diff --git a/Source/Core/Common/Random.cpp b/Source/Core/Common/Random.cpp index c512e7695d..754cea397f 100644 --- a/Source/Core/Common/Random.cpp +++ b/Source/Core/Common/Random.cpp @@ -11,10 +11,42 @@ namespace Common::Random { -class CSPRNG final +struct PRNG::Impl +{ + Impl(void* seed, std::size_t size) + { + mbedtls_hmac_drbg_init(&m_context); + const int ret = mbedtls_hmac_drbg_seed_buf( + &m_context, mbedtls_md_info_from_type(MBEDTLS_MD_SHA256), static_cast(seed), size); + ASSERT(ret == 0); + } + + ~Impl() { mbedtls_hmac_drbg_free(&m_context); } + + void Generate(void* buffer, std::size_t size) + { + const int ret = mbedtls_hmac_drbg_random(&m_context, static_cast(buffer), size); + ASSERT(ret == 0); + } + + mbedtls_hmac_drbg_context m_context; +}; + +PRNG::PRNG(void* seed, std::size_t size) : m_impl(std::make_unique(seed, size)) +{ +} + +PRNG::~PRNG() = default; + +void PRNG::Generate(void* buffer, std::size_t size) +{ + m_impl->Generate(buffer, size); +} + +class EntropySeededPRNG final { public: - CSPRNG() + EntropySeededPRNG() { mbedtls_entropy_init(&m_entropy); mbedtls_hmac_drbg_init(&m_context); @@ -23,7 +55,7 @@ public: ASSERT(ret == 0); } - ~CSPRNG() + ~EntropySeededPRNG() { mbedtls_hmac_drbg_free(&m_context); mbedtls_entropy_free(&m_entropy); @@ -40,10 +72,10 @@ private: mbedtls_hmac_drbg_context m_context; }; -static thread_local CSPRNG s_csprng; +static thread_local EntropySeededPRNG s_esprng; void Generate(void* buffer, std::size_t size) { - s_csprng.Generate(buffer, size); + s_esprng.Generate(buffer, size); } } // namespace Common::Random diff --git a/Source/Core/Common/Random.h b/Source/Core/Common/Random.h index 1f234f8ac3..789f62d9c1 100644 --- a/Source/Core/Common/Random.h +++ b/Source/Core/Common/Random.h @@ -5,12 +5,37 @@ #pragma once #include +#include #include #include "Common/CommonTypes.h" namespace Common::Random { +/// Cryptographically secure pseudo-random number generator, with explicit seed. +class PRNG final +{ +public: + explicit PRNG(u64 seed) : PRNG(&seed, sizeof(u64)) {} + PRNG(void* seed, std::size_t size); + ~PRNG(); + + void Generate(void* buffer, std::size_t size); + + template + T GenerateValue() + { + static_assert(std::is_arithmetic(), "T must be an arithmetic type in GenerateValue."); + T value; + Generate(&value, sizeof(value)); + return value; + } + +private: + struct Impl; + std::unique_ptr m_impl; +}; + /// Fill `buffer` with random bytes using a cryptographically secure pseudo-random number generator. void Generate(void* buffer, std::size_t size);