From b8c52edbf43aa7b1a57fbc09c9f1f3a18ebc0bf2 Mon Sep 17 00:00:00 2001 From: gdkchan Date: Wed, 29 Jun 2022 15:46:45 -0300 Subject: [PATCH] Protect against stack overflow caused by deep recursive calls * PPTC version bump * Also reset call depth when not using the unmanaged dispatch loop * Increment call depth on function start rather than before call --- .../Instructions/InstEmitException.cs | 4 +- .../Instructions/InstEmitException32.cs | 2 +- src/ARMeilleure/Instructions/InstEmitFlow.cs | 2 +- .../Instructions/InstEmitFlowHelper.cs | 60 +++++++++++++++++-- src/ARMeilleure/Optimizations.cs | 1 + src/ARMeilleure/State/ExecutionContext.cs | 5 ++ src/ARMeilleure/State/NativeContext.cs | 8 +++ src/ARMeilleure/Translation/PTC/Ptc.cs | 2 +- src/ARMeilleure/Translation/Translator.cs | 2 + .../Translation/TranslatorStubs.cs | 8 +++ 10 files changed, 83 insertions(+), 11 deletions(-) diff --git a/src/ARMeilleure/Instructions/InstEmitException.cs b/src/ARMeilleure/Instructions/InstEmitException.cs index d30fb2fbd..a91716c64 100644 --- a/src/ARMeilleure/Instructions/InstEmitException.cs +++ b/src/ARMeilleure/Instructions/InstEmitException.cs @@ -19,7 +19,7 @@ namespace ARMeilleure.Instructions context.LoadFromContext(); - context.Return(Const(op.Address)); + InstEmitFlowHelper.EmitReturn(context, Const(op.Address)); } public static void Svc(ArmEmitterContext context) @@ -49,7 +49,7 @@ namespace ARMeilleure.Instructions context.LoadFromContext(); - context.Return(Const(op.Address)); + InstEmitFlowHelper.EmitReturn(context, Const(op.Address)); } } } diff --git a/src/ARMeilleure/Instructions/InstEmitException32.cs b/src/ARMeilleure/Instructions/InstEmitException32.cs index 57af1522b..e5bad56ef 100644 --- a/src/ARMeilleure/Instructions/InstEmitException32.cs +++ b/src/ARMeilleure/Instructions/InstEmitException32.cs @@ -33,7 +33,7 @@ namespace ARMeilleure.Instructions context.LoadFromContext(); - context.Return(Const(context.CurrOp.Address)); + InstEmitFlowHelper.EmitReturn(context, Const(context.CurrOp.Address)); } } } diff --git a/src/ARMeilleure/Instructions/InstEmitFlow.cs b/src/ARMeilleure/Instructions/InstEmitFlow.cs index a986bf66f..cb214d3d5 100644 --- a/src/ARMeilleure/Instructions/InstEmitFlow.cs +++ b/src/ARMeilleure/Instructions/InstEmitFlow.cs @@ -66,7 +66,7 @@ namespace ARMeilleure.Instructions { OpCodeBReg op = (OpCodeBReg)context.CurrOp; - context.Return(GetIntOrZR(context, op.Rn)); + EmitReturn(context, GetIntOrZR(context, op.Rn)); } public static void Tbnz(ArmEmitterContext context) => EmitTb(context, onNotZero: true); diff --git a/src/ARMeilleure/Instructions/InstEmitFlowHelper.cs b/src/ARMeilleure/Instructions/InstEmitFlowHelper.cs index d0871b29f..74866f982 100644 --- a/src/ARMeilleure/Instructions/InstEmitFlowHelper.cs +++ b/src/ARMeilleure/Instructions/InstEmitFlowHelper.cs @@ -13,6 +13,10 @@ namespace ARMeilleure.Instructions { static class InstEmitFlowHelper { + // How many calls we can have in our call stack before we give up and return to the dispatcher. + // This prevents stack overflows caused by deep recursive calls. + private const int MaxCallDepth = 200; + public static void EmitCondBranch(ArmEmitterContext context, Operand target, Condition cond) { if (cond != Condition.Al) @@ -182,12 +186,7 @@ namespace ARMeilleure.Instructions { if (isReturn || context.IsSingleStep) { - if (target.Type == OperandType.I32) - { - target = context.ZeroExtend32(OperandType.I64, target); - } - - context.Return(target); + EmitReturn(context, target); } else { @@ -195,6 +194,19 @@ namespace ARMeilleure.Instructions } } + public static void EmitReturn(ArmEmitterContext context, Operand target) + { + Operand nativeContext = context.LoadArgument(OperandType.I64, 0); + DecreaseCallDepth(context, nativeContext); + + if (target.Type == OperandType.I32) + { + target = context.ZeroExtend32(OperandType.I64, target); + } + + context.Return(target); + } + private static void EmitTableBranch(ArmEmitterContext context, Operand guestAddress, bool isJump) { context.StoreToContext(); @@ -257,6 +269,8 @@ namespace ARMeilleure.Instructions if (isJump) { + DecreaseCallDepth(context, nativeContext); + context.Tailcall(hostAddress, nativeContext); } else @@ -278,8 +292,42 @@ namespace ARMeilleure.Instructions Operand lblContinue = context.GetLabel(nextAddr.Value); context.BranchIf(lblContinue, returnAddress, nextAddr, Comparison.Equal, BasicBlockFrequency.Cold); + DecreaseCallDepth(context, nativeContext); + context.Return(returnAddress); } } + + public static void EmitCallDepthCheckAndIncrement(EmitterContext context, Operand guestAddress) + { + if (!Optimizations.EnableDeepCallRecursionProtection) + { + return; + } + + Operand nativeContext = context.LoadArgument(OperandType.I64, 0); + Operand callDepthAddr = context.Add(nativeContext, Const((ulong)NativeContext.GetCallDepthOffset())); + Operand currentCallDepth = context.Load(OperandType.I32, callDepthAddr); + Operand lblDoCall = Label(); + + context.BranchIf(lblDoCall, currentCallDepth, Const(MaxCallDepth), Comparison.LessUI); + context.Store(callDepthAddr, context.Subtract(currentCallDepth, Const(1))); + context.Return(guestAddress); + + context.MarkLabel(lblDoCall); + context.Store(callDepthAddr, context.Add(currentCallDepth, Const(1))); + } + + private static void DecreaseCallDepth(EmitterContext context, Operand nativeContext) + { + if (!Optimizations.EnableDeepCallRecursionProtection) + { + return; + } + + Operand callDepthAddr = context.Add(nativeContext, Const((ulong)NativeContext.GetCallDepthOffset())); + Operand currentCallDepth = context.Load(OperandType.I32, callDepthAddr); + context.Store(callDepthAddr, context.Subtract(currentCallDepth, Const(1))); + } } } diff --git a/src/ARMeilleure/Optimizations.cs b/src/ARMeilleure/Optimizations.cs index 6dd7befe7..3a76b6d93 100644 --- a/src/ARMeilleure/Optimizations.cs +++ b/src/ARMeilleure/Optimizations.cs @@ -13,6 +13,7 @@ namespace ARMeilleure public static bool AllowLcqInFunctionTable { get; set; } = true; public static bool UseUnmanagedDispatchLoop { get; set; } = true; public static bool EnableDebugging { get; set; } = false; + public static bool EnableDeepCallRecursionProtection { get; set; } = true; public static bool UseAdvSimdIfAvailable { get; set; } = true; public static bool UseArm64AesIfAvailable { get; set; } = true; diff --git a/src/ARMeilleure/State/ExecutionContext.cs b/src/ARMeilleure/State/ExecutionContext.cs index fa1a4a032..365805e45 100644 --- a/src/ARMeilleure/State/ExecutionContext.cs +++ b/src/ARMeilleure/State/ExecutionContext.cs @@ -134,6 +134,11 @@ namespace ARMeilleure.State public bool GetFPstateFlag(FPState flag) => _nativeContext.GetFPStateFlag(flag); public void SetFPstateFlag(FPState flag, bool value) => _nativeContext.SetFPStateFlag(flag, value); + internal void ResetCallDepth() + { + _nativeContext.ResetCallDepth(); + } + internal void CheckInterrupt() { if (Interrupted) diff --git a/src/ARMeilleure/State/NativeContext.cs b/src/ARMeilleure/State/NativeContext.cs index a9f1c3dab..bc9e31888 100644 --- a/src/ARMeilleure/State/NativeContext.cs +++ b/src/ARMeilleure/State/NativeContext.cs @@ -22,6 +22,7 @@ namespace ARMeilleure.State public ulong ExclusiveValueHigh; public int Running; public long Tpidr2El0; + public int CallDepth; /// /// Precise PC value used for debugging. @@ -199,6 +200,8 @@ namespace ARMeilleure.State public bool GetRunning() => GetStorage().Running != 0; public void SetRunning(bool value) => GetStorage().Running = value ? 1 : 0; + public void ResetCallDepth() => GetStorage().CallDepth = 0; + public unsafe static int GetRegisterOffset(Register reg) { if (reg.Type == RegisterType.Integer) @@ -284,6 +287,11 @@ namespace ARMeilleure.State return StorageOffset(ref _dummyStorage, ref _dummyStorage.DebugPrecisePc); } + public static int GetCallDepthOffset() + { + return StorageOffset(ref _dummyStorage, ref _dummyStorage.CallDepth); + } + private static int StorageOffset(ref NativeCtxStorage storage, ref T target) { return (int)Unsafe.ByteOffset(ref Unsafe.As(ref storage), ref target); diff --git a/src/ARMeilleure/Translation/PTC/Ptc.cs b/src/ARMeilleure/Translation/PTC/Ptc.cs index c69ebcadb..cfa4377db 100644 --- a/src/ARMeilleure/Translation/PTC/Ptc.cs +++ b/src/ARMeilleure/Translation/PTC/Ptc.cs @@ -33,7 +33,7 @@ namespace ARMeilleure.Translation.PTC private const string OuterHeaderMagicString = "PTCohd\0\0"; private const string InnerHeaderMagicString = "PTCihd\0\0"; - private const uint InternalVersion = 7009; //! To be incremented manually for each change to the ARMeilleure project. + private const uint InternalVersion = 7010; //! To be incremented manually for each change to the ARMeilleure project. private const string ActualDir = "0"; private const string BackupDir = "1"; diff --git a/src/ARMeilleure/Translation/Translator.cs b/src/ARMeilleure/Translation/Translator.cs index 073b7ffe2..458ccaf9b 100644 --- a/src/ARMeilleure/Translation/Translator.cs +++ b/src/ARMeilleure/Translation/Translator.cs @@ -186,6 +186,7 @@ namespace ARMeilleure.Translation Statistics.StartTimer(); + context.ResetCallDepth(); ulong nextAddr = func.Execute(Stubs.ContextWrapper, context); Statistics.StopTimer(address); @@ -260,6 +261,7 @@ namespace ARMeilleure.Translation Logger.StartPass(PassName.Translation); + InstEmitFlowHelper.EmitCallDepthCheckAndIncrement(context, Const(address)); EmitSynchronization(context); if (blocks[0].Address != address) diff --git a/src/ARMeilleure/Translation/TranslatorStubs.cs b/src/ARMeilleure/Translation/TranslatorStubs.cs index 458a42434..2d95ceb99 100644 --- a/src/ARMeilleure/Translation/TranslatorStubs.cs +++ b/src/ARMeilleure/Translation/TranslatorStubs.cs @@ -262,10 +262,18 @@ namespace ARMeilleure.Translation Operand runningAddress = context.Add(nativeContext, Const((ulong)NativeContext.GetRunningOffset())); Operand dispatchAddress = context.Add(nativeContext, Const((ulong)NativeContext.GetDispatchAddressOffset())); + Operand callDepthAddress = context.Add(nativeContext, Const((ulong)NativeContext.GetCallDepthOffset())); EmitSyncFpContext(context, nativeContext, true); context.MarkLabel(beginLbl); + + if (Optimizations.EnableDeepCallRecursionProtection) + { + // Reset the call depth counter, since this is our first guest function call. + context.Store(callDepthAddress, Const(0)); + } + context.Store(dispatchAddress, guestAddress); context.Copy(guestAddress, context.Call(Const((ulong)DispatchStub), OperandType.I64, nativeContext)); context.BranchIfFalse(endLbl, guestAddress);