Implement hw accelerated AES

This commit is contained in:
Shawn Hoffman
2022-07-27 01:51:19 -07:00
parent fb45ed3981
commit 46ad8b9d68
13 changed files with 488 additions and 93 deletions

View File

@ -171,14 +171,13 @@ std::vector<u8> NANDImporter::GetEntryData(const NANDFSTEntry& entry)
std::vector<u8> data{};
data.reserve(remaining_bytes);
auto block = std::make_unique<u8[]>(NAND_FAT_BLOCK_SIZE);
while (remaining_bytes > 0)
{
std::array<u8, 16> iv{};
std::vector<u8> block = Common::AES::Decrypt(
m_aes_key.data(), iv.data(), &m_nand[NAND_FAT_BLOCK_SIZE * sub], NAND_FAT_BLOCK_SIZE);
m_aes_ctx->CryptIvZero(&m_nand[NAND_FAT_BLOCK_SIZE * sub], block.get(), NAND_FAT_BLOCK_SIZE);
size_t size = std::min(remaining_bytes, block.size());
data.insert(data.end(), block.begin(), block.begin() + size);
size_t size = std::min(remaining_bytes, NAND_FAT_BLOCK_SIZE);
data.insert(data.end(), block.get(), block.get() + size);
remaining_bytes -= size;
sub = m_superblock->fat[sub];
@ -260,7 +259,7 @@ void NANDImporter::ExportKeys()
{
constexpr size_t NAND_AES_KEY_OFFSET = 0x158;
std::copy_n(&m_nand_keys[NAND_AES_KEY_OFFSET], m_aes_key.size(), m_aes_key.begin());
m_aes_ctx = Common::AES::CreateContextDecrypt(&m_nand_keys[NAND_AES_KEY_OFFSET]);
const std::string file_path = m_nand_root + "/keys.bin";
File::IOFile file(file_path, "wb");

View File

@ -12,6 +12,7 @@
#include <fmt/format.h>
#include "Common/CommonTypes.h"
#include "Common/Crypto/AES.h"
#include "Common/Swap.h"
namespace DiscIO
@ -74,7 +75,7 @@ private:
std::string m_nand_root;
std::vector<u8> m_nand;
std::vector<u8> m_nand_keys;
std::array<u8, 16> m_aes_key;
std::unique_ptr<Common::AES::Context> m_aes_ctx;
std::unique_ptr<NANDSuperblock> m_superblock;
std::function<void()> m_update_callback;
};

View File

@ -13,11 +13,10 @@
#include <utility>
#include <vector>
#include <mbedtls/aes.h>
#include "Common/Align.h"
#include "Common/Assert.h"
#include "Common/CommonTypes.h"
#include "Common/Crypto/AES.h"
#include "Common/Crypto/SHA1.h"
#include "Common/Logging/Log.h"
#include "Common/MsgHandler.h"
@ -159,17 +158,14 @@ bool VolumeWAD::CheckContentIntegrity(const IOS::ES::Content& content,
if (encrypted_data.size() != Common::AlignUp(content.size, 0x40))
return false;
mbedtls_aes_context context;
const std::array<u8, 16> key = ticket.GetTitleKey();
mbedtls_aes_setkey_dec(&context, key.data(), 128);
auto context = Common::AES::CreateContextDecrypt(ticket.GetTitleKey().data());
std::array<u8, 16> iv{};
iv[0] = static_cast<u8>(content.index >> 8);
iv[1] = static_cast<u8>(content.index & 0xFF);
std::vector<u8> decrypted_data(encrypted_data.size());
mbedtls_aes_crypt_cbc(&context, MBEDTLS_AES_DECRYPT, decrypted_data.size(), iv.data(),
encrypted_data.data(), decrypted_data.data());
context->Crypt(iv.data(), encrypted_data.data(), decrypted_data.data(), decrypted_data.size());
return Common::SHA1::CalculateDigest(decrypted_data.data(), content.size) == content.sha1;
}

View File

@ -16,11 +16,10 @@
#include <utility>
#include <vector>
#include <mbedtls/aes.h>
#include "Common/Align.h"
#include "Common/Assert.h"
#include "Common/CommonTypes.h"
#include "Common/Crypto/AES.h"
#include "Common/Crypto/SHA1.h"
#include "Common/Logging/Log.h"
#include "Common/Swap.h"
@ -128,14 +127,11 @@ VolumeWii::VolumeWii(std::unique_ptr<BlobReader> reader)
return h3_table;
};
auto get_key = [this, partition]() -> std::unique_ptr<mbedtls_aes_context> {
auto get_key = [this, partition]() -> std::unique_ptr<Common::AES::Context> {
const IOS::ES::TicketReader& ticket = *m_partitions[partition].ticket;
if (!ticket.IsValid())
return nullptr;
const std::array<u8, AES_KEY_SIZE> key = ticket.GetTitleKey();
std::unique_ptr<mbedtls_aes_context> aes_context = std::make_unique<mbedtls_aes_context>();
mbedtls_aes_setkey_dec(aes_context.get(), key.data(), 128);
return aes_context;
return Common::AES::CreateContextDecrypt(ticket.GetTitleKey().data());
};
auto get_file_system = [this, partition]() -> std::unique_ptr<FileSystem> {
@ -148,7 +144,7 @@ VolumeWii::VolumeWii(std::unique_ptr<BlobReader> reader)
};
m_partitions.emplace(
partition, PartitionDetails{Common::Lazy<std::unique_ptr<mbedtls_aes_context>>(get_key),
partition, PartitionDetails{Common::Lazy<std::unique_ptr<Common::AES::Context>>(get_key),
Common::Lazy<IOS::ES::TicketReader>(get_ticket),
Common::Lazy<IOS::ES::TMDReader>(get_tmd),
Common::Lazy<std::vector<u8>>(get_cert_chain),
@ -183,11 +179,11 @@ bool VolumeWii::Read(u64 offset, u64 length, u8* buffer, const Partition& partit
buffer);
}
mbedtls_aes_context* aes_context = partition_details.key->get();
auto aes_context = partition_details.key->get();
if (!aes_context)
return false;
std::vector<u8> read_buffer(BLOCK_TOTAL_SIZE);
auto read_buffer = std::make_unique<u8[]>(BLOCK_TOTAL_SIZE);
while (length > 0)
{
// Calculate offsets
@ -198,11 +194,11 @@ bool VolumeWii::Read(u64 offset, u64 length, u8* buffer, const Partition& partit
if (m_last_decrypted_block != block_offset_on_disc)
{
// Read the current block
if (!m_reader->Read(block_offset_on_disc, BLOCK_TOTAL_SIZE, read_buffer.data()))
if (!m_reader->Read(block_offset_on_disc, BLOCK_TOTAL_SIZE, read_buffer.get()))
return false;
// Decrypt the block's data
DecryptBlockData(read_buffer.data(), m_last_decrypted_block_data, aes_context);
DecryptBlockData(read_buffer.get(), m_last_decrypted_block_data, aes_context);
m_last_decrypted_block = block_offset_on_disc;
}
@ -421,19 +417,19 @@ bool VolumeWii::CheckBlockIntegrity(u64 block_index, const u8* encrypted_data,
partition_details.h3_table->size())
return false;
mbedtls_aes_context* aes_context = partition_details.key->get();
auto aes_context = partition_details.key->get();
if (!aes_context)
return false;
HashBlock hashes;
DecryptBlockHashes(encrypted_data, &hashes, aes_context);
u8 cluster_data[BLOCK_DATA_SIZE];
DecryptBlockData(encrypted_data, cluster_data, aes_context);
auto cluster_data = std::make_unique<u8[]>(BLOCK_DATA_SIZE);
DecryptBlockData(encrypted_data, cluster_data.get(), aes_context);
for (u32 hash_index = 0; hash_index < 31; ++hash_index)
{
if (Common::SHA1::CalculateDigest(cluster_data + hash_index * 0x400, 0x400) !=
if (Common::SHA1::CalculateDigest(&cluster_data[hash_index * 0x400], 0x400) !=
hashes.h0[hash_index])
return false;
}
@ -577,8 +573,7 @@ bool VolumeWii::EncryptGroup(
std::vector<std::future<void>> encryption_futures(threads);
mbedtls_aes_context aes_context;
mbedtls_aes_setkey_enc(&aes_context, key.data(), 128);
auto aes_context = Common::AES::CreateContextEncrypt(key.data());
for (size_t i = 0; i < threads; ++i)
{
@ -589,13 +584,11 @@ bool VolumeWii::EncryptGroup(
{
u8* out_ptr = out->data() + j * BLOCK_TOTAL_SIZE;
u8 iv[16] = {};
mbedtls_aes_crypt_cbc(&aes_context, MBEDTLS_AES_ENCRYPT, BLOCK_HEADER_SIZE, iv,
reinterpret_cast<u8*>(&unencrypted_hashes[j]), out_ptr);
aes_context->CryptIvZero(reinterpret_cast<u8*>(&unencrypted_hashes[j]), out_ptr,
BLOCK_HEADER_SIZE);
std::memcpy(iv, out_ptr + 0x3D0, sizeof(iv));
mbedtls_aes_crypt_cbc(&aes_context, MBEDTLS_AES_ENCRYPT, BLOCK_DATA_SIZE, iv,
unencrypted_data[j].data(), out_ptr + BLOCK_HEADER_SIZE);
aes_context->Crypt(out_ptr + 0x3D0, unencrypted_data[j].data(),
out_ptr + BLOCK_HEADER_SIZE, BLOCK_DATA_SIZE);
}
},
i * BLOCKS_PER_GROUP / threads, (i + 1) * BLOCKS_PER_GROUP / threads);
@ -607,20 +600,14 @@ bool VolumeWii::EncryptGroup(
return true;
}
void VolumeWii::DecryptBlockHashes(const u8* in, HashBlock* out, mbedtls_aes_context* aes_context)
void VolumeWii::DecryptBlockHashes(const u8* in, HashBlock* out, Common::AES::Context* aes_context)
{
std::array<u8, 16> iv;
iv.fill(0);
mbedtls_aes_crypt_cbc(aes_context, MBEDTLS_AES_DECRYPT, sizeof(HashBlock), iv.data(), in,
reinterpret_cast<u8*>(out));
aes_context->CryptIvZero(in, reinterpret_cast<u8*>(out), sizeof(HashBlock));
}
void VolumeWii::DecryptBlockData(const u8* in, u8* out, mbedtls_aes_context* aes_context)
void VolumeWii::DecryptBlockData(const u8* in, u8* out, Common::AES::Context* aes_context)
{
std::array<u8, 16> iv;
std::copy(&in[0x3d0], &in[0x3e0], iv.data());
mbedtls_aes_crypt_cbc(aes_context, MBEDTLS_AES_DECRYPT, BLOCK_DATA_SIZE, iv.data(),
&in[BLOCK_HEADER_SIZE], out);
aes_context->Crypt(&in[0x3d0], &in[sizeof(HashBlock)], out, BLOCK_DATA_SIZE);
}
} // namespace DiscIO

View File

@ -11,8 +11,6 @@
#include <string>
#include <vector>
#include <mbedtls/aes.h>
#include "Common/CommonTypes.h"
#include "Common/Crypto/SHA1.h"
#include "Common/Lazy.h"
@ -21,6 +19,8 @@
#include "DiscIO/Volume.h"
#include "DiscIO/VolumeDisc.h"
#include "Common/Crypto/AES.h"
namespace DiscIO
{
class BlobReader;
@ -34,7 +34,7 @@ enum class Platform;
class VolumeWii : public VolumeDisc
{
public:
static constexpr size_t AES_KEY_SIZE = 16;
static constexpr size_t AES_KEY_SIZE = Common::AES::Context::KEY_SIZE;
static constexpr u32 BLOCKS_PER_GROUP = 0x40;
@ -106,8 +106,8 @@ public:
const std::function<void(HashBlock hash_blocks[BLOCKS_PER_GROUP])>&
hash_exception_callback = {});
static void DecryptBlockHashes(const u8* in, HashBlock* out, mbedtls_aes_context* aes_context);
static void DecryptBlockData(const u8* in, u8* out, mbedtls_aes_context* aes_context);
static void DecryptBlockHashes(const u8* in, HashBlock* out, Common::AES::Context* aes_context);
static void DecryptBlockData(const u8* in, u8* out, Common::AES::Context* aes_context);
protected:
u32 GetOffsetShift() const override { return 2; }
@ -115,7 +115,7 @@ protected:
private:
struct PartitionDetails
{
Common::Lazy<std::unique_ptr<mbedtls_aes_context>> key;
Common::Lazy<std::unique_ptr<Common::AES::Context>> key;
Common::Lazy<IOS::ES::TicketReader> ticket;
Common::Lazy<IOS::ES::TMDReader> tmd;
Common::Lazy<std::vector<u8>> cert_chain;

View File

@ -1318,8 +1318,7 @@ WIARVZFileReader<RVZ>::ProcessAndCompress(CompressThreadState* state, CompressPa
{
const PartitionEntry& partition_entry = partition_entries[parameters.data_entry->index];
mbedtls_aes_context aes_context;
mbedtls_aes_setkey_dec(&aes_context, partition_entry.partition_key.data(), 128);
auto aes_context = Common::AES::CreateContextDecrypt(partition_entry.partition_key.data());
const u64 groups = Common::AlignUp(parameters.data.size(), VolumeWii::GROUP_TOTAL_SIZE) /
VolumeWii::GROUP_TOTAL_SIZE;
@ -1388,7 +1387,7 @@ WIARVZFileReader<RVZ>::ProcessAndCompress(CompressThreadState* state, CompressPa
{
const u64 offset_of_block = offset_of_group + j * VolumeWii::BLOCK_TOTAL_SIZE;
VolumeWii::DecryptBlockData(parameters.data.data() + offset_of_block,
state->decryption_buffer[j].data(), &aes_context);
state->decryption_buffer[j].data(), aes_context.get());
}
else
{
@ -1413,7 +1412,7 @@ WIARVZFileReader<RVZ>::ProcessAndCompress(CompressThreadState* state, CompressPa
VolumeWii::HashBlock hashes;
VolumeWii::DecryptBlockHashes(parameters.data.data() + offset_of_block, &hashes,
&aes_context);
aes_context.get());
const auto compare_hash = [&](size_t offset_in_block) {
ASSERT(offset_in_block + Common::SHA1::DIGEST_LEN <= VolumeWii::BLOCK_HEADER_SIZE);