diff --git a/Lidgren.Network/NetConnection.Handshake.cs b/Lidgren.Network/NetConnection.Handshake.cs index 7620a3d..8c07963 100644 --- a/Lidgren.Network/NetConnection.Handshake.cs +++ b/Lidgren.Network/NetConnection.Handshake.cs @@ -361,6 +361,7 @@ namespace Lidgren.Network return; } break; + case NetMessageType.Disconnect: // ouch string reason = "Ouch"; @@ -374,6 +375,15 @@ namespace Lidgren.Network } ExecuteDisconnect(reason, false); break; + + case NetMessageType.Discovery: + m_peer.HandleIncomingDiscoveryRequest(now, m_remoteEndpoint, ptr, payloadLength); + return; + + case NetMessageType.DiscoveryResponse: + m_peer.HandleIncomingDiscoveryResponse(now, m_remoteEndpoint, ptr, payloadLength); + return; + default: m_peer.LogDebug("Unhandled type during handshake: " + tp + " length: " + payloadLength); break; diff --git a/Lidgren.Network/NetPeer.Internal.cs b/Lidgren.Network/NetPeer.Internal.cs index 1e7700c..0d5f44a 100644 --- a/Lidgren.Network/NetPeer.Internal.cs +++ b/Lidgren.Network/NetPeer.Internal.cs @@ -253,11 +253,27 @@ namespace Lidgren.Network // do handshake heartbeats if ((m_frameCounter % 3) == 0) { - foreach (NetConnection conn in m_handshakes.Values) + foreach (var kvp in m_handshakes) { + NetConnection conn = kvp.Value as NetConnection; +#if DEBUG + // sanity check + if (kvp.Key != kvp.Key) + LogWarning("Sanity fail! Connection in handshake list under wrong key!"); +#endif conn.UnconnectedHeartbeat(now); if (conn.m_status == NetConnectionStatus.Connected || conn.m_status == NetConnectionStatus.Disconnected) + { +#if DEBUG + // sanity check + if (conn.m_status == NetConnectionStatus.Disconnected && m_handshakes.ContainsKey(conn.RemoteEndpoint)) + { + LogWarning("Sanity fail! Handshakes list contained disconnected connection!"); + m_handshakes.Remove(conn.RemoteEndpoint); + } +#endif break; // collection has been modified + } } } @@ -454,6 +470,34 @@ namespace Lidgren.Network } while (m_socket.Available > 0); } + internal void HandleIncomingDiscoveryRequest(double now, IPEndPoint senderEndpoint, int ptr, int payloadByteLength) + { + if (m_configuration.IsMessageTypeEnabled(NetIncomingMessageType.DiscoveryRequest)) + { + NetIncomingMessage dm = CreateIncomingMessage(NetIncomingMessageType.DiscoveryRequest, payloadByteLength); + if (payloadByteLength > 0) + Buffer.BlockCopy(m_receiveBuffer, ptr, dm.m_data, 0, payloadByteLength); + dm.m_receiveTime = now; + dm.m_bitLength = payloadByteLength * 8; + dm.m_senderEndpoint = senderEndpoint; + ReleaseMessage(dm); + } + } + + internal void HandleIncomingDiscoveryResponse(double now, IPEndPoint senderEndpoint, int ptr, int payloadByteLength) + { + if (m_configuration.IsMessageTypeEnabled(NetIncomingMessageType.DiscoveryResponse)) + { + NetIncomingMessage dr = CreateIncomingMessage(NetIncomingMessageType.DiscoveryResponse, payloadByteLength); + if (payloadByteLength > 0) + Buffer.BlockCopy(m_receiveBuffer, ptr, dr.m_data, 0, payloadByteLength); + dr.m_receiveTime = now; + dr.m_bitLength = payloadByteLength * 8; + dr.m_senderEndpoint = senderEndpoint; + ReleaseMessage(dr); + } + } + private void ReceivedUnconnectedLibraryMessage(double now, IPEndPoint senderEndpoint, NetMessageType tp, int ptr, int payloadByteLength) { NetConnection shake; @@ -469,29 +513,10 @@ namespace Lidgren.Network switch (tp) { case NetMessageType.Discovery: - if (m_configuration.IsMessageTypeEnabled(NetIncomingMessageType.DiscoveryRequest)) - { - NetIncomingMessage dm = CreateIncomingMessage(NetIncomingMessageType.DiscoveryRequest, payloadByteLength); - if (payloadByteLength > 0) - Buffer.BlockCopy(m_receiveBuffer, ptr, dm.m_data, 0, payloadByteLength); - dm.m_receiveTime = now; - dm.m_bitLength = payloadByteLength * 8; - dm.m_senderEndpoint = senderEndpoint; - ReleaseMessage(dm); - } + HandleIncomingDiscoveryRequest(now, senderEndpoint, ptr, payloadByteLength); return; - case NetMessageType.DiscoveryResponse: - if (m_configuration.IsMessageTypeEnabled(NetIncomingMessageType.DiscoveryResponse)) - { - NetIncomingMessage dr = CreateIncomingMessage(NetIncomingMessageType.DiscoveryResponse, payloadByteLength); - if (payloadByteLength > 0) - Buffer.BlockCopy(m_receiveBuffer, ptr, dr.m_data, 0, payloadByteLength); - dr.m_receiveTime = now; - dr.m_bitLength = payloadByteLength * 8; - dr.m_senderEndpoint = senderEndpoint; - ReleaseMessage(dr); - } + HandleIncomingDiscoveryResponse(now, senderEndpoint, ptr, payloadByteLength); return; case NetMessageType.NatIntroduction: HandleNatIntroduction(ptr);