/* Name analysis.
   Copyright (C) 2025 Free Software Foundation, Inc.

GCC is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation; either version 3, or (at your option)
any later version.

GCC is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with GCC; see the file COPYING3.  If not see
<http://www.gnu.org/licenses/>.  */

#include "elna/boot/semantic.h"

namespace elna::boot
{
    undeclared_error::undeclared_error(const std::string& identifier, const char *path, const struct position position)
        : error(path, position), identifier(identifier)
    {
    }

    std::string undeclared_error::what() const
    {
        return "Type '" + identifier + "' not declared";
    }


    already_declared_error::already_declared_error(const std::string& identifier,
            const char *path, const struct position position)
        : error(path, position), identifier(identifier)
    {
    }

    std::string already_declared_error::what() const
    {
        return "Symbol '" + identifier + "' has been already declared";
    }

    // Creates a visitor whose diagnostics refer to `path` and which enters
    // resolved declarations into the given (shared) symbol table.
    declaration_visitor::declaration_visitor(const char *path,
            std::shared_ptr<symbol_table> symbols)
        : error_container(path),
          symbols(symbols)
    {
    }

    /**
     * Declaration pass over a whole program.
     *
     * Type definitions are processed in two steps so they can refer to each
     * other regardless of their order. First, every type name is registered
     * in `unresolved` as an alias placeholder. Then each definition body is
     * analysed and its placeholder's reference is filled in. The finished
     * aliases are entered into the symbol table. After that, variables,
     * procedures and the program body are visited.
     */
    void declaration_visitor::visit(program *program)
    {
        for (type_definition *const type : program->types)
        {
            const bool inserted = this->unresolved
                .insert({ type->identifier, std::make_shared<alias_type>(type->identifier) })
                .second;

            // A type name is a duplicate if it occurs twice in this program or
            // clashes with a symbol that is already in the table.
            if (!inserted || this->symbols->contains(type->identifier))
            {
                add_error<already_declared_error>(type->identifier, this->input_file, type->position());
            }
        }
        for (type_definition *const type : program->types)
        {
            type->accept(this);
        }
        for (auto& [name, alias] : this->unresolved)
        {
            // The map key is const, so std::move on it would only copy.
            // Pass it by reference, and construct the type_info in place
            // instead of copying a temporary.
            this->symbols->enter(name, std::make_shared<type_info>(type(alias)));
        }
        for (variable_declaration *const variable : program->variables)
        {
            variable->accept(this);
        }
        for (procedure_definition *const procedure : program->procedures)
        {
            procedure->accept(this);
        }
        for (statement *const statement : program->body)
        {
            statement->accept(this);
        }
    }

    // Resolves the body of a type definition, then points the alias
    // placeholder registered under the definition's name at the result.
    void declaration_visitor::visit(type_definition *definition)
    {
        definition->body().accept(this);

        const auto& placeholder = this->unresolved.at(definition->identifier);
        placeholder->reference = this->current_type;
    }

    /**
     * Resolves a named type and stores the result in current_type.
     *
     * Names of type definitions from the current program take precedence
     * and resolve to their alias placeholder. Otherwise the name is looked up
     * in the symbol table. Two cases produce an undeclared_error and reset
     * current_type to an empty type:
     * - the name is unknown;
     * - the name denotes a symbol that is not a type.
     */
    void declaration_visitor::visit(primitive_type_expression *type_expression)
    {
        auto unresolved_alias = this->unresolved.find(type_expression->name);

        if (unresolved_alias != this->unresolved.end())
        {
            this->current_type = type(unresolved_alias->second);
            return;
        }
        if (auto from_symbol_table = this->symbols->lookup(type_expression->name))
        {
            // NOTE(review): is_type() is assumed to yield a null pointer for
            // non-type symbols. Dereferencing it unchecked crashed the
            // compiler when such a name was used as a type.
            if (auto type_symbol = from_symbol_table->is_type())
            {
                this->current_type = type_symbol->symbol;
                return;
            }
        }
        add_error<undeclared_error>(type_expression->name, this->input_file, type_expression->position());
        this->current_type = type();
    }

    // Resolves the pointee expression, then wraps it in a pointer type.
    void declaration_visitor::visit(pointer_type_expression *type_expression)
    {
        type_expression->base().accept(this);

        type resolved(std::make_shared<pointer_type>(this->current_type));
        this->current_type = resolved;
    }

    // Resolves the element expression, then wraps it in an array type of the
    // declared size.
    void declaration_visitor::visit(array_type_expression *type_expression)
    {
        type_expression->base().accept(this);

        type resolved(std::make_shared<array_type>(this->current_type, type_expression->size));
        this->current_type = resolved;
    }

    // A record expression currently yields a fresh, empty record type. The
    // expression itself (and hence its fields) is not inspected here.
    void declaration_visitor::visit(record_type_expression *)
    {
        type fresh_record(std::make_shared<record_type>());
        this->current_type = fresh_record;
    }

    // A union expression currently yields a fresh, empty union type. The
    // expression itself (and hence its fields) is not inspected here.
    void declaration_visitor::visit(union_type_expression *)
    {
        type fresh_union(std::make_shared<union_type>());
        this->current_type = fresh_union;
    }

    /**
     * Procedure type expressions are not modelled by this pass yet.
     *
     * current_type is still reset to an empty type. Callers read it right
     * after visiting a type expression (visit(type_definition) assigns it
     * to the alias placeholder). Leaving it untouched would make them pick
     * up a stale type from a previously visited expression.
     */
    void declaration_visitor::visit(procedure_type_expression *)
    {
        this->current_type = type();
    }

    // Resolves the declared type of a variable; the result is left in
    // current_type and not otherwise used here.
    void declaration_visitor::visit(variable_declaration *declaration)
    {
        auto&& declared_type = declaration->variable_type();
        declared_type.accept(this);
    }

    // Constant definitions are not inspected by the declaration pass.
    void declaration_visitor::visit(constant_definition *)
    {
    }

    // Visits a procedure's parameters, its return type when one is given,
    // and finally its body when the procedure has one.
    void declaration_visitor::visit(procedure_definition *definition)
    {
        for (const auto& parameter : definition->heading().parameters)
        {
            parameter->accept(this);
        }

        auto&& result = definition->heading().return_type;
        if (result.type != nullptr)
        {
            result.type->accept(this);
        }

        if (definition->body == nullptr)
        {
            return;
        }
        definition->body->accept(this);
    }

    // Visits the assignment target first, then the assigned value.
    void declaration_visitor::visit(assign_statement *statement)
    {
        auto&& target = statement->lvalue();
        target.accept(this);

        auto&& value = statement->rvalue();
        value.accept(this);
    }

    // Walks an if statement in source order. The order is: the first
    // condition and its statements, then each further branch (condition,
    // then statements), then the alternative statements if any.
    void declaration_visitor::visit(if_statement *statement)
    {
        // Visits every statement of a list in order; avoids loop variables
        // that would shadow the `statement` parameter.
        const auto visit_each = [this](const auto& statements)
        {
            for (auto *const nested : statements)
            {
                nested->accept(this);
            }
        };

        statement->body().prerequisite().accept(this);
        visit_each(statement->body().statements);

        for (const auto& branch : statement->branches)
        {
            branch->prerequisite().accept(this);
            visit_each(branch->statements);
        }

        if (statement->alternative() != nullptr)
        {
            visit_each(*statement->alternative());
        }
    }

    // Walks a while statement in source order. The order is: the first
    // condition and its statements, then each further branch (condition,
    // then statements).
    void declaration_visitor::visit(while_statement *statement)
    {
        // Visits every statement of a list in order; avoids loop variables
        // that would shadow the `statement` parameter.
        const auto visit_each = [this](const auto& statements)
        {
            for (auto *const nested : statements)
            {
                nested->accept(this);
            }
        };

        statement->body().prerequisite().accept(this);
        visit_each(statement->body().statements);

        for (const auto& branch : statement->branches)
        {
            branch->prerequisite().accept(this);
            visit_each(branch->statements);
        }
    }

    // Visits the returned expression; a bare return has nothing to visit.
    void declaration_visitor::visit(return_statement *statement)
    {
        if (statement->return_expression() == nullptr)
        {
            return;
        }
        statement->return_expression()->accept(this);
    }

    // Visits each deferred statement in the order it was written.
    void declaration_visitor::visit(defer_statement *statement)
    {
        for (auto *const deferred : statement->statements)
        {
            deferred->accept(this);
        }
    }

    // Visits the called expression, then every argument from left to right.
    void declaration_visitor::visit(procedure_call *call)
    {
        auto&& callee = call->callable();
        callee.accept(this);

        for (auto *const actual : call->arguments)
        {
            actual->accept(this);
        }
    }

    // Visits a block's constants, then its variables, then its statements.
    void declaration_visitor::visit(block *block)
    {
        const auto visit_all = [this](const auto& nodes)
        {
            for (auto *const node : nodes)
            {
                node->accept(this);
            }
        };

        visit_all(block->constants);
        visit_all(block->variables);
        visit_all(block->body);
    }

    // Visits only the first parameter of a trait expression, if there is one;
    // any further parameters are not inspected.
    void declaration_visitor::visit(traits_expression *trait)
    {
        if (trait->parameters.empty())
        {
            return;
        }
        trait->parameters.front()->accept(this);
    }

    // Visits the value being converted, then the target type of the cast.
    void declaration_visitor::visit(cast_expression *expression)
    {
        auto&& converted = expression->value();
        converted.accept(this);

        auto&& target_type = expression->target();
        target_type.accept(this);
    }

    // Visits the left operand, then the right operand.
    void declaration_visitor::visit(binary_expression *expression)
    {
        auto&& left = expression->lhs();
        left.accept(this);

        auto&& right = expression->rhs();
        right.accept(this);
    }

    // Visits the single operand.
    void declaration_visitor::visit(unary_expression *expression)
    {
        auto&& single = expression->operand();
        single.accept(this);
    }

    // Variable references are not inspected by the declaration pass.
    void declaration_visitor::visit(variable_expression *)
    {
    }

    // Visits the indexed expression, then the index.
    void declaration_visitor::visit(array_access_expression *expression)
    {
        auto&& indexed = expression->base();
        indexed.accept(this);

        auto&& subscript = expression->index();
        subscript.accept(this);
    }

    // Visits the expression whose field is accessed.
    void declaration_visitor::visit(field_access_expression *expression)
    {
        auto&& aggregate = expression->base();
        aggregate.accept(this);
    }

    // Visits the expression being dereferenced.
    void declaration_visitor::visit(dereference_expression *expression)
    {
        auto&& referenced = expression->base();
        referenced.accept(this);
    }

    // Literal nodes are not inspected by the declaration pass; every literal
    // overload below is intentionally a no-op.
    void declaration_visitor::visit(literal<std::int32_t> *) {}

    void declaration_visitor::visit(literal<std::uint32_t> *) {}

    void declaration_visitor::visit(literal<double> *) {}

    void declaration_visitor::visit(literal<bool> *) {}

    void declaration_visitor::visit(literal<unsigned char> *) {}

    void declaration_visitor::visit(literal<std::nullptr_t> *) {}

    void declaration_visitor::visit(literal<std::string> *) {}
}