diff --git a/elna.cabal b/elna.cabal index 564e108..c35d9b7 100644 --- a/elna.cabal +++ b/elna.cabal @@ -23,7 +23,8 @@ common warnings default-extensions: ExplicitForAll, LambdaCase, - OverloadedStrings + OverloadedStrings, + RecordWildCards default-language: GHC2021 library elna-internal diff --git a/lib/Language/Elna/TypeAnalysis.hs b/lib/Language/Elna/TypeAnalysis.hs index 0cc60a2..2544947 100644 --- a/lib/Language/Elna/TypeAnalysis.hs +++ b/lib/Language/Elna/TypeAnalysis.hs @@ -6,19 +6,27 @@ module Language.Elna.TypeAnalysis import Control.Applicative (Alternative(..)) import Control.Monad.Trans.Except (Except, runExcept, throwE) import Control.Monad.Trans.Reader (ReaderT, asks, runReaderT) +import qualified Data.Vector as Vector import qualified Language.Elna.AST as AST import Language.Elna.Location (Identifier(..)) -import Language.Elna.SymbolTable (Info(..), SymbolTable) +import Language.Elna.SymbolTable (Info(..), ParameterInfo(..), SymbolTable) import qualified Language.Elna.SymbolTable as SymbolTable import Language.Elna.Types (Type(..), booleanType, intType) import Control.Monad.Trans.Class (MonadTrans(..)) -import Control.Monad (unless) +import Control.Monad (unless, when) +import Data.Foldable (traverse_) data Error = ArithmeticExpressionError Type | ComparisonExpressionError Type Type | UnexpectedVariableInfoError Info + | UnexpectedProcedureInfoError Info | UndefinedSymbolError Identifier + | InvalidConditionTypeError Type + | InvalidAssignmentError Type + | ExpectedLvalueError AST.Expression + | ParameterCountMismatchError Int Int + | ArgumentTypeMismatchError Type Type | ArrayIndexError Type | ArrayAccessError Type deriving (Eq, Show) @@ -50,6 +58,54 @@ typeAnalysis globalTable = either Just (const Nothing) program :: AST.Program -> TypeAnalysis () program (AST.Program _declarations) = pure () +statement :: SymbolTable -> AST.Statement -> TypeAnalysis () +statement globalTable = \case + AST.EmptyStatement -> pure () + AST.AssignmentStatement lhs rhs -> do + lhsType <- expression globalTable lhs + rhsType <- expression globalTable rhs + unless (lhsType == intType) + $ TypeAnalysis $ lift $ throwE $ InvalidConditionTypeError lhsType + unless (rhsType == intType) + $ TypeAnalysis $ lift $ throwE $ InvalidConditionTypeError rhsType + unless (isLvalue lhs) + $ TypeAnalysis $ lift $ throwE $ ExpectedLvalueError lhs + AST.IfStatement condition ifStatement elseStatement -> do + conditionType <- expression globalTable condition + unless (conditionType == booleanType) + $ TypeAnalysis $ lift $ throwE $ InvalidConditionTypeError conditionType + statement globalTable ifStatement + maybe (pure ()) (statement globalTable) elseStatement + AST.WhileStatement condition whileStatement -> do + conditionType <- expression globalTable condition + unless (conditionType == booleanType) + $ TypeAnalysis $ lift $ throwE $ InvalidConditionTypeError conditionType + statement globalTable whileStatement + AST.CompoundStatement statements -> traverse_ (statement globalTable) statements + AST.CallStatement procedureName arguments -> + case SymbolTable.lookup procedureName globalTable of + Just (ProcedureInfo _ parameters) + | parametersLength <- Vector.length parameters + , argumentsLength <- length arguments + , Vector.length parameters /= length arguments -> TypeAnalysis $ lift $ throwE + $ ParameterCountMismatchError parametersLength argumentsLength + | otherwise -> traverse_ (uncurry checkArgument) + $ Vector.zip parameters (Vector.fromList arguments) + Just anotherInfo -> TypeAnalysis $ lift $ throwE + $ UnexpectedVariableInfoError anotherInfo + Nothing -> TypeAnalysis $ lift $ throwE + $ UndefinedSymbolError procedureName + where + checkArgument ParameterInfo{..} argument = do + argumentType <- expression globalTable argument + unless (argumentType == type') + $ TypeAnalysis $ lift $ throwE $ ArgumentTypeMismatchError type' argumentType + when (isReferenceParameter && not (isLvalue argument)) + $ TypeAnalysis $ lift $ throwE $ ExpectedLvalueError argument + isLvalue (AST.ArrayExpression arrayExpression _) = isLvalue arrayExpression + isLvalue (AST.VariableExpression _) = True + isLvalue _ = False + expression :: SymbolTable -> AST.Expression -> TypeAnalysis Type expression globalTable = \case AST.VariableExpression identifier -> do