diff --git a/Source/Core/DiscIO/WIABlob.cpp b/Source/Core/DiscIO/WIABlob.cpp index 74f7f4293e..6c56af5ad1 100644 --- a/Source/Core/DiscIO/WIABlob.cpp +++ b/Source/Core/DiscIO/WIABlob.cpp @@ -5,8 +5,10 @@ #include "DiscIO/WIABlob.h" #include +#include #include #include +#include #include "Common/Align.h" #include "Common/CommonTypes.h" @@ -71,7 +73,8 @@ bool WIAFileReader::Initialize(const std::string& path) return false; const u32 compression_type = Common::swap32(m_header_2.compression_type); - if (compression_type != 0) + m_compression_type = static_cast(compression_type); + if (m_compression_type > CompressionType::Purge) { ERROR_LOG(DISCIO, "Unsupported WIA compression type %u in %s", compression_type, path.c_str()); return false; @@ -114,26 +117,30 @@ bool WIAFileReader::Initialize(const std::string& path) Common::swap32(b.data_entries[0].first_sector); }); - // TODO: Compression const u32 number_of_raw_data_entries = Common::swap32(m_header_2.number_of_raw_data_entries); m_raw_data_entries.resize(number_of_raw_data_entries); - if (!m_file.Seek(Common::swap64(m_header_2.raw_data_entries_offset), SEEK_SET)) - return false; - if (!m_file.ReadArray(m_raw_data_entries.data(), number_of_raw_data_entries)) + if (!ReadCompressedData(number_of_raw_data_entries * sizeof(RawDataEntry), + Common::swap64(m_header_2.raw_data_entries_offset), + Common::swap32(m_header_2.raw_data_entries_size), + reinterpret_cast(m_raw_data_entries.data()), false)) + { return false; + } std::sort(m_raw_data_entries.begin(), m_raw_data_entries.end(), [](const RawDataEntry& a, const RawDataEntry& b) { return Common::swap64(a.data_offset) < Common::swap64(b.data_offset); }); - // TODO: Compression const u32 number_of_group_entries = Common::swap32(m_header_2.number_of_group_entries); m_group_entries.resize(number_of_group_entries); - if (!m_file.Seek(Common::swap64(m_header_2.group_entries_offset), SEEK_SET)) - return false; - if (!m_file.ReadArray(m_group_entries.data(), number_of_group_entries)) + if (!ReadCompressedData(number_of_group_entries * sizeof(GroupEntry), + Common::swap64(m_header_2.group_entries_offset), + Common::swap32(m_header_2.group_entries_size), + reinterpret_cast(m_group_entries.data()), false)) + { return false; + } return true; } @@ -239,24 +246,13 @@ bool WIAFileReader::ReadFromGroups(u64* offset, u64* size, u8** out_ptr, u64 chu const u64 group_offset = data_offset + i * chunk_size; const u64 offset_in_group = *offset - group_offset; - // TODO: Compression - - u64 group_offset_in_file = static_cast(Common::swap32(group.data_offset)) << 2; - - if (exception_list) - { - u16 exceptions; - if (!m_file.Seek(group_offset_in_file, SEEK_SET) || !m_file.ReadArray(&exceptions, 1)) - return false; - - group_offset_in_file += Common::AlignUp( - sizeof(exceptions) + Common::swap16(exceptions) * sizeof(HashExceptionEntry), 4); - } - - const u64 offset_in_file = group_offset_in_file + offset_in_group; + const u64 group_offset_in_file = static_cast(Common::swap32(group.data_offset)) << 2; const u64 bytes_to_read = std::min(chunk_size - offset_in_group, *size); - if (!m_file.Seek(offset_in_file, SEEK_SET) || !m_file.ReadBytes(*out_ptr, bytes_to_read)) + if (!ReadCompressedData(chunk_size, group_offset_in_file, Common::swap32(group.data_size), + offset_in_group, bytes_to_read, *out_ptr, exception_list)) + { return false; + } *offset += bytes_to_read; *size -= bytes_to_read; @@ -266,6 +262,131 @@ bool WIAFileReader::ReadFromGroups(u64* offset, u64* size, u8** out_ptr, u64 chu return true; } +bool WIAFileReader::ReadCompressedData(u32 decompressed_data_size, u64 data_offset, u64 data_size, + u8* out_ptr, bool exception_list) +{ + switch (m_compression_type) + { + case CompressionType::None: + { + return ReadCompressedData(decompressed_data_size, data_offset, data_size, 0, + decompressed_data_size, out_ptr, exception_list); + } + + case CompressionType::Purge: + { + if (!m_file.Seek(data_offset, SEEK_SET)) + return false; + + if (exception_list) + { + const std::optional exception_size = ReadExceptionListFromFile(); + if (!exception_size) + return false; + + data_size -= *exception_size; + } + + const u64 hash_offset = data_size - sizeof(SHA1); + u32 offset_in_data = 0; + u32 offset_in_decompressed_data = 0; + + while (offset_in_data < hash_offset) + { + PurgeSegment purge_segment; + if (!m_file.ReadArray(&purge_segment, 1)) + return false; + + const u32 segment_offset = Common::swap32(purge_segment.offset); + const u32 segment_size = Common::swap32(purge_segment.size); + + if (segment_offset < offset_in_decompressed_data) + return false; + + const u32 blank_bytes = segment_offset - offset_in_decompressed_data; + std::memset(out_ptr, 0, blank_bytes); + out_ptr += blank_bytes; + + if (segment_size != 0 && !m_file.ReadBytes(out_ptr, segment_size)) + return false; + out_ptr += segment_size; + + offset_in_data += sizeof(PurgeSegment) + segment_size; + offset_in_decompressed_data = segment_offset + segment_size; + } + + if (offset_in_data != hash_offset || offset_in_decompressed_data > decompressed_data_size) + return false; + + std::memset(out_ptr, 0, decompressed_data_size - offset_in_decompressed_data); + + SHA1 expected_hash; + if (!m_file.ReadArray(&expected_hash, 1)) + return false; + + // TODO: Check hash + + return true; + } + } + + return false; +} + +bool WIAFileReader::ReadCompressedData(u32 decompressed_data_size, u64 data_offset, u64 data_size, + u64 offset_in_data, u64 size_in_data, u8* out_ptr, + bool exception_list) +{ + if (m_compression_type == CompressionType::None) + { + if (!m_file.Seek(data_offset, SEEK_SET)) + return false; + + if (exception_list) + { + const std::optional exception_list_size = ReadExceptionListFromFile(); + if (!exception_list_size) + return false; + + data_size -= *exception_list_size; + } + + if (!m_file.Seek(offset_in_data, SEEK_CUR) || !m_file.ReadBytes(out_ptr, size_in_data)) + return false; + + return true; + } + else + { + // TODO: Caching + std::vector buffer(decompressed_data_size); + if (!ReadCompressedData(decompressed_data_size, data_offset, data_size, buffer.data(), + exception_list)) + { + return false; + } + std::memcpy(out_ptr, buffer.data() + offset_in_data, size_in_data); + return true; + } +} + +std::optional WIAFileReader::ReadExceptionListFromFile() +{ + u16 exceptions; + if (!m_file.ReadArray(&exceptions, 1)) + return std::nullopt; + + const u64 exception_list_size = Common::AlignUp( + sizeof(exceptions) + Common::swap16(exceptions) * sizeof(HashExceptionEntry), 4); + + if (!m_file.Seek(exception_list_size - sizeof(exceptions), SEEK_CUR)) + return std::nullopt; + + // TODO: Actually handle the exceptions + + return exception_list_size; +} + std::string WIAFileReader::VersionToString(u32 version) { const u8 a = version >> 24; diff --git a/Source/Core/DiscIO/WIABlob.h b/Source/Core/DiscIO/WIABlob.h index fd00520dd1..b355437d5b 100644 --- a/Source/Core/DiscIO/WIABlob.h +++ b/Source/Core/DiscIO/WIABlob.h @@ -6,6 +6,7 @@ #include #include +#include #include "Common/CommonTypes.h" #include "Common/File.h" @@ -43,6 +44,13 @@ private: bool ReadFromGroups(u64* offset, u64* size, u8** out_ptr, u64 chunk_size, u32 sector_size, u64 data_offset, u64 data_size, u32 group_index, u32 number_of_groups, bool exception_list); + bool ReadCompressedData(u32 decompressed_data_size, u64 data_offset, u64 data_size, u8* out_ptr, + bool exception_list); + bool ReadCompressedData(u32 decompressed_data_size, u64 data_offset, u64 data_size, + u64 offset_in_data, u64 size_in_data, u8* out_ptr, bool exception_list); + + // Returns the number of bytes read + std::optional ReadExceptionListFromFile(); static std::string VersionToString(u32 version); @@ -128,9 +136,26 @@ private: SHA1 hash; }; static_assert(sizeof(HashExceptionEntry) == 0x16, "Wrong size for WIA hash exception entry"); + + struct PurgeSegment + { + u32 offset; + u32 size; + }; + static_assert(sizeof(PurgeSegment) == 0x08, "Wrong size for WIA purge segment"); #pragma pack(pop) + enum class CompressionType : u32 + { + None = 0, + Purge = 1, + Bzip2 = 2, + LZMA = 3, + LZMA2 = 4, + }; + bool m_valid; + CompressionType m_compression_type; File::IOFile m_file;