187 lines
8.2 KiB
Haskell
187 lines
8.2 KiB
Haskell
module Language.Elna.TypeAnalysis
|
|
( typeAnalysis
|
|
, -- Error(..)
|
|
) where
|
|
|
|
import qualified Language.Elna.AST as AST
|
|
import Language.Elna.SymbolTable ({-Info(..), ParameterInfo(..), -}SymbolTable)
|
|
|
|
-- | Type-check a whole program against the global symbol table.
--
-- This is currently a placeholder: neither the symbol table nor the program
-- is inspected, and the result is always @()@. The planned checker is kept
-- commented out further down in this module; once it is enabled the result
-- type is meant to become @Maybe Error@, i.e. @Just@ the error that aborted
-- the analysis, or @Nothing@ when the program type-checks.
typeAnalysis :: SymbolTable -> AST.Program -> () -- Maybe Error
typeAnalysis _ _ = ()
{- Planned definition (requires the commented-out checker below):

typeAnalysis globalTable = either Just (const Nothing)
    . runExcept
    . flip runReaderT globalTable
    . runTypeAnalysis
    . program
-}
|
|
|
|
{-
|
|
import Control.Applicative (Alternative(..))
|
|
import Control.Monad.Trans.Except (Except, runExcept, throwE)
|
|
import Control.Monad.Trans.Reader (ReaderT, asks, runReaderT, withReaderT, ask)
|
|
import qualified Data.Vector as Vector
|
|
import Language.Elna.Location (Identifier(..))
|
|
import qualified Language.Elna.SymbolTable as SymbolTable
|
|
import Language.Elna.Types (Type(..), booleanType, intType)
|
|
import Control.Monad.Trans.Class (MonadTrans(..))
|
|
import Control.Monad (unless, when)
|
|
import Data.Foldable (traverse_)
|
|
|
|
data Error
|
|
= ArithmeticExpressionError Type
|
|
| ComparisonExpressionError Type Type
|
|
| UnexpectedVariableInfoError Info
|
|
| UnexpectedProcedureInfoError Info
|
|
| UndefinedSymbolError Identifier
|
|
| InvalidConditionTypeError Type
|
|
| InvalidAssignmentError Type
|
|
| ExpectedLvalueError AST.Expression
|
|
| ParameterCountMismatchError Int Int
|
|
| ArgumentTypeMismatchError Type Type
|
|
| ArrayIndexError Type
|
|
| ArrayAccessError Type
|
|
deriving (Eq, Show)
|
|
|
|
newtype TypeAnalysis a = TypeAnalysis
|
|
{ runTypeAnalysis :: ReaderT SymbolTable (Except Error) a
|
|
}
|
|
|
|
instance Functor TypeAnalysis
|
|
where
|
|
fmap f (TypeAnalysis x) = TypeAnalysis $ f <$> x
|
|
|
|
instance Applicative TypeAnalysis
|
|
where
|
|
pure = TypeAnalysis . pure
|
|
(TypeAnalysis f) <*> (TypeAnalysis x) = TypeAnalysis $ f <*> x
|
|
|
|
instance Monad TypeAnalysis
|
|
where
|
|
(TypeAnalysis x) >>= f = TypeAnalysis $ x >>= (runTypeAnalysis . f)
|
|
|
|
program :: AST.Program -> TypeAnalysis ()
|
|
program (AST.Program declarations) = traverse_ declaration declarations
|
|
|
|
declaration :: AST.Declaration -> TypeAnalysis ()
|
|
declaration (AST.ProcedureDefinition procedureName _ _ body) = do
|
|
globalTable <- TypeAnalysis ask
|
|
case SymbolTable.lookup procedureName globalTable of
|
|
Just (ProcedureInfo localTable _) -> TypeAnalysis
|
|
$ withReaderT (const localTable)
|
|
$ runTypeAnalysis
|
|
$ traverse_ (statement globalTable) body
|
|
Just anotherInfo -> TypeAnalysis $ lift $ throwE
|
|
$ UnexpectedProcedureInfoError anotherInfo
|
|
Nothing -> TypeAnalysis $ lift $ throwE
|
|
$ UndefinedSymbolError procedureName
|
|
declaration _ = pure ()
|
|
|
|
statement :: SymbolTable -> AST.Statement -> TypeAnalysis ()
|
|
statement globalTable = \case
|
|
AST.EmptyStatement -> pure ()
|
|
AST.AssignmentStatement lhs rhs -> do
|
|
lhsType <- variableAccess globalTable lhs
|
|
rhsType <- expression globalTable rhs
|
|
unless (lhsType == intType)
|
|
$ TypeAnalysis $ lift $ throwE $ InvalidAssignmentError lhsType
|
|
unless (rhsType == intType)
|
|
$ TypeAnalysis $ lift $ throwE $ InvalidAssignmentError rhsType
|
|
AST.IfStatement ifCondition ifStatement elseStatement -> do
|
|
conditionType <- condition globalTable ifCondition
|
|
unless (conditionType == booleanType)
|
|
$ TypeAnalysis $ lift $ throwE $ InvalidConditionTypeError conditionType
|
|
statement globalTable ifStatement
|
|
maybe (pure ()) (statement globalTable) elseStatement
|
|
AST.WhileStatement whileCondition whileStatement -> do
|
|
conditionType <- condition globalTable whileCondition
|
|
unless (conditionType == booleanType)
|
|
$ TypeAnalysis $ lift $ throwE $ InvalidConditionTypeError conditionType
|
|
statement globalTable whileStatement
|
|
AST.CompoundStatement statements -> traverse_ (statement globalTable) statements
|
|
AST.CallStatement procedureName arguments ->
|
|
case SymbolTable.lookup procedureName globalTable of
|
|
Just (ProcedureInfo _ parameters)
|
|
| parametersLength <- Vector.length parameters
|
|
, argumentsLength <- length arguments
|
|
, Vector.length parameters /= length arguments -> TypeAnalysis $ lift $ throwE
|
|
$ ParameterCountMismatchError parametersLength argumentsLength
|
|
| otherwise -> traverse_ (uncurry checkArgument)
|
|
$ Vector.zip parameters (Vector.fromList arguments)
|
|
Just anotherInfo -> TypeAnalysis $ lift $ throwE
|
|
$ UnexpectedProcedureInfoError anotherInfo
|
|
Nothing -> TypeAnalysis $ lift $ throwE
|
|
$ UndefinedSymbolError procedureName
|
|
where
|
|
checkArgument ParameterInfo{..} argument = do
|
|
argumentType <- expression globalTable argument
|
|
unless (argumentType == type')
|
|
$ TypeAnalysis $ lift $ throwE $ ArgumentTypeMismatchError type' argumentType
|
|
when (isReferenceParameter && not (isLvalue argument))
|
|
$ TypeAnalysis $ lift $ throwE $ ExpectedLvalueError argument
|
|
isLvalue (AST.VariableExpression _) = True
|
|
isLvalue _ = False
|
|
|
|
variableAccess :: SymbolTable -> AST.VariableAccess -> TypeAnalysis Type
|
|
variableAccess globalTable (AST.VariableAccess identifier) = do
|
|
localLookup <- TypeAnalysis $ asks $ SymbolTable.lookup identifier
|
|
case localLookup <|> SymbolTable.lookup identifier globalTable of
|
|
Just (VariableInfo _ variableType) -> pure variableType
|
|
Just anotherInfo -> TypeAnalysis $ lift $ throwE
|
|
$ UnexpectedVariableInfoError anotherInfo
|
|
Nothing -> TypeAnalysis $ lift $ throwE
|
|
$ UndefinedSymbolError identifier
|
|
variableAccess globalTable (AST.ArrayAccess arrayExpression indexExpression) = do
|
|
arrayType <- variableAccess globalTable arrayExpression
|
|
indexType <- expression globalTable indexExpression
|
|
unless (indexType == intType)
|
|
$ TypeAnalysis $ lift $ throwE $ ArrayIndexError indexType
|
|
case arrayType of
|
|
ArrayType _ baseType -> pure baseType
|
|
nonArrayType -> TypeAnalysis $ lift $ throwE
|
|
$ ArrayAccessError nonArrayType
|
|
|
|
expression :: SymbolTable -> AST.Expression -> TypeAnalysis Type
|
|
expression globalTable = \case
|
|
AST.VariableExpression variableExpression -> do
|
|
variableAccess globalTable variableExpression
|
|
AST.LiteralExpression literal' -> literal literal'
|
|
AST.NegationExpression negation -> do
|
|
operandType <- expression globalTable negation
|
|
if operandType == intType
|
|
then pure intType
|
|
else TypeAnalysis $ lift $ throwE $ ArithmeticExpressionError operandType
|
|
AST.SumExpression lhs rhs -> arithmeticExpression lhs rhs
|
|
AST.SubtractionExpression lhs rhs -> arithmeticExpression lhs rhs
|
|
AST.ProductExpression lhs rhs -> arithmeticExpression lhs rhs
|
|
AST.DivisionExpression lhs rhs -> arithmeticExpression lhs rhs
|
|
where
|
|
arithmeticExpression lhs rhs = do
|
|
lhsType <- expression globalTable lhs
|
|
unless (lhsType == intType)
|
|
$ TypeAnalysis $ lift $ throwE $ ArithmeticExpressionError lhsType
|
|
rhsType <- expression globalTable rhs
|
|
unless (rhsType == intType)
|
|
$ TypeAnalysis $ lift $ throwE $ ArithmeticExpressionError rhsType
|
|
pure intType
|
|
|
|
condition :: SymbolTable -> AST.Condition -> TypeAnalysis Type
|
|
condition globalTable = \case
|
|
AST.EqualCondition lhs rhs -> comparisonExpression lhs rhs
|
|
AST.NonEqualCondition lhs rhs -> comparisonExpression lhs rhs
|
|
AST.LessCondition lhs rhs -> comparisonExpression lhs rhs
|
|
AST.GreaterCondition lhs rhs -> comparisonExpression lhs rhs
|
|
AST.LessOrEqualCondition lhs rhs -> comparisonExpression lhs rhs
|
|
AST.GreaterOrEqualCondition lhs rhs -> comparisonExpression lhs rhs
|
|
where
|
|
comparisonExpression lhs rhs = do
|
|
lhsType <- expression globalTable lhs
|
|
rhsType <- expression globalTable rhs
|
|
if lhsType == intType && rhsType == intType
|
|
then pure booleanType
|
|
else TypeAnalysis $ lift $ throwE $ ComparisonExpressionError lhsType rhsType
|
|
|
|
literal :: AST.Literal -> TypeAnalysis Type
|
|
literal (AST.IntegerLiteral _) = pure intType
|
|
literal (AST.HexadecimalLiteral _) = pure intType
|
|
literal (AST.CharacterLiteral _) = pure intType
|
|
literal (AST.BooleanLiteral _) = pure booleanType
|
|
-}
|