Merge pull request #13692 from jordan-woyak/work-queue-thread-command-queue

WorkQueueThread: Fix Cancel() race with internal command queue.
This commit is contained in:
JosJuice
2025-05-24 16:26:10 +02:00
committed by GitHub

View File

@ -3,7 +3,6 @@
#pragma once #pragma once
#include <atomic>
#include <functional> #include <functional>
#include <future> #include <future>
#include <mutex> #include <mutex>
@ -22,8 +21,10 @@ template <typename T, bool IsSingleProducer>
class WorkQueueThreadBase final class WorkQueueThreadBase final
{ {
public: public:
using FunctionType = std::function<void(T)>;
WorkQueueThreadBase() = default; WorkQueueThreadBase() = default;
WorkQueueThreadBase(std::string name, std::function<void(T)> function) WorkQueueThreadBase(std::string name, FunctionType function)
{ {
Reset(std::move(name), std::move(function)); Reset(std::move(name), std::move(function));
} }
@ -31,11 +32,10 @@ public:
// Shuts the current work thread down (if any) and starts a new thread with the given function // Shuts the current work thread down (if any) and starts a new thread with the given function
// Note: Some consumers of this API push items to the queue before starting the thread. // Note: Some consumers of this API push items to the queue before starting the thread.
void Reset(std::string name, std::function<void(T)> function) void Reset(std::string name, FunctionType function)
{ {
auto lg = GetLockGuard(); auto lg = GetLockGuard();
Shutdown(); Shutdown();
m_run_thread.store(true, std::memory_order_relaxed);
m_thread = std::thread(std::bind_front(&WorkQueueThreadBase::ThreadLoop, this), std::move(name), m_thread = std::thread(std::bind_front(&WorkQueueThreadBase::ThreadLoop, this), std::move(name),
std::move(function)); std::move(function));
} }
@ -56,25 +56,30 @@ public:
void Cancel() void Cancel()
{ {
auto lg = GetLockGuard(); auto lg = GetLockGuard();
if (IsRunning())
{ // Fast path avoids round trip thread communication and saves ~20us.
m_skip_work.store(true, std::memory_order_relaxed); if (m_items.Empty())
WaitForCompletion(); return;
m_skip_work.store(false, std::memory_order_relaxed);
} RunCommand([&] { m_items.Clear(); });
else
{
m_items.Clear();
}
} }
// Tells the worker thread to stop when its queue is empty. // Tells the worker thread to stop when its queue is empty.
// Blocks until the worker thread exits. Does nothing if thread isn't running. // Blocks until the worker thread exits. Does nothing if thread isn't running.
void Shutdown() { StopThread(true); } void Shutdown()
{
auto lg = GetLockGuard();
WaitForCompletion();
StopThread();
}
// Tells the worker thread to stop immediately, potentially leaving work in the queue. // Tells the worker thread to stop immediately, potentially leaving work in the queue.
// Blocks until the worker thread exits. Does nothing if thread isn't running. // Blocks until the worker thread exits. Does nothing if thread isn't running.
void Stop() { StopThread(false); } void Stop()
{
auto lg = GetLockGuard();
StopThread();
}
// Stops the worker thread ASAP and empties the queue. // Stops the worker thread ASAP and empties the queue.
void StopAndCancel() void StopAndCancel()
@ -94,18 +99,33 @@ public:
} }
private: private:
void StopThread(bool wait_for_completion) using CommandFunction = std::function<void()>;
// Blocking.
void RunCommand(CommandFunction cmd)
{ {
auto lg = GetLockGuard(); if (!IsRunning())
if (wait_for_completion)
WaitForCompletion();
if (m_run_thread.exchange(false, std::memory_order_relaxed))
{ {
m_event.Set(); std::invoke(cmd);
m_thread.join(); return;
} }
m_commands.Emplace(std::move(cmd));
m_event.Set();
m_commands.WaitForEmpty();
}
// Stop immediately.
void StopThread()
{
if (!m_thread.joinable())
return;
// empty-function shutdown signal.
m_commands.Emplace(CommandFunction{});
m_event.Set();
m_thread.join();
m_commands.Clear();
} }
auto GetLockGuard() auto GetLockGuard()
@ -124,24 +144,29 @@ private:
bool IsRunning() { return m_thread.joinable(); } bool IsRunning() { return m_thread.joinable(); }
void ThreadLoop(const std::string& thread_name, const std::function<void(T)>& function) void ThreadLoop(const std::string& thread_name, const FunctionType& function)
{ {
Common::SetCurrentThreadName(thread_name.c_str()); Common::SetCurrentThreadName(thread_name.c_str());
while (m_run_thread.load(std::memory_order_relaxed)) while (true)
{ {
while (!m_commands.Empty())
{
CommandFunction& command = m_commands.Front();
// empty-function shutdown signal.
if (!command)
return;
std::invoke(command);
m_commands.Pop();
}
if (m_items.Empty()) if (m_items.Empty())
{ {
m_event.Wait(); m_event.Wait();
continue; continue;
} }
if (m_skip_work.load(std::memory_order_relaxed))
{
m_items.Clear();
continue;
}
function(std::move(m_items.Front())); function(std::move(m_items.Front()));
m_items.Pop(); m_items.Pop();
} }
@ -149,9 +174,8 @@ private:
std::thread m_thread; std::thread m_thread;
Common::WaitableSPSCQueue<T> m_items; Common::WaitableSPSCQueue<T> m_items;
Common::WaitableSPSCQueue<CommandFunction> m_commands;
Common::Event m_event; Common::Event m_event;
std::atomic_bool m_skip_work = false;
std::atomic_bool m_run_thread = false;
using DummyMutex = std::type_identity<void>; using DummyMutex = std::type_identity<void>;
using ProducerMutex = std::conditional_t<IsSingleProducer, DummyMutex, std::recursive_mutex>; using ProducerMutex = std::conditional_t<IsSingleProducer, DummyMutex, std::recursive_mutex>;