306 lines
12 KiB
Haskell
306 lines
12 KiB
Haskell
module Language.Elna.Intermediate
|
|
( Label(..)
|
|
, Operand(..)
|
|
, Quadruple(..)
|
|
, Variable(..)
|
|
, intermediate
|
|
) where
|
|
|
|
import Control.Monad.Trans.State (State, runState, gets, modify')
|
|
import Data.Bifunctor (Bifunctor(..))
|
|
import Data.Int (Int32)
|
|
import Data.HashMap.Strict (HashMap)
|
|
import qualified Data.HashMap.Strict as HashMap
|
|
import Data.Text (Text)
|
|
import Data.Word (Word32)
|
|
import Data.Vector (Vector)
|
|
import qualified Data.Vector as Vector
|
|
import qualified Language.Elna.AST as AST
|
|
import Language.Elna.Types (Type(..))
|
|
import Language.Elna.SymbolTable (SymbolTable, Info(..))
|
|
import qualified Language.Elna.SymbolTable as SymbolTable
|
|
import Data.Foldable (Foldable(..), foldrM)
|
|
import GHC.Records (HasField(..))
|
|
import qualified Data.Text.Lazy.Builder.Int as Text.Builder
|
|
import qualified Data.Text.Lazy.Builder as Text.Builder
|
|
import qualified Data.Text.Lazy as Text.Lazy
|
|
|
|
-- | Argument of a quadruple: either a variable or an immediate 32-bit
-- integer value.
data Operand
    = VariableOperand Variable -- ^ Value held by a variable.
    | IntOperand Int32 -- ^ Immediate integer value.
    deriving (Eq, Show)
|
|
|
|
-- | Jump target, identified by its textual name.
newtype Label = Label Text
    deriving (Eq, Show)
|
|
|
|
-- | Storage location referenced by a quadruple.
data Variable
    = Variable Text -- ^ Variable identified by name.
    | TempVariable -- ^ Temporary; carries no name or index of its own.
    deriving (Eq, Show)
|
|
|
|
-- | State threaded through the translation. It only tracks the number used
-- to name the next label.
newtype Generator = Generator
    { labelCounter :: Int32
    } deriving (Eq, Show)

-- | Combining two generators adds their label counters.
instance Semigroup Generator where
    lhs <> rhs = Generator (labelCounter lhs + labelCounter rhs)

-- | The neutral generator starts counting labels at zero.
instance Monoid Generator where
    mempty = Generator 0
|
|
|
|
-- | Translation monad: a state computation that threads a 'Generator'.
newtype Intermediate a = Intermediate
    { runIntermediate :: State Generator a
    }

-- | Maps over the result of the wrapped state computation.
instance Functor Intermediate where
    fmap f = Intermediate . fmap f . runIntermediate

-- | Delegates to the 'Applicative' instance of the wrapped state computation.
instance Applicative Intermediate where
    pure value = Intermediate (pure value)
    wrappedFunction <*> wrappedArgument = Intermediate
        $ runIntermediate wrappedFunction <*> runIntermediate wrappedArgument

-- | Delegates to the 'Monad' instance of the wrapped state computation.
instance Monad Intermediate where
    action >>= continuation = Intermediate
        $ runIntermediate action >>= runIntermediate . continuation
|
|
|
|
-- | Single instruction of the intermediate representation.
--
-- The field descriptions below follow the way the constructors are applied
-- by the translation functions in this module.
data Quadruple
    = StartQuadruple -- ^ First quadruple of every translated procedure.
    | GoToQuadruple Label -- ^ Jump to the label.
    | AssignQuadruple Operand Variable -- ^ Value and the assigned variable.
    | ArrayQuadruple Variable Operand Variable -- ^ Array, index, result.
    | ArrayAssignQuadruple Operand Operand Variable -- ^ Value, index, array.
    | AddQuadruple Operand Operand Variable -- ^ Left, right, result.
    | SubtractionQuadruple Operand Operand Variable -- ^ Left, right, result.
    | ProductQuadruple Operand Operand Variable -- ^ Left, right, result.
    | DivisionQuadruple Operand Operand Variable -- ^ Left, right, result.
    | NegationQuadruple Operand Variable -- ^ Operand and result.
    | EqualQuadruple Operand Operand Label -- ^ Left, right, jump target.
    | NonEqualQuadruple Operand Operand Label -- ^ Left, right, jump target.
    | LessQuadruple Operand Operand Label -- ^ Left, right, jump target.
    | GreaterQuadruple Operand Operand Label -- ^ Left, right, jump target.
    | LessOrEqualQuadruple Operand Operand Label -- ^ Left, right, jump target.
    | GreaterOrEqualQuadruple Operand Operand Label -- ^ Left, right, jump target.
    | LabelQuadruple Label -- ^ Marks the position of a label.
    | ParameterQuadruple Operand -- ^ Emitted once per call argument.
    | CallQuadruple Variable Word32 -- ^ Callee and a count accompanying the call.
    | StopQuadruple -- ^ Last quadruple of every translated procedure.
    deriving (Eq, Show)
|
|
|
|
-- | Allocates a fresh label.
--
-- The label is named after the decimal rendering of the current label
-- counter; the counter is incremented afterwards.
createLabel :: Intermediate Label
createLabel = Intermediate $ do
    labelIndex <- gets labelCounter
    modify' increment
    pure $ Label $ render labelIndex
  where
    increment (Generator counter) = Generator (counter + 1)
    render labelIndex = Text.Lazy.toStrict
        $ Text.Builder.toLazyText
        $ Text.Builder.decimal labelIndex
|
|
|
|
-- | Translates a whole program into quadruples, keyed by procedure name.
--
-- The translation starts from an empty ('mempty') generator and the final
-- generator state is discarded.
intermediate :: SymbolTable -> AST.Program -> HashMap AST.Identifier (Vector Quadruple)
intermediate globalTable parsedProgram =
    fst $ runState (runIntermediate translation) mempty
  where
    translation = program globalTable parsedProgram
|
|
|
|
-- | Translates every procedure definition of the program.
--
-- Type definitions produce no quadruples. The body of each procedure is
-- wrapped between 'StartQuadruple' and 'StopQuadruple' and stored under the
-- procedure name. Declarations are folded from the right.
program
    :: SymbolTable
    -> AST.Program
    -> Intermediate (HashMap AST.Identifier (Vector Quadruple))
program globalTable (AST.Program declarations) =
    foldrM translateDeclaration HashMap.empty declarations
  where
    translateDeclaration declaration translated = case declaration of
        AST.TypeDefinition _ _ -> pure translated
        AST.ProcedureDefinition procedureName _ _ statements -> do
            body <- fold <$> traverse (statement globalTable) statements
            let framedBody = Vector.cons StartQuadruple
                    $ Vector.snoc body StopQuadruple
            pure $ HashMap.insert procedureName framedBody translated
|
|
|
|
-- | Translates a single statement into a sequence of quadruples.
--
-- The symbol table is consulted to resolve the types of accessed variables.
-- Labels needed for control flow are allocated with 'createLabel'.
statement :: SymbolTable -> AST.Statement -> Intermediate (Vector Quadruple)
statement _ AST.EmptyStatement = pure mempty
statement localTable (AST.AssignmentStatement variableAccess' assignee) = do
    -- The assigned value is evaluated before the target's index expressions.
    (rhsOperand, rhsStatements) <- expression localTable assignee
    let variableType' = variableType variableAccess' localTable
    accessResult <- variableAccess localTable variableAccess' Nothing variableType' mempty
    pure $ rhsStatements <> case accessResult of
        -- An index was accumulated: store into the array element.
        (AST.Identifier identifier, Just accumulatedIndex, accumulatedStatements) ->
            Vector.snoc accumulatedStatements
                $ ArrayAssignQuadruple rhsOperand accumulatedIndex
                $ Variable identifier
        -- No index: plain assignment to the variable.
        (AST.Identifier identifier, Nothing, accumulatedStatements) ->
            Vector.snoc accumulatedStatements
                $ AssignQuadruple rhsOperand
                $ Variable identifier
statement localTable (AST.IfStatement ifCondition ifStatement elseStatement) = do
    (conditionStatements, jumpConstructor) <- condition localTable ifCondition
    ifLabel <- createLabel
    endLabel <- createLabel
    ifStatements <- statement localTable ifStatement
    possibleElseStatements <- traverse (statement localTable) elseStatement
    -- Layout: condition, conditional jump to the "if" branch, the "else"
    -- branch (if any), jump to the end, "if" label, "if" branch, end label.
    pure $ conditionStatements <> case possibleElseStatements of
        Just elseStatements -> Vector.cons (jumpConstructor ifLabel) elseStatements
            <> Vector.fromList [GoToQuadruple endLabel, LabelQuadruple ifLabel]
            <> Vector.snoc ifStatements (LabelQuadruple endLabel)
        Nothing -> Vector.fromList [jumpConstructor ifLabel, GoToQuadruple endLabel, LabelQuadruple ifLabel]
            <> Vector.snoc ifStatements (LabelQuadruple endLabel)
statement localTable (AST.WhileStatement whileCondition whileStatement) = do
    (conditionStatements, jumpConstructor) <- condition localTable whileCondition
    startLabel <- createLabel
    endLabel <- createLabel
    conditionLabel <- createLabel
    whileStatements <- statement localTable whileStatement
    -- Layout: condition label, condition, conditional jump into the body,
    -- jump to the end, body label, body, jump back to the condition, end label.
    pure $ Vector.fromList [LabelQuadruple conditionLabel]
        <> conditionStatements
        <> Vector.fromList [jumpConstructor startLabel, GoToQuadruple endLabel, LabelQuadruple startLabel]
        <> whileStatements
        <> Vector.fromList [GoToQuadruple conditionLabel, LabelQuadruple endLabel]
statement localTable (AST.CallStatement (AST.Identifier callName) arguments) = do
    visitedArguments <- traverse (expression localTable) arguments
    let (parameterStatements, argumentStatements)
            = bimap (Vector.fromList . fmap ParameterQuadruple) Vector.concat
            $ unzip visitedArguments
    -- The count passed to the call is the number of parameter quadruples
    -- (one per argument), not the number of quadruples that were needed to
    -- evaluate the arguments: the latter is zero for literal arguments.
    pure
        $ Vector.snoc (argumentStatements <> parameterStatements)
        $ CallQuadruple (Variable callName)
        $ fromIntegral
        $ Vector.length parameterStatements
statement localTable (AST.CompoundStatement statements) =
    fold <$> traverse (statement localTable) statements
|
|
|
|
-- | Translates a condition.
--
-- Returns the quadruples that evaluate both sides (left side first) and a
-- function that, given a label, builds the matching comparison quadruple.
condition
    :: SymbolTable
    -> AST.Condition
    -> Intermediate (Vector Quadruple, Label -> Quadruple)
condition localTable = \case
    AST.EqualCondition lhs rhs -> comparison EqualQuadruple lhs rhs
    AST.NonEqualCondition lhs rhs -> comparison NonEqualQuadruple lhs rhs
    AST.LessCondition lhs rhs -> comparison LessQuadruple lhs rhs
    AST.GreaterCondition lhs rhs -> comparison GreaterQuadruple lhs rhs
    AST.LessOrEqualCondition lhs rhs -> comparison LessOrEqualQuadruple lhs rhs
    AST.GreaterOrEqualCondition lhs rhs ->
        comparison GreaterOrEqualQuadruple lhs rhs
  where
    -- Evaluates both sides and partially applies the comparison constructor.
    comparison jumpConstructor lhs rhs = do
        (lhsOperand, lhsStatements) <- expression localTable lhs
        (rhsOperand, rhsStatements) <- expression localTable rhs
        pure
            ( lhsStatements <> rhsStatements
            , jumpConstructor lhsOperand rhsOperand
            )
|
|
|
|
-- | Flattens a (possibly indexed) variable access.
--
-- Walks the access from the outermost index towards the variable name and
-- returns the accessed identifier, the accumulated index operand ('Nothing'
-- when the access has no index) and the quadruples computing that index.
--
-- Arguments: symbol table, the access, the index accumulated so far, the
-- type expected at this level, and the quadruples accumulated so far.
--
-- Calls 'error' when an index is applied to a non-array type.
--
-- NOTE(review): every temporary is the same 'TempVariable', so index
-- expressions that need temporaries themselves may overwrite each other.
-- Confirm whether a later pass is expected to rename temporaries.
variableAccess
    :: SymbolTable
    -> AST.VariableAccess
    -> Maybe Operand
    -> Type
    -> Vector Quadruple
    -> Intermediate (AST.Identifier, Maybe Operand, Vector Quadruple)
variableAccess _ (AST.VariableAccess identifier) accumulatedIndex _ accumulatedStatements =
    pure (identifier, accumulatedIndex, accumulatedStatements)
variableAccess localTable (AST.ArrayAccess access1 index1) Nothing (ArrayType _ baseType) accumulatedStatements = do
    (indexPlace, statements) <- expression localTable index1
    -- Keep whatever the caller has accumulated instead of dropping it.
    variableAccess localTable access1 (Just indexPlace) baseType
        $ accumulatedStatements <> statements
variableAccess localTable (AST.ArrayAccess arrayAccess' arrayIndex) (Just baseIndex) (ArrayType arraySize baseType) statements = do
    (indexPlace, statements') <- expression localTable arrayIndex
    let resultVariable = TempVariable
        resultOperand = VariableOperand resultVariable
        indexCalculation = Vector.fromList
            [ ProductQuadruple (IntOperand $ fromIntegral arraySize) baseIndex resultVariable
            , AddQuadruple indexPlace resultOperand resultVariable
            ]
    -- The quadruples computing indexPlace must precede the calculation that
    -- reads it.
    variableAccess localTable arrayAccess' (Just resultOperand) baseType
        $ statements <> statements' <> indexCalculation
variableAccess _ _ _ _ _ = error "Array access operator doesn't match the type."
|
|
|
|
-- | Looks up the type recorded for the variable at the root of an access.
--
-- Index operations are skipped; the identifier found underneath is looked
-- up in the symbol table. Calls 'error' when the lookup does not yield a
-- 'TypeInfo'.
variableType :: AST.VariableAccess -> SymbolTable -> Type
variableType access symbolTable = case access of
    AST.ArrayAccess innerAccess _ -> variableType innerAccess symbolTable
    AST.VariableAccess identifier ->
        case SymbolTable.lookup identifier symbolTable of
            Just (TypeInfo foundType) -> foundType
            _ -> error "Undefined type."
|
|
|
|
-- | Translates an expression.
--
-- Returns the operand holding the value of the expression together with the
-- quadruples that have to run before the operand can be read. Results of
-- array reads, negations and binary operations are placed in 'TempVariable'.
expression :: SymbolTable -> AST.Expression -> Intermediate (Operand, Vector Quadruple)
expression localTable = \case
    AST.LiteralExpression literal' -> pure (literal literal', mempty)
    AST.VariableExpression accessed -> do
        let accessedType = variableType accessed localTable
        accessResult <- variableAccess localTable accessed Nothing accessedType mempty
        pure $ case accessResult of
            -- Plain variable: it can be used as an operand directly.
            (AST.Identifier name, Nothing, setup) ->
                (VariableOperand (Variable name), setup)
            -- Indexed access: read the element into the temporary first.
            (AST.Identifier name, Just index, setup) ->
                ( VariableOperand TempVariable
                , Vector.snoc setup
                    $ ArrayQuadruple (Variable name) index TempVariable
                )
    AST.NegationExpression negated -> do
        (negatedOperand, setup) <- expression localTable negated
        pure
            ( VariableOperand TempVariable
            , Vector.snoc setup $ NegationQuadruple negatedOperand TempVariable
            )
    AST.SumExpression lhs rhs -> binaryExpression AddQuadruple lhs rhs
    AST.SubtractionExpression lhs rhs ->
        binaryExpression SubtractionQuadruple lhs rhs
    AST.ProductExpression lhs rhs -> binaryExpression ProductQuadruple lhs rhs
    AST.DivisionExpression lhs rhs -> binaryExpression DivisionQuadruple lhs rhs
  where
    -- Evaluates the left side, then the right side, then combines both with
    -- the given quadruple constructor.
    binaryExpression makeQuadruple lhs rhs = do
        (lhsOperand, lhsStatements) <- expression localTable lhs
        (rhsOperand, rhsStatements) <- expression localTable rhs
        pure
            ( VariableOperand TempVariable
            , Vector.snoc (lhsStatements <> rhsStatements)
                $ makeQuadruple lhsOperand rhsOperand TempVariable
            )
|
|
|
|
-- | Converts a literal into an integer operand.
--
-- Integer and hexadecimal literals keep their value, characters are
-- converted with 'fromIntegral', and booleans become 1 (true) or 0 (false).
literal :: AST.Literal -> Operand
literal = \case
    AST.IntegerLiteral integer -> IntOperand integer
    AST.HexadecimalLiteral integer -> IntOperand integer
    AST.CharacterLiteral character -> IntOperand (fromIntegral character)
    AST.BooleanLiteral True -> IntOperand 1
    AST.BooleanLiteral False -> IntOperand 0
|