dolphin/Source/Core/Common/Assembler/GekkoIRGen.cpp
2024-05-03 18:43:51 -07:00

834 lines
20 KiB
C++

// Copyright 2023 Dolphin Emulator Project
// SPDX-License-Identifier: GPL-2.0-or-later
#include "Common/Assembler/GekkoIRGen.h"
#include <bit>
#include <functional>
#include <map>
#include <numeric>
#include <set>
#include <stack>
#include <variant>
#include <vector>
#include <fmt/format.h>
#include "Common/Assembler/AssemblerShared.h"
#include "Common/Assembler/GekkoParser.h"
#include "Common/Assert.h"
#include "Common/BitUtils.h"
namespace Common::GekkoAssembler::detail
{
namespace
{
class GekkoIRPlugin : public ParsePlugin
{
public:
GekkoIRPlugin(GekkoIR& result, u32 base_addr)
: m_output_result(result), m_active_var(nullptr), m_operand_scan_begin(0)
{
m_active_block = &m_output_result.blocks.emplace_back(base_addr);
}
virtual ~GekkoIRPlugin() = default;
void OnDirectivePre(GekkoDirective directive) override;
void OnDirectivePost(GekkoDirective directive) override;
void OnInstructionPre(const ParseInfo& mnemonic_info, bool extended) override;
void OnInstructionPost(const ParseInfo& mnemonic_info, bool extended) override;
void OnOperandPre() override;
void OnOperandPost() override;
void OnResolvedExprPost() override;
void OnOperator(AsmOp operation) override;
void OnTerminal(Terminal type, const AssemblerToken& val) override;
void OnHiaddr(std::string_view id) override;
void OnLoaddr(std::string_view id) override;
void OnCloseParen(ParenType type) override;
void OnLabelDecl(std::string_view name) override;
void OnVarDecl(std::string_view name) override;
void PostParseAction() override;
u32 CurrentAddress() const;
std::optional<u64> LookupVar(std::string_view lab);
std::optional<u32> LookupLabel(std::string_view lab);
template <typename T>
T& GetChunk();
template <typename T>
void AddBytes(T val);
void AddStringBytes(std::string_view str, bool null_term);
void PadAlign(u32 bits);
void PadSpace(size_t space);
void StartBlock(u32 address);
void StartBlockAlign(u32 bits);
void StartInstruction(size_t mnemonic_index, bool extended);
void FinishInstruction();
void SaveOperandFixup(size_t str_left, size_t str_right);
void AddBinaryEvaluator(u32 (*evaluator)(u32, u32));
void AddUnaryEvaluator(u32 (*evaluator)(u32));
void AddAbsoluteAddressConv();
void AddLiteral(u32 lit);
void AddSymbolResolve(std::string_view sym, bool absolute);
void RunFixups();
void EvalOperatorRel(AsmOp operation);
void EvalOperatorAbs(AsmOp operation);
void EvalTerminalRel(Terminal type, const AssemblerToken& tok);
void EvalTerminalAbs(Terminal type, const AssemblerToken& tok);
private:
enum class EvalMode
{
RelAddrDoublePass,
AbsAddrSinglePass,
};
GekkoIR& m_output_result;
IRBlock* m_active_block;
GekkoInstruction m_build_inst;
u64* m_active_var;
size_t m_operand_scan_begin;
std::map<std::string, u32, std::less<>> m_labels;
std::map<std::string, u64, std::less<>> m_constants;
std::set<std::string> m_symset;
EvalMode m_evaluation_mode;
// For operand parsing
std::stack<std::function<u32()>> m_fixup_stack;
std::vector<std::function<u32()>> m_operand_fixups;
size_t m_operand_str_start;
// For directive parsing
std::vector<u64> m_eval_stack;
std::variant<std::vector<float>, std::vector<double>> m_floats_list;
std::string_view m_string_lit;
GekkoDirective m_active_directive;
};
///////////////
// OVERRIDES //
///////////////
void GekkoIRPlugin::OnDirectivePre(GekkoDirective directive)
{
m_evaluation_mode = EvalMode::AbsAddrSinglePass;
m_active_directive = directive;
m_eval_stack = std::vector<u64>{};
switch (directive)
{
case GekkoDirective::Float:
m_floats_list = std::vector<float>{};
break;
case GekkoDirective::Double:
m_floats_list = std::vector<double>{};
break;
default:
break;
}
}
void GekkoIRPlugin::OnDirectivePost(GekkoDirective directive)
{
switch (directive)
{
// .nbyte directives are handled by OnResolvedExprPost
default:
break;
case GekkoDirective::Float:
case GekkoDirective::Double:
std::visit(
[this](auto&& vec) {
for (auto&& val : vec)
{
AddBytes(val);
}
},
m_floats_list);
break;
case GekkoDirective::DefVar:
ASSERT(m_active_var != nullptr);
*m_active_var = m_eval_stack.back();
m_active_var = nullptr;
break;
case GekkoDirective::Locate:
StartBlock(static_cast<u32>(m_eval_stack.back()));
break;
case GekkoDirective::Zeros:
PadSpace(static_cast<u32>(m_eval_stack.back()));
break;
case GekkoDirective::Skip:
{
const u32 skip_len = static_cast<u32>(m_eval_stack.back());
if (skip_len > 0)
{
StartBlock(CurrentAddress() + skip_len);
}
break;
}
case GekkoDirective::PadAlign:
PadAlign(static_cast<u32>(m_eval_stack.back()));
break;
case GekkoDirective::Align:
StartBlockAlign(static_cast<u32>(m_eval_stack.back()));
break;
case GekkoDirective::Ascii:
AddStringBytes(m_string_lit, false);
break;
case GekkoDirective::Asciz:
AddStringBytes(m_string_lit, true);
break;
}
m_eval_stack = {};
}
void GekkoIRPlugin::OnInstructionPre(const ParseInfo& mnemonic_info, bool extended)
{
m_evaluation_mode = EvalMode::RelAddrDoublePass;
StartInstruction(mnemonic_info.mnemonic_index, extended);
}
void GekkoIRPlugin::OnInstructionPost(const ParseInfo&, bool)
{
FinishInstruction();
}
void GekkoIRPlugin::OnOperandPre()
{
m_operand_str_start = m_owner->lexer.ColNumber();
}
void GekkoIRPlugin::OnOperandPost()
{
SaveOperandFixup(m_operand_str_start, m_owner->lexer.ColNumber());
}
void GekkoIRPlugin::OnResolvedExprPost()
{
switch (m_active_directive)
{
case GekkoDirective::Byte:
AddBytes<u8>(static_cast<u8>(m_eval_stack.back()));
break;
case GekkoDirective::_2byte:
AddBytes<u16>(static_cast<u16>(m_eval_stack.back()));
break;
case GekkoDirective::_4byte:
AddBytes<u32>(static_cast<u32>(m_eval_stack.back()));
break;
case GekkoDirective::_8byte:
AddBytes<u64>(static_cast<u64>(m_eval_stack.back()));
break;
default:
return;
}
m_eval_stack.clear();
}
void GekkoIRPlugin::OnOperator(AsmOp operation)
{
if (m_evaluation_mode == EvalMode::RelAddrDoublePass)
{
EvalOperatorRel(operation);
}
else
{
EvalOperatorAbs(operation);
}
}
void GekkoIRPlugin::OnTerminal(Terminal type, const AssemblerToken& val)
{
if (type == Terminal::Str)
{
m_string_lit = val.token_val;
}
else if (m_evaluation_mode == EvalMode::RelAddrDoublePass)
{
EvalTerminalRel(type, val);
}
else
{
EvalTerminalAbs(type, val);
}
}
void GekkoIRPlugin::OnHiaddr(std::string_view id)
{
if (m_evaluation_mode == EvalMode::RelAddrDoublePass)
{
AddSymbolResolve(id, true);
AddLiteral(16);
AddBinaryEvaluator([](u32 lhs, u32 rhs) { return lhs >> rhs; });
AddLiteral(0xffff);
AddBinaryEvaluator([](u32 lhs, u32 rhs) { return lhs & rhs; });
}
else
{
u32 base;
if (auto lbl = LookupLabel(id); lbl)
{
base = *lbl;
}
else if (auto var = LookupVar(id); var)
{
base = *var;
}
else
{
m_owner->EmitErrorHere(fmt::format("Undefined reference to Label/Constant '{}'", id));
return;
}
m_eval_stack.push_back((base >> 16) & 0xffff);
}
}
void GekkoIRPlugin::OnLoaddr(std::string_view id)
{
if (m_evaluation_mode == EvalMode::RelAddrDoublePass)
{
AddSymbolResolve(id, true);
AddLiteral(0xffff);
AddBinaryEvaluator([](u32 lhs, u32 rhs) { return lhs & rhs; });
}
else
{
u32 base;
if (auto lbl = LookupLabel(id); lbl)
{
base = *lbl;
}
else if (auto var = LookupVar(id); var)
{
base = *var;
}
else
{
m_owner->EmitErrorHere(fmt::format("Undefined reference to Label/Constant '{}'", id));
return;
}
m_eval_stack.push_back(base & 0xffff);
}
}
void GekkoIRPlugin::OnCloseParen(ParenType type)
{
if (type != ParenType::RelConv)
{
return;
}
if (m_evaluation_mode == EvalMode::RelAddrDoublePass)
{
AddAbsoluteAddressConv();
}
else
{
m_eval_stack.push_back(CurrentAddress());
EvalOperatorAbs(AsmOp::Sub);
}
}
void GekkoIRPlugin::OnLabelDecl(std::string_view name)
{
const std::string name_str(name);
if (m_symset.contains(name_str))
{
m_owner->EmitErrorHere(fmt::format("Label/Constant {} is already defined", name));
return;
}
m_labels[name_str] = m_active_block->BlockEndAddress();
m_symset.insert(name_str);
}
void GekkoIRPlugin::OnVarDecl(std::string_view name)
{
const std::string name_str(name);
if (m_symset.contains(name_str))
{
m_owner->EmitErrorHere(fmt::format("Label/Constant {} is already defined", name));
return;
}
m_active_var = &m_constants[name_str];
m_symset.insert(name_str);
}
void GekkoIRPlugin::PostParseAction()
{
RunFixups();
}
//////////////////////
// HELPER FUNCTIONS //
//////////////////////
u32 GekkoIRPlugin::CurrentAddress() const
{
return m_active_block->BlockEndAddress();
}
std::optional<u64> GekkoIRPlugin::LookupVar(std::string_view var)
{
auto var_it = m_constants.find(var);
return var_it == m_constants.end() ? std::nullopt : std::optional(var_it->second);
}
std::optional<u32> GekkoIRPlugin::LookupLabel(std::string_view lab)
{
auto label_it = m_labels.find(lab);
return label_it == m_labels.end() ? std::nullopt : std::optional(label_it->second);
}
void GekkoIRPlugin::AddStringBytes(std::string_view str, bool null_term)
{
ByteChunk& bytes = GetChunk<ByteChunk>();
ConvertStringLiteral(str, &bytes);
if (null_term)
{
bytes.push_back('\0');
}
}
template <typename T>
T& GekkoIRPlugin::GetChunk()
{
if (!m_active_block->chunks.empty() && std::holds_alternative<T>(m_active_block->chunks.back()))
{
return std::get<T>(m_active_block->chunks.back());
}
return std::get<T>(m_active_block->chunks.emplace_back(T{}));
}
template <typename T>
void GekkoIRPlugin::AddBytes(T val)
{
if constexpr (std::is_integral_v<T>)
{
ByteChunk& bytes = GetChunk<ByteChunk>();
for (size_t i = sizeof(T) - 1; i > 0; i--)
{
bytes.push_back((val >> (8 * i)) & 0xff);
}
bytes.push_back(val & 0xff);
}
else if constexpr (std::is_same_v<T, float>)
{
static_assert(sizeof(double) == sizeof(u64));
AddBytes(std::bit_cast<u32>(val));
}
else
{
// std::is_same_v<T, double>
static_assert(sizeof(double) == sizeof(u64));
AddBytes(std::bit_cast<u64>(val));
}
}
void GekkoIRPlugin::PadAlign(u32 bits)
{
const u32 align_mask = (1 << bits) - 1;
const u32 current_addr = m_active_block->BlockEndAddress();
if (current_addr & align_mask)
{
PadChunk& current_pad = GetChunk<PadChunk>();
current_pad += (1 << bits) - (current_addr & align_mask);
}
}
void GekkoIRPlugin::PadSpace(size_t space)
{
GetChunk<PadChunk>() += space;
}
void GekkoIRPlugin::StartBlock(u32 address)
{
m_active_block = &m_output_result.blocks.emplace_back(address);
}
void GekkoIRPlugin::StartBlockAlign(u32 bits)
{
const u32 align_mask = (1 << bits) - 1;
const u32 current_addr = m_active_block->BlockEndAddress();
if (current_addr & align_mask)
{
StartBlock((1 << bits) + (current_addr & ~align_mask));
}
}
void GekkoIRPlugin::StartInstruction(size_t mnemonic_index, bool extended)
{
m_build_inst = GekkoInstruction{
.mnemonic_index = mnemonic_index,
.raw_text = m_owner->lexer.CurrentLine(),
.line_number = m_owner->lexer.LineNumber(),
.is_extended = extended,
};
m_operand_scan_begin = m_output_result.operand_pool.size();
}
void GekkoIRPlugin::AddBinaryEvaluator(u32 (*evaluator)(u32, u32))
{
std::function<u32()> rhs = std::move(m_fixup_stack.top());
m_fixup_stack.pop();
std::function<u32()> lhs = std::move(m_fixup_stack.top());
m_fixup_stack.pop();
m_fixup_stack.emplace([evaluator, lhs = std::move(lhs), rhs = std::move(rhs)]() {
return evaluator(lhs(), rhs());
});
}
void GekkoIRPlugin::AddUnaryEvaluator(u32 (*evaluator)(u32))
{
std::function<u32()> sub = std::move(m_fixup_stack.top());
m_fixup_stack.pop();
m_fixup_stack.emplace([evaluator, sub = std::move(sub)]() { return evaluator(sub()); });
}
void GekkoIRPlugin::AddAbsoluteAddressConv()
{
const u32 inst_address = m_active_block->BlockEndAddress();
std::function<u32()> sub = std::move(m_fixup_stack.top());
m_fixup_stack.pop();
m_fixup_stack.emplace([inst_address, sub = std::move(sub)] { return sub() - inst_address; });
}
void GekkoIRPlugin::AddLiteral(u32 lit)
{
m_fixup_stack.emplace([lit] { return lit; });
}
void GekkoIRPlugin::AddSymbolResolve(std::string_view sym, bool absolute)
{
const u32 source_address = m_active_block->BlockEndAddress();
AssemblerError err_on_fail = AssemblerError{
fmt::format("Unresolved symbol '{}'", sym),
m_owner->lexer.CurrentLine(),
m_owner->lexer.LineNumber(),
// Lexer should currently point to the label, as it hasn't been eaten yet
m_owner->lexer.ColNumber(),
sym.size(),
};
m_fixup_stack.emplace(
[this, sym, absolute, source_address, err_on_fail = std::move(err_on_fail)] {
auto label_it = m_labels.find(sym);
if (label_it != m_labels.end())
{
if (absolute)
{
return label_it->second;
}
return label_it->second - source_address;
}
auto var_it = m_constants.find(sym);
if (var_it != m_constants.end())
{
return static_cast<u32>(var_it->second);
}
m_owner->error = std::move(err_on_fail);
return u32{0};
});
}
void GekkoIRPlugin::SaveOperandFixup(size_t str_left, size_t str_right)
{
m_operand_fixups.emplace_back(std::move(m_fixup_stack.top()));
m_fixup_stack.pop();
m_output_result.operand_pool.emplace_back(Interval{str_left, str_right - str_left}, 0);
}
void GekkoIRPlugin::RunFixups()
{
for (size_t i = 0; i < m_operand_fixups.size(); i++)
{
ValueOf(m_output_result.operand_pool[i]) = m_operand_fixups[i]();
if (m_owner->error)
{
return;
}
}
}
void GekkoIRPlugin::FinishInstruction()
{
m_build_inst.op_interval.begin = m_operand_scan_begin;
m_build_inst.op_interval.len = m_output_result.operand_pool.size() - m_operand_scan_begin;
GetChunk<InstChunk>().emplace_back(m_build_inst);
m_operand_scan_begin = 0;
}
void GekkoIRPlugin::EvalOperatorAbs(AsmOp operation)
{
#define EVAL_BINARY_OP(OPERATOR) \
{ \
u64 rhs = m_eval_stack.back(); \
m_eval_stack.pop_back(); \
m_eval_stack.back() = m_eval_stack.back() OPERATOR rhs; \
}
switch (operation)
{
case AsmOp::Or:
EVAL_BINARY_OP(|);
break;
case AsmOp::Xor:
EVAL_BINARY_OP(^);
break;
case AsmOp::And:
EVAL_BINARY_OP(&);
break;
case AsmOp::Lsh:
EVAL_BINARY_OP(<<);
break;
case AsmOp::Rsh:
EVAL_BINARY_OP(>>);
break;
case AsmOp::Add:
EVAL_BINARY_OP(+);
break;
case AsmOp::Sub:
EVAL_BINARY_OP(-);
break;
case AsmOp::Mul:
EVAL_BINARY_OP(*);
break;
case AsmOp::Div:
EVAL_BINARY_OP(/);
break;
case AsmOp::Neg:
m_eval_stack.back() = static_cast<u32>(-static_cast<s32>(m_eval_stack.back()));
break;
case AsmOp::Not:
m_eval_stack.back() = ~m_eval_stack.back();
break;
}
#undef EVAL_BINARY_OP
#undef EVAL_UNARY_OP
}
void GekkoIRPlugin::EvalOperatorRel(AsmOp operation)
{
switch (operation)
{
case AsmOp::Or:
AddBinaryEvaluator([](u32 lhs, u32 rhs) { return lhs | rhs; });
break;
case AsmOp::Xor:
AddBinaryEvaluator([](u32 lhs, u32 rhs) { return lhs ^ rhs; });
break;
case AsmOp::And:
AddBinaryEvaluator([](u32 lhs, u32 rhs) { return lhs & rhs; });
break;
case AsmOp::Lsh:
AddBinaryEvaluator([](u32 lhs, u32 rhs) { return lhs << rhs; });
break;
case AsmOp::Rsh:
AddBinaryEvaluator([](u32 lhs, u32 rhs) { return lhs >> rhs; });
break;
case AsmOp::Add:
AddBinaryEvaluator([](u32 lhs, u32 rhs) { return lhs + rhs; });
break;
case AsmOp::Sub:
AddBinaryEvaluator([](u32 lhs, u32 rhs) { return lhs - rhs; });
break;
case AsmOp::Mul:
AddBinaryEvaluator([](u32 lhs, u32 rhs) { return lhs * rhs; });
break;
case AsmOp::Div:
AddBinaryEvaluator([](u32 lhs, u32 rhs) { return lhs / rhs; });
break;
case AsmOp::Neg:
AddUnaryEvaluator([](u32 val) { return static_cast<u32>(-static_cast<s32>(val)); });
break;
case AsmOp::Not:
AddUnaryEvaluator([](u32 val) { return ~val; });
break;
}
}
void GekkoIRPlugin::EvalTerminalRel(Terminal type, const AssemblerToken& tok)
{
switch (type)
{
case Terminal::Hex:
case Terminal::Dec:
case Terminal::Oct:
case Terminal::Bin:
case Terminal::GPR:
case Terminal::FPR:
case Terminal::SPR:
case Terminal::CRField:
case Terminal::Lt:
case Terminal::Gt:
case Terminal::Eq:
case Terminal::So:
{
std::optional<u32> val = tok.EvalToken<u32>();
ASSERT(val.has_value());
AddLiteral(*val);
break;
}
case Terminal::Dot:
AddLiteral(CurrentAddress());
break;
case Terminal::Id:
{
if (auto label_it = m_labels.find(tok.token_val); label_it != m_labels.end())
{
AddLiteral(label_it->second - CurrentAddress());
}
else if (auto var_it = m_constants.find(tok.token_val); var_it != m_constants.end())
{
AddLiteral(var_it->second);
}
else
{
AddSymbolResolve(tok.token_val, false);
}
break;
}
// Parser should disallow this from happening
default:
ASSERT(false);
break;
}
}
void GekkoIRPlugin::EvalTerminalAbs(Terminal type, const AssemblerToken& tok)
{
switch (type)
{
case Terminal::Hex:
case Terminal::Dec:
case Terminal::Oct:
case Terminal::Bin:
case Terminal::GPR:
case Terminal::FPR:
case Terminal::SPR:
case Terminal::CRField:
case Terminal::Lt:
case Terminal::Gt:
case Terminal::Eq:
case Terminal::So:
{
std::optional<u64> val = tok.EvalToken<u64>();
ASSERT(val.has_value());
m_eval_stack.push_back(*val);
break;
}
case Terminal::Flt:
{
std::visit(
[&tok](auto&& vec) {
auto opt = tok.EvalToken<typename std::decay_t<decltype(vec)>::value_type>();
ASSERT(opt.has_value());
vec.push_back(*opt);
},
m_floats_list);
break;
}
case Terminal::Dot:
m_eval_stack.push_back(static_cast<u64>(CurrentAddress()));
break;
case Terminal::Id:
{
if (auto label_it = m_labels.find(tok.token_val); label_it != m_labels.end())
{
m_eval_stack.push_back(label_it->second);
}
else if (auto var_it = m_constants.find(tok.token_val); var_it != m_constants.end())
{
m_eval_stack.push_back(var_it->second);
}
else
{
m_owner->EmitErrorHere(
fmt::format("Undefined reference to Label/Constant '{}'", tok.ValStr()));
return;
}
break;
}
// Parser should disallow this from happening
default:
ASSERT(false);
break;
}
}
} // namespace
u32 IRBlock::BlockEndAddress() const
{
return std::accumulate(chunks.begin(), chunks.end(), block_address,
[](u32 acc, const ChunkVariant& chunk) {
size_t size;
if (std::holds_alternative<InstChunk>(chunk))
{
size = std::get<InstChunk>(chunk).size() * 4;
}
else if (std::holds_alternative<ByteChunk>(chunk))
{
size = std::get<ByteChunk>(chunk).size();
}
else if (std::holds_alternative<PadChunk>(chunk))
{
size = std::get<PadChunk>(chunk);
}
else
{
ASSERT(false);
size = 0;
}
return acc + static_cast<u32>(size);
});
}
FailureOr<GekkoIR> ParseToIR(std::string_view assembly, u32 base_virtual_address)
{
GekkoIR ret;
GekkoIRPlugin plugin(ret, base_virtual_address);
ParseWithPlugin(&plugin, assembly);
if (plugin.Error())
{
return FailureOr<GekkoIR>(std::move(*plugin.Error()));
}
return std::move(ret);
}
} // namespace Common::GekkoAssembler::detail