Merge pull request #11382 from skyfloogle/traversal-fix-2

Traversal: Use low TTL for probe packet
This commit is contained in:
Mai
2023-11-30 18:03:50 -05:00
committed by GitHub
13 changed files with 284 additions and 33 deletions

View File

@ -14,8 +14,9 @@
namespace Common
{
TraversalClient::TraversalClient(ENetHost* netHost, const std::string& server, const u16 port)
: m_NetHost(netHost), m_Server(server), m_port(port)
TraversalClient::TraversalClient(ENetHost* netHost, const std::string& server, const u16 port,
const u16 port_alt)
: m_NetHost(netHost), m_Server(server), m_port(port), m_portAlt(port_alt)
{
netHost->intercept = TraversalClient::InterceptCallback;
@ -146,6 +147,8 @@ void TraversalClient::HandleServerPacket(TraversalPacket* packet)
{
if (it->packet.requestId == packet->requestId)
{
if (packet->requestId == m_TestRequestId)
HandleTraversalTest();
m_OutgoingTraversalPackets.erase(it);
break;
}
@ -161,6 +164,7 @@ void TraversalClient::HandleServerPacket(TraversalPacket* packet)
}
m_HostId = packet->helloFromServer.yourHostId;
m_external_address = packet->helloFromServer.yourAddress;
NewTraversalTest();
m_State = State::Connected;
if (m_Client)
m_Client->OnTraversalStateChanged();
@ -175,7 +179,18 @@ void TraversalClient::HandleServerPacket(TraversalPacket* packet)
ENetBuffer buf;
buf.data = message;
buf.dataLength = sizeof(message) - 1;
enet_socket_send(m_NetHost->socket, &addr, &buf, 1);
if (m_ttlReady)
{
int oldttl;
enet_socket_get_option(m_NetHost->socket, ENET_SOCKOPT_TTL, &oldttl);
enet_socket_set_option(m_NetHost->socket, ENET_SOCKOPT_TTL, m_ttl);
enet_socket_send(m_NetHost->socket, &addr, &buf, 1);
enet_socket_set_option(m_NetHost->socket, ENET_SOCKOPT_TTL, oldttl);
}
else
{
enet_socket_send(m_NetHost->socket, &addr, &buf, 1);
}
}
else
{
@ -231,12 +246,15 @@ void TraversalClient::OnFailure(FailureReason reason)
void TraversalClient::ResendPacket(OutgoingTraversalPacketInfo* info)
{
bool testPacket =
m_TestSocket != ENET_SOCKET_NULL && info->packet.type == TraversalPacketType::TestPlease;
info->sendTime = enet_time_get();
info->tries++;
ENetBuffer buf;
buf.data = &info->packet;
buf.dataLength = sizeof(info->packet);
if (enet_socket_send(m_NetHost->socket, &m_ServerAddress, &buf, 1) == -1)
if (enet_socket_send(testPacket ? m_TestSocket : m_NetHost->socket, &m_ServerAddress, &buf, 1) ==
-1)
OnFailure(FailureReason::SocketSendError);
}
@ -275,6 +293,112 @@ void TraversalClient::HandlePing()
}
}
void TraversalClient::NewTraversalTest()
{
// create test socket
if (m_TestSocket != ENET_SOCKET_NULL)
enet_socket_destroy(m_TestSocket);
m_TestSocket = enet_socket_create(ENET_SOCKET_TYPE_DATAGRAM);
ENetAddress addr = {ENET_HOST_ANY, 0};
if (m_TestSocket == ENET_SOCKET_NULL || enet_socket_bind(m_TestSocket, &addr) < 0)
{
// error, abort
if (m_TestSocket != ENET_SOCKET_NULL)
{
enet_socket_destroy(m_TestSocket);
m_TestSocket = ENET_SOCKET_NULL;
}
return;
}
enet_socket_set_option(m_TestSocket, ENET_SOCKOPT_NONBLOCK, 1);
// create holepunch packet
TraversalPacket packet = {};
packet.type = TraversalPacketType::Ping;
packet.ping.hostId = m_HostId;
packet.requestId = Common::Random::GenerateValue<TraversalRequestId>();
// create buffer
ENetBuffer buf;
buf.data = &packet;
buf.dataLength = sizeof(packet);
// send to alt port
ENetAddress altAddress = m_ServerAddress;
altAddress.port = m_portAlt;
// set up ttl and send
int oldttl;
enet_socket_get_option(m_TestSocket, ENET_SOCKOPT_TTL, &oldttl);
enet_socket_set_option(m_TestSocket, ENET_SOCKOPT_TTL, m_ttl);
if (enet_socket_send(m_TestSocket, &altAddress, &buf, 1) == -1)
{
// error, abort
enet_socket_destroy(m_TestSocket);
m_TestSocket = ENET_SOCKET_NULL;
return;
}
enet_socket_set_option(m_TestSocket, ENET_SOCKOPT_TTL, oldttl);
// send the test request
packet.type = TraversalPacketType::TestPlease;
m_TestRequestId = SendTraversalPacket(packet);
}
void TraversalClient::HandleTraversalTest()
{
if (m_TestSocket != ENET_SOCKET_NULL)
{
// check for packet on test socket (with timeout)
u32 deadline = enet_time_get() + 50;
u32 waitCondition;
do
{
waitCondition = ENET_SOCKET_WAIT_RECEIVE | ENET_SOCKET_WAIT_INTERRUPT;
u32 currentTime = enet_time_get();
if (currentTime > deadline ||
enet_socket_wait(m_TestSocket, &waitCondition, deadline - currentTime) != 0)
{
// error or timeout, exit the loop and assume test failure
waitCondition = 0;
break;
}
else if (waitCondition & ENET_SOCKET_WAIT_RECEIVE)
{
// try reading the packet and see if it's relevant
ENetAddress raddr;
TraversalPacket packet;
ENetBuffer buf;
buf.data = &packet;
buf.dataLength = sizeof(packet);
int rv = enet_socket_receive(m_TestSocket, &raddr, &buf, 1);
if (rv < 0)
{
// error, exit the loop and assume test failure
waitCondition = 0;
break;
}
else if (rv < sizeof(packet) || raddr.host != m_ServerAddress.host ||
raddr.host != m_portAlt || packet.requestId != m_TestRequestId)
{
// irrelevant packet, ignore
continue;
}
}
} while (waitCondition & ENET_SOCKET_WAIT_INTERRUPT);
// regardless of what happens next, we can throw out the socket
enet_socket_destroy(m_TestSocket);
m_TestSocket = ENET_SOCKET_NULL;
if (waitCondition & ENET_SOCKET_WAIT_RECEIVE)
{
// success, we can stop now
m_ttlReady = true;
m_Client->OnTtlDetermined(m_ttl);
}
else
{
// fail, increment and retry
if (++m_ttl < 32)
NewTraversalTest();
}
}
}
TraversalRequestId TraversalClient::SendTraversalPacket(const TraversalPacket& packet)
{
OutgoingTraversalPacketInfo info;
@ -313,15 +437,19 @@ ENet::ENetHostPtr g_MainNetHost;
// explicitly requested.
static std::string g_OldServer;
static u16 g_OldServerPort;
static u16 g_OldServerPortAlt;
static u16 g_OldListenPort;
bool EnsureTraversalClient(const std::string& server, u16 server_port, u16 listen_port)
bool EnsureTraversalClient(const std::string& server, u16 server_port, u16 server_port_alt,
u16 listen_port)
{
if (!g_MainNetHost || !g_TraversalClient || server != g_OldServer ||
server_port != g_OldServerPort || listen_port != g_OldListenPort)
server_port != g_OldServerPort || server_port_alt != g_OldServerPortAlt ||
listen_port != g_OldListenPort)
{
g_OldServer = server;
g_OldServerPort = server_port;
g_OldServerPortAlt = server_port_alt;
g_OldListenPort = listen_port;
ENetAddress addr = {ENET_HOST_ANY, listen_port};
@ -337,7 +465,8 @@ bool EnsureTraversalClient(const std::string& server, u16 server_port, u16 liste
}
host->mtu = std::min(host->mtu, NetPlay::MAX_ENET_MTU);
g_MainNetHost = std::move(host);
g_TraversalClient.reset(new TraversalClient(g_MainNetHost.get(), server, server_port));
g_TraversalClient.reset(
new TraversalClient(g_MainNetHost.get(), server, server_port, server_port_alt));
}
return true;
}