781 lines
22 KiB
C++

/* Abstract syntax tree representation.
Copyright (C) 2025 Free Software Foundation, Inc.
GCC is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation; either version 3, or (at your option)
any later version.
GCC is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with GCC; see the file COPYING3. If not see
<http://www.gnu.org/licenses/>. */
#pragma once
#include <cstddef>
#include <cstdint>
#include <memory>
#include <string>
#include <type_traits>
#include <utility>
#include <variant>
#include <vector>
#include "elna/boot/result.h"
namespace elna::boot
{
/**
 * Operators that may appear in a binary_expression.
 *
 * The underlying type is spelled out; `int` is what a scoped enumeration
 * uses by default, so values and size are unchanged.
 */
enum class binary_operator : int
{
    // Arithmetic.
    sum,
    subtraction,
    multiplication,
    division,
    remainder,

    // Comparison.
    equals,
    not_equals,
    less,
    greater,
    less_equal,
    greater_equal,

    // Or, and, exclusive or.
    disjunction,
    conjunction,
    exclusive_disjunction,

    // Shifts.
    shift_left,
    shift_right
};
/**
 * Operators that may appear in a unary_expression.
 *
 * NOTE(review): meanings below are inferred from the names only —
 * confirm against the code generator.
 */
enum class unary_operator : int
{
    reference, // presumably taking the address of the operand
    negation,  // presumably logical/bitwise "not"
    minus      // presumably arithmetic negation
};
// Forward declarations of every AST node kind, grouped by category, so that
// parser_visitor and the node classes can refer to each other.

// Definitions.
class variable_declaration;
class constant_definition;
class procedure_definition;
class type_definition;

// Statements (procedure_call is also an expression).
class assign_statement;
class if_statement;
class while_statement;
class return_statement;
class defer_statement;
class procedure_call;

// Expressions.
class cast_expression;
class traits_expression;
class binary_expression;
class unary_expression;
class designator_expression;
class variable_expression;
class array_access_expression;
class field_access_expression;
class dereference_expression;
class literal_expression;
template<typename T>
class literal;

// Type expressions.
class primitive_type_expression;
class array_type_expression;
class pointer_type_expression;
class record_type_expression;
class union_type_expression;
class procedure_type_expression;

// Structural nodes.
class block;
class program;
/**
 * Interface for AST visitors.
 *
 * Nodes call back into a visitor through node::accept(); for example
 * literal<T>::accept() passes itself to the matching visit() overload.
 */
struct parser_visitor
{
    // Visitors are used polymorphically through parser_visitor pointers.
    // Without a virtual destructor, deleting a derived visitor through such a
    // pointer is undefined behaviour.
    virtual ~parser_visitor() = default;

    virtual void visit(variable_declaration *) = 0;
    virtual void visit(constant_definition *) = 0;
    virtual void visit(procedure_definition *) = 0;
    virtual void visit(type_definition *) = 0;
    virtual void visit(procedure_call *) = 0;
    virtual void visit(cast_expression *) = 0;
    virtual void visit(traits_expression *) = 0;
    virtual void visit(assign_statement *) = 0;
    virtual void visit(if_statement *) = 0;
    virtual void visit(while_statement *) = 0;
    virtual void visit(return_statement *) = 0;
    virtual void visit(defer_statement *) = 0;
    virtual void visit(block *) = 0;
    virtual void visit(program *) = 0;
    virtual void visit(binary_expression *) = 0;
    virtual void visit(unary_expression *) = 0;
    virtual void visit(primitive_type_expression *) = 0;
    virtual void visit(array_type_expression *) = 0;
    virtual void visit(pointer_type_expression *) = 0;
    virtual void visit(record_type_expression *) = 0;
    virtual void visit(union_type_expression *) = 0;
    virtual void visit(procedure_type_expression *) = 0;
    virtual void visit(variable_expression *) = 0;
    virtual void visit(array_access_expression *) = 0;
    virtual void visit(field_access_expression *) = 0;
    virtual void visit(dereference_expression *) = 0;
    virtual void visit(literal<std::int32_t> *) = 0;
    virtual void visit(literal<std::uint32_t> *) = 0;
    virtual void visit(literal<double> *) = 0;
    virtual void visit(literal<bool> *) = 0;
    virtual void visit(literal<unsigned char> *) = 0;
    virtual void visit(literal<std::nullptr_t> *) = 0;
    virtual void visit(literal<std::string> *) = 0;
};
/**
 * AST node.
 *
 * Root of the syntax tree hierarchy; every node records its position in the
 * source code.
 */
class node
{
    const struct position source_position;

protected:
    /**
     * \param position Source code position.
     *
     * The elaborated `struct position` is required: the unqualified name
     * `position` is redeclared below as a member function. Using it as a type
     * before that point would change its meaning within the class scope,
     * which is ill-formed. Every other node constructor already spells it
     * this way.
     */
    explicit node(const struct position position);

public:
    /**
     * Calls back into \p visitor. Concrete nodes (e.g. literal<T>) pass
     * themselves to the matching parser_visitor::visit() overload.
     */
    virtual void accept(parser_visitor *visitor) = 0;
    virtual ~node() = 0;

    /**
     * \return Node position in the source code.
     */
    const struct position& position() const;
};
/**
 * Base class of statement nodes.
 *
 * Each is_*() query returns this node downcast to the named statement kind;
 * concrete statements override their own query. The base versions are
 * defined out of line and presumably return nullptr (TODO confirm).
 *
 * node is a virtual base because procedure_call derives from both
 * expression and statement and must contain a single node subobject.
 */
class statement : public virtual node
{
public:
virtual assign_statement *is_assign();
virtual if_statement *is_if();
virtual while_statement *is_while();
virtual return_statement *is_return();
virtual defer_statement *is_defer();
virtual procedure_call *is_call_statement();
};
/**
 * Base class of value expressions.
 *
 * Each is_*() query returns this node downcast to the named expression kind;
 * concrete expressions override their own query. The base versions are
 * defined out of line and presumably return nullptr (TODO confirm).
 *
 * node is a virtual base so that procedure_call, which is both an expression
 * and a statement, has a single node subobject.
 */
class expression : public virtual node
{
public:
virtual cast_expression *is_cast();
virtual traits_expression *is_traits();
virtual binary_expression *is_binary();
virtual unary_expression *is_unary();
virtual designator_expression *is_designator();
virtual procedure_call *is_call_expression();
virtual literal_expression *is_literal();
};
/**
 * Symbol definition.
 *
 * Common base of named declarations (variables, constants, procedures,
 * types) that can be marked as exported.
 */
class definition : public node
{
protected:
definition(const struct position position, const std::string& identifier, const bool exported);
public:
// Name of the defined symbol.
const std::string identifier;
// Whether the definition was marked as exported. NOTE(review): the exact
// visibility semantics live in the compiler back end — confirm there.
const bool exported;
};
/**
 * Some type expression.
 *
 * Type expressions are shared through std::shared_ptr. Each is_*() query
 * returns this node downcast to the named kind; the matching subclass
 * overrides its own query.
 *
 * NOTE(review): the queries presumably use shared_from_this(), so the node
 * must already be owned by a shared_ptr when they are called — confirm in
 * the implementation.
 */
class type_expression : public node, public std::enable_shared_from_this<type_expression>
{
public:
virtual std::shared_ptr<primitive_type_expression> is_primitive();
virtual std::shared_ptr<array_type_expression> is_array();
virtual std::shared_ptr<pointer_type_expression> is_pointer();
virtual std::shared_ptr<record_type_expression> is_record();
virtual std::shared_ptr<union_type_expression> is_union();
virtual std::shared_ptr<procedure_type_expression> is_procedure();
protected:
type_expression(const struct position position);
};
/**
 * Expression defining a basic type.
 *
 * The type is referred to by name.
 */
class primitive_type_expression : public type_expression
{
public:
// Name of the referenced type as written in the source.
const std::string name;
primitive_type_expression(const struct position position, const std::string& name);
void accept(parser_visitor *visitor) override;
std::shared_ptr<primitive_type_expression> is_primitive() override;
};
/**
 * Array type: a base (element) type together with a size.
 */
class array_type_expression : public type_expression
{
std::shared_ptr<type_expression> m_base;
public:
// Array size as written in the source (presumably the element count).
const std::uint32_t size;
array_type_expression(const struct position position,
std::shared_ptr<type_expression> base, const std::uint32_t size);
void accept(parser_visitor *visitor) override;
std::shared_ptr<array_type_expression> is_array() override;
// \return The element type.
type_expression& base();
};
/**
 * Pointer type referring to a base (pointee) type.
 */
class pointer_type_expression : public type_expression
{
std::shared_ptr<type_expression> m_base;
public:
pointer_type_expression(const struct position position, std::shared_ptr<type_expression> base);
void accept(parser_visitor *visitor) override;
std::shared_ptr<pointer_type_expression> is_pointer() override;
// \return The pointee type.
type_expression& base();
};
// A record or union field: its name paired with its type.
using field_declaration = std::pair<std::string, std::shared_ptr<type_expression>>;
/**
 * Record type: an ordered list of named fields.
 */
class record_type_expression : public type_expression
{
public:
// Fields in declaration order; immutable after construction.
const std::vector<field_declaration> fields;
record_type_expression(const struct position position, std::vector<field_declaration>&& fields);
void accept(parser_visitor *visitor) override;
std::shared_ptr<record_type_expression> is_record() override;
};
/**
 * Union type: an ordered list of named fields.
 */
class union_type_expression : public type_expression
{
public:
// Fields in declaration order. NOTE(review): unlike
// record_type_expression::fields this is not const — confirm whether it is
// mutated after construction or could be made const.
std::vector<field_declaration> fields;
union_type_expression(const struct position position, std::vector<field_declaration>&& fields);
void accept(parser_visitor *visitor) override;
std::shared_ptr<union_type_expression> is_union() override;
};
/**
 * Variable declaration.
 *
 * Binds an identifier to a type; not exported unless requested.
 */
class variable_declaration : public definition
{
std::shared_ptr<type_expression> m_variable_type;
public:
variable_declaration(const struct position position, const std::string& identifier,
std::shared_ptr<type_expression> variable_type, const bool exported = false);
void accept(parser_visitor *visitor) override;
// \return Declared type of the variable.
type_expression& variable_type();
};
/**
 * Literal expression.
 *
 * Abstract base of literal<T>. Each is_*() query returns this literal as the
 * corresponding literal<T> specialization, or nullptr when the literal holds
 * a value of a different type.
 */
class literal_expression : public expression
{
public:
virtual literal<std::int32_t> *is_int() = 0;
virtual literal<std::uint32_t> *is_word() = 0;
virtual literal<double> *is_float() = 0;
virtual literal<bool> *is_bool() = 0;
virtual literal<unsigned char> *is_char() = 0;
virtual literal<std::nullptr_t> *is_nil() = 0;
virtual literal<std::string> *is_string() = 0;
literal_expression *is_literal() override;
protected:
literal_expression();
};
/**
 * Constant definition.
 *
 * Binds an identifier to a literal value.
 */
class constant_definition : public definition
{
// Raw pointer; the class declares its own destructor, presumably to delete
// it (TODO confirm in the implementation).
literal_expression *m_body;
public:
constant_definition(const struct position position, const std::string& identifier,
const bool exported, literal_expression *body);
void accept(parser_visitor *visitor) override;
// \return The constant's literal value.
literal_expression& body();
virtual ~constant_definition() override;
};
/**
 * Return declaration of a procedure type.
 *
 * Describes what a procedure returns: either a return type (\p type is set),
 * no value (\p type is null), or — via \p no_return — the procedure never
 * returns. A default-constructed declaration has no type and no_return false.
 */
struct return_declaration
{
return_declaration() = default;
// Declares \p type as the return type.
explicit return_declaration(std::shared_ptr<type_expression> type);
// NOTE(review): presumably marks the procedure as never returning
// (sets no_return) — confirm in the implementation.
explicit return_declaration(std::monostate);
std::shared_ptr<type_expression> type{ nullptr };
bool no_return{ false };
};
/**
 * Procedure type.
 *
 * Holds the return declaration and the parameter types; parameter names are
 * stored separately in procedure_definition::parameter_names.
 */
class procedure_type_expression : public type_expression
{
public:
const return_declaration return_type;
// Parameter types in declaration order.
std::vector<std::shared_ptr<type_expression>> parameters;
procedure_type_expression(const struct position position,
return_declaration return_type = return_declaration());
void accept(parser_visitor *visitor) override;
std::shared_ptr<procedure_type_expression> is_procedure() override;
};
/**
 * Procedure definition.
 *
 * Couples a procedure heading (its type) with an optional body.
 */
class procedure_definition : public definition
{
std::shared_ptr<procedure_type_expression> m_heading;
public:
// Procedure body; null by default. NOTE(review): presumably null for
// procedures declared without a body — confirm.
block *const body;
// Parameter names, presumably parallel to heading().parameters.
std::vector<std::string> parameter_names;
procedure_definition(const struct position position, const std::string& identifier,
const bool exported, std::shared_ptr<procedure_type_expression> heading, block *body = nullptr);
void accept(parser_visitor *visitor) override;
// \return The procedure's type (parameters and return declaration).
procedure_type_expression& heading();
virtual ~procedure_definition() override;
};
/**
 * Type definition.
 *
 * Binds an identifier to a type expression.
 */
class type_definition : public definition
{
std::shared_ptr<type_expression> m_body;
public:
type_definition(const struct position position, const std::string& identifier,
const bool exported, std::shared_ptr<type_expression> expression);
void accept(parser_visitor *visitor) override;
// \return The defined type expression.
type_expression& body();
};
/**
 * Cast expression.
 *
 * Converts \p value to the \p target type.
 */
class cast_expression : public expression
{
std::shared_ptr<type_expression> m_target;
// Raw pointer; presumably released by the destructor (TODO confirm).
expression *m_value;
public:
cast_expression(const struct position position, std::shared_ptr<type_expression> target, expression *value);
void accept(parser_visitor *visitor) override;
cast_expression *is_cast() override;
// \return Type the value is cast to.
type_expression& target();
// \return Expression being cast.
expression& value();
virtual ~cast_expression() override;
};
/**
 * Traits expression: a query identified by \p name that takes type
 * expressions as its parameters.
 *
 * NOTE(review): the set of supported trait names is resolved elsewhere —
 * confirm in the semantic analysis.
 */
class traits_expression : public expression
{
public:
// Type arguments of the trait, in source order.
std::vector<std::shared_ptr<type_expression>> parameters;
// Name of the queried trait.
const std::string name;
traits_expression(const struct position position, const std::string& name);
void accept(parser_visitor *visitor) override;
traits_expression *is_traits() override;
};
/**
 * List of statements paired with a condition.
 *
 * Used for the branches of if_statement and while_statement.
 */
class conditional_statements
{
    expression *m_prerequisite;

public:
    // Statements guarded by the condition, in source order.
    std::vector<statement *> statements;

    /**
     * \param prerequisite Condition guarding the statements.
     */
    conditional_statements(expression *prerequisite);

    // The class holds raw pointers and has a user-declared destructor,
    // which presumably deletes them (defined out of line — TODO confirm).
    // An implicit member-wise copy would then delete the same nodes twice,
    // so copying is disabled.
    conditional_statements(const conditional_statements&) = delete;
    conditional_statements& operator=(const conditional_statements&) = delete;

    /**
     * \return The guarding condition.
     */
    expression& prerequisite();

    virtual ~conditional_statements();
};
/**
 * Return statement with an optional returned expression.
 */
class return_statement : public statement
{
// Null when the statement returns no value.
expression *m_return_expression{ nullptr };
public:
return_statement(const struct position position, expression *return_expression = nullptr);
void accept(parser_visitor *visitor) override;
virtual return_statement *is_return() override;
// \return The returned expression, or nullptr if there is none.
expression *return_expression();
virtual ~return_statement() override;
};
class designator_expression : public expression
{
public:
virtual variable_expression *is_variable();
virtual array_access_expression *is_array_access();
virtual field_access_expression *is_field_access();
virtual dereference_expression *is_dereference();
designator_expression *is_designator() override;
void accept(parser_visitor *visitor);
~designator_expression() = 0;
protected:
designator_expression();
};
/**
 * Designator referring to a symbol by name.
 */
class variable_expression : public designator_expression
{
public:
// Referenced symbol name as written in the source.
const std::string name;
variable_expression(const struct position position, const std::string& name);
void accept(parser_visitor *visitor) override;
variable_expression *is_variable() override;
};
/**
 * Designator selecting the element at \p index of \p base.
 */
class array_access_expression : public designator_expression
{
// Raw pointers; presumably released by the destructor (TODO confirm).
expression *m_base;
expression *m_index;
public:
array_access_expression(const struct position position, expression *base, expression *index);
void accept(parser_visitor *visitor) override;
// \return The indexed expression.
expression& base();
// \return The index expression.
expression& index();
array_access_expression *is_array_access() override;
~array_access_expression() override;
};
/**
 * Designator selecting the named \p field of \p base.
 */
class field_access_expression : public designator_expression
{
// Raw pointer; presumably released by the destructor (TODO confirm).
expression *m_base;
std::string m_field;
public:
field_access_expression(const struct position position, expression *base,
const std::string& field);
void accept(parser_visitor *visitor) override;
// \return The expression whose field is accessed.
expression& base();
// \return Name of the accessed field (mutable reference).
std::string& field();
field_access_expression *is_field_access() override;
~field_access_expression() override;
};
/**
 * Designator for the object that \p base points to.
 */
class dereference_expression : public designator_expression
{
// Raw pointer; presumably released by the destructor (TODO confirm).
expression *m_base;
public:
dereference_expression(const struct position position, expression *base);
void accept(parser_visitor *visitor) override;
// \return The dereferenced expression.
expression& base();
dereference_expression *is_dereference() override;
~dereference_expression() override;
};
/**
 * Procedure call expression.
 *
 * Derives from both expression and statement, so a call can be used as a
 * value or stand on its own; both bases share one virtual node subobject.
 */
class procedure_call : public expression, public statement
{
// Raw pointer; presumably released by the destructor (TODO confirm).
designator_expression *m_callable;
public:
// Call arguments in source order.
std::vector<expression *> arguments;
procedure_call(const struct position position, designator_expression *callable);
void accept(parser_visitor *visitor) override;
virtual procedure_call *is_call_statement() override;
virtual procedure_call *is_call_expression() override;
// \return The called designator.
designator_expression& callable();
virtual ~procedure_call() override;
};
/**
 * Assignment of an expression to a designator.
 */
class assign_statement : public statement
{
// Raw pointers; presumably released by the destructor (TODO confirm).
designator_expression *m_lvalue;
expression *m_rvalue;
public:
/**
* \param position Source code position.
* \param lvalue Left-hand side.
* \param rvalue Assigned expression.
*/
assign_statement(const struct position position, designator_expression *lvalue,
expression *rvalue);
void accept(parser_visitor *visitor) override;
designator_expression& lvalue();
expression& rvalue();
virtual ~assign_statement() override;
assign_statement *is_assign() override;
};
/**
 * If-statement.
 *
 * A first conditional branch (body), further conditional branches, and an
 * optional alternative executed when no condition holds.
 */
class if_statement : public statement
{
conditional_statements *m_body;
// Null when there is no alternative branch.
std::vector<statement *> *m_alternative;
public:
// Additional conditional branches after the first one, presumably
// "elsif" branches in source order (TODO confirm).
std::vector<conditional_statements *> branches;
if_statement(const struct position position, conditional_statements *body,
std::vector<statement *> *alternative = nullptr);
void accept(parser_visitor *visitor) override;
virtual if_statement *is_if() override;
// \return The first conditional branch.
conditional_statements& body();
// \return The alternative statements, or nullptr if there are none.
std::vector<statement *> *alternative();
virtual ~if_statement() override;
};
/**
 * While-statement.
 *
 * A first conditional branch (body) plus further conditional branches.
 */
class while_statement : public statement
{
conditional_statements *m_body;
public:
// Additional conditional branches after the first one, in source order.
std::vector<conditional_statements *> branches;
while_statement(const struct position position, conditional_statements *body);
void accept(parser_visitor *visitor) override;
while_statement *is_while() override;
// \return The first conditional branch.
conditional_statements& body();
virtual ~while_statement() override;
};
/**
 * Block: local variable declarations, constant definitions and a statement
 * sequence.
 */
class block : public node
{
public:
std::vector<variable_declaration *> variables;
std::vector<constant_definition *> constants;
// Statements in execution order.
std::vector<statement *> body;
block(const struct position position);
virtual void accept(parser_visitor *visitor) override;
virtual ~block() override;
};
/**
 * Program: a block that additionally holds type definitions and procedure
 * definitions.
 */
class program : public block
{
public:
std::vector<type_definition *> types;
std::vector<procedure_definition *> procedures;
program(const struct position position);
void accept(parser_visitor *visitor) override;
virtual ~program() override;
};
/**
 * Literal expression holding a value of type \p T.
 *
 * parser_visitor provides a visit() overload for each supported \p T:
 * std::int32_t, std::uint32_t, double, bool, unsigned char, std::nullptr_t
 * and std::string.
 *
 * \tparam T Type of the stored literal value.
 */
template<typename T>
class literal : public literal_expression
{
    /**
     * \return This literal if its value type is \p U, nullptr otherwise.
     *
     * `if constexpr` discards the mismatching branch at compile time, so no
     * cast is needed. The previous version used a runtime check plus a
     * reinterpret_cast.
     */
    template<typename U>
    literal<U> *downcast_to()
    {
        if constexpr (std::is_same_v<T, U>)
        {
            return this;
        }
        else
        {
            return nullptr;
        }
    }

public:
    // The literal value.
    T value;

    /**
     * \param position Source code position.
     * \param value Literal value. Taken by value and moved into place, so
     *              temporary strings are not copied a second time.
     */
    literal(const struct position position, T value)
        : node(position), value(std::move(value))
    {
    }

    literal<std::int32_t> *is_int() override
    {
        return downcast_to<std::int32_t>();
    }

    literal<std::uint32_t> *is_word() override
    {
        return downcast_to<std::uint32_t>();
    }

    literal<double> *is_float() override
    {
        return downcast_to<double>();
    }

    literal<bool> *is_bool() override
    {
        return downcast_to<bool>();
    }

    literal<unsigned char> *is_char() override
    {
        return downcast_to<unsigned char>();
    }

    literal<std::nullptr_t> *is_nil() override
    {
        return downcast_to<std::nullptr_t>();
    }

    literal<std::string> *is_string() override
    {
        return downcast_to<std::string>();
    }

    /**
     * Passes this literal to the parser_visitor::visit() overload for
     * literal<T>.
     */
    void accept(parser_visitor *visitor) override
    {
        visitor->visit(this);
    }
};
/**
 * Defer statement holding a list of deferred statements.
 *
 * NOTE(review): presumably the statements run when the enclosing scope is
 * left — confirm in the code generator.
 */
class defer_statement : public statement
{
public:
// Deferred statements in source order.
std::vector<statement *> statements;
defer_statement(const struct position position);
void accept(parser_visitor *visitor) override;
defer_statement *is_defer() override;
virtual ~defer_statement() override;
};
/**
 * Binary operation applied to a left and a right operand.
 */
class binary_expression : public expression
{
// Raw pointers; presumably released by the destructor (TODO confirm).
expression *m_lhs;
expression *m_rhs;
binary_operator m_operator;
public:
binary_expression(const struct position position, expression *lhs,
expression *rhs, const binary_operator operation);
void accept(parser_visitor *visitor) override;
binary_expression *is_binary() override;
// \return Left operand.
expression& lhs();
// \return Right operand.
expression& rhs();
// \return The applied operator.
binary_operator operation() const;
virtual ~binary_expression() override;
};
/**
 * Unary operation applied to a single operand.
 */
class unary_expression : public expression
{
// Raw pointer; presumably released by the destructor (TODO confirm).
expression *m_operand;
unary_operator m_operator;
public:
unary_expression(const struct position position, expression *operand,
const unary_operator operation);
void accept(parser_visitor *visitor) override;
unary_expression *is_unary() override;
// \return The operand.
expression& operand();
// \return The applied operator.
unary_operator operation() const;
virtual ~unary_expression() override;
};
const char *print_binary_operator(const binary_operator operation);
}