Merge pull request #11881 from JosJuice/aarch64-function-call

JitArm64: Add utility for calling a function with arguments
This commit is contained in:
Admiral H. Curtiss
2023-11-25 17:30:42 +01:00
committed by GitHub
12 changed files with 411 additions and 99 deletions

View File

@ -1795,6 +1795,62 @@ void ARM64XEmitter::ADRP(ARM64Reg Rd, s64 imm)
EncodeAddressInst(1, Rd, static_cast<s32>(imm >> 12));
}
// This is using a hand-rolled algorithm. The goal is zero memory allocations, not necessarily
// the best JIT-time time complexity. (The number of moves is usually very small.)
void ARM64XEmitter::ParallelMoves(RegisterMove* begin, RegisterMove* end,
std::array<u8, 32>* source_gpr_usages)
{
// X0-X7 are used for passing arguments.
// X18-X31 are either callee saved or used for special purposes.
constexpr size_t temp_reg_begin = 8;
constexpr size_t temp_reg_end = 18;
while (begin != end)
{
bool removed_moves_during_this_loop_iteration = false;
RegisterMove* move = end;
while (move != begin)
{
RegisterMove* prev_move = move;
--move;
if ((*source_gpr_usages)[DecodeReg(move->dst)] == 0)
{
MOV(move->dst, move->src);
(*source_gpr_usages)[DecodeReg(move->src)]--;
std::move(prev_move, end, move);
--end;
removed_moves_during_this_loop_iteration = true;
}
}
if (!removed_moves_during_this_loop_iteration)
{
// We need to break a cycle using a temporary register.
size_t temp_reg = temp_reg_begin;
while ((*source_gpr_usages)[temp_reg] != 0)
{
++temp_reg;
ASSERT_MSG(COMMON, temp_reg != temp_reg_end, "Out of registers");
}
const ARM64Reg src = begin->src;
const ARM64Reg dst =
(Is64Bit(src) ? EncodeRegTo64 : EncodeRegTo32)(static_cast<ARM64Reg>(temp_reg));
MOV(dst, src);
(*source_gpr_usages)[DecodeReg(dst)] = (*source_gpr_usages)[DecodeReg(src)];
(*source_gpr_usages)[DecodeReg(src)] = 0;
std::for_each(begin, end, [src, dst](RegisterMove& move) {
if (move.src == src)
move.src = dst;
});
}
}
}
template <typename T>
void ARM64XEmitter::MOVI2RImpl(ARM64Reg Rd, T imm)
{

View File

@ -3,10 +3,12 @@
#pragma once
#include <array>
#include <bit>
#include <cstring>
#include <functional>
#include <optional>
#include <type_traits>
#include <utility>
#include "Common/ArmCommon.h"
@ -17,6 +19,7 @@
#include "Common/Common.h"
#include "Common/CommonTypes.h"
#include "Common/MathUtil.h"
#include "Common/SmallVector.h"
namespace Arm64Gen
{
@ -599,6 +602,12 @@ class ARM64XEmitter
friend class ARM64FloatEmitter;
private:
struct RegisterMove
{
ARM64Reg dst;
ARM64Reg src;
};
// Pointer to memory where code will be emitted to.
u8* m_code = nullptr;
@ -646,6 +655,10 @@ private:
[[nodiscard]] FixupBranch WriteFixupBranch();
// This function solves the "parallel moves" problem common in compilers.
// The arguments are mutated!
void ParallelMoves(RegisterMove* begin, RegisterMove* end, std::array<u8, 32>* source_gpr_usages);
template <typename T>
void MOVI2RImpl(ARM64Reg Rd, T imm);
@ -1058,6 +1071,114 @@ public:
void ABI_PushRegisters(BitSet32 registers);
void ABI_PopRegisters(BitSet32 registers, BitSet32 ignore_mask = BitSet32(0));
// Plain function call
void QuickCallFunction(ARM64Reg scratchreg, const void* func);
template <typename T>
void QuickCallFunction(ARM64Reg scratchreg, T func)
{
QuickCallFunction(scratchreg, (const void*)func);
}
template <typename FuncRet, typename... FuncArgs, typename... Args>
void ABI_CallFunction(FuncRet (*func)(FuncArgs...), Args... args)
{
static_assert(sizeof...(FuncArgs) == sizeof...(Args), "Wrong number of arguments");
static_assert(sizeof...(FuncArgs) <= 8, "Passing arguments on the stack is not supported");
if constexpr (!std::is_void_v<FuncRet>)
static_assert(sizeof(FuncRet) <= 16, "Large return types are not supported");
std::array<u8, 32> source_gpr_uses{};
auto check_argument = [&](auto& arg) {
using Arg = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<Arg, ARM64Reg>)
{
ASSERT(IsGPR(arg));
source_gpr_uses[DecodeReg(arg)]++;
}
else
{
// To be more correct, we should be checking FuncArgs here rather than Args, but that's a
// lot more effort to implement. Let's just do these best-effort checks for now.
static_assert(!std::is_floating_point_v<Arg>, "Floating-point arguments are not supported");
static_assert(sizeof(Arg) <= 8, "Arguments bigger than a register are not supported");
}
};
(check_argument(args), ...);
{
Common::SmallVector<RegisterMove, sizeof...(Args)> pending_moves;
size_t i = 0;
auto handle_register_argument = [&](auto& arg) {
using Arg = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<Arg, ARM64Reg>)
{
const ARM64Reg dst_reg =
(Is64Bit(arg) ? EncodeRegTo64 : EncodeRegTo32)(static_cast<ARM64Reg>(i));
if (dst_reg == arg)
{
// The value is already in the right register.
source_gpr_uses[DecodeReg(arg)]--;
}
else if (source_gpr_uses[i] == 0)
{
// The destination register isn't used as the source of another move.
// We can go ahead and do the move right away.
MOV(dst_reg, arg);
source_gpr_uses[DecodeReg(arg)]--;
}
else
{
// The destination register is used as the source of a move we haven't gotten to yet.
// Let's record that we need to deal with this move later.
pending_moves.emplace_back(dst_reg, arg);
}
}
++i;
};
(handle_register_argument(args), ...);
if (!pending_moves.empty())
{
ParallelMoves(pending_moves.data(), pending_moves.data() + pending_moves.size(),
&source_gpr_uses);
}
}
{
size_t i = 0;
auto handle_immediate_argument = [&](auto& arg) {
using Arg = std::decay_t<decltype(arg)>;
if constexpr (!std::is_same_v<Arg, ARM64Reg>)
{
const ARM64Reg dst_reg =
(sizeof(arg) == 8 ? EncodeRegTo64 : EncodeRegTo32)(static_cast<ARM64Reg>(i));
if constexpr (std::is_pointer_v<Arg>)
MOVP2R(dst_reg, arg);
else
MOVI2R(dst_reg, arg);
}
++i;
};
(handle_immediate_argument(args), ...);
}
QuickCallFunction(ARM64Reg::X8, func);
}
// Utility to generate a call to a std::function object.
//
// Unfortunately, calling operator() directly is undefined behavior in C++
@ -1069,23 +1190,11 @@ public:
return (*f)(args...);
}
// This function expects you to have set up the state.
// Overwrites X0 and X8
template <typename T, typename... Args>
ARM64Reg ABI_SetupLambda(const std::function<T(Args...)>* f)
template <typename FuncRet, typename... FuncArgs, typename... Args>
void ABI_CallLambdaFunction(const std::function<FuncRet(FuncArgs...)>* f, Args... args)
{
auto trampoline = &ARM64XEmitter::CallLambdaTrampoline<T, Args...>;
MOVP2R(ARM64Reg::X8, trampoline);
MOVP2R(ARM64Reg::X0, const_cast<void*>((const void*)f));
return ARM64Reg::X8;
}
// Plain function call
void QuickCallFunction(ARM64Reg scratchreg, const void* func);
template <typename T>
void QuickCallFunction(ARM64Reg scratchreg, T func)
{
QuickCallFunction(scratchreg, (const void*)func);
auto trampoline = &ARM64XEmitter::CallLambdaTrampoline<FuncRet, FuncArgs...>;
ABI_CallFunction(trampoline, f, args...);
}
};

View File

@ -29,9 +29,11 @@ public:
T& operator[](size_t i) { return m_array[i]; }
const T& operator[](size_t i) const { return m_array[i]; }
auto data() { return m_array.data(); }
auto begin() { return m_array.begin(); }
auto end() { return m_array.begin() + m_size; }
auto data() const { return m_array.data(); }
auto begin() const { return m_array.begin(); }
auto end() const { return m_array.begin() + m_size; }