// $Id: TypeValidator.cpp,v 1.10 2008/04/04 17:29:37 david Exp $ -*- c++ -*-
/*
 * $Log: TypeValidator.cpp,v $
 */

#include <string>
#include "semantics/TypeValidator.h"

#include <cdk/nodes/Node.h>
#include <cdk/nodes/Data.h>
#include <cdk/nodes/Nil.h>
#include <cdk/nodes/Composite.h>
#include <cdk/nodes/Sequence.h>

#include <cdk/nodes/expressions/Integer.h>
#include <cdk/nodes/expressions/Double.h>
#include <cdk/nodes/expressions/String.h>
#include <cdk/nodes/expressions/Identifier.h>
#include <cdk/nodes/expressions/NEG.h>
#include <cdk/nodes/expressions/ADD.h>
#include <cdk/nodes/expressions/SUB.h>
#include <cdk/nodes/expressions/MUL.h>
#include <cdk/nodes/expressions/DIV.h>
#include <cdk/nodes/expressions/MOD.h>
#include <cdk/nodes/expressions/LT.h>
#include <cdk/nodes/expressions/GT.h>
#include <cdk/nodes/expressions/GE.h>
#include <cdk/nodes/expressions/LE.h>
#include <cdk/nodes/expressions/NE.h>
#include <cdk/nodes/expressions/EQ.h>

#include "nodes/ReadNode.h"
#include "nodes/PrintNode.h"
#include "nodes/AssignmentNode.h"
#include "nodes/WhileNode.h"
#include "nodes/IfNode.h"
#include "nodes/IfElseNode.h"
#include "nodes/ProgramNode.h"

//---------------------------------------------------------------------------

// Destructor: flushes the stream returned by os() before the validator
// is destroyed.
TypeValidator::~TypeValidator() {
  os().flush();
}

//---------------------------------------------------------------------------

// Visits each child of a composite node in index order, handing this
// validator to the child with the level increased by 2.
void TypeValidator::processComposite(cdk::node::Composite *const node, int lvl) {
  const int childLvl = lvl + 2;
  size_t idx = 0;
  while (idx < node->size()) {
    node->at(idx)->accept(this, childLvl);
    ++idx;
  }
}

// Visits the elements of a sequence in index order with the level
// increased by 2. Traversal stops at the first null element: any
// elements after it are not visited.
void TypeValidator::processSequence(cdk::node::Sequence *const node, int lvl) {
  size_t idx = 0;
  while (idx < node->size()) {
    cdk::node::Node *child = node->node(idx);
    if (!child)
      return;  // a null entry ends the traversal
    child->accept(this, lvl + 2);
    ++idx;
  }
}

//---------------------------------------------------------------------------

// Tags a double literal with ExpressionType(8, TYPE_DOUBLE).
// The first constructor argument is presumably the size in bytes.
// 'lvl' is not used here.
void TypeValidator::processDouble(cdk::node::expression::Double *const node, int lvl) {
  const ExpressionType doubleType(8, ExpressionType::TYPE_DOUBLE);
  node->type(doubleType);
}

// Tags an integer literal with ExpressionType(4, TYPE_INT).
// The first constructor argument is presumably the size in bytes.
// 'lvl' is not used here.
void TypeValidator::processInteger(cdk::node::expression::Integer *const node, int lvl) {
  const ExpressionType intType(4, ExpressionType::TYPE_INT);
  node->type(intType);
}

// Tags a string literal with ExpressionType(4, TYPE_STRING).
// The first constructor argument is presumably the size in bytes
// (of the reference to the string, not of its contents) -- TODO confirm.
// 'lvl' is not used here.
void TypeValidator::processString(cdk::node::expression::String *const node, int lvl) {
  const ExpressionType stringType(4, ExpressionType::TYPE_STRING);
  node->type(stringType);
}

// Tags an identifier with ExpressionType(4, TYPE_INT). No lookup of the
// identifier is performed here: every identifier is treated as an integer.
// 'lvl' is not used here.
void TypeValidator::processIdentifier(cdk::node::expression::Identifier *const node, int lvl) {
  const ExpressionType identifierType(4, ExpressionType::TYPE_INT);
  node->type(identifierType);
}

//---------------------------------------------------------------------------
// Unary negation: validates the operand first, then requires it to be an
// integer. On success the node takes the operand's type.
// Throws std::string("integer expression expected") otherwise.
void TypeValidator::processNEG(cdk::node::expression::NEG *const node, int lvl) {
  // The operand is visited at the same level as the operator itself.
  node->argument()->accept(this, lvl);

  if (node->argument()->type().name() != ExpressionType::TYPE_INT) {
    throw std::string("integer expression expected");
  }

  node->type(node->argument()->type());
}

//---------------------------------------------------------------------------
//protected:
// Shared check for binary operators that accept only integer operands and
// produce an integer result.
// The left operand is validated and checked before the right one is even
// visited, so a bad left operand is reported first.
// Throws std::string naming the offending side when an operand is not an
// integer.
void TypeValidator::processIntOnlyExpression(cdk::node::expression::BinaryExpression *const node, int lvl) {
  const int operandLvl = lvl + 2;

  node->left()->accept(this, operandLvl);
  if (node->left()->type().name() != ExpressionType::TYPE_INT) {
    throw std::string("integer expression expected in binary operator (left)");
  }

  node->right()->accept(this, operandLvl);
  if (node->right()->type().name() != ExpressionType::TYPE_INT) {
    throw std::string("integer expression expected in binary operator (right)");
  }

  const ExpressionType resultType(4, ExpressionType::TYPE_INT);
  node->type(resultType);
}

//---------------------------------------------------------------------------
//protected:
// Shared check for binary operators that take two integer operands and
// produce a boolean result.
// The left operand is validated and checked before the right one is visited.
// Throws std::string naming the offending side when an operand is not an
// integer.
void TypeValidator::processScalarLogicalExpression(cdk::node::expression::BinaryExpression *const node, int lvl) {
  const int operandLvl = lvl + 2;

  node->left()->accept(this, operandLvl);
  if (node->left()->type().name() != ExpressionType::TYPE_INT) {
    throw std::string("integer expression expected in binary logical expression (left)");
  }

  node->right()->accept(this, operandLvl);
  if (node->right()->type().name() != ExpressionType::TYPE_INT) {
    throw std::string("integer expression expected in binary logical expression (right)");
  }

  const ExpressionType resultType(4, ExpressionType::TYPE_BOOLEAN);
  node->type(resultType);
}

//protected:
// Shared check for binary operators that take two boolean operands and
// produce a boolean result.
// The left operand is validated and checked before the right one is visited.
// Throws std::string when an operand is not boolean; the message is the same
// for both sides.
void TypeValidator::processBooleanLogicalExpression(cdk::node::expression::BinaryExpression *const node, int lvl) {
  const int operandLvl = lvl + 2;

  node->left()->accept(this, operandLvl);
  if (node->left()->type().name() != ExpressionType::TYPE_BOOLEAN) {
    throw std::string("boolean expression expected in binary expression");
  }

  node->right()->accept(this, operandLvl);
  if (node->right()->type().name() != ExpressionType::TYPE_BOOLEAN) {
    throw std::string("boolean expression expected in binary expression");
  }

  const ExpressionType resultType(4, ExpressionType::TYPE_BOOLEAN);
  node->type(resultType);
}

//protected:
// Shared check for binary operators whose operands may be of any type as
// long as both sides have the same type name; the result is boolean.
// Both operands are visited before the comparison is made.
// Throws std::string when the two type names differ.
void TypeValidator::processGeneralLogicalExpression(cdk::node::expression::BinaryExpression *const node, int lvl) {
  const int operandLvl = lvl + 2;
  node->left()->accept(this, operandLvl);
  node->right()->accept(this, operandLvl);

  if (node->left()->type().name() != node->right()->type().name()) {
    throw std::string("same type expected on both sides of equality operator");
  }

  const ExpressionType resultType(4, ExpressionType::TYPE_BOOLEAN);
  node->type(resultType);
}

//---------------------------------------------------------------------------
//public:
// Addition: delegates to the integer-only binary operator check.
void TypeValidator::processADD(cdk::node::expression::ADD *const node, int lvl) {
  this->processIntOnlyExpression(node, lvl);
}
// Subtraction: delegates to the integer-only binary operator check.
void TypeValidator::processSUB(cdk::node::expression::SUB *const node, int lvl) {
  this->processIntOnlyExpression(node, lvl);
}
// Multiplication: delegates to the integer-only binary operator check.
void TypeValidator::processMUL(cdk::node::expression::MUL *const node, int lvl) {
  this->processIntOnlyExpression(node, lvl);
}
// Division: delegates to the integer-only binary operator check.
void TypeValidator::processDIV(cdk::node::expression::DIV *const node, int lvl) {
  this->processIntOnlyExpression(node, lvl);
}
// Modulo: delegates to the integer-only binary operator check.
void TypeValidator::processMOD(cdk::node::expression::MOD *const node, int lvl) {
  this->processIntOnlyExpression(node, lvl);
}

//---------------------------------------------------------------------------

// Greater-than: delegates to the integer-operands / boolean-result check.
void TypeValidator::processGT(cdk::node::expression::GT *const node, int lvl) {
  this->processScalarLogicalExpression(node, lvl);
}
// Greater-or-equal: delegates to the integer-operands / boolean-result check.
void TypeValidator::processGE(cdk::node::expression::GE *const node, int lvl) {
  this->processScalarLogicalExpression(node, lvl);
}
// Less-or-equal: delegates to the integer-operands / boolean-result check.
void TypeValidator::processLE(cdk::node::expression::LE *const node, int lvl) {
  this->processScalarLogicalExpression(node, lvl);
}
// Less-than: delegates to the integer-operands / boolean-result check.
void TypeValidator::processLT(cdk::node::expression::LT *const node, int lvl) {
  this->processScalarLogicalExpression(node, lvl);
}
// Equality: delegates to the same-type-on-both-sides / boolean-result check.
void TypeValidator::processEQ(cdk::node::expression::EQ *const node, int lvl) {
  this->processGeneralLogicalExpression(node, lvl);
}
// Inequality: delegates to the same-type-on-both-sides / boolean-result check.
void TypeValidator::processNE(cdk::node::expression::NE *const node, int lvl) {
  this->processGeneralLogicalExpression(node, lvl);
}

//===========================================================================
//---------     C O M P A C T - S P E C I F I C    N O D E S     ------------
//===========================================================================

// Validates an assignment: both sides are visited (lvalue first, then
// rvalue, each with the level increased by 4) and must end up with the same
// type name.
// Throws std::string("wrong types in assignment") on a mismatch. Exceptions
// raised while validating either side propagate to the caller unchanged.
// The node itself receives no type.
void TypeValidator::processAssignmentNode(AssignmentNode *const node, int lvl) {
  // No try/catch here: the previous handler caught std::string by value and
  // rethrew a copy, which only added string copies without changing what
  // callers observe.
  node->lvalue()->accept(this, lvl+4);
  node->rvalue()->accept(this, lvl+4);

  if (node->lvalue()->type().name() != node->rvalue()->type().name())
    throw std::string("wrong types in assignment");
  //not an expression in Compact: node->type(node->lvalue()->type());
}

//---------------------------------------------------------------------------

// Validates a read statement: the argument is visited first and must then
// have integer type.
// Throws std::string("wrong type in read expression") otherwise.
// The node itself receives no type.
void TypeValidator::processReadNode(ReadNode *const node, int lvl) {
  const int argumentLvl = lvl + 2;
  node->argument()->accept(this, argumentLvl);

  if (node->argument()->type().name() != ExpressionType::TYPE_INT) {
    throw std::string("wrong type in read expression");
  }
  //not an expression in Compact: node->type(ExpressionType(4, ExpressionType::TYPE_INT));
}

//---------------------------------------------------------------------------

// Validates the test of a 'while' cycle: the condition is visited and must
// then have boolean type. Only the condition is examined by this method.
// Throws std::string("boolean expression expected in 'while' cycle test") on
// a non-boolean condition. Exceptions raised while validating the condition
// propagate to the caller unchanged.
void TypeValidator::processWhileNode(WhileNode *const node, int lvl) {
  // No try/catch here: the previous handler caught std::string by value and
  // rethrew a copy, which only added string copies without changing what
  // callers observe.
  node->condition()->accept(this, lvl+2);
  if (node->condition()->type().name() != ExpressionType::TYPE_BOOLEAN)
    throw std::string("boolean expression expected in 'while' cycle test");
}
//---------------------------------------------------------------------------

// Validates the test of an 'if' statement: the condition is visited and must
// then have boolean type. Only the condition is examined by this method.
// Throws std::string("boolean expression expected in 'if' test") on a
// non-boolean condition. Exceptions raised while validating the condition
// propagate to the caller unchanged.
void TypeValidator::processIfNode(IfNode *const node, int lvl) {
  // No try/catch here: the previous handler caught std::string by value and
  // rethrew a copy, which only added string copies without changing what
  // callers observe.
  node->condition()->accept(this, lvl+2);
  if (node->condition()->type().name() != ExpressionType::TYPE_BOOLEAN)
    throw std::string("boolean expression expected in 'if' test");
}
//---------------------------------------------------------------------------

// Validates the test of an 'if'/'else' statement: the condition is visited
// and must then have boolean type. Only the condition is examined by this
// method.
// Throws std::string("boolean expression expected in 'if' test") on a
// non-boolean condition. Exceptions raised while validating the condition
// propagate to the caller unchanged.
void TypeValidator::processIfElseNode(IfElseNode *const node, int lvl) {
  // No try/catch here: the previous handler caught std::string by value and
  // rethrew a copy, which only added string copies without changing what
  // callers observe.
  node->condition()->accept(this, lvl+2);
  if (node->condition()->type().name() != ExpressionType::TYPE_BOOLEAN)
    throw std::string("boolean expression expected in 'if' test");
}

//---------------------------------------------------------------------------
//     T H E    E N D
//---------------------------------------------------------------------------
