Typecheck statements
This commit is contained in:
parent
573990551c
commit
18a299098c
@ -23,7 +23,8 @@ common warnings
|
|||||||
default-extensions:
|
default-extensions:
|
||||||
ExplicitForAll,
|
ExplicitForAll,
|
||||||
LambdaCase,
|
LambdaCase,
|
||||||
OverloadedStrings
|
OverloadedStrings,
|
||||||
|
RecordWildCards
|
||||||
default-language: GHC2021
|
default-language: GHC2021
|
||||||
|
|
||||||
library elna-internal
|
library elna-internal
|
||||||
|
@ -6,19 +6,27 @@ module Language.Elna.TypeAnalysis
|
|||||||
import Control.Applicative (Alternative(..))
|
import Control.Applicative (Alternative(..))
|
||||||
import Control.Monad.Trans.Except (Except, runExcept, throwE)
|
import Control.Monad.Trans.Except (Except, runExcept, throwE)
|
||||||
import Control.Monad.Trans.Reader (ReaderT, asks, runReaderT)
|
import Control.Monad.Trans.Reader (ReaderT, asks, runReaderT)
|
||||||
|
import qualified Data.Vector as Vector
|
||||||
import qualified Language.Elna.AST as AST
|
import qualified Language.Elna.AST as AST
|
||||||
import Language.Elna.Location (Identifier(..))
|
import Language.Elna.Location (Identifier(..))
|
||||||
import Language.Elna.SymbolTable (Info(..), SymbolTable)
|
import Language.Elna.SymbolTable (Info(..), ParameterInfo(..), SymbolTable)
|
||||||
import qualified Language.Elna.SymbolTable as SymbolTable
|
import qualified Language.Elna.SymbolTable as SymbolTable
|
||||||
import Language.Elna.Types (Type(..), booleanType, intType)
|
import Language.Elna.Types (Type(..), booleanType, intType)
|
||||||
import Control.Monad.Trans.Class (MonadTrans(..))
|
import Control.Monad.Trans.Class (MonadTrans(..))
|
||||||
import Control.Monad (unless)
|
import Control.Monad (unless, when)
|
||||||
|
import Data.Foldable (traverse_)
|
||||||
|
|
||||||
data Error
|
data Error
|
||||||
= ArithmeticExpressionError Type
|
= ArithmeticExpressionError Type
|
||||||
| ComparisonExpressionError Type Type
|
| ComparisonExpressionError Type Type
|
||||||
| UnexpectedVariableInfoError Info
|
| UnexpectedVariableInfoError Info
|
||||||
|
| UnexpectedProcedureInfoError Info
|
||||||
| UndefinedSymbolError Identifier
|
| UndefinedSymbolError Identifier
|
||||||
|
| InvalidConditionTypeError Type
|
||||||
|
| InvalidAssignmentError Type
|
||||||
|
| ExpectedLvalueError AST.Expression
|
||||||
|
| ParameterCountMismatchError Int Int
|
||||||
|
| ArgumentTypeMismatchError Type Type
|
||||||
| ArrayIndexError Type
|
| ArrayIndexError Type
|
||||||
| ArrayAccessError Type
|
| ArrayAccessError Type
|
||||||
deriving (Eq, Show)
|
deriving (Eq, Show)
|
||||||
@ -50,6 +58,54 @@ typeAnalysis globalTable = either Just (const Nothing)
|
|||||||
program :: AST.Program -> TypeAnalysis ()
|
program :: AST.Program -> TypeAnalysis ()
|
||||||
program (AST.Program _declarations) = pure ()
|
program (AST.Program _declarations) = pure ()
|
||||||
|
|
||||||
|
statement :: SymbolTable -> AST.Statement -> TypeAnalysis ()
|
||||||
|
statement globalTable = \case
|
||||||
|
AST.EmptyStatement -> pure ()
|
||||||
|
AST.AssignmentStatement lhs rhs -> do
|
||||||
|
lhsType <- expression globalTable lhs
|
||||||
|
rhsType <- expression globalTable rhs
|
||||||
|
unless (lhsType == intType)
|
||||||
|
$ TypeAnalysis $ lift $ throwE $ InvalidConditionTypeError lhsType
|
||||||
|
unless (rhsType == intType)
|
||||||
|
$ TypeAnalysis $ lift $ throwE $ InvalidConditionTypeError rhsType
|
||||||
|
unless (isLvalue lhs)
|
||||||
|
$ TypeAnalysis $ lift $ throwE $ ExpectedLvalueError lhs
|
||||||
|
AST.IfStatement condition ifStatement elseStatement -> do
|
||||||
|
conditionType <- expression globalTable condition
|
||||||
|
unless (conditionType == booleanType)
|
||||||
|
$ TypeAnalysis $ lift $ throwE $ InvalidConditionTypeError conditionType
|
||||||
|
statement globalTable ifStatement
|
||||||
|
maybe (pure ()) (statement globalTable) elseStatement
|
||||||
|
AST.WhileStatement condition whileStatement -> do
|
||||||
|
conditionType <- expression globalTable condition
|
||||||
|
unless (conditionType == booleanType)
|
||||||
|
$ TypeAnalysis $ lift $ throwE $ InvalidConditionTypeError conditionType
|
||||||
|
statement globalTable whileStatement
|
||||||
|
AST.CompoundStatement statements -> traverse_ (statement globalTable) statements
|
||||||
|
AST.CallStatement procedureName arguments ->
|
||||||
|
case SymbolTable.lookup procedureName globalTable of
|
||||||
|
Just (ProcedureInfo _ parameters)
|
||||||
|
| parametersLength <- Vector.length parameters
|
||||||
|
, argumentsLength <- length arguments
|
||||||
|
, Vector.length parameters /= length arguments -> TypeAnalysis $ lift $ throwE
|
||||||
|
$ ParameterCountMismatchError parametersLength argumentsLength
|
||||||
|
| otherwise -> traverse_ (uncurry checkArgument)
|
||||||
|
$ Vector.zip parameters (Vector.fromList arguments)
|
||||||
|
Just anotherInfo -> TypeAnalysis $ lift $ throwE
|
||||||
|
$ UnexpectedVariableInfoError anotherInfo
|
||||||
|
Nothing -> TypeAnalysis $ lift $ throwE
|
||||||
|
$ UndefinedSymbolError procedureName
|
||||||
|
where
|
||||||
|
checkArgument ParameterInfo{..} argument = do
|
||||||
|
argumentType <- expression globalTable argument
|
||||||
|
unless (argumentType == type')
|
||||||
|
$ TypeAnalysis $ lift $ throwE $ ArgumentTypeMismatchError type' argumentType
|
||||||
|
when (isReferenceParameter && not (isLvalue argument))
|
||||||
|
$ TypeAnalysis $ lift $ throwE $ ExpectedLvalueError argument
|
||||||
|
isLvalue (AST.ArrayExpression arrayExpression _) = isLvalue arrayExpression
|
||||||
|
isLvalue (AST.VariableExpression _) = True
|
||||||
|
isLvalue _ = False
|
||||||
|
|
||||||
expression :: SymbolTable -> AST.Expression -> TypeAnalysis Type
|
expression :: SymbolTable -> AST.Expression -> TypeAnalysis Type
|
||||||
expression globalTable = \case
|
expression globalTable = \case
|
||||||
AST.VariableExpression identifier -> do
|
AST.VariableExpression identifier -> do
|
||||||
|
Loading…
Reference in New Issue
Block a user