summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--elna.cabal3
-rw-r--r--lib/Language/Elna/TypeAnalysis.hs60
2 files changed, 60 insertions, 3 deletions
diff --git a/elna.cabal b/elna.cabal
index 564e108..c35d9b7 100644
--- a/elna.cabal
+++ b/elna.cabal
@@ -23,7 +23,8 @@ common warnings
default-extensions:
ExplicitForAll,
LambdaCase,
- OverloadedStrings
+ OverloadedStrings,
+ RecordWildCards
default-language: GHC2021
library elna-internal
diff --git a/lib/Language/Elna/TypeAnalysis.hs b/lib/Language/Elna/TypeAnalysis.hs
index 0cc60a2..2544947 100644
--- a/lib/Language/Elna/TypeAnalysis.hs
+++ b/lib/Language/Elna/TypeAnalysis.hs
@@ -6,19 +6,27 @@ module Language.Elna.TypeAnalysis
import Control.Applicative (Alternative(..))
import Control.Monad.Trans.Except (Except, runExcept, throwE)
import Control.Monad.Trans.Reader (ReaderT, asks, runReaderT)
+import qualified Data.Vector as Vector
import qualified Language.Elna.AST as AST
import Language.Elna.Location (Identifier(..))
-import Language.Elna.SymbolTable (Info(..), SymbolTable)
+import Language.Elna.SymbolTable (Info(..), ParameterInfo(..), SymbolTable)
import qualified Language.Elna.SymbolTable as SymbolTable
import Language.Elna.Types (Type(..), booleanType, intType)
import Control.Monad.Trans.Class (MonadTrans(..))
-import Control.Monad (unless)
+import Control.Monad (unless, when)
+import Data.Foldable (traverse_)
data Error
= ArithmeticExpressionError Type
| ComparisonExpressionError Type Type
| UnexpectedVariableInfoError Info
+ | UnexpectedProcedureInfoError Info
| UndefinedSymbolError Identifier
+ | InvalidConditionTypeError Type
+ | InvalidAssignmentError Type
+ | ExpectedLvalueError AST.Expression
+ | ParameterCountMismatchError Int Int
+ | ArgumentTypeMismatchError Type Type
| ArrayIndexError Type
| ArrayAccessError Type
deriving (Eq, Show)
@@ -50,6 +58,54 @@ typeAnalysis globalTable = either Just (const Nothing)
program :: AST.Program -> TypeAnalysis ()
program (AST.Program _declarations) = pure ()
+statement :: SymbolTable -> AST.Statement -> TypeAnalysis ()
+statement globalTable = \case
+ AST.EmptyStatement -> pure ()
+ AST.AssignmentStatement lhs rhs -> do
+ lhsType <- expression globalTable lhs
+ rhsType <- expression globalTable rhs
+ unless (lhsType == intType)
+ $ TypeAnalysis $ lift $ throwE $ InvalidConditionTypeError lhsType
+ unless (rhsType == intType)
+ $ TypeAnalysis $ lift $ throwE $ InvalidConditionTypeError rhsType
+ unless (isLvalue lhs)
+ $ TypeAnalysis $ lift $ throwE $ ExpectedLvalueError lhs
+ AST.IfStatement condition ifStatement elseStatement -> do
+ conditionType <- expression globalTable condition
+ unless (conditionType == booleanType)
+ $ TypeAnalysis $ lift $ throwE $ InvalidConditionTypeError conditionType
+ statement globalTable ifStatement
+ maybe (pure ()) (statement globalTable) elseStatement
+ AST.WhileStatement condition whileStatement -> do
+ conditionType <- expression globalTable condition
+ unless (conditionType == booleanType)
+ $ TypeAnalysis $ lift $ throwE $ InvalidConditionTypeError conditionType
+ statement globalTable whileStatement
+ AST.CompoundStatement statements -> traverse_ (statement globalTable) statements
+ AST.CallStatement procedureName arguments ->
+ case SymbolTable.lookup procedureName globalTable of
+ Just (ProcedureInfo _ parameters)
+ | parametersLength <- Vector.length parameters
+ , argumentsLength <- length arguments
+ , Vector.length parameters /= length arguments -> TypeAnalysis $ lift $ throwE
+ $ ParameterCountMismatchError parametersLength argumentsLength
+ | otherwise -> traverse_ (uncurry checkArgument)
+ $ Vector.zip parameters (Vector.fromList arguments)
+ Just anotherInfo -> TypeAnalysis $ lift $ throwE
+ $ UnexpectedVariableInfoError anotherInfo
+ Nothing -> TypeAnalysis $ lift $ throwE
+ $ UndefinedSymbolError procedureName
+ where
+ checkArgument ParameterInfo{..} argument = do
+ argumentType <- expression globalTable argument
+ unless (argumentType == type')
+ $ TypeAnalysis $ lift $ throwE $ ArgumentTypeMismatchError type' argumentType
+ when (isReferenceParameter && not (isLvalue argument))
+ $ TypeAnalysis $ lift $ throwE $ ExpectedLvalueError argument
+ isLvalue (AST.ArrayExpression arrayExpression _) = isLvalue arrayExpression
+ isLvalue (AST.VariableExpression _) = True
+ isLvalue _ = False
+
expression :: SymbolTable -> AST.Expression -> TypeAnalysis Type
expression globalTable = \case
AST.VariableExpression identifier -> do