From da86c2bcc6fbe6e1fef40d2d19e51fc1df842431 Mon Sep 17 00:00:00 2001 From: Sepalani Date: Thu, 1 Feb 2024 22:05:40 +0400 Subject: [PATCH] BBA/HLE: Don't assume connect is successful --- Source/Core/Core/HW/EXI/BBA/BuiltIn.cpp | 115 ++++++++++++++++++++---- Source/Core/Core/HW/EXI/BBA/BuiltIn.h | 15 ++++ 2 files changed, 111 insertions(+), 19 deletions(-) diff --git a/Source/Core/Core/HW/EXI/BBA/BuiltIn.cpp b/Source/Core/Core/HW/EXI/BBA/BuiltIn.cpp index 04bb763407..be77eb6d59 100644 --- a/Source/Core/Core/HW/EXI/BBA/BuiltIn.cpp +++ b/Source/Core/Core/HW/EXI/BBA/BuiltIn.cpp @@ -6,6 +6,7 @@ #ifdef _WIN32 #include #else +#include #include #include #endif @@ -18,8 +19,6 @@ #include "Core/HW/EXI/EXI_Device.h" #include "Core/HW/EXI/EXI_DeviceEthernet.h" -namespace ExpansionInterface -{ namespace { u64 GetTickCountStd() @@ -27,7 +26,12 @@ u64 GetTickCountStd() using namespace std::chrono; return duration_cast(steady_clock::now().time_since_epoch()).count(); } +} // namespace +namespace ExpansionInterface +{ +namespace +{ std::vector BuildFINFrame(StackRef* ref) { const Common::TCPPacket result(ref->bba_mac, ref->my_mac, ref->from, ref->to, ref->seq_num, @@ -255,6 +259,9 @@ CEXIETHERNET::BuiltInBBAInterface::TryGetDataFromSocket(StackRef* ref) } case IPPROTO_TCP: + if (!ref->tcp_socket.Connected(ref)) + return std::nullopt; + sf::Socket::Status st = sf::Socket::Status::Done; TcpBuffer* tcp_buffer = nullptr; for (auto& tcp_buf : ref->tcp_buffers) @@ -357,26 +364,11 @@ void CEXIETHERNET::BuiltInBBAInterface::HandleTCPFrame(const Common::TCPPacket& ref->bba_mac = m_current_mac; ref->my_mac = ResolveAddress(destination_ip); ref->tcp_socket.setBlocking(false); - - // reply with a sin_ack - Common::TCPPacket result(ref->bba_mac, ref->my_mac, ref->from, ref->to, ref->seq_num, - ref->ack_num, TCP_FLAG_SIN | TCP_FLAG_ACK); - - result.tcp_options = { - 0x02, 0x04, 0x05, 0xb4, // Maximum segment size: 1460 bytes - 0x01, 0x01, 0x01, 0x01 // NOPs - }; - - ref->seq_num++; - target = sf::IpAddress(ntohl(destination_ip)); - ref->tcp_socket.Connect(target, ntohs(tcp_header.destination_port), m_current_ip); ref->ready = false; ref->ip = Common::BitCast(ip_header.destination_addr); - ref->tcp_buffers[0].data = result.Build(); - ref->tcp_buffers[0].seq_id = ref->seq_num - 1; - ref->tcp_buffers[0].tick = GetTickCountStd() - 900; // delay - ref->tcp_buffers[0].used = true; + target = sf::IpAddress(ntohl(destination_ip)); + ref->tcp_socket.Connect(target, ntohs(tcp_header.destination_port), m_current_ip); } else { @@ -808,6 +800,7 @@ sf::Socket::Status BbaTcpSocket::Connect(const sf::IpAddress& dest, u16 port, u3 addr.sin_family = AF_INET; addr.sin_port = 0; ::bind(getHandle(), reinterpret_cast(&addr), sizeof(addr)); + m_connecting_state = ConnectingState::Connecting; return this->connect(dest, port); } @@ -833,6 +826,90 @@ sf::Socket::Status BbaTcpSocket::GetSockName(sockaddr_in* addr) const return sf::Socket::Status::Done; } +bool BbaTcpSocket::Connected(StackRef* ref) +{ + // Called by ReadThreadHandler's TryGetDataFromSocket + // TODO: properly handle error state + switch (m_connecting_state) + { + case ConnectingState::Connected: + return true; + case ConnectingState::Connecting: + { + const int fd = getHandle(); + const s32 nfds = fd + 1; + fd_set read_fds; + fd_set write_fds; + fd_set except_fds; + struct timeval t = {0, 0}; + FD_ZERO(&read_fds); + FD_ZERO(&write_fds); + FD_ZERO(&except_fds); + FD_SET(fd, &write_fds); + FD_SET(fd, &except_fds); + + if (select(nfds, &read_fds, &write_fds, &except_fds, &t) < 0) + { + ERROR_LOG_FMT(SP1, "Failed to get BBA socket connection state: {}", + Common::StrNetworkError()); + break; + } + + if (FD_ISSET(fd, &write_fds) == 0 && FD_ISSET(fd, &except_fds) == 0) + break; + + s32 error = 0; + socklen_t len = sizeof(error); + if (getsockopt(fd, SOL_SOCKET, SO_ERROR, reinterpret_cast(&error), &len) != 0) + { + ERROR_LOG_FMT(SP1, "Failed to get BBA socket error state: {}", Common::StrNetworkError()); + break; + } + + if (error != 0) + { + ERROR_LOG_FMT(SP1, "BBA connect failed (err={}): {}", error, + Common::DecodeNetworkError(error)); + m_connecting_state = ConnectingState::Error; + break; + } + + // Get peername to ensure the socket is connected + sockaddr_in peer; + socklen_t peer_len = sizeof(peer); + if (getpeername(fd, reinterpret_cast(&peer), &peer_len) != 0) + { + ERROR_LOG_FMT(SP1, "BBA connect failed to get peername: {}", Common::StrNetworkError()); + m_connecting_state = ConnectingState::Error; + break; + } + + // Create the resulting SYN ACK packet + m_connecting_state = ConnectingState::Connected; + INFO_LOG_FMT(SP1, "BBA connect succeeded"); + + Common::TCPPacket result(ref->bba_mac, ref->my_mac, ref->from, ref->to, ref->seq_num, + ref->ack_num, TCP_FLAG_SIN | TCP_FLAG_ACK); + + result.tcp_options = { + 0x02, 0x04, 0x05, 0xb4, // Maximum segment size: 1460 bytes + 0x01, 0x01, 0x01, 0x01 // NOPs + }; + + ref->seq_num++; + ref->tcp_buffers[0].data = result.Build(); + ref->tcp_buffers[0].seq_id = ref->seq_num - 1; + ref->tcp_buffers[0].tick = GetTickCountStd() - 900; // delay + ref->tcp_buffers[0].used = true; + + break; + } + default: + break; + } + return false; +} + BbaUdpSocket::BbaUdpSocket() = default; sf::Socket::Status BbaUdpSocket::Bind(u16 port, u32 net_ip) diff --git a/Source/Core/Core/HW/EXI/BBA/BuiltIn.h b/Source/Core/Core/HW/EXI/BBA/BuiltIn.h index afcf39d293..981b2e46e2 100644 --- a/Source/Core/Core/HW/EXI/BBA/BuiltIn.h +++ b/Source/Core/Core/HW/EXI/BBA/BuiltIn.h @@ -32,6 +32,8 @@ struct TcpBuffer std::vector data; }; +struct StackRef; + // Socket helper classes to ensure network interface consistency. // // If the socket isn't bound, the system will pick the interface to use automatically. @@ -45,6 +47,19 @@ public: sf::Socket::Status Connect(const sf::IpAddress& dest, u16 port, u32 net_ip); sf::Socket::Status GetPeerName(sockaddr_in* addr) const; sf::Socket::Status GetSockName(sockaddr_in* addr) const; + + bool Connected(StackRef* ref); + +private: + enum class ConnectingState + { + None, + Connecting, + Connected, + Error + }; + + ConnectingState m_connecting_state = ConnectingState::None; }; class BbaUdpSocket : public sf::UdpSocket