module Language.Elna.TypeAnalysis
( typeAnalysis
, -- Error(..)
) where
import qualified Language.Elna.AST as AST
import Language.Elna.SymbolTable ({-Info(..), ParameterInfo(..), -}SymbolTable)
-- | Entry point of the type analysis pass.
--
-- Currently a stub: the global symbol table and the program are both
-- ignored and the result is always @()@. The block comment after the body
-- keeps the disabled pipeline, which would produce a @Maybe Error@ instead
-- (see the trailing comment on the signature).
typeAnalysis :: SymbolTable -> AST.Program -> () -- Maybe Error
typeAnalysis _globalTable _program = ()
    {- either Just (const Nothing)
    . runExcept
    . flip runReaderT globalTable
    . runTypeAnalysis
    . program -}
import Control.Applicative (Alternative(..))
import Control.Monad.Trans.Except (Except, runExcept, throwE)
import Control.Monad.Trans.Reader (ReaderT, asks, runReaderT, withReaderT, ask)
import qualified Data.Vector as Vector
import Language.Elna.Location (Identifier(..))
import qualified Language.Elna.SymbolTable as SymbolTable
import Language.Elna.Types (Type(..), booleanType, intType)
import Control.Monad.Trans.Class (MonadTrans(..))
import Control.Monad (unless, when)
import Data.Foldable (traverse_)
-- | Failures that the type analysis can abort with.
data Error
    = ArithmeticExpressionError Type
    -- ^ An arithmetic operand did not have 'intType'; carries the operand's
    -- type.
    | ComparisonExpressionError Type Type
    -- ^ A comparison had an operand that is not 'intType'; carries the left
    -- and the right operand type.
    | UnexpectedVariableInfoError Info
    -- ^ A looked-up symbol resolved to an unsuitable kind of 'Info'; carries
    -- the info that was found.
    | UnexpectedProcedureInfoError Info
    -- ^ A procedure name resolved to something other than a procedure;
    -- carries the info that was found.
    | UndefinedSymbolError Identifier
    -- ^ The identifier was not found in the symbol table(s) consulted.
    | InvalidConditionTypeError Type
    -- ^ A statement-level type check failed, e.g. an @if@ or @while@
    -- condition that is not 'booleanType'; carries the offending type.
    | InvalidAssignmentError Type
    -- ^ Carries the offending type of an invalid assignment (meaning taken
    -- from the constructor name; confirm against its use sites).
    | ExpectedLvalueError AST.Expression
    -- ^ A reference parameter received an argument that is not a variable
    -- expression; carries the argument.
    | ParameterCountMismatchError Int Int
    -- ^ A call's argument count differs from the parameter count; carries
    -- the parameter count followed by the argument count.
    | ArgumentTypeMismatchError Type Type
    -- ^ An argument's type differs from the parameter's; carries the
    -- parameter type followed by the argument type.
    | ArrayIndexError Type
    -- ^ An array index did not have 'intType'; carries the index type.
    | ArrayAccessError Type
    -- ^ Indexing was applied to a value that is not an array; carries that
    -- value's type.
    deriving (Eq, Show)
-- | Monad of the type analysis: a reader over the 'SymbolTable' currently in
-- scope whose computations may abort with an 'Error'.
newtype TypeAnalysis a = TypeAnalysis
    { runTypeAnalysis :: ReaderT SymbolTable (Except Error) a
    -- ^ Unwraps the underlying reader/except stack.
    }
-- | Maps over the result by delegating to the wrapped reader stack.
instance Functor TypeAnalysis where
    fmap f = TypeAnalysis . fmap f . runTypeAnalysis
-- | Lifts pure values and function application through the wrapped reader
-- stack.
instance Applicative TypeAnalysis where
    pure = TypeAnalysis . pure
    TypeAnalysis wrappedFunction <*> TypeAnalysis wrappedArgument =
        TypeAnalysis (wrappedFunction <*> wrappedArgument)
-- | Sequences analyses by binding in the wrapped reader stack and unwrapping
-- the continuation's result.
instance Monad TypeAnalysis where
    TypeAnalysis action >>= continuation =
        TypeAnalysis (action >>= runTypeAnalysis . continuation)
-- | Type-checks every declaration of the program, in order, discarding the
-- unit results.
program :: AST.Program -> TypeAnalysis ()
program (AST.Program declarations) = mapM_ declaration declarations
-- | Type-checks a single top-level declaration.
--
-- For a procedure definition the procedure name is looked up in the table
-- taken from the reader environment (bound as @globalTable@). When it
-- resolves to a 'ProcedureInfo', the body statements are checked with the
-- procedure's local table as the reader environment while the global table
-- is passed along explicitly. Any other kind of info raises
-- 'UnexpectedProcedureInfoError'; a missing entry raises
-- 'UndefinedSymbolError'. Every other declaration is accepted unchanged.
declaration :: AST.Declaration -> TypeAnalysis ()
declaration (AST.ProcedureDefinition procedureName _ _ body) =
    TypeAnalysis ask >>= checkBody
  where
    checkBody globalTable =
        case SymbolTable.lookup procedureName globalTable of
            Nothing -> failWith (UndefinedSymbolError procedureName)
            Just (ProcedureInfo localTable _) ->
                -- Swap the environment to the procedure's own table for the
                -- duration of the body.
                TypeAnalysis
                    . withReaderT (const localTable)
                    . runTypeAnalysis
                    $ traverse_ (statement globalTable) body
            Just anotherInfo ->
                failWith (UnexpectedProcedureInfoError anotherInfo)
    failWith reason = TypeAnalysis (lift (throwE reason))
declaration _ = pure ()
-- | Type-checks a statement.
--
-- The global table is passed explicitly; the reader environment is consulted
-- indirectly through 'variableAccess'.
--
-- * Assignment: both sides must have 'intType', otherwise
--   'InvalidAssignmentError' is raised with the offending type.
-- * @if@ / @while@: the condition must have 'booleanType', otherwise
--   'InvalidConditionTypeError' is raised; nested statements are checked
--   recursively.
-- * Compound: every contained statement is checked in order.
-- * Call: the callee must resolve to a 'ProcedureInfo' in the global table
--   ('UnexpectedProcedureInfoError' / 'UndefinedSymbolError' otherwise), the
--   number of arguments must equal the number of parameters
--   ('ParameterCountMismatchError'), and every argument must match its
--   parameter's type ('ArgumentTypeMismatchError') and be a variable
--   expression when the parameter is passed by reference
--   ('ExpectedLvalueError').
statement :: SymbolTable -> AST.Statement -> TypeAnalysis ()
statement globalTable = \case
    AST.EmptyStatement -> pure ()
    AST.AssignmentStatement lhs rhs -> do
        lhsType <- variableAccess globalTable lhs
        rhsType <- expression globalTable rhs
        -- Assignment failures are reported with the dedicated constructor
        -- rather than the one meant for conditions.
        unless (lhsType == intType)
            $ failWith $ InvalidAssignmentError lhsType
        unless (rhsType == intType)
            $ failWith $ InvalidAssignmentError rhsType
    AST.IfStatement ifCondition ifStatement elseStatement -> do
        checkCondition ifCondition
        statement globalTable ifStatement
        maybe (pure ()) (statement globalTable) elseStatement
    AST.WhileStatement whileCondition whileStatement -> do
        checkCondition whileCondition
        statement globalTable whileStatement
    AST.CompoundStatement statements ->
        traverse_ (statement globalTable) statements
    AST.CallStatement procedureName arguments ->
        case SymbolTable.lookup procedureName globalTable of
            Just (ProcedureInfo _ parameters)
                | parametersLength <- Vector.length parameters
                , argumentsLength <- length arguments
                , parametersLength /= argumentsLength -> failWith
                    $ ParameterCountMismatchError parametersLength argumentsLength
                | otherwise -> traverse_ (uncurry checkArgument)
                    $ Vector.zip parameters (Vector.fromList arguments)
            -- A procedure was expected here, so use the same constructor as
            -- 'declaration' does for this situation.
            Just anotherInfo -> failWith
                $ UnexpectedProcedureInfoError anotherInfo
            Nothing -> failWith $ UndefinedSymbolError procedureName
  where
    -- Aborts the analysis with the given error.
    failWith :: Error -> TypeAnalysis a
    failWith = TypeAnalysis . lift . throwE
    -- Requires the condition to have the boolean type.
    checkCondition conditionExpression = do
        conditionType <- condition globalTable conditionExpression
        unless (conditionType == booleanType)
            $ failWith $ InvalidConditionTypeError conditionType
    -- Checks one argument against the parameter it is passed for.
    checkArgument ParameterInfo{..} argument = do
        argumentType <- expression globalTable argument
        unless (argumentType == type')
            $ failWith $ ArgumentTypeMismatchError type' argumentType
        when (isReferenceParameter && not (isLvalue argument))
            $ failWith $ ExpectedLvalueError argument
    -- Only variable expressions count as lvalues.
    isLvalue (AST.VariableExpression _) = True
    isLvalue _ = False
-- | Computes the type denoted by a variable access.
--
-- A plain identifier is looked up in the reader environment first and in the
-- given global table second; it must resolve to a 'VariableInfo', whose type
-- is returned. An array access checks the accessed variable and the index,
-- requires the index to have 'intType' and the variable to have an
-- 'ArrayType', and returns the type stored in that 'ArrayType'.
variableAccess :: SymbolTable -> AST.VariableAccess -> TypeAnalysis Type
variableAccess globalTable = \case
    AST.VariableAccess identifier ->
        TypeAnalysis (asks (SymbolTable.lookup identifier)) >>= \localLookup ->
            -- The local result wins; the global table is only a fallback.
            case localLookup <|> SymbolTable.lookup identifier globalTable of
                Nothing -> failWith (UndefinedSymbolError identifier)
                Just (VariableInfo _ variableType) -> pure variableType
                Just anotherInfo ->
                    failWith (UnexpectedVariableInfoError anotherInfo)
    AST.ArrayAccess arrayExpression indexExpression -> do
        arrayType <- variableAccess globalTable arrayExpression
        indexType <- expression globalTable indexExpression
        unless (indexType == intType)
            (failWith (ArrayIndexError indexType))
        case arrayType of
            ArrayType _ baseType -> pure baseType
            nonArrayType -> failWith (ArrayAccessError nonArrayType)
  where
    -- Aborts the analysis with the given error.
    failWith :: Error -> TypeAnalysis a
    failWith reason = TypeAnalysis (lift (throwE reason))
-- | Computes the type of an expression.
--
-- Variable expressions are delegated to 'variableAccess' and literals to
-- 'literal'. Negation and the four binary arithmetic operators require every
-- operand to have 'intType' and yield 'intType'; an operand of any other type
-- raises 'ArithmeticExpressionError' carrying that type.
expression :: SymbolTable -> AST.Expression -> TypeAnalysis Type
expression globalTable = \case
    AST.VariableExpression variableExpression ->
        variableAccess globalTable variableExpression
    AST.LiteralExpression literal' -> literal literal'
    AST.NegationExpression negation -> intOperand negation
    AST.SumExpression lhs rhs -> arithmeticExpression lhs rhs
    AST.SubtractionExpression lhs rhs -> arithmeticExpression lhs rhs
    AST.ProductExpression lhs rhs -> arithmeticExpression lhs rhs
    AST.DivisionExpression lhs rhs -> arithmeticExpression lhs rhs
  where
    -- The left operand is checked (and may fail) before the right operand is
    -- even visited, so its error takes precedence.
    arithmeticExpression lhs rhs = intOperand lhs >> intOperand rhs
    -- Checks a single operand and insists on the integer type.
    intOperand operand = do
        operandType <- expression globalTable operand
        if operandType == intType
            then pure intType
            else TypeAnalysis $ lift $ throwE
                $ ArithmeticExpressionError operandType
-- | Computes the type of a condition.
--
-- All six comparison forms are handled the same way: both operands are
-- checked and must have 'intType', in which case the condition has
-- 'booleanType'. Otherwise 'ComparisonExpressionError' is raised with the
-- left and right operand types.
condition :: SymbolTable -> AST.Condition -> TypeAnalysis Type
condition globalTable = \case
    AST.EqualCondition lhs rhs -> comparisonExpression lhs rhs
    AST.NonEqualCondition lhs rhs -> comparisonExpression lhs rhs
    AST.LessCondition lhs rhs -> comparisonExpression lhs rhs
    AST.GreaterCondition lhs rhs -> comparisonExpression lhs rhs
    AST.LessOrEqualCondition lhs rhs -> comparisonExpression lhs rhs
    AST.GreaterOrEqualCondition lhs rhs -> comparisonExpression lhs rhs
  where
    -- Both operands are evaluated before their types are compared, so an
    -- error inside either operand is reported ahead of the comparison error.
    comparisonExpression lhs rhs = do
        lhsType <- expression globalTable lhs
        rhsType <- expression globalTable rhs
        unless (lhsType == intType && rhsType == intType)
            $ TypeAnalysis $ lift $ throwE
            $ ComparisonExpressionError lhsType rhsType
        pure booleanType
-- | Computes the type of a literal: boolean literals have 'booleanType',
-- while integer, hexadecimal and character literals all have 'intType'.
literal :: AST.Literal -> TypeAnalysis Type
literal = \case
    AST.BooleanLiteral _ -> pure booleanType
    AST.IntegerLiteral _ -> pure intType
    AST.HexadecimalLiteral _ -> pure intType
    AST.CharacterLiteral _ -> pure intType
