elna/lib/Language/Elna/TypeAnalysis.hs

187 lines
8.2 KiB
Haskell
Raw Normal View History

2024-08-06 17:02:18 +02:00
module Language.Elna.TypeAnalysis
2024-09-08 02:08:13 +02:00
( typeAnalysis
, -- Error(..)
2024-08-06 17:02:18 +02:00
) where
2024-09-08 02:08:13 +02:00
import qualified Language.Elna.AST as AST
import Language.Elna.SymbolTable ({-Info(..), ParameterInfo(..), -}SymbolTable)
-- | Entry point of the type analysis pass.
--
-- The pass is currently stubbed out: both the symbol table and the program
-- are ignored and the result is always @()@. The pipeline that is meant to
-- run here is kept in the comment below the equation.
typeAnalysis :: SymbolTable -> AST.Program -> () -- Maybe Error
typeAnalysis _ _ = ()
    {- either Just (const Nothing)
        . runExcept
        . flip runReaderT globalTable
        . runTypeAnalysis
        . program -}
{-
2024-08-07 22:47:35 +02:00
import Control.Applicative (Alternative(..))
import Control.Monad.Trans.Except (Except, runExcept, throwE)
2024-08-09 20:16:41 +02:00
import Control.Monad.Trans.Reader (ReaderT, asks, runReaderT, withReaderT, ask)
2024-08-08 19:11:24 +02:00
import qualified Data.Vector as Vector
2024-08-07 22:47:35 +02:00
import Language.Elna.Location (Identifier(..))
import qualified Language.Elna.SymbolTable as SymbolTable
import Language.Elna.Types (Type(..), booleanType, intType)
import Control.Monad.Trans.Class (MonadTrans(..))
2024-08-08 19:11:24 +02:00
import Control.Monad (unless, when)
import Data.Foldable (traverse_)
2024-08-06 17:02:18 +02:00
2024-08-07 22:47:35 +02:00
-- | Errors that the type analysis can report.
data Error
    -- | An operand of an arithmetic expression or of a negation is not an
    -- integer; carries the operand's type.
    = ArithmeticExpressionError Type
    -- | The operands of a comparison are not both integers; carries the
    -- left and the right operand type.
    | ComparisonExpressionError Type Type
    -- | A looked-up symbol has an info of an unexpected kind (reported, for
    -- example, when a name used as a variable is not a variable); carries
    -- that info.
    | UnexpectedVariableInfoError Info
    -- | A name used as a procedure resolves to an info that is not a
    -- procedure info; carries that info.
    | UnexpectedProcedureInfoError Info
    -- | The identifier was not found in the consulted symbol tables.
    | UndefinedSymbolError Identifier
    -- | A type is not acceptable where it occurs (reported, for example,
    -- for @if@ and @while@ conditions that are not boolean); carries the
    -- type.
    | InvalidConditionTypeError Type
    -- | A value of the carried type cannot take part in an assignment.
    -- NOTE(review): confirm where this constructor is raised.
    | InvalidAssignmentError Type
    -- | An argument passed for a reference parameter is not a variable
    -- expression; carries the argument.
    | ExpectedLvalueError AST.Expression
    -- | Number of declared parameters, then number of passed arguments.
    | ParameterCountMismatchError Int Int
    -- | Declared parameter type, then the type of the passed argument.
    | ArgumentTypeMismatchError Type Type
    -- | An array index is not an integer; carries the index type.
    | ArrayIndexError Type
    -- | An indexed value is not an array; carries its type.
    | ArrayAccessError Type
    deriving (Eq, Show)
-- | Monad of the type analysis: a reader over the current 'SymbolTable'
-- layered on top of 'Except', so the first 'Error' aborts the analysis.
newtype TypeAnalysis a = TypeAnalysis
    { runTypeAnalysis :: ReaderT SymbolTable (Except Error) a
    -- ^ Unwraps the underlying reader/except computation.
    }
-- | Maps over the result by delegating to the wrapped transformer stack.
instance Functor TypeAnalysis
  where
    fmap f = TypeAnalysis . fmap f . runTypeAnalysis
-- | Lifts 'pure' and '<*>' from the wrapped transformer stack.
instance Applicative TypeAnalysis
  where
    pure x = TypeAnalysis (pure x)
    analysisF <*> analysisX = TypeAnalysis
        $ runTypeAnalysis analysisF <*> runTypeAnalysis analysisX
-- | Sequences analyses by binding inside the wrapped transformer stack.
instance Monad TypeAnalysis
  where
    analysis >>= continue = TypeAnalysis
        $ runTypeAnalysis analysis >>= runTypeAnalysis . continue
-- | Analyses every declaration of the program in order; the first error
-- stops the traversal.
program :: AST.Program -> TypeAnalysis ()
program (AST.Program topLevel) = traverse_ declaration topLevel
-- | Analyses a single top-level declaration.
--
-- For a procedure definition the procedure is looked up in the table held
-- by the reader, and its body is checked statement by statement. Any other
-- declaration is accepted without checks.
declaration :: AST.Declaration -> TypeAnalysis ()
declaration (AST.ProcedureDefinition procedureName _ _ body) =
    TypeAnalysis ask >>= checkBody
  where
    checkBody globalTable =
        case SymbolTable.lookup procedureName globalTable of
            Just (ProcedureInfo localTable _) ->
                -- The body runs with the procedure's own table as reader
                -- environment; the outer table is handed over explicitly.
                TypeAnalysis
                    . withReaderT (const localTable)
                    . runTypeAnalysis
                    $ traverse_ (statement globalTable) body
            Just unexpectedInfo ->
                raise $ UnexpectedProcedureInfoError unexpectedInfo
            Nothing ->
                raise $ UndefinedSymbolError procedureName
    raise failure = TypeAnalysis . lift $ throwE failure
declaration _ = pure ()
2024-08-07 22:47:35 +02:00
2024-08-08 19:11:24 +02:00
-- | Type-checks a single statement.
--
-- The first argument is the global symbol table; the table held by the
-- reader is consulted by 'variableAccess' before the global one.
--
-- Reported errors:
--
-- * 'InvalidAssignmentError' when either side of an assignment is not an
--   integer.
-- * 'InvalidConditionTypeError' when an @if@ or @while@ condition is not
--   boolean.
-- * 'ParameterCountMismatchError', 'ArgumentTypeMismatchError' and
--   'ExpectedLvalueError' for malformed procedure calls.
-- * 'UnexpectedProcedureInfoError' when the called name is not a procedure.
-- * 'UndefinedSymbolError' when the called name is unknown.
statement :: SymbolTable -> AST.Statement -> TypeAnalysis ()
statement globalTable = \case
    AST.EmptyStatement -> pure ()
    AST.AssignmentStatement lhs rhs -> do
        lhsType <- variableAccess globalTable lhs
        rhsType <- expression globalTable rhs
        -- Only integers can be assigned; report it as an assignment error
        -- rather than as a condition error.
        unless (lhsType == intType)
            $ TypeAnalysis $ lift $ throwE $ InvalidAssignmentError lhsType
        unless (rhsType == intType)
            $ TypeAnalysis $ lift $ throwE $ InvalidAssignmentError rhsType
    AST.IfStatement ifCondition ifStatement elseStatement -> do
        conditionType <- condition globalTable ifCondition
        unless (conditionType == booleanType)
            $ TypeAnalysis $ lift $ throwE $ InvalidConditionTypeError conditionType
        statement globalTable ifStatement
        -- The else branch is optional; nothing is checked when it is absent.
        traverse_ (statement globalTable) elseStatement
    AST.WhileStatement whileCondition whileStatement -> do
        conditionType <- condition globalTable whileCondition
        unless (conditionType == booleanType)
            $ TypeAnalysis $ lift $ throwE $ InvalidConditionTypeError conditionType
        statement globalTable whileStatement
    AST.CompoundStatement statements -> traverse_ (statement globalTable) statements
    AST.CallStatement procedureName arguments ->
        -- Procedures are resolved in the global table only.
        case SymbolTable.lookup procedureName globalTable of
            Just (ProcedureInfo _ parameters)
                | parametersLength <- Vector.length parameters
                , argumentsLength <- length arguments
                , parametersLength /= argumentsLength -> TypeAnalysis $ lift $ throwE
                    $ ParameterCountMismatchError parametersLength argumentsLength
                | otherwise -> traverse_ (uncurry checkArgument)
                    $ Vector.zip parameters (Vector.fromList arguments)
            -- A procedure was expected here, so use the same error as
            -- the procedure lookup in declaration does.
            Just anotherInfo -> TypeAnalysis $ lift $ throwE
                $ UnexpectedProcedureInfoError anotherInfo
            Nothing -> TypeAnalysis $ lift $ throwE
                $ UndefinedSymbolError procedureName
  where
    -- Checks one argument against its parameter: the types must be equal
    -- and a reference parameter requires a variable expression.
    checkArgument ParameterInfo{..} argument = do
        argumentType <- expression globalTable argument
        unless (argumentType == type')
            $ TypeAnalysis $ lift $ throwE $ ArgumentTypeMismatchError type' argumentType
        when (isReferenceParameter && not (isLvalue argument))
            $ TypeAnalysis $ lift $ throwE $ ExpectedLvalueError argument
    -- Only variable expressions count as lvalues.
    isLvalue (AST.VariableExpression _) = True
    isLvalue _ = False
2024-08-15 20:13:56 +02:00
-- | Determines the type of a variable access.
--
-- A plain variable is resolved in the table held by the reader first and in
-- the given global table second. An array access requires an integer index
-- and an array-typed base, and yields the element type.
variableAccess :: SymbolTable -> AST.VariableAccess -> TypeAnalysis Type
variableAccess globalTable = \case
    AST.VariableAccess identifier -> do
        localInfo <- TypeAnalysis $ asks $ SymbolTable.lookup identifier
        case localInfo <|> SymbolTable.lookup identifier globalTable of
            Just (VariableInfo _ variableType) -> pure variableType
            Just unexpectedInfo -> raise $ UnexpectedVariableInfoError unexpectedInfo
            Nothing -> raise $ UndefinedSymbolError identifier
    AST.ArrayAccess arrayExpression indexExpression -> do
        -- The base is analysed before the index, and the index type is
        -- verified before the base type.
        baseAccessType <- variableAccess globalTable arrayExpression
        indexType <- expression globalTable indexExpression
        unless (indexType == intType)
            $ raise $ ArrayIndexError indexType
        case baseAccessType of
            ArrayType _ elementType -> pure elementType
            nonArrayType -> raise $ ArrayAccessError nonArrayType
  where
    raise failure = TypeAnalysis . lift $ throwE failure
2024-08-07 22:47:35 +02:00
-- | Determines the type of an expression.
--
-- Variables and literals are delegated to 'variableAccess' and 'literal'.
-- Negation and the four binary arithmetic operators require integer
-- operands and yield an integer; otherwise 'ArithmeticExpressionError' is
-- reported with the offending operand type.
expression :: SymbolTable -> AST.Expression -> TypeAnalysis Type
expression globalTable = \case
    AST.VariableExpression variableExpression ->
        variableAccess globalTable variableExpression
    AST.LiteralExpression literal' -> literal literal'
    AST.NegationExpression operand -> do
        intOperand operand
        pure intType
    AST.SumExpression lhs rhs -> arithmeticExpression lhs rhs
    AST.SubtractionExpression lhs rhs -> arithmeticExpression lhs rhs
    AST.ProductExpression lhs rhs -> arithmeticExpression lhs rhs
    AST.DivisionExpression lhs rhs -> arithmeticExpression lhs rhs
  where
    -- The left operand is analysed and verified before the right one is
    -- looked at.
    arithmeticExpression lhs rhs = do
        intOperand lhs
        intOperand rhs
        pure intType
    -- Analyses an operand and fails unless it is an integer.
    intOperand operand = do
        operandType <- expression globalTable operand
        unless (operandType == intType)
            $ TypeAnalysis $ lift $ throwE $ ArithmeticExpressionError operandType
2024-08-15 20:13:56 +02:00
-- | Determines the type of a condition.
--
-- Every comparison operator requires two integer operands and yields a
-- boolean; otherwise 'ComparisonExpressionError' is reported with both
-- operand types.
condition :: SymbolTable -> AST.Condition -> TypeAnalysis Type
condition globalTable = \case
    AST.EqualCondition lhs rhs -> comparisonExpression lhs rhs
    AST.NonEqualCondition lhs rhs -> comparisonExpression lhs rhs
    AST.LessCondition lhs rhs -> comparisonExpression lhs rhs
    AST.GreaterCondition lhs rhs -> comparisonExpression lhs rhs
    AST.LessOrEqualCondition lhs rhs -> comparisonExpression lhs rhs
    AST.GreaterOrEqualCondition lhs rhs -> comparisonExpression lhs rhs
  where
    -- Both operands are analysed before either type is verified.
    comparisonExpression lhs rhs = do
        leftType <- expression globalTable lhs
        rightType <- expression globalTable rhs
        unless (leftType == intType && rightType == intType)
            $ TypeAnalysis $ lift $ throwE
            $ ComparisonExpressionError leftType rightType
        pure booleanType
-- | Determines the type of a literal: integer, hexadecimal and character
-- literals are integers, boolean literals are booleans.
literal :: AST.Literal -> TypeAnalysis Type
literal = \case
    AST.IntegerLiteral _ -> pure intType
    AST.HexadecimalLiteral _ -> pure intType
    AST.CharacterLiteral _ -> pure intType
    AST.BooleanLiteral _ -> pure booleanType
2024-09-08 02:08:13 +02:00
-}