Check argument list length

This commit is contained in:
2024-10-31 22:19:48 +01:00
parent 43882a3a06
commit e2d4b76c0b
3 changed files with 87 additions and 65 deletions

View File

@ -3,42 +3,61 @@ module Language.Elna.Frontend.TypeAnalysis
, -- Error(..)
) where
import Control.Monad (unless)
import Control.Monad.Trans.Class (MonadTrans(..))
import Control.Monad.Trans.Except (Except, runExcept, throwE)
import Control.Monad.Trans.Reader (ReaderT, runReaderT, withReaderT, ask)
import Data.Foldable (traverse_)
import qualified Data.Vector as Vector
import qualified Language.Elna.Frontend.AST as AST
import Language.Elna.Frontend.SymbolTable ({-Info(..), ParameterInfo(..), -}SymbolTable)
import Language.Elna.Frontend.SymbolTable (Info(..), {-ParameterInfo(..), -}SymbolTable)
import qualified Language.Elna.Frontend.SymbolTable as SymbolTable
import Language.Elna.Frontend.Types (Type(..), booleanType, intType)
import Language.Elna.Location (Identifier(..))
typeAnalysis :: SymbolTable -> AST.Program -> () -- Maybe Error
typeAnalysis _globalTable = const () {- either Just (const Nothing)
typeAnalysis :: SymbolTable -> AST.Program -> Maybe Error
typeAnalysis globalTable = either Just (const Nothing)
. runExcept
. flip runReaderT globalTable
. runTypeAnalysis
. program -}
. program
{-
import Control.Applicative (Alternative(..))
import Control.Monad.Trans.Except (Except, runExcept, throwE)
import Control.Monad.Trans.Reader (ReaderT, asks, runReaderT, withReaderT, ask)
import qualified Data.Vector as Vector
import Language.Elna.Location (Identifier(..))
import qualified Language.Elna.SymbolTable as SymbolTable
import Language.Elna.Types (Type(..), booleanType, intType)
import Control.Monad.Trans.Class (MonadTrans(..))
import Control.Monad (unless, when)
import Data.Foldable (traverse_)
-}
data Error
= ArithmeticExpressionError Type
| ComparisonExpressionError Type Type
| UnexpectedVariableInfoError Info
| UnexpectedProcedureInfoError Info
= UnexpectedProcedureInfoError Info
| UndefinedSymbolError Identifier
| InvalidConditionTypeError Type
| InvalidAssignmentError Type
| ExpectedLvalueError AST.Expression
| ParameterCountMismatchError Int Int
| UnexpectedVariableInfoError Info
| ArithmeticExpressionError Type
| ComparisonExpressionError Type Type
| InvalidConditionTypeError Type
{- | InvalidAssignmentError Type
| ExpectedLvalueError AST.Expression
| ArgumentTypeMismatchError Type Type
| ArrayIndexError Type
| ArrayAccessError Type
deriving (Eq, Show)
| ArrayAccessError Type -}
deriving Eq
instance Show Error
where
show (UnexpectedProcedureInfoError info) =
"Expected to encounter a procedure, got: " <> show info
show (UndefinedSymbolError identifier) =
concat ["Symbol \"", show identifier, "\" is not defined"]
show (ParameterCountMismatchError parameterCount argumentCount)
= "The function was expected to receive " <> show argumentCount
<> " arguments, but got " <> show parameterCount
show (UnexpectedVariableInfoError info) =
"Expected to encounter a variable, got: " <> show info
show (ArithmeticExpressionError got) =
"Expected an arithmetic expression to be an integer, got: " <> show got
show (ComparisonExpressionError lhs rhs)
= "Expected an arithmetic expression to be an integer, got \""
<> show lhs <> "\" and \"" <> show rhs <> "\""
show (InvalidConditionTypeError got) =
"Expected a condition to be a boolean, got: " <> show got
newtype TypeAnalysis a = TypeAnalysis
{ runTypeAnalysis :: ReaderT SymbolTable (Except Error) a
@ -61,7 +80,7 @@ program :: AST.Program -> TypeAnalysis ()
program (AST.Program declarations) = traverse_ declaration declarations
declaration :: AST.Declaration -> TypeAnalysis ()
declaration (AST.ProcedureDefinition procedureName _ _ body) = do
declaration (AST.ProcedureDeclaration procedureName _ _ body) = do
globalTable <- TypeAnalysis ask
case SymbolTable.lookup procedureName globalTable of
Just (ProcedureInfo localTable _) -> TypeAnalysis
@ -72,29 +91,29 @@ declaration (AST.ProcedureDefinition procedureName _ _ body) = do
$ UnexpectedProcedureInfoError anotherInfo
Nothing -> TypeAnalysis $ lift $ throwE
$ UndefinedSymbolError procedureName
declaration _ = pure ()
declaration (AST.TypeDefinition _ _) = pure ()
statement :: SymbolTable -> AST.Statement -> TypeAnalysis ()
statement globalTable = \case
AST.EmptyStatement -> pure ()
AST.AssignmentStatement lhs rhs -> do
{- AST.AssignmentStatement lhs rhs -> do
lhsType <- variableAccess globalTable lhs
rhsType <- expression globalTable rhs
unless (lhsType == intType)
$ TypeAnalysis $ lift $ throwE $ InvalidConditionTypeError lhsType
unless (rhsType == intType)
$ TypeAnalysis $ lift $ throwE $ InvalidConditionTypeError rhsType
AST.WhileStatement whileCondition whileStatement -> do
conditionType <- condition globalTable whileCondition
unless (conditionType == booleanType)
$ TypeAnalysis $ lift $ throwE $ InvalidConditionTypeError conditionType
statement globalTable whileStatement -}
AST.IfStatement ifCondition ifStatement elseStatement -> do
conditionType <- condition globalTable ifCondition
unless (conditionType == booleanType)
$ TypeAnalysis $ lift $ throwE $ InvalidConditionTypeError conditionType
statement globalTable ifStatement
maybe (pure ()) (statement globalTable) elseStatement
AST.WhileStatement whileCondition whileStatement -> do
conditionType <- condition globalTable whileCondition
unless (conditionType == booleanType)
$ TypeAnalysis $ lift $ throwE $ InvalidConditionTypeError conditionType
statement globalTable whileStatement
AST.CompoundStatement statements -> traverse_ (statement globalTable) statements
AST.CallStatement procedureName arguments ->
case SymbolTable.lookup procedureName globalTable of
@ -110,7 +129,7 @@ statement globalTable = \case
Nothing -> TypeAnalysis $ lift $ throwE
$ UndefinedSymbolError procedureName
where
checkArgument ParameterInfo{..} argument = do
checkArgument SymbolTable.ParameterInfo{} _argument = pure () {-
argumentType <- expression globalTable argument
unless (argumentType == type')
$ TypeAnalysis $ lift $ throwE $ ArgumentTypeMismatchError type' argumentType
@ -137,11 +156,11 @@ variableAccess globalTable (AST.ArrayAccess arrayExpression indexExpression) = d
ArrayType _ baseType -> pure baseType
nonArrayType -> TypeAnalysis $ lift $ throwE
$ ArrayAccessError nonArrayType
-}
expression :: SymbolTable -> AST.Expression -> TypeAnalysis Type
expression globalTable = \case
AST.VariableExpression variableExpression -> do
variableAccess globalTable variableExpression
{- AST.VariableExpression variableExpression -> do
variableAccess globalTable variableExpression -}
AST.LiteralExpression literal' -> literal literal'
AST.NegationExpression negation -> do
operandType <- expression globalTable negation
@ -179,8 +198,6 @@ condition globalTable = \case
else TypeAnalysis $ lift $ throwE $ ComparisonExpressionError lhsType rhsType
literal :: AST.Literal -> TypeAnalysis Type
literal (AST.IntegerLiteral _) = pure intType
literal (AST.DecimalLiteral _) = pure intType
literal (AST.HexadecimalLiteral _) = pure intType
literal (AST.CharacterLiteral _) = pure intType
literal (AST.BooleanLiteral _) = pure booleanType
-}