-- | Type analysis pass of the Elna compiler front end.
--
-- Only a stub is active at the moment; the previous implementation is kept
-- in the disabled block at the end of this module.
module Language.Elna.TypeAnalysis
    ( typeAnalysis
    , -- Error(..)
    ) where

import qualified Language.Elna.AST as AST
import Language.Elna.SymbolTable ({-Info(..), ParameterInfo(..), -}SymbolTable)

-- | Entry point of the type analysis.
--
-- Currently a stub: both the global symbol table and the program are ignored
-- and the result is always @()@. The trailing comments record the intended
-- result type (@Maybe Error@) and the pipeline that used to produce it.
typeAnalysis :: SymbolTable -> AST.Program -> () -- Maybe Error
typeAnalysis _globalTable = const () {- either Just (const Nothing)
    . runExcept
    . flip runReaderT globalTable
    . runTypeAnalysis
    . program -}

{-
-- Disabled implementation, preserved as is apart from layout. Lines marked
-- NOTE(review) flag spots that should be looked at before re-enabling.

import Control.Applicative (Alternative(..))
import Control.Monad.Trans.Except (Except, runExcept, throwE)
import Control.Monad.Trans.Reader (ReaderT, asks, runReaderT, withReaderT, ask)
import qualified Data.Vector as Vector
import Language.Elna.Location (Identifier(..))
import qualified Language.Elna.SymbolTable as SymbolTable
import Language.Elna.Types (Type(..), booleanType, intType)
import Control.Monad.Trans.Class (MonadTrans(..))
import Control.Monad (unless, when)
import Data.Foldable (traverse_)

data Error
    = ArithmeticExpressionError Type
    | ComparisonExpressionError Type Type
    | UnexpectedVariableInfoError Info
    | UnexpectedProcedureInfoError Info
    | UndefinedSymbolError Identifier
    | InvalidConditionTypeError Type
    | InvalidAssignmentError Type
    | ExpectedLvalueError AST.Expression
    | ParameterCountMismatchError Int Int
    | ArgumentTypeMismatchError Type Type
    | ArrayIndexError Type
    | ArrayAccessError Type
    deriving (Eq, Show)

newtype TypeAnalysis a = TypeAnalysis
    { runTypeAnalysis :: ReaderT SymbolTable (Except Error) a
    }

instance Functor TypeAnalysis
  where
    fmap f (TypeAnalysis x) = TypeAnalysis $ f <$> x

instance Applicative TypeAnalysis
  where
    pure = TypeAnalysis . pure
    (TypeAnalysis f) <*> (TypeAnalysis x) = TypeAnalysis $ f <*> x

instance Monad TypeAnalysis
  where
    (TypeAnalysis x) >>= f = TypeAnalysis $ x >>= (runTypeAnalysis . f)

program :: AST.Program -> TypeAnalysis ()
program (AST.Program declarations) = traverse_ declaration declarations

declaration :: AST.Declaration -> TypeAnalysis ()
declaration (AST.ProcedureDefinition procedureName _ _ body) = do
    globalTable <- TypeAnalysis ask
    case SymbolTable.lookup procedureName globalTable of
        Just (ProcedureInfo localTable _) -> TypeAnalysis
            $ withReaderT (const localTable)
            $ runTypeAnalysis
            $ traverse_ (statement globalTable) body
        Just anotherInfo -> TypeAnalysis $ lift $ throwE
            $ UnexpectedProcedureInfoError anotherInfo
        Nothing -> TypeAnalysis $ lift $ throwE
            $ UndefinedSymbolError procedureName
declaration _ = pure ()

statement :: SymbolTable -> AST.Statement -> TypeAnalysis ()
statement globalTable = \case
    AST.EmptyStatement -> pure ()
    AST.AssignmentStatement lhs rhs -> do
        lhsType <- variableAccess globalTable lhs
        rhsType <- expression globalTable rhs
        -- NOTE(review): both checks below raise InvalidConditionTypeError,
        -- while InvalidAssignmentError is declared and never used in this
        -- module; presumably the latter was intended here. Confirm.
        unless (lhsType == intType)
            $ TypeAnalysis $ lift $ throwE $ InvalidConditionTypeError lhsType
        unless (rhsType == intType)
            $ TypeAnalysis $ lift $ throwE $ InvalidConditionTypeError rhsType
    AST.IfStatement ifCondition ifStatement elseStatement -> do
        conditionType <- condition globalTable ifCondition
        unless (conditionType == booleanType)
            $ TypeAnalysis $ lift $ throwE $ InvalidConditionTypeError conditionType
        statement globalTable ifStatement
        maybe (pure ()) (statement globalTable) elseStatement
    AST.WhileStatement whileCondition whileStatement -> do
        conditionType <- condition globalTable whileCondition
        unless (conditionType == booleanType)
            $ TypeAnalysis $ lift $ throwE $ InvalidConditionTypeError conditionType
        statement globalTable whileStatement
    AST.CompoundStatement statements -> traverse_ (statement globalTable) statements
    AST.CallStatement procedureName arguments ->
        case SymbolTable.lookup procedureName globalTable of
            -- NOTE(review): the guard binds parametersLength and
            -- argumentsLength but compares freshly computed lengths instead
            -- of the bound names; equivalent result, redundant work.
            Just (ProcedureInfo _ parameters)
                | parametersLength <- Vector.length parameters
                , argumentsLength <- length arguments
                , Vector.length parameters /= length arguments -> TypeAnalysis
                    $ lift $ throwE
                    $ ParameterCountMismatchError parametersLength argumentsLength
                | otherwise -> traverse_ (uncurry checkArgument)
                    $ Vector.zip parameters (Vector.fromList arguments)
            -- NOTE(review): a symbol that is not a procedure is reported as
            -- UnexpectedVariableInfoError here, whereas declaration reports
            -- the same situation as UnexpectedProcedureInfoError. Confirm
            -- which one is intended.
            Just anotherInfo -> TypeAnalysis $ lift $ throwE
                $ UnexpectedVariableInfoError anotherInfo
            Nothing -> TypeAnalysis $ lift $ throwE
                $ UndefinedSymbolError procedureName
  where
    checkArgument ParameterInfo{..} argument = do
        argumentType <- expression globalTable argument
        unless (argumentType == type')
            $ TypeAnalysis $ lift $ throwE $ ArgumentTypeMismatchError type' argumentType
        when (isReferenceParameter && not (isLvalue argument))
            $ TypeAnalysis $ lift $ throwE $ ExpectedLvalueError argument
    isLvalue (AST.VariableExpression _) = True
    isLvalue _ = False

variableAccess :: SymbolTable -> AST.VariableAccess -> TypeAnalysis Type
variableAccess globalTable (AST.VariableAccess identifier) = do
    localLookup <- TypeAnalysis $ asks $ SymbolTable.lookup identifier
    case localLookup <|> SymbolTable.lookup identifier globalTable of
        Just (VariableInfo _ variableType) -> pure variableType
        Just anotherInfo -> TypeAnalysis $ lift $ throwE
            $ UnexpectedVariableInfoError anotherInfo
        Nothing -> TypeAnalysis $ lift $ throwE
            $ UndefinedSymbolError identifier
variableAccess globalTable (AST.ArrayAccess arrayExpression indexExpression) = do
    arrayType <- variableAccess globalTable arrayExpression
    indexType <- expression globalTable indexExpression
    unless (indexType == intType)
        $ TypeAnalysis $ lift $ throwE $ ArrayIndexError indexType
    case arrayType of
        ArrayType _ baseType -> pure baseType
        nonArrayType -> TypeAnalysis $ lift $ throwE
            $ ArrayAccessError nonArrayType

expression :: SymbolTable -> AST.Expression -> TypeAnalysis Type
expression globalTable = \case
    AST.VariableExpression variableExpression -> do
        variableAccess globalTable variableExpression
    AST.LiteralExpression literal' -> literal literal'
    AST.NegationExpression negation -> do
        operandType <- expression globalTable negation
        if operandType == intType
            then pure intType
            else TypeAnalysis $ lift $ throwE $ ArithmeticExpressionError operandType
    AST.SumExpression lhs rhs -> arithmeticExpression lhs rhs
    AST.SubtractionExpression lhs rhs -> arithmeticExpression lhs rhs
    AST.ProductExpression lhs rhs -> arithmeticExpression lhs rhs
    AST.DivisionExpression lhs rhs -> arithmeticExpression lhs rhs
  where
    arithmeticExpression lhs rhs = do
        lhsType <- expression globalTable lhs
        unless (lhsType == intType)
            $ TypeAnalysis $ lift $ throwE $ ArithmeticExpressionError lhsType
        rhsType <- expression globalTable rhs
        unless (rhsType == intType)
            $ TypeAnalysis $ lift $ throwE $ ArithmeticExpressionError rhsType
        pure intType

condition :: SymbolTable -> AST.Condition -> TypeAnalysis Type
condition globalTable = \case
    AST.EqualCondition lhs rhs -> comparisonExpression lhs rhs
    AST.NonEqualCondition lhs rhs -> comparisonExpression lhs rhs
    AST.LessCondition lhs rhs -> comparisonExpression lhs rhs
    AST.GreaterCondition lhs rhs -> comparisonExpression lhs rhs
    AST.LessOrEqualCondition lhs rhs -> comparisonExpression lhs rhs
    AST.GreaterOrEqualCondition lhs rhs -> comparisonExpression lhs rhs
  where
    comparisonExpression lhs rhs = do
        lhsType <- expression globalTable lhs
        rhsType <- expression globalTable rhs
        if lhsType == intType && rhsType == intType
            then pure booleanType
            else TypeAnalysis $ lift $ throwE $ ComparisonExpressionError lhsType rhsType

literal :: AST.Literal -> TypeAnalysis Type
literal (AST.IntegerLiteral _) = pure intType
literal (AST.HexadecimalLiteral _) = pure intType
literal (AST.CharacterLiteral _) = pure intType
literal (AST.BooleanLiteral _) = pure booleanType
-}