From 1c996b3c8bb290d17e4d0dcdf809c8458866bb12 Mon Sep 17 00:00:00 2001
From: Eugen Wissner <belka@caraus.de>
Date: Wed, 4 Dec 2024 16:11:06 +0100
Subject: [PATCH] Make IR for array access

---
 lib/Language/Elna/Backend/Allocator.hs        | 127 +++++++++---------
 lib/Language/Elna/Backend/Intermediate.hs     |   2 +-
 lib/Language/Elna/Frontend/SymbolTable.hs     |   4 +
 lib/Language/Elna/Glue.hs                     |  34 ++---
 lib/Language/Elna/RiscV/CodeGenerator.hs      |  32 +++++
 src/Main.hs                                   |   2 +-
 .../expectations/array_element_assignment.txt |   1 +
 tests/expectations/print_array_element.txt    |   2 +
 tests/vm/array_element_assignment.elna        |   3 +
 tests/vm/print_array_element.elna             |   8 ++
 10 files changed, 138 insertions(+), 77 deletions(-)
 create mode 100644 tests/expectations/print_array_element.txt
 create mode 100644 tests/vm/print_array_element.elna

diff --git a/lib/Language/Elna/Backend/Allocator.hs b/lib/Language/Elna/Backend/Allocator.hs
index acdf3e5..df50d6e 100644
--- a/lib/Language/Elna/Backend/Allocator.hs
+++ b/lib/Language/Elna/Backend/Allocator.hs
@@ -5,6 +5,7 @@ module Language.Elna.Backend.Allocator
     ) where
 
 import Data.HashMap.Strict (HashMap)
+import qualified Data.HashMap.Strict as HashMap
 import Data.Int (Int32)
 import Data.Word (Word32)
 import Data.Vector (Vector)
@@ -16,11 +17,13 @@ import Language.Elna.Backend.Intermediate
     )
 import Language.Elna.Location (Identifier(..))
 import Control.Monad.Trans.Reader (ReaderT, asks, runReaderT)
-import Control.Monad.Trans.State (State, runState, modify')
+import Control.Monad.Trans.State (State, runState)
 import GHC.Records (HasField(..))
 import Control.Monad.Trans.Class (MonadTrans(..))
 import Control.Monad.Trans.Except (ExceptT(..), runExceptT, throwE)
 import Data.List ((!?))
+import Language.Elna.Frontend.SymbolTable (Info(..), SymbolTable)
+import qualified Language.Elna.Frontend.SymbolTable as SymbolTable
 
 data Store r
     = RegisterStore r
@@ -38,7 +41,7 @@ newtype MachineConfiguration r = MachineConfiguration
     }
 
 newtype MachineState = MachineState
-    { stackSize :: Word32
+    { symbolTable :: SymbolTable
     } deriving (Eq, Show)
 
 newtype Allocator r a = Allocator
@@ -61,87 +64,92 @@ instance forall r. Monad (Allocator r)
 allocate
     :: forall r
     . MachineConfiguration r
+    -> SymbolTable
     -> HashMap Identifier (Vector (Quadruple Variable))
     -> Either AllocationError (HashMap Identifier (ProcedureQuadruples (Store r)))
-allocate machineConfiguration = traverse function
+allocate machineConfiguration globalTable = HashMap.traverseWithKey function
   where
-    run = flip runState (MachineState{ stackSize = 0 })
+    run localTable = flip runState (MachineState{ symbolTable = localTable })
         . flip runReaderT machineConfiguration
         . runExceptT
         . runAllocator
         . mapM quadruple
-    function :: Vector (Quadruple Variable) -> Either AllocationError (ProcedureQuadruples (Store r))
-    function quadruples' =
-        let (result, lastState) = run quadruples'
+    function :: Identifier -> Vector (Quadruple Variable) -> Either AllocationError (ProcedureQuadruples (Store r))
+    function identifier quadruples' =
+        let Just (ProcedureInfo localTable _) = SymbolTable.lookup identifier globalTable
+            (result, lastState) = run localTable quadruples'
          in makeResult lastState <$> result
-    makeResult MachineState{ stackSize } result = ProcedureQuadruples
+    makeResult MachineState{ symbolTable } result = ProcedureQuadruples
         { quadruples = result
-        , stackSize = stackSize
+        , stackSize = fromIntegral $ SymbolTable.size symbolTable * 4
         }
 
 quadruple :: Quadruple Variable -> Allocator r (Quadruple (Store r))
 quadruple = \case
     StartQuadruple -> pure StartQuadruple
     StopQuadruple -> pure StopQuadruple
-    ParameterQuadruple operand1 -> do
-        operand1' <- operand operand1
-        pure $ ParameterQuadruple operand1'
+    ParameterQuadruple operand1 -> ParameterQuadruple
+        <$> operand operand1
     CallQuadruple name count -> pure $ CallQuadruple name count
-    AddQuadruple operand1 operand2 variable -> do
-        operand1' <- operand operand1
-        operand2' <- operand operand2
-        AddQuadruple operand1' operand2' <$> storeVariable variable
-    SubtractionQuadruple operand1 operand2 variable -> do
-        operand1' <- operand operand1
-        operand2' <- operand operand2
-        SubtractionQuadruple operand1' operand2' <$> storeVariable variable
-    NegationQuadruple operand1 variable -> do
-        operand1' <- operand operand1
-        NegationQuadruple operand1' <$> storeVariable variable
-    ProductQuadruple operand1 operand2 variable -> do
-        operand1' <- operand operand1
-        operand2' <- operand operand2
-        ProductQuadruple operand1' operand2' <$> storeVariable variable
-    DivisionQuadruple operand1 operand2 variable -> do
-        operand1' <- operand operand1
-        operand2' <- operand operand2
-        DivisionQuadruple operand1' operand2' <$> storeVariable variable
+    AddQuadruple operand1 operand2 variable -> AddQuadruple
+        <$> operand operand1
+        <*> operand operand2
+        <*> storeVariable variable
+    SubtractionQuadruple operand1 operand2 variable -> SubtractionQuadruple
+        <$> operand operand1
+        <*> operand operand2
+        <*> storeVariable variable
+    NegationQuadruple operand1 variable -> NegationQuadruple
+        <$> operand operand1
+        <*> storeVariable variable
+    ProductQuadruple operand1 operand2 variable -> ProductQuadruple
+        <$> operand operand1
+        <*> operand operand2
+        <*> storeVariable variable
+    DivisionQuadruple operand1 operand2 variable -> DivisionQuadruple
+        <$> operand operand1
+        <*> operand operand2
+        <*> storeVariable variable
     LabelQuadruple label -> pure $ LabelQuadruple label
     GoToQuadruple label -> pure $ GoToQuadruple label
-    EqualQuadruple operand1 operand2 goToLabel -> do
-        operand1' <- operand operand1
-        operand2' <- operand operand2
-        pure $ EqualQuadruple operand1' operand2' goToLabel
-    NonEqualQuadruple operand1 operand2 goToLabel -> do
-        operand1' <- operand operand1
-        operand2' <- operand operand2
-        pure $ NonEqualQuadruple operand1' operand2' goToLabel
-    LessQuadruple operand1 operand2 goToLabel -> do
-        operand1' <- operand operand1
-        operand2' <- operand operand2
-        pure $ LessQuadruple operand1' operand2' goToLabel
+    EqualQuadruple operand1 operand2 goToLabel -> EqualQuadruple
+        <$> operand operand1
+        <*> operand operand2
+        <*> pure goToLabel
+    NonEqualQuadruple operand1 operand2 goToLabel -> NonEqualQuadruple
+        <$> operand operand1
+        <*> operand operand2
+        <*> pure goToLabel
+    LessQuadruple operand1 operand2 goToLabel -> LessQuadruple
+        <$> operand operand1
+        <*> operand operand2
+        <*> pure goToLabel
     GreaterQuadruple operand1 operand2 goToLabel -> do
         operand1' <- operand operand1
         operand2' <- operand operand2
         pure $ GreaterQuadruple operand1' operand2' goToLabel
-    LessOrEqualQuadruple operand1 operand2 goToLabel -> do
-        operand1' <- operand operand1
-        operand2' <- operand operand2
-        pure $ LessOrEqualQuadruple operand1' operand2' goToLabel
-    GreaterOrEqualQuadruple operand1 operand2 goToLabel -> do
-        operand1' <- operand operand1
-        operand2' <- operand operand2
-        pure $ GreaterOrEqualQuadruple operand1' operand2' goToLabel
-    AssignQuadruple operand1 variable -> do
-        operand1' <- operand operand1
-        AssignQuadruple operand1' <$> storeVariable variable
-    ArrayAssignQuadruple operand1 operand2 variable -> do
-        operand1' <- operand operand1
-        operand2' <- operand operand2
-        ArrayAssignQuadruple operand1' operand2' <$> storeVariable variable
+    LessOrEqualQuadruple operand1 operand2 goToLabel -> LessOrEqualQuadruple
+        <$> operand operand1
+        <*> operand operand2
+        <*> pure goToLabel
+    GreaterOrEqualQuadruple operand1 operand2 goToLabel -> GreaterOrEqualQuadruple
+        <$> operand operand1
+        <*> operand operand2
+        <*> pure goToLabel
+    AssignQuadruple operand1 variable -> AssignQuadruple 
+        <$> operand operand1
+        <*> storeVariable variable
+    ArrayAssignQuadruple operand1 operand2 variable -> ArrayAssignQuadruple 
+        <$> operand operand1
+        <*> operand operand2
+        <*> storeVariable variable
+    ArrayQuadruple variable1 operand1 variable2 -> ArrayQuadruple 
+        <$> storeVariable variable1
+        <*> operand operand1
+        <*> storeVariable variable2
 
 operand :: Operand Variable -> Allocator r (Operand (Store r))
-operand (IntOperand x) = pure $ IntOperand x
+operand (IntOperand literalOperand) = pure $ IntOperand literalOperand
 operand (VariableOperand variableOperand) =
     VariableOperand <$> storeVariable variableOperand
 
@@ -152,7 +160,6 @@ storeVariable (TempVariable index) = do
         $ temporaryRegisters' !? fromIntegral index
 storeVariable (LocalVariable index) = do
     temporaryRegisters' <- Allocator $ lift $ asks $ getField @"temporaryRegisters"
-    Allocator $ lift $ lift $ modify' $ MachineState . (+ 4) . getField @"stackSize"
     maybe (Allocator $ throwE AllocationError) (pure . StackStore (fromIntegral (succ index) * (-4)))
         $ temporaryRegisters' !? pred (length temporaryRegisters' - fromIntegral index)
 storeVariable (ParameterVariable index) = do
diff --git a/lib/Language/Elna/Backend/Intermediate.hs b/lib/Language/Elna/Backend/Intermediate.hs
index dcf0ede..bb0ae7e 100644
--- a/lib/Language/Elna/Backend/Intermediate.hs
+++ b/lib/Language/Elna/Backend/Intermediate.hs
@@ -50,7 +50,7 @@ data Quadruple v
     | DivisionQuadruple (Operand v) (Operand v) v
     | GoToQuadruple Label
     | AssignQuadruple (Operand v) v
-    {-| ArrayQuadruple Variable Operand Variable -}
+    | ArrayQuadruple v (Operand v) v
     | ArrayAssignQuadruple (Operand v) (Operand v) v
     | LessOrEqualQuadruple (Operand v) (Operand v) Label
     | GreaterOrEqualQuadruple (Operand v) (Operand v) Label
diff --git a/lib/Language/Elna/Frontend/SymbolTable.hs b/lib/Language/Elna/Frontend/SymbolTable.hs
index e90a942..4333acc 100644
--- a/lib/Language/Elna/Frontend/SymbolTable.hs
+++ b/lib/Language/Elna/Frontend/SymbolTable.hs
@@ -9,6 +9,7 @@ module Language.Elna.Frontend.SymbolTable
     , lookup
     , member
     , scope
+    , size
     , toMap
     , update
     ) where
@@ -76,6 +77,9 @@ member :: Identifier -> SymbolTable -> Bool
 member identifier table =
     isJust $ lookup identifier table
 
+size :: SymbolTable -> Int
+size (SymbolTable _ map') = HashMap.size map'
+
 fromList :: [(Identifier, Info)] -> Either (NonEmpty Identifier) SymbolTable
 fromList elements
     | Just identifierDuplicates' <- identifierDuplicates =
diff --git a/lib/Language/Elna/Glue.hs b/lib/Language/Elna/Glue.hs
index 3cd46e3..ebb0f69 100644
--- a/lib/Language/Elna/Glue.hs
+++ b/lib/Language/Elna/Glue.hs
@@ -26,6 +26,7 @@ import Language.Elna.Frontend.SymbolTable (Info(..), SymbolTable)
 import qualified Language.Elna.Frontend.SymbolTable as SymbolTable
 import GHC.Records (HasField(..))
 import Language.Elna.Frontend.AST (Identifier(..))
+import Debug.Trace (traceShow)
 
 data Paste = Paste
     { temporaryCounter :: Word32
@@ -71,11 +72,12 @@ declaration
     :: SymbolTable
     -> AST.Declaration
     -> Glue (Maybe (AST.Identifier, Vector (Quadruple Variable)))
-declaration globalTable (AST.ProcedureDeclaration procedureName parameters variableDeclarations statements)
-    = Glue (modify' resetTemporaryCounter)
-    >> traverseWithIndex registerVariable variableDeclarations
-    >> traverseWithIndex registerParameter (reverse parameters)
-    >> nameQuadruplesTuple <$> traverse (statement globalTable) statements
+declaration globalTable (AST.ProcedureDeclaration procedureName parameters variableDeclarations statements) =
+    let Just (ProcedureInfo localTable _) = SymbolTable.lookup procedureName globalTable
+     in Glue (modify' resetTemporaryCounter)
+        >> traverseWithIndex registerVariable variableDeclarations
+        >> traverseWithIndex registerParameter (reverse parameters)
+        >> nameQuadruplesTuple <$> traverse (statement localTable) statements
   where
     traverseWithIndex f = traverse_ (uncurry f) . zip [0..]
     registerParameter index (AST.Parameter identifier _ _) =
@@ -129,11 +131,11 @@ statement localTable (AST.AssignmentStatement variableAccess' assignee) = do
     let variableType' = variableType variableAccess' localTable
     accessResult <- variableAccess localTable variableAccess' Nothing variableType' mempty
     lhsStatements <- case accessResult of
-            {-(AST.Identifier identifier, Just accumulatedIndex, accumulatedStatements) ->
-                Vector.snoc accumulatedStatements
-                    $ ArrayAssignQuadruple rhsOperand accumulatedIndex
-                    $ LocalVariable identifier -}
-            (identifier, _Nothing, accumulatedStatements)
+            (identifier, Just accumulatedIndex, accumulatedStatements)
+                -> Vector.snoc accumulatedStatements
+                . ArrayAssignQuadruple rhsOperand accumulatedIndex
+                <$> lookupLocal identifier
+            (identifier, Nothing, accumulatedStatements)
                 -> Vector.snoc accumulatedStatements
                 . AssignQuadruple rhsOperand
                 <$> lookupLocal identifier
@@ -251,7 +253,8 @@ variableAccess _ _ _ _ _ = error "Array access operator doesn't match the type."
 variableType :: AST.VariableAccess -> SymbolTable -> Type
 variableType (AST.VariableAccess identifier) symbolTable
     | Just (TypeInfo type') <- SymbolTable.lookup identifier symbolTable = type'
-    | otherwise = error "Undefined type."
+    | Just (VariableInfo _ type') <- SymbolTable.lookup identifier symbolTable = type'
+    | otherwise = traceShow identifier $ error "Undefined type."
 variableType (AST.ArrayAccess arrayAccess' _) symbolTable =
     variableType arrayAccess' symbolTable
 
@@ -277,16 +280,17 @@ expression localTable = \case
         let variableType' = variableType variableExpression localTable
         variableAccess' <- variableAccess localTable variableExpression Nothing variableType' mempty
         case variableAccess' of
-            (identifier, _Nothing, statements)
+            (identifier, Nothing, statements)
                 -> (, statements) . VariableOperand 
                 <$> lookupLocal identifier
-            {-(AST.Identifier identifier, Just operand, statements) -> do
+            (identifier, Just operand, statements) -> do
                 arrayAddress <- createTemporary
-                let arrayStatement = ArrayQuadruple (Variable identifier) operand arrayAddress
+                localVariable <- lookupLocal identifier
+                let arrayStatement = ArrayQuadruple localVariable operand arrayAddress
                 pure
                     ( VariableOperand arrayAddress
                     , Vector.snoc statements arrayStatement
-                    ) -}
+                    )
   where
     binaryExpression f lhs rhs = do
         (lhsOperand, lhsStatements) <- expression localTable lhs
diff --git a/lib/Language/Elna/RiscV/CodeGenerator.hs b/lib/Language/Elna/RiscV/CodeGenerator.hs
index 6e7a8cd..a1dcdbe 100644
--- a/lib/Language/Elna/RiscV/CodeGenerator.hs
+++ b/lib/Language/Elna/RiscV/CodeGenerator.hs
@@ -295,6 +295,38 @@ quadruple _ (ArrayAssignQuadruple assigneeOperand indexOperand store)
                 , storeInstruction
                 ]
         in (register, indexStatements <> statements)
+quadruple _ (ArrayQuadruple assigneeVariable indexOperand store) =
+    let (operandRegister1, statements1) = loadWithOffset assigneeVariable indexOperand
+        (storeRegister, storeStatements) = storeToStore store
+        instruction = Instruction
+            $ RiscV.BaseInstruction RiscV.OpImm
+            $ RiscV.I storeRegister RiscV.ADDI operandRegister1 0
+     in pure $ statements1 <> Vector.cons instruction storeStatements 
+  where
+    loadWithOffset :: RiscVStore -> Operand RiscVStore -> (RiscV.XRegister, Vector Statement)
+    loadWithOffset (RegisterStore register) _ = (register, mempty)
+    loadWithOffset (StackStore offset register) (IntOperand indexOffset) =
+        let loadInstruction = Instruction
+                $ RiscV.BaseInstruction RiscV.Load
+                $ RiscV.I register RiscV.LW RiscV.S0 (fromIntegral $ offset + indexOffset)
+        in (register, Vector.singleton loadInstruction)
+    loadWithOffset (StackStore offset register) (VariableOperand indexOffset) =
+        let baseRegisterInstruction = Instruction
+                $ RiscV.BaseInstruction RiscV.OpImm
+                $ RiscV.I immediateRegister RiscV.ADDI RiscV.S0 0
+            (indexRegister, indexStatements) = loadFromStore indexOffset
+            registerWithOffset = Instruction
+                $ RiscV.BaseInstruction RiscV.OpImm
+                $ RiscV.I immediateRegister RiscV.ADDI indexRegister 0
+            loadInstruction = Instruction
+                $ RiscV.BaseInstruction RiscV.Load
+                $ RiscV.I register RiscV.SW immediateRegister (fromIntegral offset)
+            statements = Vector.fromList
+                [ baseRegisterInstruction
+                , registerWithOffset
+                , loadInstruction
+                ]
+        in (register, indexStatements <> statements)
 
 unconditionalJal :: Label -> Statement
 unconditionalJal (Label goToLabel) = Instruction
diff --git a/src/Main.hs b/src/Main.hs
index 0d1387f..164f8e1 100644
--- a/src/Main.hs
+++ b/src/Main.hs
@@ -48,7 +48,7 @@ main = execParser commandLine >>= withCommandLine
         | otherwise =
             let makeObject = elfObject output . riscv32Elf . generateRiscV
              in either (printAndExit 6) makeObject
-                $ allocate riscVConfiguration
+                $ allocate riscVConfiguration symbolTable
                 $ glue symbolTable program
     printAndExit :: Show b => forall a. Int -> b -> IO a
     printAndExit failureCode e = print e >> exitWith (ExitFailure failureCode)
diff --git a/tests/expectations/array_element_assignment.txt b/tests/expectations/array_element_assignment.txt
index e69de29..7ed6ff8 100644
--- a/tests/expectations/array_element_assignment.txt
+++ b/tests/expectations/array_element_assignment.txt
@@ -0,0 +1 @@
+5
diff --git a/tests/expectations/print_array_element.txt b/tests/expectations/print_array_element.txt
new file mode 100644
index 0000000..b3172d1
--- /dev/null
+++ b/tests/expectations/print_array_element.txt
@@ -0,0 +1,2 @@
+5
+7
diff --git a/tests/vm/array_element_assignment.elna b/tests/vm/array_element_assignment.elna
index 4d76031..d1f00d6 100644
--- a/tests/vm/array_element_assignment.elna
+++ b/tests/vm/array_element_assignment.elna
@@ -1,3 +1,6 @@
 proc main() {
   var a: array[1] of int;
+  a[0] := 5;
+
+  printi(a[0]);
 }
diff --git a/tests/vm/print_array_element.elna b/tests/vm/print_array_element.elna
new file mode 100644
index 0000000..4c9d0cd
--- /dev/null
+++ b/tests/vm/print_array_element.elna
@@ -0,0 +1,8 @@
+proc main() {
+  var a: array[2] of int;
+  a[0] := 5;
+  a[1] := 7;
+
+  printi(a[0]);
+  printi(a[1]);
+}