diff options
Diffstat (limited to 'Tools/Hazel-Networking/Hazel/FewerThreads')
3 files changed, 556 insertions, 0 deletions
diff --git a/Tools/Hazel-Networking/Hazel/FewerThreads/HazelThreadPool.cs b/Tools/Hazel-Networking/Hazel/FewerThreads/HazelThreadPool.cs new file mode 100644 index 0000000..fb36b00 --- /dev/null +++ b/Tools/Hazel-Networking/Hazel/FewerThreads/HazelThreadPool.cs @@ -0,0 +1,44 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading; +using System.Threading.Tasks; + +namespace Hazel +{ + internal class HazelThreadPool + { + private Thread[] threads; + + public HazelThreadPool(int numThreads, ThreadStart action) + { + this.threads = new Thread[numThreads]; + for (int i = 0; i < this.threads.Length; ++i) + { + this.threads[i] = new Thread(action); + } + } + + public void Start() + { + for (int i = 0; i < this.threads.Length; ++i) + { + this.threads[i].Start(); + } + } + + public void Join() + { + for (int i = 0; i < this.threads.Length; ++i) + { + var thread = this.threads[i]; + try + { + thread.Join(); + } + catch { } + } + } + } +}
\ No newline at end of file diff --git a/Tools/Hazel-Networking/Hazel/FewerThreads/ThreadLimitedUdpConnectionListener.cs b/Tools/Hazel-Networking/Hazel/FewerThreads/ThreadLimitedUdpConnectionListener.cs new file mode 100644 index 0000000..e37be45 --- /dev/null +++ b/Tools/Hazel-Networking/Hazel/FewerThreads/ThreadLimitedUdpConnectionListener.cs @@ -0,0 +1,402 @@ +using System; +using System.Collections.Concurrent; +using System.Linq; +using System.Net; +using System.Net.Sockets; +using System.Threading; + +namespace Hazel.Udp.FewerThreads +{ + /// <summary> + /// Listens for new UDP connections and creates UdpConnections for them. + /// </summary> + /// <inheritdoc /> + public class ThreadLimitedUdpConnectionListener : NetworkConnectionListener + { + private struct SendMessageInfo + { + public ByteSpan Span; + public IPEndPoint Recipient; + } + + private struct ReceiveMessageInfo + { + public MessageReader Message; + public IPEndPoint Sender; + public ConnectionId ConnectionId; + } + + private const int SendReceiveBufferSize = 1024 * 1024; + + private Socket socket; + protected ILogger Logger; + + private Thread reliablePacketThread; + private Thread receiveThread; + private Thread sendThread; + private HazelThreadPool processThreads; + + public bool ReceiveThreadRunning => this.receiveThread.ThreadState == ThreadState.Running; + + public struct ConnectionId : IEquatable<ConnectionId> + { + public IPEndPoint EndPoint; + public int Serial; + + public static ConnectionId Create(IPEndPoint endPoint, int serial) + { + return new ConnectionId{ + EndPoint = endPoint, + Serial = serial, + }; + } + + public bool Equals(ConnectionId other) + { + return this.Serial == other.Serial + && this.EndPoint.Equals(other.EndPoint) + ; + } + + public override bool Equals(object obj) + { + if (obj is ConnectionId) + { + return this.Equals((ConnectionId)obj); + } + + return false; + } + + public override int GetHashCode() + { + ///NOTE(mendsley): We're only hashing the endpoint + /// here, as the common case will have one + /// connection per address+port tuple. + return this.EndPoint.GetHashCode(); + } + } + + protected ConcurrentDictionary<ConnectionId, ThreadLimitedUdpServerConnection> allConnections = new ConcurrentDictionary<ConnectionId, ThreadLimitedUdpServerConnection>(); + + private BlockingCollection<ReceiveMessageInfo> receiveQueue; + private BlockingCollection<SendMessageInfo> sendQueue = new BlockingCollection<SendMessageInfo>(); + + public int MaxAge + { + get + { + var now = DateTime.UtcNow; + TimeSpan max = new TimeSpan(); + foreach (var con in allConnections.Values) + { + var val = now - con.CreationTime; + if (val > max) max = val; + } + + return (int)max.TotalSeconds; + } + } + + public override double AveragePing => this.allConnections.Values.Sum(c => c.AveragePingMs) / this.allConnections.Count; + public override int ConnectionCount { get { return this.allConnections.Count; } } + public override int SendQueueLength { get { return this.sendQueue.Count; } } + public override int ReceiveQueueLength { get { return this.receiveQueue.Count; } } + + private bool isActive; + + public ThreadLimitedUdpConnectionListener(int numWorkers, IPEndPoint endPoint, ILogger logger, IPMode ipMode = IPMode.IPv4) + { + this.Logger = logger; + this.EndPoint = endPoint; + this.IPMode = ipMode; + + this.receiveQueue = new BlockingCollection<ReceiveMessageInfo>(10000); + + this.socket = UdpConnection.CreateSocket(this.IPMode); + this.socket.ExclusiveAddressUse = true; + this.socket.Blocking = false; + + this.socket.ReceiveBufferSize = SendReceiveBufferSize; + this.socket.SendBufferSize = SendReceiveBufferSize; + + this.reliablePacketThread = new Thread(ManageReliablePackets); + this.sendThread = new Thread(SendLoop); + this.receiveThread = new Thread(ReceiveLoop); + this.processThreads = new HazelThreadPool(numWorkers, ProcessingLoop); + } + + ~ThreadLimitedUdpConnectionListener() + { + this.Dispose(false); + } + + // This is just for booting people after they've been connected a certain amount of time... + public void DisconnectOldConnections(TimeSpan maxAge, MessageWriter disconnectMessage) + { + var now = DateTime.UtcNow; + foreach (var conn in this.allConnections.Values) + { + if (now - conn.CreationTime > maxAge) + { + conn.Disconnect("Stale Connection", disconnectMessage); + } + } + } + + private void ManageReliablePackets() + { + while (this.isActive) + { + foreach (var kvp in this.allConnections) + { + var sock = kvp.Value; + sock.ManageReliablePackets(); + } + + Thread.Sleep(100); + } + } + + public override void Start() + { + try + { + socket.Bind(EndPoint); + } + catch (SocketException e) + { + throw new HazelException("Could not start listening as a SocketException occurred", e); + } + + this.isActive = true; + this.reliablePacketThread.Start(); + this.sendThread.Start(); + this.receiveThread.Start(); + this.processThreads.Start(); + } + + private void ReceiveLoop() + { + while (this.isActive) + { + if (this.socket.Poll(1000, SelectMode.SelectRead)) + { + if (!isActive) break; + + EndPoint remoteEP = new IPEndPoint(this.EndPoint.Address, this.EndPoint.Port); + var message = MessageReader.GetSized(this.ReceiveBufferSize); + try + { + message.Length = socket.ReceiveFrom(message.Buffer, 0, message.Buffer.Length, SocketFlags.None, ref remoteEP); + } + catch (SocketException sx) + { + message.Recycle(); + if (sx.SocketErrorCode == SocketError.NotConnected) + { + this.InvokeInternalError(HazelInternalErrors.ConnectionDisconnected); + return; + } + + this.Logger.WriteError("Socket Ex in ReceiveLoop: " + sx.Message); + continue; + } + catch (Exception ex) + { + message.Recycle(); + this.Logger.WriteError("Stopped due to: " + ex.Message); + return; + } + + ConnectionId connectionId = ConnectionId.Create((IPEndPoint)remoteEP, 0); + this.ProcessIncomingMessageFromOtherThread(message, (IPEndPoint)remoteEP, connectionId); + } + } + } + + private void ProcessingLoop() + { + foreach (ReceiveMessageInfo msg in this.receiveQueue.GetConsumingEnumerable()) + { + try + { + this.ReadCallback(msg.Message, msg.Sender, msg.ConnectionId); + } + catch + { + + } + } + } + + protected void ProcessIncomingMessageFromOtherThread(MessageReader message, IPEndPoint remoteEndPoint, ConnectionId connectionId) + { + var info = new ReceiveMessageInfo() { Message = message, Sender = remoteEndPoint, ConnectionId = connectionId }; + if (!this.receiveQueue.TryAdd(info)) + { + this.Statistics.AddReceiveThreadBlocking(); + this.receiveQueue.Add(info); + } + } + + private void SendLoop() + { + foreach (SendMessageInfo msg in this.sendQueue.GetConsumingEnumerable()) + { + try + { + if (this.socket.Poll(Timeout.Infinite, SelectMode.SelectWrite)) + { + this.socket.SendTo(msg.Span.GetUnderlyingArray(), msg.Span.Offset, msg.Span.Length, SocketFlags.None, msg.Recipient); + this.Statistics.AddBytesSent(msg.Span.Length - msg.Span.Offset); + } + else + { + this.Logger.WriteError("Socket is no longer able to send"); + break; + } + } + catch (Exception e) + { + this.Logger.WriteError("Error in loop while sending: " + e.Message); + Thread.Sleep(1); + } + } + } + + protected virtual void ReadCallback(MessageReader message, IPEndPoint remoteEndPoint, ConnectionId connectionId) + { + int bytesReceived = message.Length; + bool aware = true; + bool isHello = message.Buffer[0] == (byte)UdpSendOption.Hello; + + // If we're aware of this connection use the one already + // If this is a new client then connect with them! + ThreadLimitedUdpServerConnection connection; + if (!this.allConnections.TryGetValue(connectionId, out connection)) + { + lock (this.allConnections) + { + if (!this.allConnections.TryGetValue(connectionId, out connection)) + { + // Check for malformed connection attempts + if (!isHello) + { + message.Recycle(); + return; + } + + if (AcceptConnection != null) + { + if (!AcceptConnection(remoteEndPoint, message.Buffer, out var response)) + { + message.Recycle(); + if (response != null) + { + SendDataRaw(response, remoteEndPoint); + } + + return; + } + } + + aware = false; + connection = new ThreadLimitedUdpServerConnection(this, connectionId, remoteEndPoint, this.IPMode, this.Logger); + if (!this.allConnections.TryAdd(connectionId, connection)) + { + throw new HazelException("Failed to add a connection. This should never happen."); + } + } + } + } + + // If it's a new connection invoke the NewConnection event. + // This needs to happen before handling the message because in localhost scenarios, the ACK and + // subsequent messages can happen before the NewConnection event sets up OnDataRecieved handlers + if (!aware) + { + // Skip header and hello byte; + message.Offset = 4; + message.Length = bytesReceived - 4; + message.Position = 0; + try + { + this.InvokeNewConnection(message, connection); + } + catch (Exception e) + { + this.Logger.WriteError("NewConnection handler threw: " + e); + } + } + + // Inform the connection of the buffer (new connections need to send an ack back to client) + connection.HandleReceive(message, bytesReceived); + } + + internal void SendDataRaw(byte[] response, IPEndPoint remoteEndPoint) + { + QueueRawData(response, remoteEndPoint); + } + + protected virtual void QueueRawData(ByteSpan span, IPEndPoint remoteEndPoint) + { + this.sendQueue.TryAdd(new SendMessageInfo() { Span = span, Recipient = remoteEndPoint }); + } + + /// <summary> + /// Removes a virtual connection from the list. + /// </summary> + /// <param name="endPoint">Connection key of the virtual connection.</param> + internal bool RemoveConnectionTo(ConnectionId connectionId) + { + return this.allConnections.TryRemove(connectionId, out _); + } + + /// <summary> + /// This is after all messages could be sent. Clean up anything extra. + /// </summary> + internal virtual void RemovePeerRecord(ConnectionId connectionId) + { + } + + protected override void Dispose(bool disposing) + { + foreach (var kvp in this.allConnections) + { + kvp.Value.Dispose(); + } + + bool wasActive = this.isActive; + this.isActive = false; + + // Flush outgoing packets + this.sendQueue?.CompleteAdding(); + + if (wasActive) + { + this.sendThread.Join(); + } + + try { this.socket.Shutdown(SocketShutdown.Both); } catch { } + try { this.socket.Close(); } catch { } + try { this.socket.Dispose(); } catch { } + + this.receiveQueue?.CompleteAdding(); + + if (wasActive) + { + this.reliablePacketThread.Join(); + this.receiveThread.Join(); + this.processThreads.Join(); + } + + this.receiveQueue?.Dispose(); + this.receiveQueue = null; + this.sendQueue?.Dispose(); + this.sendQueue = null; + + base.Dispose(disposing); + } + } +} diff --git a/Tools/Hazel-Networking/Hazel/FewerThreads/ThreadLimitedUdpServerConnection.cs b/Tools/Hazel-Networking/Hazel/FewerThreads/ThreadLimitedUdpServerConnection.cs new file mode 100644 index 0000000..bb139c7 --- /dev/null +++ b/Tools/Hazel-Networking/Hazel/FewerThreads/ThreadLimitedUdpServerConnection.cs @@ -0,0 +1,110 @@ +using System; +using System.Net; + +namespace Hazel.Udp.FewerThreads +{ + /// <summary> + /// Represents a servers's connection to a client that uses the UDP protocol. + /// </summary> + /// <inheritdoc/> + public sealed class ThreadLimitedUdpServerConnection : UdpConnection + { + public readonly DateTime CreationTime = DateTime.UtcNow; + + /// <summary> + /// The connection listener that we use the socket of. + /// </summary> + /// <remarks> + /// Udp server connections utilize the same socket in the listener for sends/receives, this is the listener that + /// created this connection and is hence the listener this conenction sends and receives via. + /// </remarks> + public ThreadLimitedUdpConnectionListener Listener { get; private set; } + + public ThreadLimitedUdpConnectionListener.ConnectionId ConnectionId { get; private set; } + + /// <summary> + /// Creates a UdpConnection for the virtual connection to the endpoint. + /// </summary> + /// <param name="listener">The listener that created this connection.</param> + /// <param name="endPoint">The endpoint that we are connected to.</param> + /// <param name="IPMode">The IPMode we are connected using.</param> + internal ThreadLimitedUdpServerConnection(ThreadLimitedUdpConnectionListener listener, ThreadLimitedUdpConnectionListener.ConnectionId connectionId, IPEndPoint endPoint, IPMode IPMode, ILogger logger) + : base(logger) + { + this.Listener = listener; + this.ConnectionId = connectionId; + this.EndPoint = endPoint; + this.IPMode = IPMode; + + State = ConnectionState.Connected; + this.InitializeKeepAliveTimer(); + } + + /// <inheritdoc /> + protected override void WriteBytesToConnection(byte[] bytes, int length) + { + if (bytes.Length != length) throw new ArgumentException("I made an assumption here. I hope you see this error."); + + // Hrm, well this is inaccurate for DTLS connections because the Listener does the encryption which may change the size. + // but I don't want to have a bunch of client references in the send queue... + // Does this perhaps mean the encryption is being done in the wrong class? + this.Statistics.LogPacketSend(length); + Listener.SendDataRaw(bytes, EndPoint); + } + + /// <inheritdoc /> + /// <remarks> + /// This will always throw a HazelException. + /// </remarks> + public override void Connect(byte[] bytes = null, int timeout = 5000) + { + throw new InvalidOperationException("Cannot manually connect a UdpServerConnection, did you mean to use UdpClientConnection?"); + } + + /// <inheritdoc /> + /// <remarks> + /// This will always throw a HazelException. + /// </remarks> + public override void ConnectAsync(byte[] bytes = null) + { + throw new InvalidOperationException("Cannot manually connect a UdpServerConnection, did you mean to use UdpClientConnection?"); + } + + /// <summary> + /// Sends a disconnect message to the end point. + /// </summary> + protected override bool SendDisconnect(MessageWriter data = null) + { + if (!Listener.RemoveConnectionTo(this.ConnectionId)) return false; + this._state = ConnectionState.NotConnected; + + var bytes = EmptyDisconnectBytes; + if (data != null && data.Length > 0) + { + if (data.SendOption != SendOption.None) throw new ArgumentException("Disconnect messages can only be unreliable."); + + bytes = data.ToByteArray(true); + bytes[0] = (byte)UdpSendOption.Disconnect; + } + + try + { + this.WriteBytesToConnection(bytes, bytes.Length); + } + catch { } + + return true; + } + + protected override void Dispose(bool disposing) + { + if (disposing) + { + SendDisconnect(); + } + + Listener.RemovePeerRecord(this.ConnectionId); + base.Dispose(disposing); + } + } +} |