dolphin/Source/Core/Common/Assembler/GekkoLexer.cpp
2024-06-06 13:05:45 -06:00

794 lines
16 KiB
C++

// Copyright 2023 Dolphin Emulator Project
// SPDX-License-Identifier: GPL-2.0-or-later
#include "Common/Assembler/GekkoLexer.h"
#include "Common/Assert.h"
#include "Common/StringUtil.h"
#include <iterator>
#include <numeric>
namespace Common::GekkoAssembler::detail
{
namespace
{
constexpr bool IsOctal(char c)
{
return c >= '0' && c <= '7';
}
constexpr bool IsBinary(char c)
{
return c == '0' || c == '1';
}
template <typename T>
constexpr T ConvertNib(char c)
{
if (c >= 'a' && c <= 'f')
{
return static_cast<T>(c - 'a' + 10);
}
if (c >= 'A' && c <= 'F')
{
return static_cast<T>(c - 'A' + 10);
}
return static_cast<T>(c - '0');
}
constexpr TokenType SingleCharToken(char ch)
{
switch (ch)
{
case ',':
return TokenType::Comma;
case '(':
return TokenType::Lparen;
case ')':
return TokenType::Rparen;
case '|':
return TokenType::Pipe;
case '^':
return TokenType::Caret;
case '&':
return TokenType::Ampersand;
case '+':
return TokenType::Plus;
case '-':
return TokenType::Minus;
case '*':
return TokenType::Star;
case '/':
return TokenType::Slash;
case '~':
return TokenType::Tilde;
case '@':
return TokenType::At;
case ':':
return TokenType::Colon;
case '`':
return TokenType::Grave;
case '.':
return TokenType::Dot;
case '\0':
return TokenType::Eof;
case '\n':
return TokenType::Eol;
default:
return TokenType::Invalid;
}
}
// Convert a string literal into its raw-data form
template <typename Cont>
void ConvertStringLiteral(std::string_view literal, std::back_insert_iterator<Cont> out_it)
{
for (size_t i = 1; i < literal.size() - 1;)
{
if (literal[i] == '\\')
{
++i;
if (IsOctal(literal[i]))
{
// Octal escape
char octal_escape = 0;
for (char c = literal[i]; IsOctal(c); c = literal[++i])
{
octal_escape = (octal_escape << 3) + (c - '0');
}
out_it = static_cast<u8>(octal_escape);
}
else if (literal[i] == 'x')
{
// Hex escape
char hex_escape = 0;
for (char c = literal[++i]; std::isxdigit(c); c = literal[++i])
{
hex_escape = (hex_escape << 4) + ConvertNib<char>(c);
}
out_it = static_cast<u8>(hex_escape);
}
else
{
char simple_escape;
switch (literal[i])
{
case '\'':
simple_escape = '\x27';
break;
case '"':
simple_escape = '\x22';
break;
case '?':
simple_escape = '\x3f';
break;
case '\\':
simple_escape = '\x5c';
break;
case 'a':
simple_escape = '\x07';
break;
case 'b':
simple_escape = '\x08';
break;
case 'f':
simple_escape = '\x0c';
break;
case 'n':
simple_escape = '\x0a';
break;
case 'r':
simple_escape = '\x0d';
break;
case 't':
simple_escape = '\x09';
break;
case 'v':
simple_escape = '\x0b';
break;
default:
simple_escape = literal[i];
break;
}
out_it = static_cast<u8>(simple_escape);
++i;
}
}
else
{
out_it = static_cast<u8>(literal[i]);
++i;
}
}
}
template <typename T>
std::optional<T> EvalIntegral(TokenType tp, std::string_view val)
{
constexpr auto hex_step = [](T acc, char c) { return acc << 4 | ConvertNib<T>(c); };
constexpr auto dec_step = [](T acc, char c) { return acc * 10 + (c - '0'); };
constexpr auto oct_step = [](T acc, char c) { return acc << 3 | (c - '0'); };
constexpr auto bin_step = [](T acc, char c) { return acc << 1 | (c - '0'); };
switch (tp)
{
case TokenType::HexadecimalLit:
return std::accumulate(val.begin() + 2, val.end(), T{0}, hex_step);
case TokenType::DecimalLit:
return std::accumulate(val.begin(), val.end(), T{0}, dec_step);
case TokenType::OctalLit:
return std::accumulate(val.begin() + 1, val.end(), T{0}, oct_step);
case TokenType::BinaryLit:
return std::accumulate(val.begin() + 2, val.end(), T{0}, bin_step);
case TokenType::GPR:
if (CaseInsensitiveEquals(val, "sp"))
return T{1};
if (CaseInsensitiveEquals(val, "rtoc"))
return T{2};
[[fallthrough]];
case TokenType::FPR:
return std::accumulate(val.begin() + 1, val.end(), T{0}, dec_step);
case TokenType::CRField:
return std::accumulate(val.begin() + 2, val.end(), T{0}, dec_step);
case TokenType::SPR:
return static_cast<T>(*sprg_map.Find(val));
case TokenType::Lt:
return T{0};
case TokenType::Gt:
return T{1};
case TokenType::Eq:
return T{2};
case TokenType::So:
return T{3};
default:
return std::nullopt;
}
}
} // namespace
void ConvertStringLiteral(std::string_view literal, std::vector<u8>* out_vec)
{
ConvertStringLiteral(literal, std::back_inserter(*out_vec));
}
std::string_view TokenTypeToStr(TokenType tp)
{
switch (tp)
{
case TokenType::GPR:
return "GPR";
case TokenType::FPR:
return "FPR";
case TokenType::SPR:
return "SPR";
case TokenType::CRField:
return "CR Field";
case TokenType::Lt:
case TokenType::Gt:
case TokenType::Eq:
case TokenType::So:
return "CR Bit";
case TokenType::Identifier:
return "Identifier";
case TokenType::StringLit:
return "String Literal";
case TokenType::DecimalLit:
return "Decimal Literal";
case TokenType::BinaryLit:
return "Binary Literal";
case TokenType::HexadecimalLit:
return "Hexadecimal Literal";
case TokenType::OctalLit:
return "Octal Literal";
case TokenType::FloatLit:
return "Float Literal";
case TokenType::Invalid:
return "Invalid";
case TokenType::Lsh:
return "<<";
case TokenType::Rsh:
return ">>";
case TokenType::Comma:
return ",";
case TokenType::Lparen:
return "(";
case TokenType::Rparen:
return ")";
case TokenType::Pipe:
return "|";
case TokenType::Caret:
return "^";
case TokenType::Ampersand:
return "&";
case TokenType::Plus:
return "+";
case TokenType::Minus:
return "-";
case TokenType::Star:
return "*";
case TokenType::Slash:
return "/";
case TokenType::Tilde:
return "~";
case TokenType::At:
return "@";
case TokenType::Colon:
return ":";
case TokenType::Grave:
return "`";
case TokenType::Dot:
return ".";
case TokenType::Eof:
return "End of File";
case TokenType::Eol:
return "End of Line";
default:
return "";
}
}
std::string_view AssemblerToken::TypeStr() const
{
return TokenTypeToStr(token_type);
}
std::string_view AssemblerToken::ValStr() const
{
switch (token_type)
{
case TokenType::Eol:
return "<EOL>";
case TokenType::Eof:
return "<EOF>";
default:
return token_val;
}
}
template <>
std::optional<float> AssemblerToken::EvalToken() const
{
if (token_type == TokenType::FloatLit)
{
return std::stof(std::string(token_val));
}
return std::nullopt;
}
template <>
std::optional<double> AssemblerToken::EvalToken() const
{
if (token_type == TokenType::FloatLit)
{
return std::stod(std::string(token_val));
}
return std::nullopt;
}
template <>
std::optional<u8> AssemblerToken::EvalToken() const
{
return EvalIntegral<u8>(token_type, token_val);
}
template <>
std::optional<u16> AssemblerToken::EvalToken() const
{
return EvalIntegral<u16>(token_type, token_val);
}
template <>
std::optional<u32> AssemblerToken::EvalToken() const
{
return EvalIntegral<u32>(token_type, token_val);
}
template <>
std::optional<u64> AssemblerToken::EvalToken() const
{
return EvalIntegral<u64>(token_type, token_val);
}
size_t Lexer::LineNumber() const
{
return m_lexed_tokens.empty() ? m_pos.line : TagOf(m_lexed_tokens.front()).line;
}
size_t Lexer::ColNumber() const
{
return m_lexed_tokens.empty() ? m_pos.col : TagOf(m_lexed_tokens.front()).col;
}
std::string_view Lexer::CurrentLine() const
{
const size_t line_index =
m_lexed_tokens.empty() ? m_pos.index : TagOf(m_lexed_tokens.front()).index;
size_t begin_index = line_index == 0 ? 0 : line_index - 1;
for (; begin_index > 0; begin_index--)
{
if (m_lex_string[begin_index] == '\n')
{
begin_index++;
break;
}
}
size_t end_index = begin_index;
for (; end_index < m_lex_string.size(); end_index++)
{
if (m_lex_string[end_index] == '\n')
{
end_index++;
break;
}
}
return m_lex_string.substr(begin_index, end_index - begin_index);
}
void Lexer::SetIdentifierMatchRule(IdentifierMatchRule set)
{
FeedbackTokens();
m_match_rule = set;
}
const Tagged<CursorPosition, AssemblerToken>& Lexer::LookaheadTagRef(size_t num_fwd) const
{
while (m_lexed_tokens.size() < num_fwd)
{
LookaheadRef();
}
return m_lexed_tokens[num_fwd];
}
AssemblerToken Lexer::Lookahead() const
{
if (m_lexed_tokens.empty())
{
CursorPosition pos_pre = m_pos;
m_lexed_tokens.emplace_back(pos_pre, LexSingle());
}
return ValueOf(m_lexed_tokens.front());
}
const AssemblerToken& Lexer::LookaheadRef() const
{
if (m_lexed_tokens.empty())
{
CursorPosition pos_pre = m_pos;
m_lexed_tokens.emplace_back(pos_pre, LexSingle());
}
return ValueOf(m_lexed_tokens.front());
}
TokenType Lexer::LookaheadType() const
{
return LookaheadRef().token_type;
}
AssemblerToken Lexer::LookaheadFloat() const
{
FeedbackTokens();
SkipWs();
CursorPosition pos_pre = m_pos;
ScanStart();
std::optional<std::string_view> failure_reason = RunDfa(float_dfa);
// Special case: lex at least a single char for no matches for errors to make sense
if (m_scan_pos.index == pos_pre.index)
{
Step();
}
std::string_view tok_str = ScanFinishOut();
AssemblerToken tok;
if (!failure_reason)
{
tok = AssemblerToken{
TokenType::FloatLit,
tok_str,
"",
Interval{0, 0},
};
}
else
{
tok = AssemblerToken{
TokenType::Invalid,
tok_str,
*failure_reason,
Interval{0, tok_str.length()},
};
}
m_lexed_tokens.emplace_back(pos_pre, tok);
return tok;
}
void Lexer::Eat()
{
if (m_lexed_tokens.empty())
{
LexSingle();
}
else
{
m_lexed_tokens.pop_front();
}
}
void Lexer::EatAndReset()
{
Eat();
SetIdentifierMatchRule(IdentifierMatchRule::Typical);
}
std::optional<std::string_view> Lexer::RunDfa(const std::vector<DfaNode>& dfa) const
{
size_t dfa_index = 0;
bool transition_found;
do
{
transition_found = false;
if (Peek() == '\0')
{
break;
}
const DfaNode& n = dfa[dfa_index];
for (auto&& edge : n.edges)
{
if (edge.first(Peek()))
{
transition_found = true;
dfa_index = edge.second;
break;
}
}
if (transition_found)
{
Step();
}
} while (transition_found);
return dfa[dfa_index].match_failure_reason;
}
void Lexer::SkipWs() const
{
ScanStart();
for (char c = Peek(); std::isspace(c) && c != '\n'; c = Step().Peek())
{
}
if (Peek() == '#')
{
while (Peek() != '\n' && Peek() != '\0')
{
Step();
}
}
ScanFinish();
}
void Lexer::FeedbackTokens() const
{
if (m_lexed_tokens.empty())
{
return;
}
m_pos = m_scan_pos = TagOf(m_lexed_tokens.front());
m_lexed_tokens.clear();
}
bool Lexer::IdentifierHeadExtra(char h) const
{
switch (m_match_rule)
{
case IdentifierMatchRule::Typical:
case IdentifierMatchRule::Mnemonic:
return false;
case IdentifierMatchRule::Directive:
return std::isdigit(h);
}
return false;
}
bool Lexer::IdentifierExtra(char c) const
{
switch (m_match_rule)
{
case IdentifierMatchRule::Typical:
case IdentifierMatchRule::Directive:
return false;
case IdentifierMatchRule::Mnemonic:
return c == '+' || c == '-' || c == '.';
}
return false;
}
void Lexer::ScanStart() const
{
m_scan_pos = m_pos;
}
void Lexer::ScanFinish() const
{
m_pos = m_scan_pos;
}
std::string_view Lexer::ScanFinishOut() const
{
const size_t start = m_pos.index;
m_pos = m_scan_pos;
return m_lex_string.substr(start, m_scan_pos.index - start);
}
char Lexer::Peek() const
{
if (m_scan_pos.index >= m_lex_string.length())
{
return 0;
}
return m_lex_string[m_scan_pos.index];
}
const Lexer& Lexer::Step() const
{
if (m_scan_pos.index >= m_lex_string.length())
{
return *this;
}
if (Peek() == '\n')
{
m_scan_pos.line++;
m_scan_pos.col = 0;
}
else
{
m_scan_pos.col++;
}
m_scan_pos.index++;
return *this;
}
TokenType Lexer::LexStringLit(std::string_view& invalid_reason, Interval& invalid_region) const
{
// The open quote has alread been matched
const size_t string_start = m_scan_pos.index - 1;
TokenType token_type = TokenType::StringLit;
std::optional<std::string_view> failure_reason = RunDfa(string_dfa);
if (failure_reason)
{
token_type = TokenType::Invalid;
invalid_reason = *failure_reason;
invalid_region = Interval{0, m_scan_pos.index - string_start};
}
return token_type;
}
TokenType Lexer::ClassifyAlnum() const
{
const std::string_view alnum = m_lex_string.substr(m_pos.index, m_scan_pos.index - m_pos.index);
constexpr auto valid_regnum = [](std::string_view rn) {
if (rn.length() == 1 && std::isdigit(rn[0]))
{
return true;
}
else if (rn.length() == 2 && std::isdigit(rn[0]) && std::isdigit(rn[1]))
{
if (rn[0] == '1' || rn[0] == '2')
{
return true;
}
if (rn[0] == '3')
{
return rn[1] < '2';
}
}
return false;
};
if (std::tolower(alnum[0]) == 'r' && valid_regnum(alnum.substr(1)))
{
return TokenType::GPR;
}
else if ((CaseInsensitiveEquals(alnum, "sp")) || (CaseInsensitiveEquals(alnum, "rtoc")))
{
return TokenType::GPR;
}
else if (std::tolower(alnum[0]) == 'f' && valid_regnum(alnum.substr(1)))
{
return TokenType::FPR;
}
else if (alnum.length() == 3 && CaseInsensitiveEquals(alnum.substr(0, 2), "cr") &&
alnum[2] >= '0' && alnum[2] <= '7')
{
return TokenType::CRField;
}
else if (CaseInsensitiveEquals(alnum, "lt"))
{
return TokenType::Lt;
}
else if (CaseInsensitiveEquals(alnum, "gt"))
{
return TokenType::Gt;
}
else if (CaseInsensitiveEquals(alnum, "eq"))
{
return TokenType::Eq;
}
else if (CaseInsensitiveEquals(alnum, "so"))
{
return TokenType::So;
}
else if (sprg_map.Find(alnum) != nullptr)
{
return TokenType::SPR;
}
return TokenType::Identifier;
}
AssemblerToken Lexer::LexSingle() const
{
SkipWs();
ScanStart();
const char h = Peek();
TokenType token_type;
std::string_view invalid_reason = "";
Interval invalid_region = Interval{0, 0};
Step();
if (std::isalpha(h) || h == '_' || IdentifierHeadExtra(h))
{
for (char c = Peek(); std::isalnum(c) || c == '_' || IdentifierExtra(c); c = Step().Peek())
{
}
token_type = ClassifyAlnum();
}
else if (h == '"')
{
token_type = LexStringLit(invalid_reason, invalid_region);
}
else if (h == '0')
{
const char imm_type = Peek();
if (imm_type == 'x')
{
token_type = TokenType::HexadecimalLit;
Step();
for (char c = Peek(); std::isxdigit(c); c = Step().Peek())
{
}
}
else if (imm_type == 'b')
{
token_type = TokenType::BinaryLit;
Step();
for (char c = Peek(); IsBinary(c); c = Step().Peek())
{
}
}
else if (IsOctal(imm_type))
{
token_type = TokenType::OctalLit;
for (char c = Peek(); IsOctal(c); c = Step().Peek())
{
}
}
else
{
token_type = TokenType::DecimalLit;
}
}
else if (std::isdigit(h))
{
for (char c = Peek(); std::isdigit(c); c = Step().Peek())
{
}
token_type = TokenType::DecimalLit;
}
else if (h == '<' || h == '>')
{
// Special case for two-character operators
const char second_ch = Peek();
if (second_ch == h)
{
Step();
token_type = second_ch == '<' ? TokenType::Lsh : TokenType::Rsh;
}
else
{
token_type = TokenType::Invalid;
invalid_reason = "Unrecognized character";
invalid_region = Interval{0, 1};
}
}
else
{
token_type = SingleCharToken(h);
if (token_type == TokenType::Invalid)
{
invalid_reason = "Unrecognized character";
invalid_region = Interval{0, 1};
}
}
AssemblerToken new_tok = {token_type, ScanFinishOut(), invalid_reason, invalid_region};
SkipWs();
return new_tok;
}
} // namespace Common::GekkoAssembler::detail