using System;
using System.Collections.Concurrent;
using System.Linq;
using System.Net;
using System.Net.Sockets;
using System.Threading;

namespace Hazel.Udp.FewerThreads
{
    /// <summary>
    /// Listens for new UDP connections and creates UdpConnections for them.
    /// </summary>
    /// <remarks>
    /// Uses a fixed set of threads: one receive thread, one send thread, one
    /// reliable-packet maintenance thread, and a pool of <c>numWorkers</c>
    /// processing threads that consume the receive queue.
    /// </remarks>
    /// <inheritdoc />
    public class ThreadLimitedUdpConnectionListener : NetworkConnectionListener
    {
        /// <summary>An outgoing datagram waiting on <see cref="sendQueue"/>.</summary>
        private struct SendMessageInfo
        {
            public ByteSpan Span;
            public IPEndPoint Recipient;
        }

        /// <summary>An incoming datagram waiting on <see cref="receiveQueue"/> for a worker thread.</summary>
        private struct ReceiveMessageInfo
        {
            public MessageReader Message;
            public IPEndPoint Sender;
            public ConnectionId ConnectionId;
        }

        // Requested OS socket buffer size, applied to both the send and receive buffers.
        private const int SendReceiveBufferSize = 1024 * 1024;

        // Number of leading bytes skipped before a Hello's payload is handed to
        // InvokeNewConnection (header + hello byte); a Hello shorter than this is malformed.
        private const int HelloHeaderSize = 4;

        private Socket socket;
        protected ILogger Logger;

        private Thread reliablePacketThread;
        private Thread receiveThread;
        private Thread sendThread;
        private HazelThreadPool processThreads;

        /// <summary>True while the receive thread's state is exactly <see cref="ThreadState.Running"/>.</summary>
        public bool ReceiveThreadRunning => this.receiveThread.ThreadState == ThreadState.Running;

        /// <summary>
        /// Key identifying a virtual connection: the remote endpoint plus a serial number.
        /// </summary>
        public struct ConnectionId : IEquatable<ConnectionId>
        {
            public IPEndPoint EndPoint;
            public int Serial;

            public static ConnectionId Create(IPEndPoint endPoint, int serial)
            {
                return new ConnectionId
                {
                    EndPoint = endPoint,
                    Serial = serial,
                };
            }

            public bool Equals(ConnectionId other)
            {
                // object.Equals tolerates a null EndPoint (e.g. default(ConnectionId)).
                return this.Serial == other.Serial
                    && object.Equals(this.EndPoint, other.EndPoint);
            }

            public override bool Equals(object obj)
            {
                return obj is ConnectionId other && this.Equals(other);
            }

            public override int GetHashCode()
            {
                // NOTE(mendsley): We're only hashing the endpoint
                // here, as the common case will have one
                // connection per address+port tuple.
                return this.EndPoint?.GetHashCode() ?? 0;
            }
        }

        protected ConcurrentDictionary<ConnectionId, ThreadLimitedUdpServerConnection> allConnections =
            new ConcurrentDictionary<ConnectionId, ThreadLimitedUdpServerConnection>();

        // Bounded (see constructor); producers block when it is full.
        private BlockingCollection<ReceiveMessageInfo> receiveQueue;

        // Unbounded.
        private BlockingCollection<SendMessageInfo> sendQueue = new BlockingCollection<SendMessageInfo>();

        /// <summary>
        /// Age, in whole seconds, of the oldest currently tracked connection (0 when there are none).
        /// </summary>
        public int MaxAge
        {
            get
            {
                var now = DateTime.UtcNow;
                TimeSpan max = TimeSpan.Zero;
                foreach (var con in this.allConnections.Values)
                {
                    var val = now - con.CreationTime;
                    if (val > max)
                    {
                        max = val;
                    }
                }

                return (int)max.TotalSeconds;
            }
        }

        /// <summary>
        /// Mean of <c>AveragePingMs</c> over all current connections, or 0 when there are none.
        /// </summary>
        public override double AveragePing
        {
            get
            {
                // ConcurrentDictionary.Values is a point-in-time snapshot, so the sum and the
                // count below describe the same set of connections even under concurrent churn.
                var connections = this.allConnections.Values;
                if (connections.Count == 0)
                {
                    return 0;
                }

                return connections.Sum(c => c.AveragePingMs) / connections.Count;
            }
        }

        public override int ConnectionCount
        {
            get { return this.allConnections.Count; }
        }

        public override int SendQueueLength
        {
            get { return this.sendQueue.Count; }
        }

        public override int ReceiveQueueLength
        {
            get { return this.receiveQueue.Count; }
        }

        // Read by several long-running threads as their loop condition; volatile so a
        // write from Dispose() is observed promptly.
        private volatile bool isActive;

        /// <summary>
        /// Creates the listener and its socket, but does not bind or start any thread.
        /// Call <see cref="Start"/> to begin listening.
        /// </summary>
        /// <param name="numWorkers">Number of threads that process received packets.</param>
        /// <param name="endPoint">Local endpoint to bind to.</param>
        /// <param name="logger">Logger for errors.</param>
        /// <param name="ipMode">IP mode used to create the socket.</param>
        public ThreadLimitedUdpConnectionListener(int numWorkers, IPEndPoint endPoint, ILogger logger, IPMode ipMode = IPMode.IPv4)
        {
            this.Logger = logger;
            this.EndPoint = endPoint;
            this.IPMode = ipMode;

            this.receiveQueue = new BlockingCollection<ReceiveMessageInfo>(10000);

            this.socket = UdpConnection.CreateSocket(this.IPMode);
            this.socket.ExclusiveAddressUse = true;
            this.socket.Blocking = false;

            this.socket.ReceiveBufferSize = SendReceiveBufferSize;
            this.socket.SendBufferSize = SendReceiveBufferSize;

            this.reliablePacketThread = new Thread(ManageReliablePackets);
            this.sendThread = new Thread(SendLoop);
            this.receiveThread = new Thread(ReceiveLoop);
            this.processThreads = new HazelThreadPool(numWorkers, ProcessingLoop);
        }

        ~ThreadLimitedUdpConnectionListener()
        {
            this.Dispose(false);
        }

        // This is just for booting people after they've been connected a certain amount of time...
        /// <summary>
        /// Disconnects every connection whose age exceeds <paramref name="maxAge"/>,
        /// sending <paramref name="disconnectMessage"/> with the disconnect.
        /// </summary>
        public void DisconnectOldConnections(TimeSpan maxAge, MessageWriter disconnectMessage)
        {
            var now = DateTime.UtcNow;
            foreach (var conn in this.allConnections.Values)
            {
                if (now - conn.CreationTime > maxAge)
                {
                    conn.Disconnect("Stale Connection", disconnectMessage);
                }
            }
        }

        /// <summary>
        /// Thread body: every 100 ms, calls ManageReliablePackets on each connection
        /// until the listener is deactivated.
        /// </summary>
        private void ManageReliablePackets()
        {
            while (this.isActive)
            {
                foreach (var kvp in this.allConnections)
                {
                    var sock = kvp.Value;
                    sock.ManageReliablePackets();
                }

                Thread.Sleep(100);
            }
        }

        /// <summary>
        /// Binds the socket and starts all listener threads.
        /// </summary>
        /// <exception cref="HazelException">The socket could not be bound.</exception>
        public override void Start()
        {
            try
            {
                socket.Bind(EndPoint);
            }
            catch (SocketException e)
            {
                throw new HazelException("Could not start listening as a SocketException occurred", e);
            }

            this.isActive = true;
            this.reliablePacketThread.Start();
            this.sendThread.Start();
            this.receiveThread.Start();
            this.processThreads.Start();
        }

        /// <summary>
        /// Thread body: polls the socket (1 ms timeout) and hands each received datagram
        /// to the processing queue until the listener is deactivated or the socket fails.
        /// </summary>
        private void ReceiveLoop()
        {
            while (this.isActive)
            {
                bool readable;
                try
                {
                    readable = this.socket.Poll(1000, SelectMode.SelectRead);
                }
                catch (ObjectDisposedException)
                {
                    // Dispose() closed the socket between the isActive check and the poll.
                    // Without this catch the exception would escape the thread and kill the process.
                    break;
                }
                catch (SocketException) when (!this.isActive)
                {
                    // The socket is being torn down by Dispose(); just stop.
                    break;
                }

                if (!readable)
                {
                    continue;
                }

                if (!this.isActive)
                {
                    break;
                }

                EndPoint remoteEP = new IPEndPoint(this.EndPoint.Address, this.EndPoint.Port);
                var message = MessageReader.GetSized(this.ReceiveBufferSize);
                try
                {
                    message.Length = socket.ReceiveFrom(message.Buffer, 0, message.Buffer.Length, SocketFlags.None, ref remoteEP);
                }
                catch (SocketException sx)
                {
                    message.Recycle();
                    if (sx.SocketErrorCode == SocketError.NotConnected)
                    {
                        this.InvokeInternalError(HazelInternalErrors.ConnectionDisconnected);
                        return;
                    }

                    this.Logger.WriteError("Socket Ex in ReceiveLoop: " + sx.Message);
                    continue;
                }
                catch (Exception ex)
                {
                    message.Recycle();
                    this.Logger.WriteError("Stopped due to: " + ex.Message);
                    return;
                }

                ConnectionId connectionId = ConnectionId.Create((IPEndPoint)remoteEP, 0);
                this.ProcessIncomingMessageFromOtherThread(message, (IPEndPoint)remoteEP, connectionId);
            }
        }

        /// <summary>
        /// Worker-thread body: consumes the receive queue until it is completed and drained,
        /// dispatching each packet to <see cref="ReadCallback"/>.
        /// </summary>
        private void ProcessingLoop()
        {
            foreach (ReceiveMessageInfo msg in this.receiveQueue.GetConsumingEnumerable())
            {
                try
                {
                    this.ReadCallback(msg.Message, msg.Sender, msg.ConnectionId);
                }
                catch (Exception e)
                {
                    // One bad packet must not stop this worker, but the failure should be visible.
                    this.Logger.WriteError("Exception while processing packet from " + msg.Sender + ": " + e);
                }
            }
        }

        /// <summary>
        /// Enqueues a received datagram for a worker thread. Tries a non-blocking add first;
        /// if the bounded queue is full, records the stall in Statistics and then blocks
        /// until space is available.
        /// </summary>
        protected void ProcessIncomingMessageFromOtherThread(MessageReader message, IPEndPoint remoteEndPoint, ConnectionId connectionId)
        {
            var info = new ReceiveMessageInfo() { Message = message, Sender = remoteEndPoint, ConnectionId = connectionId };
            if (!this.receiveQueue.TryAdd(info))
            {
                this.Statistics.AddReceiveThreadBlocking();
                this.receiveQueue.Add(info);
            }
        }

        /// <summary>
        /// Thread body: drains the send queue until <c>CompleteAdding</c> is called
        /// (see <see cref="Dispose(bool)"/>), waiting for the socket to become writable
        /// before each send. Stops early if the socket stops reporting writable.
        /// </summary>
        private void SendLoop()
        {
            foreach (SendMessageInfo msg in this.sendQueue.GetConsumingEnumerable())
            {
                try
                {
                    if (this.socket.Poll(Timeout.Infinite, SelectMode.SelectWrite))
                    {
                        // Count what the socket actually reports sending. Span.Length is already
                        // the byte count (it is passed as SendTo's size), so it must not be reduced by Offset.
                        int sent = this.socket.SendTo(msg.Span.GetUnderlyingArray(), msg.Span.Offset, msg.Span.Length, SocketFlags.None, msg.Recipient);
                        this.Statistics.AddBytesSent(sent);
                    }
                    else
                    {
                        this.Logger.WriteError("Socket is no longer able to send");
                        break;
                    }
                }
                catch (Exception e)
                {
                    this.Logger.WriteError("Error in loop while sending: " + e.Message);
                    Thread.Sleep(1);
                }
            }
        }

        /// <summary>
        /// Routes one received datagram to its connection, creating the connection first
        /// when the datagram is a well-formed Hello from an unknown <paramref name="connectionId"/>
        /// that <c>AcceptConnection</c> (if set) accepts.
        /// </summary>
        /// <param name="message">The received datagram; <c>Length</c> is the number of bytes received.</param>
        /// <param name="remoteEndPoint">Sender of the datagram.</param>
        /// <param name="connectionId">Key used to look up or register the connection.</param>
        /// <exception cref="HazelException">A new connection could not be registered (should never happen).</exception>
        protected virtual void ReadCallback(MessageReader message, IPEndPoint remoteEndPoint, ConnectionId connectionId)
        {
            int bytesReceived = message.Length;

            // Everything below inspects at least Buffer[0]. For an empty datagram that byte
            // does not belong to this packet, so there is nothing meaningful to act on.
            if (bytesReceived < 1)
            {
                message.Recycle();
                return;
            }

            bool aware = true;
            bool isHello = message.Buffer[0] == (byte)UdpSendOption.Hello;

            // If we're aware of this connection use the one already
            // If this is a new client then connect with them!
            ThreadLimitedUdpServerConnection connection;
            if (!this.allConnections.TryGetValue(connectionId, out connection))
            {
                lock (this.allConnections)
                {
                    // Re-check under the lock so two workers cannot both create this connection.
                    if (!this.allConnections.TryGetValue(connectionId, out connection))
                    {
                        // Check for malformed connection attempts: only a Hello long enough to
                        // contain the header skipped below may open a connection. A shorter one
                        // would otherwise produce a negative payload length.
                        if (!isHello || bytesReceived < HelloHeaderSize)
                        {
                            message.Recycle();
                            return;
                        }

                        if (AcceptConnection != null)
                        {
                            if (!AcceptConnection(remoteEndPoint, message.Buffer, out var response))
                            {
                                message.Recycle();
                                if (response != null)
                                {
                                    SendDataRaw(response, remoteEndPoint);
                                }

                                return;
                            }
                        }

                        aware = false;
                        connection = new ThreadLimitedUdpServerConnection(this, connectionId, remoteEndPoint, this.IPMode, this.Logger);
                        if (!this.allConnections.TryAdd(connectionId, connection))
                        {
                            throw new HazelException("Failed to add a connection. This should never happen.");
                        }
                    }
                }
            }

            // If it's a new connection invoke the NewConnection event.
            // This needs to happen before handling the message because in localhost scenarios, the ACK and
            // subsequent messages can happen before the NewConnection event sets up OnDataRecieved handlers
            if (!aware)
            {
                // Skip header and hello byte;
                message.Offset = HelloHeaderSize;
                message.Length = bytesReceived - HelloHeaderSize;
                message.Position = 0;
                try
                {
                    this.InvokeNewConnection(message, connection);
                }
                catch (Exception e)
                {
                    this.Logger.WriteError("NewConnection handler threw: " + e);
                }
            }

            // Inform the connection of the buffer (new connections need to send an ack back to client)
            connection.HandleReceive(message, bytesReceived);
        }

        /// <summary>Queues a raw byte array for sending to <paramref name="remoteEndPoint"/>.</summary>
        internal void SendDataRaw(byte[] response, IPEndPoint remoteEndPoint)
        {
            QueueRawData(response, remoteEndPoint);
        }

        /// <summary>Queues <paramref name="span"/> for the send thread to deliver to <paramref name="remoteEndPoint"/>.</summary>
        protected virtual void QueueRawData(ByteSpan span, IPEndPoint remoteEndPoint)
        {
            this.sendQueue.TryAdd(new SendMessageInfo() { Span = span, Recipient = remoteEndPoint });
        }

        /// <summary>
        /// Removes a virtual connection from the list.
        /// </summary>
        /// <param name="connectionId">Connection key of the virtual connection.</param>
        /// <returns><c>true</c> if a connection with that key was removed.</returns>
        internal bool RemoveConnectionTo(ConnectionId connectionId)
        {
            return this.allConnections.TryRemove(connectionId, out _);
        }

        /// <summary>
        /// This is after all messages could be sent. Clean up anything extra.
        /// </summary>
        internal virtual void RemovePeerRecord(ConnectionId connectionId)
        {
        }

        /// <summary>
        /// Disposes all connections, then stops the threads and releases the socket and queues.
        /// Shutdown order:
        /// <list type="number">
        /// <item>Drain and stop the sender.</item>
        /// <item>Close the socket.</item>
        /// <item>Stop the receive-queue producers.</item>
        /// <item>Complete the receive queue and let the workers drain it.</item>
        /// </list>
        /// </summary>
        protected override void Dispose(bool disposing)
        {
            foreach (var kvp in this.allConnections)
            {
                kvp.Value.Dispose();
            }

            bool wasActive = this.isActive;
            this.isActive = false;

            // Flush outgoing packets
            this.sendQueue?.CompleteAdding();
            if (wasActive)
            {
                this.sendThread.Join();
            }

            // Best-effort teardown: the socket may already be shut down or closed.
            try { this.socket.Shutdown(SocketShutdown.Both); } catch { }
            try { this.socket.Close(); } catch { }
            try { this.socket.Dispose(); } catch { }

            if (wasActive)
            {
                // Join the receive thread before completing receiveQueue. It may be blocked
                // in receiveQueue.Add, which would throw (unhandled, on a raw thread) once
                // adding is completed.
                this.reliablePacketThread.Join();
                this.receiveThread.Join();
            }

            this.receiveQueue?.CompleteAdding();

            if (wasActive)
            {
                this.processThreads.Join();
            }

            this.receiveQueue?.Dispose();
            this.receiveQueue = null;
            this.sendQueue?.Dispose();
            this.sendQueue = null;

            base.Dispose(disposing);
        }
    }
}