diff --git a/Lidgren.Network/NetConnection.cs b/Lidgren.Network/NetConnection.cs index ab013fc..1de2d89 100644 --- a/Lidgren.Network/NetConnection.cs +++ b/Lidgren.Network/NetConnection.cs @@ -79,6 +79,11 @@ namespace Lidgren.Network /// public IPEndPoint RemoteEndpoint { get { return m_remoteEndpoint; } } + /// + /// Gets the owning NetPeer instance + /// + public NetPeer Owner { get { return m_owner; } } + internal NetConnection(NetPeer owner, IPEndPoint remoteEndpoint) { m_owner = owner; diff --git a/Lidgren.Network/NetNatIntroduction.cs b/Lidgren.Network/NetNatIntroduction.cs index 0861629..838e79c 100644 --- a/Lidgren.Network/NetNatIntroduction.cs +++ b/Lidgren.Network/NetNatIntroduction.cs @@ -48,6 +48,8 @@ namespace Lidgren.Network string token = tmp.ReadString(); bool isHost = (hostByte != 0); + LogDebug("NAT introduction received; we are designated " + (isHost ? "host" : "client")); + NetOutgoingMessage punch; if (!isHost && m_configuration.IsMessageTypeEnabled(NetIncomingMessageType.NatIntroductionSuccess) == false) @@ -56,21 +58,13 @@ namespace Lidgren.Network // send internal punch punch = CreateMessage(1); punch.Write(hostByte); - if (hostByte == 0) - { - // only client needs to send token - punch.Write(token); - } + punch.Write(token); SendUnconnectedLibraryMessage(punch, NetMessageLibraryType.NatPunchMessage, remoteInternal); // send external punch punch = CreateMessage(1); punch.Write(hostByte); - if (hostByte == 0) - { - // only client needs to send token - punch.Write(token); - } + punch.Write(token); SendUnconnectedLibraryMessage(punch, NetMessageLibraryType.NatPunchMessage, remoteExternal); } @@ -82,11 +76,17 @@ namespace Lidgren.Network NetIncomingMessage tmp = new NetIncomingMessage(m_receiveBuffer, 1000); // never mind length tmp.Position = (ptr * 8); - byte hostByte = tmp.ReadByte(); - if (hostByte != 0) + byte fromHostByte = tmp.ReadByte(); + if (fromHostByte == 0) + { + // it's from client + LogDebug("NAT punch received from " + senderEndpoint + " we're host, so we ignore this"); return; // don't alert hosts about nat punch successes; only clients + } string token = tmp.ReadString(); + LogDebug("NAT punch received from " + senderEndpoint + " we're client, so we've succeeded - token is " + token); + // // Release punch success to client; enabling him to Connect() to msg.SenderIPEndPoint if token is ok // diff --git a/Lidgren.Network/NetPeer.Internal.cs b/Lidgren.Network/NetPeer.Internal.cs index d6ab3df..8678cee 100644 --- a/Lidgren.Network/NetPeer.Internal.cs +++ b/Lidgren.Network/NetPeer.Internal.cs @@ -32,7 +32,7 @@ namespace Lidgren.Network internal Socket m_socket; internal byte[] m_macAddressBytes; private int m_listenPort; - private readonly AutoResetEvent m_messageReceivedEvent = new AutoResetEvent(false); + private AutoResetEvent m_messageReceivedEvent = new AutoResetEvent(false); private readonly NetQueue m_releasedIncomingMessages = new NetQueue(16); private readonly NetQueue m_unsentUnconnectedMessage = new NetQueue(4); @@ -178,7 +178,10 @@ namespace Lidgren.Network m_socket.Close(2); // 2 seconds timeout } if (m_messageReceivedEvent != null) + { m_messageReceivedEvent.Close(); + m_messageReceivedEvent = null; + } } finally { @@ -400,122 +403,114 @@ namespace Lidgren.Network { VerifyNetworkThread(); - if (libType != NetMessageLibraryType.Connect && libType != NetMessageLibraryType.Discovery && libType != NetMessageLibraryType.DiscoveryResponse) - { - LogWarning("Received unconnected library message of type " + libType); - return; - } - int payloadLengthBytes = NetUtility.BytesToHoldBits(payloadLengthBits); - // - // Handle nat introduction - // - if (libType == NetMessageLibraryType.NatIntroduction) - HandleNatIntroduction(ptr); - - // - // Handle Discovery - // - if (libType == NetMessageLibraryType.Discovery) + switch (libType) { - if (m_configuration.IsMessageTypeEnabled(NetIncomingMessageType.DiscoveryRequest)) - { - NetIncomingMessage dm = CreateIncomingMessage(NetIncomingMessageType.DiscoveryRequest, payloadLengthBytes); - if (payloadLengthBytes > 0) - Buffer.BlockCopy(m_receiveBuffer, ptr, dm.m_data, 0, payloadLengthBytes); - dm.m_bitLength = payloadLengthBits; - dm.m_senderEndpoint = senderEndpoint; - ReleaseMessage(dm); - } - return; - } - - if (libType == NetMessageLibraryType.DiscoveryResponse) - { - if (m_configuration.IsMessageTypeEnabled(NetIncomingMessageType.DiscoveryResponse)) - { - NetIncomingMessage dr = CreateIncomingMessage(NetIncomingMessageType.DiscoveryResponse, payloadLengthBytes); - if (payloadLengthBytes > 0) - Buffer.BlockCopy(m_receiveBuffer, ptr, dr.m_data, 0, payloadLengthBytes); - dr.m_bitLength = payloadLengthBits; - dr.m_senderEndpoint = senderEndpoint; - ReleaseMessage(dr); - } - return; - } - - // - // Handle NetMessageLibraryType.Connect - // - - if (!m_configuration.m_acceptIncomingConnections) - { - LogWarning("Connect received; but we're not accepting incoming connections!"); - return; - } - - string appIdent; - long remoteUniqueIdentifier = 0; - NetIncomingMessage approval = null; - try - { - NetIncomingMessage reader = new NetIncomingMessage(); - - reader.m_data = GetStorage(payloadLengthBytes); - Buffer.BlockCopy(m_receiveBuffer, ptr, reader.m_data, 0, payloadLengthBytes); - ptr += payloadLengthBytes; - reader.m_bitLength = payloadLengthBits; - appIdent = reader.ReadString(); - remoteUniqueIdentifier = reader.ReadInt64(); - - int approvalBitLength = (int)reader.ReadVariableUInt32(); - if (approvalBitLength > 0) - { - int approvalByteLength = NetUtility.BytesToHoldBits(approvalBitLength); - if (approvalByteLength < m_configuration.MaximumTransmissionUnit) + case NetMessageLibraryType.NatPunchMessage: + HandleNatPunch(ptr, senderEndpoint); + break; + case NetMessageLibraryType.NatIntroduction: + HandleNatIntroduction(ptr); + break; + case NetMessageLibraryType.Discovery: + if (m_configuration.IsMessageTypeEnabled(NetIncomingMessageType.DiscoveryRequest)) { - approval = CreateIncomingMessage(NetIncomingMessageType.ConnectionApproval, approvalByteLength); - reader.ReadBits(approval.m_data, 0, approvalBitLength); - approval.m_bitLength = approvalBitLength; + NetIncomingMessage dm = CreateIncomingMessage(NetIncomingMessageType.DiscoveryRequest, payloadLengthBytes); + if (payloadLengthBytes > 0) + Buffer.BlockCopy(m_receiveBuffer, ptr, dm.m_data, 0, payloadLengthBytes); + dm.m_bitLength = payloadLengthBits; + dm.m_senderEndpoint = senderEndpoint; + ReleaseMessage(dm); } - } - } - catch (Exception ex) - { - // malformed connect packet - LogWarning("Malformed connect packet from " + senderEndpoint + " - " + ex.ToString()); - return; - } - if (appIdent.Equals(m_configuration.AppIdentifier) == false) - { - // wrong app ident - LogWarning("Connect received with wrong appidentifier (need '" + m_configuration.AppIdentifier + "' found '" + appIdent + "') from " + senderEndpoint); - return; + break; + case NetMessageLibraryType.DiscoveryResponse: + if (m_configuration.IsMessageTypeEnabled(NetIncomingMessageType.DiscoveryResponse)) + { + NetIncomingMessage dr = CreateIncomingMessage(NetIncomingMessageType.DiscoveryResponse, payloadLengthBytes); + if (payloadLengthBytes > 0) + Buffer.BlockCopy(m_receiveBuffer, ptr, dr.m_data, 0, payloadLengthBytes); + dr.m_bitLength = payloadLengthBits; + dr.m_senderEndpoint = senderEndpoint; + ReleaseMessage(dr); + } + break; + + case NetMessageLibraryType.Connect: + + + if (!m_configuration.m_acceptIncomingConnections) + { + LogWarning("Connect received; but we're not accepting incoming connections!"); + break; + } + + string appIdent; + long remoteUniqueIdentifier = 0; + NetIncomingMessage approval = null; + try + { + NetIncomingMessage reader = new NetIncomingMessage(); + + reader.m_data = GetStorage(payloadLengthBytes); + Buffer.BlockCopy(m_receiveBuffer, ptr, reader.m_data, 0, payloadLengthBytes); + ptr += payloadLengthBytes; + reader.m_bitLength = payloadLengthBits; + appIdent = reader.ReadString(); + remoteUniqueIdentifier = reader.ReadInt64(); + + int approvalBitLength = (int)reader.ReadVariableUInt32(); + if (approvalBitLength > 0) + { + int approvalByteLength = NetUtility.BytesToHoldBits(approvalBitLength); + if (approvalByteLength < m_configuration.MaximumTransmissionUnit) + { + approval = CreateIncomingMessage(NetIncomingMessageType.ConnectionApproval, approvalByteLength); + reader.ReadBits(approval.m_data, 0, approvalBitLength); + approval.m_bitLength = approvalBitLength; + } + } + } + catch (Exception ex) + { + // malformed connect packet + LogWarning("Malformed connect packet from " + senderEndpoint + " - " + ex.ToString()); + break; + } + + if (appIdent.Equals(m_configuration.AppIdentifier) == false) + { + // wrong app ident + LogWarning("Connect received with wrong appidentifier (need '" + m_configuration.AppIdentifier + "' found '" + appIdent + "') from " + senderEndpoint); + break; + } + + // ok, someone wants to connect to us, and we're accepting connections! + if (m_connections.Count >= m_configuration.MaximumConnections) + { + HandleServerFull(senderEndpoint); + break; + } + + NetConnection conn = new NetConnection(this, senderEndpoint); + conn.m_connectionInitiator = false; + conn.m_connectInitationTime = NetTime.Now; + conn.m_remoteUniqueIdentifier = remoteUniqueIdentifier; + + if (m_configuration.IsMessageTypeEnabled(NetIncomingMessageType.ConnectionApproval)) + { + // do connection approval before accepting this connection + AddPendingConnection(conn, approval); + break; + } + + AcceptConnection(conn); + break; + default: + LogWarning("Received unconnected library message of type " + libType); + break; } - - // ok, someone wants to connect to us, and we're accepting connections! - if (m_connections.Count >= m_configuration.MaximumConnections) - { - HandleServerFull(senderEndpoint); - return; - } - - NetConnection conn = new NetConnection(this, senderEndpoint); - conn.m_connectionInitiator = false; - conn.m_connectInitationTime = NetTime.Now; - conn.m_remoteUniqueIdentifier = remoteUniqueIdentifier; - - if (m_configuration.IsMessageTypeEnabled(NetIncomingMessageType.ConnectionApproval)) - { - // do connection approval before accepting this connection - AddPendingConnection(conn, approval); - return; - } - - AcceptConnection(conn); - return; } private void HandleUnconnectedUserMessage(int ptr, int payloadLengthBits, IPEndPoint senderEndpoint)