Access multidimensional arrays

This commit is contained in:
Eugen Wissner 2024-08-15 20:13:56 +02:00
parent f78592378a
commit d405072dbf
Signed by: belka
GPG Key ID: A27FDC1E8EE902C0
5 changed files with 241 additions and 104 deletions

View File

@ -1,5 +1,7 @@
module Language.Elna.AST module Language.Elna.AST
( Declaration(..) ( VariableAccess(..)
, Condition(..)
, Declaration(..)
, Expression(..) , Expression(..)
, Identifier(..) , Identifier(..)
, Literal(..) , Literal(..)
@ -44,21 +46,25 @@ instance Show Literal
| boolean = "true" | boolean = "true"
| otherwise = "false" | otherwise = "false"
data VariableAccess
= VariableAccess Identifier
| ArrayAccess VariableAccess Expression
deriving Eq
instance Show VariableAccess
where
show (VariableAccess variableName) = show variableName
show (ArrayAccess arrayAccess elementIndex) =
concat [show arrayAccess, "[", show elementIndex, "]"]
data Expression data Expression
= VariableExpression Identifier = VariableExpression VariableAccess
| LiteralExpression Literal | LiteralExpression Literal
| NegationExpression Expression | NegationExpression Expression
| SumExpression Expression Expression | SumExpression Expression Expression
| SubtractionExpression Expression Expression | SubtractionExpression Expression Expression
| ProductExpression Expression Expression | ProductExpression Expression Expression
| DivisionExpression Expression Expression | DivisionExpression Expression Expression
| EqualExpression Expression Expression
| NonEqualExpression Expression Expression
| LessExpression Expression Expression
| GreaterExpression Expression Expression
| LessOrEqualExpression Expression Expression
| GreaterOrEqualExpression Expression Expression
| ArrayExpression Expression Expression
deriving Eq deriving Eq
instance Show Expression instance Show Expression
@ -70,20 +76,30 @@ instance Show Expression
show (SubtractionExpression lhs rhs) = concat [show lhs, " - ", show rhs] show (SubtractionExpression lhs rhs) = concat [show lhs, " - ", show rhs]
show (ProductExpression lhs rhs) = concat [show lhs, " * ", show rhs] show (ProductExpression lhs rhs) = concat [show lhs, " * ", show rhs]
show (DivisionExpression lhs rhs) = concat [show lhs, " / ", show rhs] show (DivisionExpression lhs rhs) = concat [show lhs, " / ", show rhs]
show (EqualExpression lhs rhs) = concat [show lhs, " = ", show rhs]
show (NonEqualExpression lhs rhs) = concat [show lhs, " # ", show rhs] data Condition
show (LessExpression lhs rhs) = concat [show lhs, " < ", show rhs] = EqualCondition Expression Expression
show (GreaterExpression lhs rhs) = concat [show lhs, " > ", show rhs] | NonEqualCondition Expression Expression
show (LessOrEqualExpression lhs rhs) = concat [show lhs, " <= ", show rhs] | LessCondition Expression Expression
show (GreaterOrEqualExpression lhs rhs) = concat [show lhs, " >= ", show rhs] | GreaterCondition Expression Expression
show (ArrayExpression arrayExpression indexExpression) = | LessOrEqualCondition Expression Expression
concat [show arrayExpression, "[", show indexExpression, "]"] | GreaterOrEqualCondition Expression Expression
deriving Eq
instance Show Condition
where
show (EqualCondition lhs rhs) = concat [show lhs, " = ", show rhs]
show (NonEqualCondition lhs rhs) = concat [show lhs, " # ", show rhs]
show (LessCondition lhs rhs) = concat [show lhs, " < ", show rhs]
show (GreaterCondition lhs rhs) = concat [show lhs, " > ", show rhs]
show (LessOrEqualCondition lhs rhs) = concat [show lhs, " <= ", show rhs]
show (GreaterOrEqualCondition lhs rhs) = concat [show lhs, " >= ", show rhs]
data Statement data Statement
= EmptyStatement = EmptyStatement
| AssignmentStatement Expression Expression | AssignmentStatement VariableAccess Expression
| IfStatement Expression Statement (Maybe Statement) | IfStatement Condition Statement (Maybe Statement)
| WhileStatement Expression Statement | WhileStatement Condition Statement
| CompoundStatement [Statement] | CompoundStatement [Statement]
| CallStatement Identifier [Expression] | CallStatement Identifier [Expression]
deriving Eq deriving Eq

View File

@ -3,28 +3,37 @@ module Language.Elna.Intermediate
, Operand(..) , Operand(..)
, Quadruple(..) , Quadruple(..)
, Variable(..) , Variable(..)
, intermediate
) where ) where
import Data.Int (Int32) import Data.Int (Int32)
import Data.Word (Word32) import Data.HashMap.Strict (HashMap)
import qualified Data.HashMap.Strict as HashMap
import Data.Text (Text) import Data.Text (Text)
import Data.Word (Word32)
import Data.Vector (Vector)
import qualified Data.Vector as Vector
import qualified Language.Elna.AST as AST
import Language.Elna.Types (Type(..))
import Language.Elna.SymbolTable (SymbolTable, Info(..))
import qualified Language.Elna.SymbolTable as SymbolTable
data Operand data Operand
= VariableOperand Text = VariableOperand Variable
| IntOperand Int32 | IntOperand Int32
deriving (Eq, Show) deriving (Eq, Show)
newtype Label = Label Text newtype Label = Label Text
deriving (Eq, Show) deriving (Eq, Show)
newtype Variable = Variable Text data Variable = Variable Text | TempVariable
deriving (Eq, Show) deriving (Eq, Show)
data Quadruple data Quadruple
= StartQuadruple = StartQuadruple
| GoToQuadruple Label | GoToQuadruple Label
| AssignQuadruple Operand Variable | AssignQuadruple Operand Variable
| ArrayQuadruple Variable Word32 Variable | ArrayQuadruple Variable Operand Variable
| ArrayAssignQuadruple Operand Word32 Variable | ArrayAssignQuadruple Operand Word32 Variable
| AddQuadruple Operand Operand Variable | AddQuadruple Operand Operand Variable
| SubtractionQuadruple Operand Operand Variable | SubtractionQuadruple Operand Operand Variable
@ -42,3 +51,91 @@ data Quadruple
| CallQuadruple Variable Word32 | CallQuadruple Variable Word32
| StopQuadruple | StopQuadruple
deriving (Eq, Show) deriving (Eq, Show)
intermediate :: SymbolTable -> AST.Program -> HashMap AST.Identifier (Vector Quadruple)
intermediate globalTable (AST.Program declarations) =
foldr go HashMap.empty declarations
where
go (AST.TypeDefinition _ _) accumulator = accumulator
go (AST.ProcedureDefinition procedureName _ _ statements) accumulator =
let translatedStatements
= Vector.cons StartQuadruple
$ flip Vector.snoc StopQuadruple
$ foldMap (statement globalTable) statements
in HashMap.insert procedureName translatedStatements accumulator
statement :: SymbolTable -> AST.Statement -> Vector Quadruple
statement _ AST.EmptyStatement = mempty
statement globalTable (AST.CompoundStatement statements) =
foldMap (statement globalTable) statements
variableAccess
:: SymbolTable
-> AST.VariableAccess
-> Maybe Operand
-> Type
-> Vector Quadruple
-> (AST.Identifier, Maybe Operand, Vector Quadruple)
variableAccess _ (AST.VariableAccess identifier) accumulatedIndex _ accumulatedStatements =
(identifier, accumulatedIndex, accumulatedStatements)
variableAccess localTable (AST.ArrayAccess access1 index1) Nothing (ArrayType _ baseType) _ =
let (indexPlace, statements) = expression localTable index1
in variableAccess localTable access1 (Just indexPlace) baseType statements
variableAccess localTable (AST.ArrayAccess arrayAccess' arrayIndex) (Just baseIndex) (ArrayType arraySize baseType) statements =
let (indexPlace, statements') = expression localTable arrayIndex
resultVariable = TempVariable
resultOperand = VariableOperand resultVariable
indexCalculation = Vector.fromList
[ ProductQuadruple (IntOperand $ fromIntegral arraySize) baseIndex resultVariable
, AddQuadruple indexPlace resultOperand resultVariable
]
in variableAccess localTable arrayAccess' (Just resultOperand) baseType
$ statements <> indexCalculation <> statements'
variableAccess _ _ _ _ _ = error "Array access operator doesn't match the type."
variableType :: AST.VariableAccess -> SymbolTable -> Type
variableType (AST.VariableAccess identifier) symbolTable
| Just (TypeInfo type') <- SymbolTable.lookup identifier symbolTable = type'
| otherwise = error "Undefined type."
variableType (AST.ArrayAccess arrayAccess' _) symbolTable =
variableType arrayAccess' symbolTable
expression :: SymbolTable -> AST.Expression -> (Operand, Vector Quadruple)
expression localTable = \case
(AST.VariableExpression variableExpression) ->
let variableType' = variableType variableExpression localTable
in case variableAccess localTable variableExpression Nothing variableType' mempty of
(AST.Identifier identifier, Nothing, statements) ->
(VariableOperand (Variable identifier), statements)
(AST.Identifier identifier, Just operand, statements) ->
let arrayAddress = TempVariable
arrayStatement = ArrayQuadruple (Variable identifier) operand arrayAddress
in (VariableOperand arrayAddress, Vector.snoc statements arrayStatement)
(AST.LiteralExpression literal') -> (literal literal', mempty)
(AST.NegationExpression negation) ->
let (operand, statements) = expression localTable negation
tempVariable = TempVariable
negationQuadruple = NegationQuadruple operand tempVariable
in (VariableOperand tempVariable, Vector.snoc statements negationQuadruple)
(AST.SumExpression lhs rhs) -> binaryExpression AddQuadruple lhs rhs
(AST.SubtractionExpression lhs rhs) ->
binaryExpression SubtractionQuadruple lhs rhs
(AST.ProductExpression lhs rhs) ->
binaryExpression ProductQuadruple lhs rhs
(AST.DivisionExpression lhs rhs) ->
binaryExpression DivisionQuadruple lhs rhs
where
binaryExpression f lhs rhs =
let (lhsOperand, lhsStatements) = expression localTable lhs
(rhsOperand, rhsStatements) = expression localTable rhs
tempVariable = TempVariable
newQuadruple = f lhsOperand rhsOperand tempVariable
in (VariableOperand tempVariable, Vector.snoc (lhsStatements <> rhsStatements) newQuadruple)
literal :: AST.Literal -> Operand
literal (AST.IntegerLiteral integer) = IntOperand integer
literal (AST.HexadecimalLiteral integer) = IntOperand integer
literal (AST.CharacterLiteral character) = IntOperand $ fromIntegral character
literal (AST.BooleanLiteral boolean)
| boolean = IntOperand 1
| otherwise = IntOperand 0

View File

@ -83,23 +83,23 @@ declaration globalTable (AST.ProcedureDefinition identifier parameters variables
statement :: SymbolTable -> AST.Statement -> NameAnalysis () statement :: SymbolTable -> AST.Statement -> NameAnalysis ()
statement _ AST.EmptyStatement = pure () statement _ AST.EmptyStatement = pure ()
statement globalTable (AST.AssignmentStatement lvalue rvalue) statement globalTable (AST.AssignmentStatement lvalue rvalue)
= expression globalTable lvalue = variableAccess globalTable lvalue
>> expression globalTable rvalue >> expression globalTable rvalue
statement globalTable (AST.IfStatement condition ifStatement elseStatement) statement globalTable (AST.IfStatement ifCondition ifStatement elseStatement)
= expression globalTable condition = condition globalTable ifCondition
>> statement globalTable ifStatement >> statement globalTable ifStatement
>> maybe (pure ()) (statement globalTable) elseStatement >> maybe (pure ()) (statement globalTable) elseStatement
statement globalTable (AST.WhileStatement condition loop) statement globalTable (AST.WhileStatement whileCondition loop)
= expression globalTable condition = condition globalTable whileCondition
>> statement globalTable loop >> statement globalTable loop
statement globalTable (AST.CompoundStatement statements) = statement globalTable (AST.CompoundStatement statements) =
traverse_ (statement globalTable) statements traverse_ (statement globalTable) statements
statement globalTable (AST.CallStatement name arguments) statement globalTable (AST.CallStatement name arguments)
= checkSymbol name globalTable = checkSymbol globalTable name
>> traverse_ (expression globalTable) arguments >> traverse_ (expression globalTable) arguments
checkSymbol :: Identifier -> SymbolTable -> NameAnalysis () checkSymbol :: SymbolTable -> Identifier -> NameAnalysis ()
checkSymbol identifier globalTable = checkSymbol globalTable identifier =
let undefinedSymbolError = NameAnalysis let undefinedSymbolError = NameAnalysis
$ lift $ lift
$ throwE $ throwE
@ -109,8 +109,8 @@ checkSymbol identifier globalTable =
>>= (flip unless undefinedSymbolError . (isDefined ||)) >>= (flip unless undefinedSymbolError . (isDefined ||))
expression :: SymbolTable -> AST.Expression -> NameAnalysis () expression :: SymbolTable -> AST.Expression -> NameAnalysis ()
expression globalTable (AST.VariableExpression identifier) = expression globalTable (AST.VariableExpression variableExpression) =
checkSymbol identifier globalTable variableAccess globalTable variableExpression
expression _ (AST.LiteralExpression _) = pure () expression _ (AST.LiteralExpression _) = pure ()
expression globalTable (AST.NegationExpression negation) = expression globalTable (AST.NegationExpression negation) =
expression globalTable negation expression globalTable negation
@ -126,27 +126,33 @@ expression globalTable (AST.ProductExpression lhs rhs)
expression globalTable (AST.DivisionExpression lhs rhs) expression globalTable (AST.DivisionExpression lhs rhs)
= expression globalTable lhs = expression globalTable lhs
>> expression globalTable rhs >> expression globalTable rhs
expression globalTable (AST.EqualExpression lhs rhs)
= expression globalTable lhs variableAccess :: SymbolTable -> AST.VariableAccess -> NameAnalysis ()
>> expression globalTable rhs variableAccess globalTable (AST.ArrayAccess arrayExpression indexExpression)
expression globalTable (AST.NonEqualExpression lhs rhs) = variableAccess globalTable arrayExpression
= expression globalTable lhs
>> expression globalTable rhs
expression globalTable (AST.LessExpression lhs rhs)
= expression globalTable lhs
>> expression globalTable rhs
expression globalTable (AST.GreaterExpression lhs rhs)
= expression globalTable lhs
>> expression globalTable rhs
expression globalTable (AST.LessOrEqualExpression lhs rhs)
= expression globalTable lhs
>> expression globalTable rhs
expression globalTable (AST.GreaterOrEqualExpression lhs rhs)
= expression globalTable lhs
>> expression globalTable rhs
expression globalTable (AST.ArrayExpression arrayExpression indexExpression)
= expression globalTable arrayExpression
>> expression globalTable indexExpression >> expression globalTable indexExpression
variableAccess globalTable (AST.VariableAccess identifier) =
checkSymbol globalTable identifier
condition :: SymbolTable -> AST.Condition -> NameAnalysis ()
condition globalTable (AST.EqualCondition lhs rhs)
= expression globalTable lhs
>> expression globalTable rhs
condition globalTable (AST.NonEqualCondition lhs rhs)
= expression globalTable lhs
>> expression globalTable rhs
condition globalTable (AST.LessCondition lhs rhs)
= expression globalTable lhs
>> expression globalTable rhs
condition globalTable (AST.GreaterCondition lhs rhs)
= expression globalTable lhs
>> expression globalTable rhs
condition globalTable (AST.LessOrEqualCondition lhs rhs)
= expression globalTable lhs
>> expression globalTable rhs
condition globalTable (AST.GreaterOrEqualCondition lhs rhs)
= expression globalTable lhs
>> expression globalTable rhs
enter :: Identifier -> Info -> SymbolTable -> NameAnalysis SymbolTable enter :: Identifier -> Info -> SymbolTable -> NameAnalysis SymbolTable
enter identifier info table enter identifier info table

View File

@ -9,7 +9,9 @@ import Data.Text (Text)
import qualified Data.Text as Text import qualified Data.Text as Text
import Data.Void (Void) import Data.Void (Void)
import Language.Elna.AST import Language.Elna.AST
( Declaration(..) ( VariableAccess(..)
, Condition(..)
, Declaration(..)
, Expression(..) , Expression(..)
, Identifier(..) , Identifier(..)
, Literal(..) , Literal(..)
@ -21,12 +23,12 @@ import Language.Elna.AST
) )
import Text.Megaparsec import Text.Megaparsec
( Parsec ( Parsec
, MonadParsec(..)
, (<?>) , (<?>)
, optional , optional
, between , between
, sepBy , sepBy
, choice , choice
, MonadParsec(..)
) )
import Text.Megaparsec.Char import Text.Megaparsec.Char
( alphaNumChar ( alphaNumChar
@ -124,16 +126,20 @@ termP :: Parser Expression
termP = choice termP = choice
[ parensP expressionP [ parensP expressionP
, LiteralExpression <$> literalP , LiteralExpression <$> literalP
, VariableExpression <$> identifierP , VariableExpression <$> variableAccessP
] ]
variableAccessP :: Parser VariableAccess
variableAccessP = do
identifier <- identifierP
indices <- many $ bracketsP expressionP
pure $ foldr (flip ArrayAccess) (VariableAccess identifier) indices
operatorTable :: [[Operator Parser Expression]] operatorTable :: [[Operator Parser Expression]]
operatorTable = operatorTable =
[ [Postfix (flip ArrayExpression <$> bracketsP expressionP)] [ unaryOperator
, unaryOperator
, factorOperator , factorOperator
, termOperator , termOperator
, comparisonOperator
] ]
where where
unaryOperator = unaryOperator =
@ -148,20 +154,27 @@ operatorTable =
[ binary "+" SumExpression [ binary "+" SumExpression
, binary "-" SubtractionExpression , binary "-" SubtractionExpression
] ]
comparisonOperator =
[ binary "<" LessExpression
, binary "<=" LessOrEqualExpression
, binary ">" GreaterExpression
, binary ">=" GreaterOrEqualExpression
, binary "=" EqualExpression
, binary "#" NonEqualExpression
]
prefix name f = Prefix (f <$ symbol name) prefix name f = Prefix (f <$ symbol name)
binary name f = InfixL (f <$ symbol name) binary name f = InfixL (f <$ symbol name)
expressionP :: Parser Expression expressionP :: Parser Expression
expressionP = makeExprParser termP operatorTable expressionP = makeExprParser termP operatorTable
conditionP :: Parser Condition
conditionP = do
lhs <- expressionP
conditionCons <- choice comparisonOperator
conditionCons lhs <$> expressionP
where
comparisonOperator =
[ symbol "<" >> pure LessCondition
, symbol "<=" >> pure LessOrEqualCondition
, symbol ">" >> pure GreaterCondition
, symbol ">=" >> pure GreaterOrEqualCondition
, symbol "=" >> pure EqualCondition
, symbol "#" >> pure NonEqualCondition
]
statementP :: Parser Statement statementP :: Parser Statement
statementP statementP
= EmptyStatement <$ semicolonP = EmptyStatement <$ semicolonP
@ -173,18 +186,18 @@ statementP
<?> "statement" <?> "statement"
where where
ifElseP = IfStatement ifElseP = IfStatement
<$> (symbol "if" *> parensP expressionP) <$> (symbol "if" *> parensP conditionP)
<*> statementP <*> statementP
<*> optional (symbol "else" *> statementP) <*> optional (symbol "else" *> statementP)
whileP = WhileStatement whileP = WhileStatement
<$> (symbol "while" *> parensP expressionP) <$> (symbol "while" *> parensP conditionP)
<*> statementP <*> statementP
callP = CallStatement callP = CallStatement
<$> identifierP <$> identifierP
<*> parensP (sepBy expressionP commaP) <*> parensP (sepBy expressionP commaP)
<* semicolonP <* semicolonP
assignmentP = AssignmentStatement assignmentP = AssignmentStatement
<$> expressionP <$> variableAccessP
<* symbol ":=" <* symbol ":="
<*> expressionP <*> expressionP
<* semicolonP <* semicolonP

View File

@ -76,22 +76,20 @@ statement :: SymbolTable -> AST.Statement -> TypeAnalysis ()
statement globalTable = \case statement globalTable = \case
AST.EmptyStatement -> pure () AST.EmptyStatement -> pure ()
AST.AssignmentStatement lhs rhs -> do AST.AssignmentStatement lhs rhs -> do
lhsType <- expression globalTable lhs lhsType <- variableAccess globalTable lhs
rhsType <- expression globalTable rhs rhsType <- expression globalTable rhs
unless (lhsType == intType) unless (lhsType == intType)
$ TypeAnalysis $ lift $ throwE $ InvalidConditionTypeError lhsType $ TypeAnalysis $ lift $ throwE $ InvalidConditionTypeError lhsType
unless (rhsType == intType) unless (rhsType == intType)
$ TypeAnalysis $ lift $ throwE $ InvalidConditionTypeError rhsType $ TypeAnalysis $ lift $ throwE $ InvalidConditionTypeError rhsType
unless (isLvalue lhs) AST.IfStatement ifCondition ifStatement elseStatement -> do
$ TypeAnalysis $ lift $ throwE $ ExpectedLvalueError lhs conditionType <- condition globalTable ifCondition
AST.IfStatement condition ifStatement elseStatement -> do
conditionType <- expression globalTable condition
unless (conditionType == booleanType) unless (conditionType == booleanType)
$ TypeAnalysis $ lift $ throwE $ InvalidConditionTypeError conditionType $ TypeAnalysis $ lift $ throwE $ InvalidConditionTypeError conditionType
statement globalTable ifStatement statement globalTable ifStatement
maybe (pure ()) (statement globalTable) elseStatement maybe (pure ()) (statement globalTable) elseStatement
AST.WhileStatement condition whileStatement -> do AST.WhileStatement whileCondition whileStatement -> do
conditionType <- expression globalTable condition conditionType <- condition globalTable whileCondition
unless (conditionType == booleanType) unless (conditionType == booleanType)
$ TypeAnalysis $ lift $ throwE $ InvalidConditionTypeError conditionType $ TypeAnalysis $ lift $ throwE $ InvalidConditionTypeError conditionType
statement globalTable whileStatement statement globalTable whileStatement
@ -116,13 +114,11 @@ statement globalTable = \case
$ TypeAnalysis $ lift $ throwE $ ArgumentTypeMismatchError type' argumentType $ TypeAnalysis $ lift $ throwE $ ArgumentTypeMismatchError type' argumentType
when (isReferenceParameter && not (isLvalue argument)) when (isReferenceParameter && not (isLvalue argument))
$ TypeAnalysis $ lift $ throwE $ ExpectedLvalueError argument $ TypeAnalysis $ lift $ throwE $ ExpectedLvalueError argument
isLvalue (AST.ArrayExpression arrayExpression _) = isLvalue arrayExpression
isLvalue (AST.VariableExpression _) = True isLvalue (AST.VariableExpression _) = True
isLvalue _ = False isLvalue _ = False
expression :: SymbolTable -> AST.Expression -> TypeAnalysis Type variableAccess :: SymbolTable -> AST.VariableAccess -> TypeAnalysis Type
expression globalTable = \case variableAccess globalTable (AST.VariableAccess identifier) = do
AST.VariableExpression identifier -> do
localLookup <- TypeAnalysis $ asks $ SymbolTable.lookup identifier localLookup <- TypeAnalysis $ asks $ SymbolTable.lookup identifier
case localLookup <|> SymbolTable.lookup identifier globalTable of case localLookup <|> SymbolTable.lookup identifier globalTable of
Just (VariableInfo _ variableType) -> pure variableType Just (VariableInfo _ variableType) -> pure variableType
@ -130,6 +126,20 @@ expression globalTable = \case
$ UnexpectedVariableInfoError anotherInfo $ UnexpectedVariableInfoError anotherInfo
Nothing -> TypeAnalysis $ lift $ throwE Nothing -> TypeAnalysis $ lift $ throwE
$ UndefinedSymbolError identifier $ UndefinedSymbolError identifier
variableAccess globalTable (AST.ArrayAccess arrayExpression indexExpression) = do
arrayType <- variableAccess globalTable arrayExpression
indexType <- expression globalTable indexExpression
unless (indexType == intType)
$ TypeAnalysis $ lift $ throwE $ ArrayIndexError indexType
case arrayType of
ArrayType _ baseType -> pure baseType
nonArrayType -> TypeAnalysis $ lift $ throwE
$ ArrayAccessError nonArrayType
expression :: SymbolTable -> AST.Expression -> TypeAnalysis Type
expression globalTable = \case
AST.VariableExpression variableExpression -> do
variableAccess globalTable variableExpression
AST.LiteralExpression literal' -> literal literal' AST.LiteralExpression literal' -> literal literal'
AST.NegationExpression negation -> do AST.NegationExpression negation -> do
operandType <- expression globalTable negation operandType <- expression globalTable negation
@ -140,21 +150,6 @@ expression globalTable = \case
AST.SubtractionExpression lhs rhs -> arithmeticExpression lhs rhs AST.SubtractionExpression lhs rhs -> arithmeticExpression lhs rhs
AST.ProductExpression lhs rhs -> arithmeticExpression lhs rhs AST.ProductExpression lhs rhs -> arithmeticExpression lhs rhs
AST.DivisionExpression lhs rhs -> arithmeticExpression lhs rhs AST.DivisionExpression lhs rhs -> arithmeticExpression lhs rhs
AST.EqualExpression lhs rhs -> comparisonExpression lhs rhs
AST.NonEqualExpression lhs rhs -> comparisonExpression lhs rhs
AST.LessExpression lhs rhs -> comparisonExpression lhs rhs
AST.GreaterExpression lhs rhs -> comparisonExpression lhs rhs
AST.LessOrEqualExpression lhs rhs -> comparisonExpression lhs rhs
AST.GreaterOrEqualExpression lhs rhs -> comparisonExpression lhs rhs
AST.ArrayExpression arrayExpression indexExpression -> do
arrayType <- expression globalTable arrayExpression
indexType <- expression globalTable indexExpression
unless (indexType == intType)
$ TypeAnalysis $ lift $ throwE $ ArrayIndexError indexType
case arrayType of
ArrayType _ baseType -> pure baseType
nonArrayType -> TypeAnalysis $ lift $ throwE
$ ArrayAccessError nonArrayType
where where
arithmeticExpression lhs rhs = do arithmeticExpression lhs rhs = do
lhsType <- expression globalTable lhs lhsType <- expression globalTable lhs
@ -164,10 +159,20 @@ expression globalTable = \case
unless (rhsType == intType) unless (rhsType == intType)
$ TypeAnalysis $ lift $ throwE $ ArithmeticExpressionError rhsType $ TypeAnalysis $ lift $ throwE $ ArithmeticExpressionError rhsType
pure intType pure intType
condition :: SymbolTable -> AST.Condition -> TypeAnalysis Type
condition globalTable = \case
AST.EqualCondition lhs rhs -> comparisonExpression lhs rhs
AST.NonEqualCondition lhs rhs -> comparisonExpression lhs rhs
AST.LessCondition lhs rhs -> comparisonExpression lhs rhs
AST.GreaterCondition lhs rhs -> comparisonExpression lhs rhs
AST.LessOrEqualCondition lhs rhs -> comparisonExpression lhs rhs
AST.GreaterOrEqualCondition lhs rhs -> comparisonExpression lhs rhs
where
comparisonExpression lhs rhs = do comparisonExpression lhs rhs = do
lhsType <- expression globalTable lhs lhsType <- expression globalTable lhs
rhsType <- expression globalTable rhs rhsType <- expression globalTable rhs
if lhsType == intType && rhsType ==intType if lhsType == intType && rhsType == intType
then pure booleanType then pure booleanType
else TypeAnalysis $ lift $ throwE $ ComparisonExpressionError lhsType rhsType else TypeAnalysis $ lift $ throwE $ ComparisonExpressionError lhsType rhsType