diff options
Diffstat (limited to 'Impostor-dev/src/Impostor.Hazel')
28 files changed, 3905 insertions, 0 deletions
diff --git a/Impostor-dev/src/Impostor.Hazel/Connection.cs b/Impostor-dev/src/Impostor.Hazel/Connection.cs new file mode 100644 index 0000000..dec8cfe --- /dev/null +++ b/Impostor-dev/src/Impostor.Hazel/Connection.cs @@ -0,0 +1,249 @@ +using System; +using System.Net; +using System.Threading.Tasks; +using Impostor.Api.Net.Messages; +using Serilog; + +namespace Impostor.Hazel +{ + /// <summary> + /// Base class for all connections. + /// </summary> + /// <remarks> + /// <para> + /// Connection is the base class for all connections that Hazel can make. It provides common functionality and a + /// standard interface to allow connections to be swapped easily. + /// </para> + /// <para> + /// Any class inheriting from Connection should provide the 3 standard guarantees that Hazel provides: + /// <list type="bullet"> + /// <item> + /// <description>Thread Safe</description> + /// </item> + /// <item> + /// <description>Connection Orientated</description> + /// </item> + /// <item> + /// <description>Packet/Message Based</description> + /// </item> + /// </list> + /// </para> + /// </remarks> + /// <threadsafety static="true" instance="true"/> + public abstract class Connection : IDisposable + { + private static readonly ILogger Logger = Log.ForContext<Connection>(); + + /// <summary> + /// Called when a message has been received. + /// </summary> + /// <remarks> + /// <para> + /// DataReceived is invoked everytime a message is received from the end point of this connection, the message + /// that was received can be found in the <see cref="DataReceivedEventArgs"/> alongside other information from the + /// event. + /// </para> + /// <include file="DocInclude/common.xml" path="docs/item[@name='Event_Thread_Safety_Warning']/*" /> + /// </remarks> + /// <example> + /// <code language="C#" source="DocInclude/TcpClientExample.cs"/> + /// </example> + public Func<DataReceivedEventArgs, ValueTask> DataReceived; + + public int TestLagMs = -1; + public int TestDropRate = 0; + protected int testDropCount = 0; + + /// <summary> + /// Called when the end point disconnects or an error occurs. + /// </summary> + /// <remarks> + /// <para> + /// Disconnected is invoked when the connection is closed due to an exception occuring or because the remote + /// end point disconnected. If it was invoked due to an exception occuring then the exception is available + /// in the <see cref="DisconnectedEventArgs"/> passed with the event. + /// </para> + /// <include file="DocInclude/common.xml" path="docs/item[@name='Event_Thread_Safety_Warning']/*" /> + /// </remarks> + /// <example> + /// <code language="C#" source="DocInclude/TcpClientExample.cs"/> + /// </example> + public Func<DisconnectedEventArgs, ValueTask> Disconnected; + + /// <summary> + /// The remote end point of this Connection. + /// </summary> + /// <remarks> + /// This is the end point that this connection is connected to (i.e. the other device). This returns an abstract + /// <see cref="ConnectionEndPoint"/> which can then be cast to an appropriate end point depending on the + /// connection type. + /// </remarks> + public IPEndPoint EndPoint { get; protected set; } + + public IPMode IPMode { get; protected set; } + + /// <summary> + /// The traffic statistics about this Connection. + /// </summary> + /// <remarks> + /// Contains statistics about the number of messages and bytes sent and received by this connection. + /// </remarks> + public ConnectionStatistics Statistics { get; protected set; } + + /// <summary> + /// The state of this connection. + /// </summary> + /// <remarks> + /// All implementers should be aware that when this is set to ConnectionState.Connected it will + /// release all threads that are blocked on <see cref="WaitOnConnect"/>. + /// </remarks> + public ConnectionState State + { + get + { + return this._state; + } + + protected set + { + this._state = value; + this.SetState(value); + } + } + + protected ConnectionState _state; + protected virtual void SetState(ConnectionState state) { } + + /// <summary> + /// Constructor that initializes the ConnecitonStatistics object. + /// </summary> + /// <remarks> + /// This constructor initialises <see cref="Statistics"/> with empty statistics and sets <see cref="State"/> to + /// <see cref="ConnectionState.NotConnected"/>. + /// </remarks> + protected Connection() + { + this.Statistics = new ConnectionStatistics(); + this.State = ConnectionState.NotConnected; + } + + /// <summary> + /// Sends a number of bytes to the end point of the connection using the specified <see cref="MessageType"/>. + /// </summary> + /// <param name="msg">The message to send.</param> + /// <remarks> + /// <include file="DocInclude/common.xml" path="docs/item[@name='Connection_SendBytes_General']/*" /> + /// <para> + /// The messageType parameter is only a request to use those options and the actual method used to send the + /// data is up to the implementation. There are circumstances where this parameter may be ignored but in + /// general any implementer should aim to always follow the user's request. + /// </para> + /// </remarks> + public abstract ValueTask SendAsync(IMessageWriter msg); + + /// <summary> + /// Sends a number of bytes to the end point of the connection using the specified <see cref="MessageType"/>. + /// </summary> + /// <param name="bytes">The bytes of the message to send.</param> + /// <param name="messageType">The option specifying how the message should be sent.</param> + /// <remarks> + /// <include file="DocInclude/common.xml" path="docs/item[@name='Connection_SendBytes_General']/*" /> + /// <para> + /// The messageType parameter is only a request to use those options and the actual method used to send the + /// data is up to the implementation. There are circumstances where this parameter may be ignored but in + /// general any implementer should aim to always follow the user's request. + /// </para> + /// </remarks> + public abstract ValueTask SendBytes(byte[] bytes, MessageType messageType = MessageType.Unreliable); + + /// <summary> + /// Connects the connection to a server and begins listening. + /// This method does not block. + /// </summary> + /// <param name="bytes">The bytes of data to send in the handshake.</param> + public abstract ValueTask ConnectAsync(byte[] bytes = null); + + /// <summary> + /// Invokes the DataReceived event. + /// </summary> + /// <param name="msg">The bytes received.</param> + /// <param name="messageType">The <see cref="MessageType"/> the message was received with.</param> + /// <remarks> + /// Invokes the <see cref="DataReceived"/> event on this connection to alert subscribers a new message has been + /// received. The bytes and the send option that the message was sent with should be passed in to give to the + /// subscribers. + /// </remarks> + protected async ValueTask InvokeDataReceived(IMessageReader msg, MessageType messageType) + { + // Make a copy to avoid race condition between null check and invocation + var handler = DataReceived; + if (handler != null) + { + try + { + await handler(new DataReceivedEventArgs(this, msg, messageType)); + } + catch (Exception e) + { + Logger.Error(e, "Invoking data received failed"); + await Disconnect("Invoking data received failed"); + } + } + } + + /// <summary> + /// Invokes the Disconnected event. + /// </summary> + /// <param name="e">The exception, if any, that occurred to cause this.</param> + /// <param name="reader">Extra disconnect data</param> + /// <remarks> + /// Invokes the <see cref="Disconnected"/> event to alert subscribres this connection has been disconnected either + /// by the end point or because an error occurred. If an error occurred the error should be passed in in order to + /// pass to the subscribers, otherwise null can be passed in. + /// </remarks> + protected async ValueTask InvokeDisconnected(string e, IMessageReader reader) + { + // Make a copy to avoid race condition between null check and invocation + var handler = Disconnected; + if (handler != null) + { + try + { + await handler(new DisconnectedEventArgs(e, reader)); + } + catch (Exception ex) + { + Logger.Error(ex, "Error in InvokeDisconnected"); + } + } + } + + /// <summary> + /// For times when you want to force the disconnect handler to fire as well as close it. + /// If you only want to close it, just use Dispose. + /// </summary> + public abstract ValueTask Disconnect(string reason, MessageWriter writer = null); + + /// <summary> + /// Disposes of this NetworkConnection. + /// </summary> + public void Dispose() + { + Dispose(true); + GC.SuppressFinalize(this); + } + + /// <summary> + /// Disposes of this NetworkConnection. + /// </summary> + /// <param name="disposing">Are we currently disposing?</param> + protected virtual void Dispose(bool disposing) + { + if (disposing) + { + this.DataReceived = null; + this.Disconnected = null; + } + } + } +} diff --git a/Impostor-dev/src/Impostor.Hazel/ConnectionListener.cs b/Impostor-dev/src/Impostor.Hazel/ConnectionListener.cs new file mode 100644 index 0000000..116f657 --- /dev/null +++ b/Impostor-dev/src/Impostor.Hazel/ConnectionListener.cs @@ -0,0 +1,100 @@ +using System; +using System.Threading.Tasks; +using Impostor.Api.Net.Messages; +using Serilog; + +namespace Impostor.Hazel +{ + /// <summary> + /// Base class for all connection listeners. + /// </summary> + /// <remarks> + /// <para> + /// ConnectionListeners are server side objects that listen for clients and create matching server side connections + /// for each client in a similar way to TCP does. These connections should be ready for communication immediately. + /// </para> + /// <para> + /// Each time a client connects the <see cref="NewConnection"/> event will be invoked to alert all subscribers to + /// the new connection. A disconnected event is then present on the <see cref="Connection"/> that is passed to the + /// subscribers. + /// </para> + /// </remarks> + /// <threadsafety static="true" instance="true"/> + public abstract class ConnectionListener : IAsyncDisposable + { + private static readonly ILogger Logger = Log.ForContext<ConnectionListener>(); + + /// <summary> + /// Invoked when a new client connects. + /// </summary> + /// <remarks> + /// <para> + /// NewConnection is invoked each time a client connects to the listener. The + /// <see cref="NewConnectionEventArgs"/> contains the new <see cref="Connection"/> for communication with this + /// client. + /// </para> + /// <para> + /// Hazel may or may not store connections so it is your responsibility to keep track and properly Dispose of + /// connections to your server. + /// </para> + /// <include file="DocInclude/common.xml" path="docs/item[@name='Event_Thread_Safety_Warning']/*" /> + /// </remarks> + /// <example> + /// <code language="C#" source="DocInclude/TcpListenerExample.cs"/> + /// </example> + public Func<NewConnectionEventArgs, ValueTask> NewConnection; + + /// <summary> + /// Makes this connection listener begin listening for connections. + /// </summary> + /// <remarks> + /// <para> + /// This instructs the listener to begin listening for new clients connecting to the server. When a new client + /// connects the <see cref="NewConnection"/> event will be invoked containing the connection to the new client. + /// </para> + /// <para> + /// To stop listening you should call <see cref="DisposeAsync()"/>. + /// </para> + /// </remarks> + /// <example> + /// <code language="C#" source="DocInclude/TcpListenerExample.cs"/> + /// </example> + public abstract Task StartAsync(); + + /// <summary> + /// Invokes the NewConnection event with the supplied connection. + /// </summary> + /// <param name="msg">The user sent bytes that were received as part of the handshake.</param> + /// <param name="connection">The connection to pass in the arguments.</param> + /// <remarks> + /// Implementers should call this to invoke the <see cref="NewConnection"/> event before data is received so that + /// subscribers do not miss any data that may have been sent immediately after connecting. + /// </remarks> + internal async Task InvokeNewConnection(IMessageReader msg, Connection connection) + { + // Make a copy to avoid race condition between null check and invocation + var handler = NewConnection; + if (handler != null) + { + try + { + await handler(new NewConnectionEventArgs(msg, connection)); + } + catch (Exception e) + { + Logger.Error(e, "Accepting connection failed"); + await connection.Disconnect("Accepting connection failed"); + } + } + } + + /// <summary> + /// Call to dispose of the connection listener. + /// </summary> + public virtual ValueTask DisposeAsync() + { + this.NewConnection = null; + return ValueTask.CompletedTask; + } + } +} diff --git a/Impostor-dev/src/Impostor.Hazel/ConnectionState.cs b/Impostor-dev/src/Impostor.Hazel/ConnectionState.cs new file mode 100644 index 0000000..5dd7c6a --- /dev/null +++ b/Impostor-dev/src/Impostor.Hazel/ConnectionState.cs @@ -0,0 +1,23 @@ +namespace Impostor.Hazel +{ + /// <summary> + /// Represents the state a <see cref="Connection"/> is currently in. + /// </summary> + public enum ConnectionState + { + /// <summary> + /// The Connection has either not been established yet or has been disconnected. + /// </summary> + NotConnected, + + /// <summary> + /// The Connection is currently connecting to an endpoint. + /// </summary> + Connecting, + + /// <summary> + /// The Connection is connected and data can be transfered. + /// </summary> + Connected, + } +} diff --git a/Impostor-dev/src/Impostor.Hazel/ConnectionStatistics.cs b/Impostor-dev/src/Impostor.Hazel/ConnectionStatistics.cs new file mode 100644 index 0000000..4802620 --- /dev/null +++ b/Impostor-dev/src/Impostor.Hazel/ConnectionStatistics.cs @@ -0,0 +1,566 @@ +using System.Runtime.CompilerServices; +using System.Threading; + +[assembly: InternalsVisibleTo("Hazel.Tests")] +namespace Impostor.Hazel +{ + /// <summary> + /// Holds statistics about the traffic through a <see cref="Connection"/>. + /// </summary> + /// <threadsafety static="true" instance="true"/> + public class ConnectionStatistics + { + private const int ExpectedMTU = 1200; + + /// <summary> + /// The total number of messages sent. + /// </summary> + public int MessagesSent + { + get + { + return UnreliableMessagesSent + ReliableMessagesSent + FragmentedMessagesSent + AcknowledgementMessagesSent + HelloMessagesSent; + } + } + + /// <summary> + /// The number of messages sent larger than 576 bytes. This is smaller than most default MTUs. + /// </summary> + /// <remarks> + /// This is the number of unreliable messages that were sent from the <see cref="Connection"/>, incremented + /// each time that LogUnreliableSend is called by the Connection. Messages that caused an error are not + /// counted and messages are only counted once all other operations in the send are complete. + /// </remarks> + public int FragmentableMessagesSent + { + get + { + return fragmentableMessagesSent; + } + } + + /// <summary> + /// The number of messages sent larger than 576 bytes. + /// </summary> + int fragmentableMessagesSent; + + /// <summary> + /// The number of unreliable messages sent. + /// </summary> + /// <remarks> + /// This is the number of unreliable messages that were sent from the <see cref="Connection"/>, incremented + /// each time that LogUnreliableSend is called by the Connection. Messages that caused an error are not + /// counted and messages are only counted once all other operations in the send are complete. + /// </remarks> + public int UnreliableMessagesSent + { + get + { + return unreliableMessagesSent; + } + } + + /// <summary> + /// The number of unreliable messages sent. + /// </summary> + int unreliableMessagesSent; + + /// <summary> + /// The number of reliable messages sent. + /// </summary> + /// <remarks> + /// This is the number of reliable messages that were sent from the <see cref="Connection"/>, incremented + /// each time that LogReliableSend is called by the Connection. Messages that caused an error are not + /// counted and messages are only counted once all other operations in the send are complete. + /// </remarks> + public int ReliableMessagesSent + { + get + { + return reliableMessagesSent; + } + } + + /// <summary> + /// The number of unreliable messages sent. + /// </summary> + int reliableMessagesSent; + + /// <summary> + /// The number of fragmented messages sent. + /// </summary> + /// <remarks> + /// This is the number of fragmented messages that were sent from the <see cref="Connection"/>, incremented + /// each time that LogFragmentedSend is called by the Connection. Messages that caused an error are not + /// counted and messages are only counted once all other operations in the send are complete. + /// </remarks> + public int FragmentedMessagesSent + { + get + { + return fragmentedMessagesSent; + } + } + + /// <summary> + /// The number of fragmented messages sent. + /// </summary> + int fragmentedMessagesSent; + + /// <summary> + /// The number of acknowledgement messages sent. + /// </summary> + /// <remarks> + /// This is the number of acknowledgements that were sent from the <see cref="Connection"/>, incremented + /// each time that LogAcknowledgementSend is called by the Connection. Messages that caused an error are not + /// counted and messages are only counted once all other operations in the send are complete. + /// </remarks> + public int AcknowledgementMessagesSent + { + get + { + return acknowledgementMessagesSent; + } + } + + /// <summary> + /// The number of acknowledgement messages sent. + /// </summary> + int acknowledgementMessagesSent; + + /// <summary> + /// The number of hello messages sent. + /// </summary> + /// <remarks> + /// This is the number of hello messages that were sent from the <see cref="Connection"/>, incremented + /// each time that LogHelloSend is called by the Connection. Messages that caused an error are not + /// counted and messages are only counted once all other operations in the send are complete. + /// </remarks> + public int HelloMessagesSent + { + get + { + return helloMessagesSent; + } + } + + /// <summary> + /// The number of hello messages sent. + /// </summary> + int helloMessagesSent; + + /// <summary> + /// The number of bytes of data sent. + /// </summary> + /// <remarks> + /// <para> + /// This is the number of bytes of data (i.e. user bytes) that were sent from the <see cref="Connection"/>, + /// accumulated each time that LogSend is called by the Connection. Messages that caused an error are not + /// counted and messages are only counted once all other operations in the send are complete. + /// </para> + /// <para> + /// For the number of bytes including protocol bytes see <see cref="TotalBytesSent"/>. + /// </para> + /// </remarks> + public long DataBytesSent + { + get + { + return Interlocked.Read(ref dataBytesSent); + } + } + + /// <summary> + /// The number of bytes of data sent. + /// </summary> + long dataBytesSent; + + /// <summary> + /// The number of bytes sent in total. + /// </summary> + /// <remarks> + /// <para> + /// This is the total number of bytes (the data bytes plus protocol bytes) that were sent from the + /// <see cref="Connection"/>, accumulated each time that LogSend is called by the Connection. Messages that + /// caused an error are not counted and messages are only counted once all other operations in the send are + /// complete. + /// </para> + /// <para> + /// For the number of data bytes excluding protocol bytes see <see cref="DataBytesSent"/>. + /// </para> + /// </remarks> + public long TotalBytesSent + { + get + { + return Interlocked.Read(ref totalBytesSent); + } + } + + /// <summary> + /// The number of bytes sent in total. + /// </summary> + long totalBytesSent; + + /// <summary> + /// The total number of messages received. + /// </summary> + public int MessagesReceived + { + get + { + return UnreliableMessagesReceived + ReliableMessagesReceived + FragmentedMessagesReceived + AcknowledgementMessagesReceived + helloMessagesReceived; + } + } + + /// <summary> + /// The number of unreliable messages received. + /// </summary> + /// <remarks> + /// This is the number of unreliable messages that were received by the <see cref="Connection"/>, incremented + /// each time that LogUnreliableReceive is called by the Connection. Messages are counted before the receive event is invoked. + /// </remarks> + public int UnreliableMessagesReceived + { + get + { + return unreliableMessagesReceived; + } + } + + /// <summary> + /// The number of unreliable messages received. + /// </summary> + int unreliableMessagesReceived; + + /// <summary> + /// The number of reliable messages received. + /// </summary> + /// <remarks> + /// This is the number of reliable messages that were received by the <see cref="Connection"/>, incremented + /// each time that LogReliableReceive is called by the Connection. Messages are counted before the receive event is invoked. + /// </remarks> + public int ReliableMessagesReceived + { + get + { + return reliableMessagesReceived; + } + } + + /// <summary> + /// The number of reliable messages received. + /// </summary> + int reliableMessagesReceived; + + /// <summary> + /// The number of fragmented messages received. + /// </summary> + /// <remarks> + /// This is the number of fragmented messages that were received by the <see cref="Connection"/>, incremented + /// each time that LogFragmentedReceive is called by the Connection. Messages are counted before the receive event is invoked. + /// </remarks> + public int FragmentedMessagesReceived + { + get + { + return fragmentedMessagesReceived; + } + } + + /// <summary> + /// The number of fragmented messages received. + /// </summary> + int fragmentedMessagesReceived; + + /// <summary> + /// The number of acknowledgement messages received. + /// </summary> + /// <remarks> + /// This is the number of acknowledgement messages that were received by the <see cref="Connection"/>, incremented + /// each time that LogAcknowledgemntReceive is called by the Connection. Messages are counted before the receive event is invoked. + /// </remarks> + public int AcknowledgementMessagesReceived + { + get + { + return acknowledgementMessagesReceived; + } + } + + /// <summary> + /// The number of acknowledgement messages received. + /// </summary> + int acknowledgementMessagesReceived; + + /// <summary> + /// The number of ping messages received. + /// </summary> + /// <remarks> + /// This is the number of hello messages that were received by the <see cref="Connection"/>, incremented + /// each time that LogHelloReceive is called by the Connection. Messages are counted before the receive event is invoked. + /// </remarks> + public int PingMessagesReceived + { + get + { + return pingMessagesReceived; + } + } + + /// <summary> + /// The number of hello messages received. + /// </summary> + int pingMessagesReceived; + + /// <summary> + /// The number of hello messages received. + /// </summary> + /// <remarks> + /// This is the number of hello messages that were received by the <see cref="Connection"/>, incremented + /// each time that LogHelloReceive is called by the Connection. Messages are counted before the receive event is invoked. + /// </remarks> + public int HelloMessagesReceived + { + get + { + return helloMessagesReceived; + } + } + + /// <summary> + /// The number of hello messages received. + /// </summary> + int helloMessagesReceived; + + /// <summary> + /// The number of bytes of data received. + /// </summary> + /// <remarks> + /// <para> + /// This is the number of bytes of data (i.e. user bytes) that were received by the <see cref="Connection"/>, + /// accumulated each time that LogReceive is called by the Connection. Messages are counted before the receive + /// event is invoked. + /// </para> + /// <para> + /// For the number of bytes including protocol bytes see <see cref="TotalBytesReceived"/>. + /// </para> + /// </remarks> + public long DataBytesReceived + { + get + { + return Interlocked.Read(ref dataBytesReceived); + } + } + + /// <summary> + /// The number of bytes of data received. + /// </summary> + long dataBytesReceived; + + /// <summary> + /// The number of bytes received in total. + /// </summary> + /// <remarks> + /// <para> + /// This is the total number of bytes (the data bytes plus protocol bytes) that were received by the + /// <see cref="Connection"/>, accumulated each time that LogReceive is called by the Connection. Messages are + /// counted before the receive event is invoked. + /// </para> + /// <para> + /// For the number of data bytes excluding protocol bytes see <see cref="DataBytesReceived"/>. + /// </para> + /// </remarks> + public long TotalBytesReceived + { + get + { + return Interlocked.Read(ref totalBytesReceived); + } + } + + /// <summary> + /// The number of bytes received in total. + /// </summary> + long totalBytesReceived; + + public int MessagesResent { get { return messagesResent; } } + int messagesResent; + + /// <summary> + /// Logs the sending of an unreliable data packet in the statistics. + /// </summary> + /// <param name="dataLength">The number of bytes of data sent.</param> + /// <param name="totalLength">The total number of bytes sent.</param> + /// <remarks> + /// This should be called after the data has been sent and should only be called for data that is sent sucessfully. + /// </remarks> + internal void LogUnreliableSend(int dataLength, int totalLength) + { + Interlocked.Increment(ref unreliableMessagesSent); + Interlocked.Add(ref dataBytesSent, dataLength); + Interlocked.Add(ref totalBytesSent, totalLength); + + if (totalLength > ExpectedMTU) + { + Interlocked.Increment(ref fragmentableMessagesSent); + } + } + + /// <summary> + /// Logs the sending of a reliable data packet in the statistics. + /// </summary> + /// <param name="dataLength">The number of bytes of data sent.</param> + /// <param name="totalLength">The total number of bytes sent.</param> + /// <remarks> + /// This should be called after the data has been sent and should only be called for data that is sent sucessfully. + /// </remarks> + internal void LogReliableSend(int dataLength, int totalLength) + { + Interlocked.Increment(ref reliableMessagesSent); + Interlocked.Add(ref dataBytesSent, dataLength); + Interlocked.Add(ref totalBytesSent, totalLength); + + if (totalLength > ExpectedMTU) + { + Interlocked.Increment(ref fragmentableMessagesSent); + } + } + + /// <summary> + /// Logs the sending of a fragmented data packet in the statistics. + /// </summary> + /// <param name="dataLength">The number of bytes of data sent.</param> + /// <param name="totalLength">The total number of bytes sent.</param> + /// <remarks> + /// This should be called after the data has been sent and should only be called for data that is sent sucessfully. + /// </remarks> + internal void LogFragmentedSend(int dataLength, int totalLength) + { + Interlocked.Increment(ref fragmentedMessagesSent); + Interlocked.Add(ref dataBytesSent, dataLength); + Interlocked.Add(ref totalBytesSent, totalLength); + + if (totalLength > ExpectedMTU) + { + Interlocked.Increment(ref fragmentableMessagesSent); + } + } + + /// <summary> + /// Logs the sending of a acknowledgement data packet in the statistics. + /// </summary> + /// <param name="totalLength">The total number of bytes sent.</param> + /// <remarks> + /// This should be called after the data has been sent and should only be called for data that is sent sucessfully. + /// </remarks> + internal void LogAcknowledgementSend(int totalLength) + { + Interlocked.Increment(ref acknowledgementMessagesSent); + Interlocked.Add(ref totalBytesSent, totalLength); + } + + /// <summary> + /// Logs the sending of a hellp data packet in the statistics. + /// </summary> + /// <param name="totalLength">The total number of bytes sent.</param> + /// <remarks> + /// This should be called after the data has been sent and should only be called for data that is sent sucessfully. + /// </remarks> + internal void LogHelloSend(int totalLength) + { + Interlocked.Increment(ref helloMessagesSent); + Interlocked.Add(ref totalBytesSent, totalLength); + } + + /// <summary> + /// Logs the receiving of an unreliable data packet in the statistics. + /// </summary> + /// <param name="dataLength">The number of bytes of data received.</param> + /// <param name="totalLength">The total number of bytes received.</param> + /// <remarks> + /// This should be called before the received event is invoked so it is up to date for subscribers to that event. + /// </remarks> + internal void LogUnreliableReceive(int dataLength, int totalLength) + { + Interlocked.Increment(ref unreliableMessagesReceived); + Interlocked.Add(ref dataBytesReceived, dataLength); + Interlocked.Add(ref totalBytesReceived, totalLength); + } + + /// <summary> + /// Logs the receiving of a reliable data packet in the statistics. + /// </summary> + /// <param name="dataLength">The number of bytes of data received.</param> + /// <param name="totalLength">The total number of bytes received.</param> + /// <remarks> + /// This should be called before the received event is invoked so it is up to date for subscribers to that event. + /// </remarks> + internal void LogReliableReceive(int dataLength, int totalLength) + { + Interlocked.Increment(ref reliableMessagesReceived); + Interlocked.Add(ref dataBytesReceived, dataLength); + Interlocked.Add(ref totalBytesReceived, totalLength); + } + + /// <summary> + /// Logs the receiving of a fragmented data packet in the statistics. + /// </summary> + /// <param name="dataLength">The number of bytes of data received.</param> + /// <param name="totalLength">The total number of bytes received.</param> + /// <remarks> + /// This should be called before the received event is invoked so it is up to date for subscribers to that event. + /// </remarks> + internal void LogFragmentedReceive(int dataLength, int totalLength) + { + Interlocked.Increment(ref fragmentedMessagesReceived); + Interlocked.Add(ref dataBytesReceived, dataLength); + Interlocked.Add(ref totalBytesReceived, totalLength); + } + + /// <summary> + /// Logs the receiving of an acknowledgement data packet in the statistics. + /// </summary> + /// <param name="totalLength">The total number of bytes received.</param> + /// <remarks> + /// This should be called before the received event is invoked so it is up to date for subscribers to that event. + /// </remarks> + internal void LogAcknowledgementReceive(int totalLength) + { + Interlocked.Increment(ref acknowledgementMessagesReceived); + Interlocked.Add(ref totalBytesReceived, totalLength); + } + + /// <summary> + /// Logs the receiving of a hello data packet in the statistics. + /// </summary> + /// <param name="totalLength">The total number of bytes received.</param> + /// <remarks> + /// This should be called before the received event is invoked so it is up to date for subscribers to that event. + /// </remarks> + internal void LogPingReceive(int totalLength) + { + Interlocked.Increment(ref pingMessagesReceived); + Interlocked.Add(ref totalBytesReceived, totalLength); + } + + /// <summary> + /// Logs the receiving of a hello data packet in the statistics. + /// </summary> + /// <param name="totalLength">The total number of bytes received.</param> + /// <remarks> + /// This should be called before the received event is invoked so it is up to date for subscribers to that event. + /// </remarks> + internal void LogHelloReceive(int totalLength) + { + Interlocked.Increment(ref helloMessagesReceived); + Interlocked.Add(ref totalBytesReceived, totalLength); + } + + internal void LogMessageResent() + { + Interlocked.Increment(ref messagesResent); + } + } +} diff --git a/Impostor-dev/src/Impostor.Hazel/DataReceivedEventArgs.cs b/Impostor-dev/src/Impostor.Hazel/DataReceivedEventArgs.cs new file mode 100644 index 0000000..9176d8d --- /dev/null +++ b/Impostor-dev/src/Impostor.Hazel/DataReceivedEventArgs.cs @@ -0,0 +1,26 @@ +using Impostor.Api.Net.Messages; + +namespace Impostor.Hazel +{ + public struct DataReceivedEventArgs + { + public readonly Connection Sender; + + /// <summary> + /// The bytes received from the client. + /// </summary> + public readonly IMessageReader Message; + + /// <summary> + /// The <see cref="Type"/> the data was sent with. + /// </summary> + public readonly MessageType Type; + + public DataReceivedEventArgs(Connection sender, IMessageReader msg, MessageType type) + { + this.Sender = sender; + this.Message = msg; + this.Type = type; + } + } +} diff --git a/Impostor-dev/src/Impostor.Hazel/DisconnectedEventArgs.cs b/Impostor-dev/src/Impostor.Hazel/DisconnectedEventArgs.cs new file mode 100644 index 0000000..d46df4b --- /dev/null +++ b/Impostor-dev/src/Impostor.Hazel/DisconnectedEventArgs.cs @@ -0,0 +1,25 @@ +using System; +using Impostor.Api.Net.Messages; + +namespace Impostor.Hazel +{ + public class DisconnectedEventArgs : EventArgs + { + /// <summary> + /// Optional disconnect reason. May be null. + /// </summary> + public readonly string Reason; + + /// <summary> + /// Optional data sent with a disconnect message. May be null. + /// You must not recycle this. If you need the message outside of a callback, you should copy it. + /// </summary> + public readonly IMessageReader Message; + + public DisconnectedEventArgs(string reason, IMessageReader message) + { + this.Reason = reason; + this.Message = message; + } + } +} diff --git a/Impostor-dev/src/Impostor.Hazel/Extensions/ServiceProviderExtensions.cs b/Impostor-dev/src/Impostor.Hazel/Extensions/ServiceProviderExtensions.cs new file mode 100644 index 0000000..56c7380 --- /dev/null +++ b/Impostor-dev/src/Impostor.Hazel/Extensions/ServiceProviderExtensions.cs @@ -0,0 +1,21 @@ +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.DependencyInjection.Extensions; +using Microsoft.Extensions.ObjectPool; + +namespace Impostor.Hazel.Extensions +{ + public static class ServiceProviderExtensions + { + public static void AddHazel(this IServiceCollection services) + { + services.TryAddSingleton<ObjectPoolProvider>(new DefaultObjectPoolProvider()); + + services.AddSingleton(serviceProvider => + { + var provider = serviceProvider.GetRequiredService<ObjectPoolProvider>(); + var policy = ActivatorUtilities.CreateInstance<MessageReaderPolicy>(serviceProvider); + return provider.Create(policy); + }); + } + } +} diff --git a/Impostor-dev/src/Impostor.Hazel/HazelException.cs b/Impostor-dev/src/Impostor.Hazel/HazelException.cs new file mode 100644 index 0000000..8c6fc3c --- /dev/null +++ b/Impostor-dev/src/Impostor.Hazel/HazelException.cs @@ -0,0 +1,21 @@ +using System; + +namespace Impostor.Hazel +{ + /// <summary> + /// Wrapper for exceptions thrown from Hazel. + /// </summary> + [Serializable] + public class HazelException : Exception + { + internal HazelException(string msg) : base (msg) + { + + } + + internal HazelException(string msg, Exception e) : base (msg, e) + { + + } + } +} diff --git a/Impostor-dev/src/Impostor.Hazel/IPMode.cs b/Impostor-dev/src/Impostor.Hazel/IPMode.cs new file mode 100644 index 0000000..5eb6679 --- /dev/null +++ b/Impostor-dev/src/Impostor.Hazel/IPMode.cs @@ -0,0 +1,24 @@ +namespace Impostor.Hazel +{ + /// <summary> + /// Represents the IP version that a connection or listener will use. + /// </summary> + /// <remarks> + /// If you wand a client to connect or be able to connect using IPv6 then you should use <see cref="IPv4AndIPv6"/>, + /// this sets the underlying sockets to use IPv6 but still allow IPv4 sockets to connect for backwards compatability + /// and hence it is the default IPMode in most cases. + /// </remarks> + public enum IPMode + { + /// <summary> + /// Instruction to use IPv4 only, IPv6 connections will not be able to connect. + /// </summary> + IPv4, + + /// <summary> + /// Instruction to use IPv6 only, IPv4 connections will not be able to connect. IPv4 addresses can be connected + /// by converting to IPv6 addresses. + /// </summary> + IPv6 + } +} diff --git a/Impostor-dev/src/Impostor.Hazel/IRecyclable.cs b/Impostor-dev/src/Impostor.Hazel/IRecyclable.cs new file mode 100644 index 0000000..69be122 --- /dev/null +++ b/Impostor-dev/src/Impostor.Hazel/IRecyclable.cs @@ -0,0 +1,24 @@ +namespace Impostor.Hazel +{ + /// <summary> + /// Interface for all items that can be returned to an object pool. + /// </summary> + /// <threadsafety static="true" instance="true"/> + public interface IRecyclable + { + /// <summary> + /// Returns this object back to the object pool. + /// </summary> + /// <remarks> + /// <para> + /// Calling this when you are done with the object returns the object back to a pool in order to be reused. + /// This can reduce the amount of work the GC has to do dramatically but it is optional to call this. + /// </para> + /// <para> + /// Calling this indicates to Hazel that this can be reused and thus you should only call this when you are + /// completely finished with the object as the contents can be overwritten at any point after. + /// </para> + /// </remarks> + void Recycle(); + } +} diff --git a/Impostor-dev/src/Impostor.Hazel/Impostor.Hazel.csproj b/Impostor-dev/src/Impostor.Hazel/Impostor.Hazel.csproj new file mode 100644 index 0000000..3e035fb --- /dev/null +++ b/Impostor-dev/src/Impostor.Hazel/Impostor.Hazel.csproj @@ -0,0 +1,19 @@ +<Project Sdk="Microsoft.NET.Sdk"> + + <PropertyGroup> + <AllowUnsafeBlocks>true</AllowUnsafeBlocks> + <TargetFramework>net5.0</TargetFramework> + <DefineConstants>HAZEL_BAG</DefineConstants> + <Version>1.0.0</Version> + </PropertyGroup> + + <ItemGroup> + <PackageReference Include="Microsoft.Extensions.ObjectPool" Version="5.0.0" /> + <PackageReference Include="Serilog" Version="2.10.0" /> + </ItemGroup> + + <ItemGroup> + <ProjectReference Include="..\Impostor.Api\Impostor.Api.csproj" /> + </ItemGroup> + +</Project> diff --git a/Impostor-dev/src/Impostor.Hazel/MessageReader.cs b/Impostor-dev/src/Impostor.Hazel/MessageReader.cs new file mode 100644 index 0000000..986d0b0 --- /dev/null +++ b/Impostor-dev/src/Impostor.Hazel/MessageReader.cs @@ -0,0 +1,256 @@ +using System; +using System.Buffers; +using System.Buffers.Binary; +using System.Runtime.CompilerServices; +using System.Text; +using Impostor.Api; +using Impostor.Api.Net.Messages; +using Microsoft.Extensions.ObjectPool; + +namespace Impostor.Hazel +{ + public class MessageReader : IMessageReader + { + private static readonly ArrayPool<byte> ArrayPool = ArrayPool<byte>.Shared; + + private readonly ObjectPool<MessageReader> _pool; + private bool _inUse; + + internal MessageReader(ObjectPool<MessageReader> pool) + { + _pool = pool; + } + + public byte[] Buffer { get; private set; } + + public int Offset { get; internal set; } + + public int Position { get; internal set; } + + public int Length { get; internal set; } + + public byte Tag { get; private set; } + + public MessageReader Parent { get; private set; } + + private int ReadPosition => Offset + Position; + + public void Update(byte[] buffer, int offset = 0, int position = 0, int? length = null, byte tag = byte.MaxValue, MessageReader parent = null) + { + _inUse = true; + + Buffer = buffer; + Offset = offset; + Position = position; + Length = length ?? buffer.Length; + Tag = tag; + Parent = parent; + } + + internal void Reset() + { + _inUse = false; + + Tag = byte.MaxValue; + Buffer = null; + Offset = 0; + Position = 0; + Length = 0; + Parent = null; + } + + public IMessageReader ReadMessage() + { + var length = ReadUInt16(); + var tag = FastByte(); + var pos = ReadPosition; + + Position += length; + + var reader = _pool.Get(); + reader.Update(Buffer, pos, 0, length, tag, this); + return reader; + } + + public bool ReadBoolean() + { + byte val = FastByte(); + return val != 0; + } + + public sbyte ReadSByte() + { + return (sbyte)FastByte(); + } + + public byte ReadByte() + { + return FastByte(); + } + + public ushort ReadUInt16() + { + var output = BinaryPrimitives.ReadUInt16LittleEndian(Buffer.AsSpan(ReadPosition)); + Position += sizeof(ushort); + return output; + } + + public short ReadInt16() + { + var output = BinaryPrimitives.ReadInt16LittleEndian(Buffer.AsSpan(ReadPosition)); + Position += sizeof(short); + return output; + } + + public uint ReadUInt32() + { + var output = BinaryPrimitives.ReadUInt32LittleEndian(Buffer.AsSpan(ReadPosition)); + Position += sizeof(uint); + return output; + } + + public int ReadInt32() + { + var output = BinaryPrimitives.ReadInt32LittleEndian(Buffer.AsSpan(ReadPosition)); + Position += sizeof(int); + return output; + } + + public unsafe float ReadSingle() + { + var output = BinaryPrimitives.ReadSingleLittleEndian(Buffer.AsSpan(ReadPosition)); + Position += sizeof(float); + return output; + } + + public string ReadString() + { + var len = ReadPackedInt32(); + var output = Encoding.UTF8.GetString(Buffer.AsSpan(ReadPosition, len)); + Position += len; + return output; + } + + public ReadOnlyMemory<byte> ReadBytesAndSize() + { + var len = ReadPackedInt32(); + return ReadBytes(len); + } + + public ReadOnlyMemory<byte> ReadBytes(int length) + { + var output = Buffer.AsMemory(ReadPosition, length); + Position += length; + return output; + } + + public int ReadPackedInt32() + { + return (int)ReadPackedUInt32(); + } + + public uint ReadPackedUInt32() + { + bool readMore = true; + int shift = 0; + uint output = 0; + + while (readMore) + { + byte b = FastByte(); + if (b >= 0x80) + { + readMore = true; + b ^= 0x80; + } + else + { + readMore = false; + } + + output |= (uint)(b << shift); + shift += 7; + } + + return output; + } + + public void CopyTo(IMessageWriter writer) + { + writer.Write((ushort) Length); + writer.Write((byte) Tag); + writer.Write(Buffer.AsMemory(Offset, Length)); + } + + public void Seek(int position) + { + Position = position; + } + + public void RemoveMessage(IMessageReader message) + { + if (message.Buffer != Buffer) + { + throw new ImpostorProtocolException("Tried to remove message from a message that does not have the same buffer."); + } + + // Offset of where to start removing. + var offsetStart = message.Offset - 3; + + // Offset of where to end removing. + var offsetEnd = message.Offset + message.Length; + + // The amount of bytes to copy over ourselves. + var lengthToCopy = message.Buffer.Length - offsetEnd; + + System.Buffer.BlockCopy(Buffer, offsetEnd, Buffer, offsetStart, lengthToCopy); + + ((MessageReader) message).Parent.AdjustLength(message.Offset, message.Length + 3); + } + + private void AdjustLength(int offset, int amount) + { + this.Length -= amount; + + if (this.ReadPosition > offset) + { + this.Position -= amount; + } + + if (Parent != null) + { + var lengthOffset = this.Offset - 3; + var curLen = this.Buffer[lengthOffset] | + (this.Buffer[lengthOffset + 1] << 8); + + curLen -= amount; + + this.Buffer[lengthOffset] = (byte)curLen; + this.Buffer[lengthOffset + 1] = (byte)(this.Buffer[lengthOffset + 1] >> 8); + + Parent.AdjustLength(offset, amount); + } + } + + public IMessageReader Copy(int offset = 0) + { + var reader = _pool.Get(); + reader.Update(Buffer, Offset + offset, Position, Length - offset, Tag, Parent); + return reader; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private byte FastByte() + { + return Buffer[Offset + Position++]; + } + + public void Dispose() + { + if (_inUse) + { + _pool.Return(this); + } + } + } +} diff --git a/Impostor-dev/src/Impostor.Hazel/MessageReaderPolicy.cs b/Impostor-dev/src/Impostor.Hazel/MessageReaderPolicy.cs new file mode 100644 index 0000000..ef3939a --- /dev/null +++ b/Impostor-dev/src/Impostor.Hazel/MessageReaderPolicy.cs @@ -0,0 +1,27 @@ +using System; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.ObjectPool; + +namespace Impostor.Hazel +{ + public class MessageReaderPolicy : IPooledObjectPolicy<MessageReader> + { + private readonly IServiceProvider _serviceProvider; + + public MessageReaderPolicy(IServiceProvider serviceProvider) + { + _serviceProvider = serviceProvider; + } + + public MessageReader Create() + { + return new MessageReader(_serviceProvider.GetRequiredService<ObjectPool<MessageReader>>()); + } + + public bool Return(MessageReader obj) + { + obj.Reset(); + return true; + } + } +} diff --git a/Impostor-dev/src/Impostor.Hazel/MessageWriter.cs b/Impostor-dev/src/Impostor.Hazel/MessageWriter.cs new file mode 100644 index 0000000..5b7342a --- /dev/null +++ b/Impostor-dev/src/Impostor.Hazel/MessageWriter.cs @@ -0,0 +1,335 @@ +using Impostor.Api.Games; +using Impostor.Api.Net.Messages; + +using System; +using System.Collections.Generic; +using System.Net; +using System.Text; + +namespace Impostor.Hazel +{ + public class MessageWriter : IMessageWriter, IRecyclable, IDisposable + { + private static int BufferSize = 64000; + private static readonly ObjectPoolCustom<MessageWriter> WriterPool = new ObjectPoolCustom<MessageWriter>(() => new MessageWriter(BufferSize)); + + public MessageType SendOption { get; private set; } + + private Stack<int> messageStarts = new Stack<int>(); + + public MessageWriter(byte[] buffer) + { + this.Buffer = buffer; + this.Length = this.Buffer.Length; + } + + public MessageWriter(int bufferSize) + { + this.Buffer = new byte[bufferSize]; + } + + public byte[] Buffer { get; } + public int Length { get; set; } + public int Position { get; set; } + + public byte[] ToByteArray(bool includeHeader) + { + if (includeHeader) + { + byte[] output = new byte[this.Length]; + System.Buffer.BlockCopy(this.Buffer, 0, output, 0, this.Length); + return output; + } + else + { + switch (this.SendOption) + { + case MessageType.Reliable: + { + byte[] output = new byte[this.Length - 3]; + System.Buffer.BlockCopy(this.Buffer, 3, output, 0, this.Length - 3); + return output; + } + case MessageType.Unreliable: + { + byte[] output = new byte[this.Length - 1]; + System.Buffer.BlockCopy(this.Buffer, 1, output, 0, this.Length - 1); + return output; + } + default: + throw new ArgumentOutOfRangeException(); + } + } + + throw new NotImplementedException(); + } + + /// + /// <param name="sendOption">The option specifying how the message should be sent.</param> + public static MessageWriter Get(MessageType sendOption = MessageType.Unreliable) + { + var output = WriterPool.GetObject(); + output.Clear(sendOption); + + return output; + } + + public bool HasBytes(int expected) + { + if (this.SendOption == MessageType.Unreliable) + { + return this.Length > 1 + expected; + } + + return this.Length > 3 + expected; + } + + public void Write(GameCode value) + { + this.Write(value.Value); + } + + /// + public void StartMessage(byte typeFlag) + { + messageStarts.Push(this.Position); + this.Position += 2; // Skip for size + this.Write(typeFlag); + } + + /// + public void EndMessage() + { + var lastMessageStart = messageStarts.Pop(); + ushort length = (ushort)(this.Position - lastMessageStart - 3); // Minus length and type byte + this.Buffer[lastMessageStart] = (byte)length; + this.Buffer[lastMessageStart + 1] = (byte)(length >> 8); + } + + /// + public void CancelMessage() + { + this.Position = this.messageStarts.Pop(); + this.Length = this.Position; + } + + public void Clear(MessageType sendOption) + { + this.messageStarts.Clear(); + this.SendOption = sendOption; + this.Buffer[0] = (byte)sendOption; + switch (sendOption) + { + default: + case MessageType.Unreliable: + this.Length = this.Position = 1; + break; + + case MessageType.Reliable: + this.Length = this.Position = 3; + break; + } + } + + /// + public void Recycle() + { + this.Position = this.Length = 0; + WriterPool.PutObject(this); + } + + #region WriteMethods + + public void Write(bool value) + { + this.Buffer[this.Position++] = (byte)(value ? 1 : 0); + if (this.Position > this.Length) this.Length = this.Position; + } + + public void Write(sbyte value) + { + this.Buffer[this.Position++] = (byte)value; + if (this.Position > this.Length) this.Length = this.Position; + } + + public void Write(byte value) + { + this.Buffer[this.Position++] = value; + if (this.Position > this.Length) this.Length = this.Position; + } + + public void Write(short value) + { + this.Buffer[this.Position++] = (byte)value; + this.Buffer[this.Position++] = (byte)(value >> 8); + if (this.Position > this.Length) this.Length = this.Position; + } + + public void Write(ushort value) + { + this.Buffer[this.Position++] = (byte)value; + this.Buffer[this.Position++] = (byte)(value >> 8); + if (this.Position > this.Length) this.Length = this.Position; + } + + public void Write(uint value) + { + this.Buffer[this.Position++] = (byte)value; + this.Buffer[this.Position++] = (byte)(value >> 8); + this.Buffer[this.Position++] = (byte)(value >> 16); + this.Buffer[this.Position++] = (byte)(value >> 24); + if (this.Position > this.Length) this.Length = this.Position; + } + + public void Write(int value) + { + this.Buffer[this.Position++] = (byte)value; + this.Buffer[this.Position++] = (byte)(value >> 8); + this.Buffer[this.Position++] = (byte)(value >> 16); + this.Buffer[this.Position++] = (byte)(value >> 24); + if (this.Position > this.Length) this.Length = this.Position; + } + + public unsafe void Write(float value) + { + fixed (byte* ptr = &this.Buffer[this.Position]) + { + byte* valuePtr = (byte*)&value; + + *ptr = *valuePtr; + *(ptr + 1) = *(valuePtr + 1); + *(ptr + 2) = *(valuePtr + 2); + *(ptr + 3) = *(valuePtr + 3); + } + + this.Position += 4; + if (this.Position > this.Length) this.Length = this.Position; + } + + public void Write(string value) + { + var bytes = UTF8Encoding.UTF8.GetBytes(value); + this.WritePacked(bytes.Length); + this.Write(bytes); + } + + public void Write(IPAddress value) + { + this.Write(value.GetAddressBytes()); + } + + public void WriteBytesAndSize(byte[] bytes) + { + this.WritePacked((uint)bytes.Length); + this.Write(bytes); + } + + public void WriteBytesAndSize(byte[] bytes, int length) + { + this.WritePacked((uint)length); + this.Write(bytes, length); + } + + public void WriteBytesAndSize(byte[] bytes, int offset, int length) + { + this.WritePacked((uint)length); + this.Write(bytes, offset, length); + } + + public void Write(ReadOnlyMemory<byte> data) + { + Write(data.Span); + } + + public void Write(ReadOnlySpan<byte> bytes) + { + bytes.CopyTo(this.Buffer.AsSpan(this.Position, bytes.Length)); + + this.Position += bytes.Length; + if (this.Position > this.Length) this.Length = this.Position; + } + + public void Write(byte[] bytes) + { + Array.Copy(bytes, 0, this.Buffer, this.Position, bytes.Length); + this.Position += bytes.Length; + if (this.Position > this.Length) this.Length = this.Position; + } + + public void Write(byte[] bytes, int offset, int length) + { + Array.Copy(bytes, offset, this.Buffer, this.Position, length); + this.Position += length; + if (this.Position > this.Length) this.Length = this.Position; + } + + public void Write(byte[] bytes, int length) + { + Array.Copy(bytes, 0, this.Buffer, this.Position, length); + this.Position += length; + if (this.Position > this.Length) this.Length = this.Position; + } + + /// + public void WritePacked(int value) + { + this.WritePacked((uint)value); + } + + /// + public void WritePacked(uint value) + { + do + { + byte b = (byte)(value & 0xFF); + if (value >= 0x80) + { + b |= 0x80; + } + + this.Write(b); + value >>= 7; + } while (value > 0); + } + + #endregion WriteMethods + + public void Write(MessageWriter msg, bool includeHeader) + { + int offset = 0; + if (!includeHeader) + { + switch (msg.SendOption) + { + case MessageType.Unreliable: + offset = 1; + break; + + case MessageType.Reliable: + offset = 3; + break; + } + } + + this.Write(msg.Buffer, offset, msg.Length - offset); + } + + public unsafe static bool IsLittleEndian() + { + byte b; + unsafe + { + int i = 1; + byte* bp = (byte*)&i; + b = *bp; + } + + return b == 1; + } + + public void Dispose() + { + Recycle(); + } + } +} diff --git a/Impostor-dev/src/Impostor.Hazel/NetworkConnection.cs b/Impostor-dev/src/Impostor.Hazel/NetworkConnection.cs new file mode 100644 index 0000000..282fe10 --- /dev/null +++ b/Impostor-dev/src/Impostor.Hazel/NetworkConnection.cs @@ -0,0 +1,121 @@ +using System; +using System.Net; +using System.Threading.Tasks; +using Impostor.Api.Net.Messages; + +namespace Impostor.Hazel +{ + public enum HazelInternalErrors + { + SocketExceptionSend, + SocketExceptionReceive, + ReceivedZeroBytes, + PingsWithoutResponse, + ReliablePacketWithoutResponse, + ConnectionDisconnected + } + + /// <summary> + /// Abstract base class for a <see cref="Connection"/> to a remote end point via a network protocol like TCP or UDP. + /// </summary> + /// <threadsafety static="true" instance="true"/> + public abstract class NetworkConnection : Connection + { + /// <summary> + /// An event that gives us a chance to send well-formed disconnect messages to clients when an internal disconnect happens. + /// </summary> + public Func<HazelInternalErrors, MessageWriter> OnInternalDisconnect; + + /// <summary> + /// The remote end point of this connection. + /// </summary> + /// <remarks> + /// This is the end point of the other device given as an <see cref="System.Net.EndPoint"/> rather than a generic + /// <see cref="ConnectionEndPoint"/> as the base <see cref="Connection"/> does. + /// </remarks> + public IPEndPoint RemoteEndPoint { get; protected set; } + + public long GetIP4Address() + { + if (IPMode == IPMode.IPv4) + { + return ((IPEndPoint)this.RemoteEndPoint).Address.Address; + } + else + { + var bytes = ((IPEndPoint)this.RemoteEndPoint).Address.GetAddressBytes(); + return BitConverter.ToInt64(bytes, bytes.Length - 8); + } + } + + /// <summary> + /// Sends a disconnect message to the end point. + /// </summary> + protected abstract ValueTask<bool> SendDisconnect(MessageWriter writer); + + /// <summary> + /// Called when the socket has been disconnected at the remote host. + /// </summary> + protected async ValueTask DisconnectRemote(string reason, IMessageReader reader) + { + if (await SendDisconnect(null)) + { + try + { + await InvokeDisconnected(reason, reader); + } + catch { } + } + + this.Dispose(); + } + + /// <summary> + /// Called when socket is disconnected internally + /// </summary> + internal async ValueTask DisconnectInternal(HazelInternalErrors error, string reason) + { + var handler = this.OnInternalDisconnect; + if (handler != null) + { + MessageWriter messageToRemote = handler(error); + if (messageToRemote != null) + { + try + { + await Disconnect(reason, messageToRemote); + } + finally + { + messageToRemote.Recycle(); + } + } + else + { + await Disconnect(reason); + } + } + else + { + await Disconnect(reason); + } + } + + /// <summary> + /// Called when the socket has been disconnected locally. + /// </summary> + public override async ValueTask Disconnect(string reason, MessageWriter writer = null) + { + if (await SendDisconnect(writer)) + { + try + { + await InvokeDisconnected(reason, null); + } + catch { } + } + + this.Dispose(); + } + } +} diff --git a/Impostor-dev/src/Impostor.Hazel/NetworkConnectionListener.cs b/Impostor-dev/src/Impostor.Hazel/NetworkConnectionListener.cs new file mode 100644 index 0000000..e1d7ffa --- /dev/null +++ b/Impostor-dev/src/Impostor.Hazel/NetworkConnectionListener.cs @@ -0,0 +1,21 @@ +using System.Net; + +namespace Impostor.Hazel +{ + /// <summary> + /// Abstract base class for a <see cref="ConnectionListener"/> for network based connections. + /// </summary> + /// <threadsafety static="true" instance="true"/> + public abstract class NetworkConnectionListener : ConnectionListener + { + /// <summary> + /// The local end point the listener is listening for new clients on. + /// </summary> + public IPEndPoint EndPoint { get; protected set; } + + /// <summary> + /// The <see cref="IPMode">IPMode</see> the listener is listening for new clients on. + /// </summary> + public IPMode IPMode { get; protected set; } + } +} diff --git a/Impostor-dev/src/Impostor.Hazel/NewConnectionEventArgs.cs b/Impostor-dev/src/Impostor.Hazel/NewConnectionEventArgs.cs new file mode 100644 index 0000000..be9e7a2 --- /dev/null +++ b/Impostor-dev/src/Impostor.Hazel/NewConnectionEventArgs.cs @@ -0,0 +1,24 @@ +using Impostor.Api.Net.Messages; + +namespace Impostor.Hazel +{ + public struct NewConnectionEventArgs + { + /// <summary> + /// The data received from the client in the handshake. + /// This data is yours. Remember to recycle it. + /// </summary> + public readonly IMessageReader HandshakeData; + + /// <summary> + /// The <see cref="Connection"/> to the new client. + /// </summary> + public readonly Connection Connection; + + public NewConnectionEventArgs(IMessageReader handshakeData, Connection connection) + { + this.HandshakeData = handshakeData; + this.Connection = connection; + } + } +} diff --git a/Impostor-dev/src/Impostor.Hazel/ObjectPoolCustom.cs b/Impostor-dev/src/Impostor.Hazel/ObjectPoolCustom.cs new file mode 100644 index 0000000..5c9ef9b --- /dev/null +++ b/Impostor-dev/src/Impostor.Hazel/ObjectPoolCustom.cs @@ -0,0 +1,107 @@ +using System; +using System.Collections.Concurrent; +using System.Threading; + +namespace Impostor.Hazel +{ + /// <summary> + /// A fairly simple object pool for items that will be created a lot. + /// </summary> + /// <typeparam name="T">The type that is pooled.</typeparam> + /// <threadsafety static="true" instance="true"/> + public sealed class ObjectPoolCustom<T> where T : IRecyclable + { + private int numberCreated; + public int NumberCreated { get { return numberCreated; } } + + public int NumberInUse { get { return this.inuse.Count; } } + public int NumberNotInUse { get { return this.pool.Count; } } + public int Size { get { return this.NumberInUse + this.NumberNotInUse; } } + +#if HAZEL_BAG + private readonly ConcurrentBag<T> pool = new ConcurrentBag<T>(); +#else + private readonly List<T> pool = new List<T>(); +#endif + + // Unavailable objects + private readonly ConcurrentDictionary<T, bool> inuse = new ConcurrentDictionary<T, bool>(); + + /// <summary> + /// The generator for creating new objects. + /// </summary> + /// <returns></returns> + private readonly Func<T> objectFactory; + + /// <summary> + /// Internal constructor for our ObjectPool. + /// </summary> + internal ObjectPoolCustom(Func<T> objectFactory) + { + this.objectFactory = objectFactory; + } + + /// <summary> + /// Returns a pooled object of type T, if none are available another is created. + /// </summary> + /// <returns>An instance of T.</returns> + internal T GetObject() + { +#if HAZEL_BAG + if (!pool.TryTake(out T item)) + { + Interlocked.Increment(ref numberCreated); + item = objectFactory.Invoke(); + } +#else + T item; + lock (this.pool) + { + if (this.pool.Count > 0) + { + var idx = this.pool.Count - 1; + item = this.pool[idx]; + this.pool.RemoveAt(idx); + } + else + { + Interlocked.Increment(ref numberCreated); + item = objectFactory.Invoke(); + } + } +#endif + + if (!inuse.TryAdd(item, true)) + { + throw new Exception("Duplicate pull " + typeof(T).Name); + } + + return item; + } + + /// <summary> + /// Returns an object to the pool. + /// </summary> + /// <param name="item">The item to return.</param> + internal void PutObject(T item) + { + if (inuse.TryRemove(item, out bool b)) + { +#if HAZEL_BAG + pool.Add(item); +#else + lock (this.pool) + { + pool.Add(item); + } +#endif + } + else + { +#if DEBUG + throw new Exception("Duplicate add " + typeof(T).Name); +#endif + } + } + } +} diff --git a/Impostor-dev/src/Impostor.Hazel/Udp/SendOptionInternal.cs b/Impostor-dev/src/Impostor.Hazel/Udp/SendOptionInternal.cs new file mode 100644 index 0000000..c0c4e21 --- /dev/null +++ b/Impostor-dev/src/Impostor.Hazel/Udp/SendOptionInternal.cs @@ -0,0 +1,33 @@ +namespace Impostor.Hazel.Udp +{ + /// <summary> + /// Extra internal states for SendOption enumeration when using UDP. + /// </summary> + public enum UdpSendOption : byte + { + /// <summary> + /// Hello message for initiating communication. + /// </summary> + Hello = 8, + + /// <summary> + /// A single byte of continued existence + /// </summary> + Ping = 12, + + /// <summary> + /// Message for discontinuing communication. + /// </summary> + Disconnect = 9, + + /// <summary> + /// Message acknowledging the receipt of a message. + /// </summary> + Acknowledgement = 10, + + /// <summary> + /// Message that is part of a larger, fragmented message. + /// </summary> + Fragment = 11, + } +} diff --git a/Impostor-dev/src/Impostor.Hazel/Udp/UdpBroadcastListener.cs b/Impostor-dev/src/Impostor.Hazel/Udp/UdpBroadcastListener.cs new file mode 100644 index 0000000..ed7b68d --- /dev/null +++ b/Impostor-dev/src/Impostor.Hazel/Udp/UdpBroadcastListener.cs @@ -0,0 +1,156 @@ +using System; +using System.Collections.Generic; +using System.Net; +using System.Net.Sockets; +using System.Text; + +namespace Impostor.Hazel.Udp +{ + public class BroadcastPacket + { + public string Data; + public DateTime ReceiveTime; + public IPEndPoint Sender; + + public BroadcastPacket(string data, IPEndPoint sender) + { + this.Data = data; + this.Sender = sender; + this.ReceiveTime = DateTime.Now; + } + + public string GetAddress() + { + return this.Sender.Address.ToString(); + } + } + + public class UdpBroadcastListener : IDisposable + { + private Socket socket; + private EndPoint endpoint; + private Action<string> logger; + + private byte[] buffer = new byte[1024]; + + private List<BroadcastPacket> packets = new List<BroadcastPacket>(); + + public bool Running { get; private set; } + + /// + public UdpBroadcastListener(int port, Action<string> logger = null) + { + this.logger = logger; + this.socket = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp); + this.socket.EnableBroadcast = true; + this.socket.MulticastLoopback = false; + this.endpoint = new IPEndPoint(IPAddress.Any, port); + this.socket.Bind(this.endpoint); + } + + /// + public void StartListen() + { + if (this.Running) return; + this.Running = true; + + try + { + EndPoint endpt = new IPEndPoint(IPAddress.Any, 0); + this.socket.BeginReceiveFrom(buffer, 0, buffer.Length, SocketFlags.None, ref endpt, this.HandleData, null); + } + catch (NullReferenceException) { } + catch (Exception e) + { + this.logger?.Invoke("BroadcastListener: " + e); + this.Dispose(); + } + } + + private void HandleData(IAsyncResult result) + { + this.Running = false; + + int numBytes; + EndPoint endpt = new IPEndPoint(IPAddress.Any, 0); + try + { + numBytes = this.socket.EndReceiveFrom(result, ref endpt); + } + catch (NullReferenceException) + { + // Already disposed + return; + } + catch (Exception e) + { + this.logger?.Invoke("BroadcastListener: " + e); + this.Dispose(); + return; + } + + if (numBytes < 3 + || buffer[0] != 4 || buffer[1] != 2) + { + this.StartListen(); + return; + } + + IPEndPoint ipEnd = (IPEndPoint)endpt; + string data = UTF8Encoding.UTF8.GetString(buffer, 2, numBytes - 2); + int dataHash = data.GetHashCode(); + + lock (packets) + { + bool found = false; + for (int i = 0; i < this.packets.Count; ++i) + { + var pkt = this.packets[i]; + if (pkt == null || pkt.Data == null) + { + this.packets.RemoveAt(i); + i--; + continue; + } + + if (pkt.Data.GetHashCode() == dataHash + && pkt.Sender.Equals(ipEnd)) + { + this.packets[i].ReceiveTime = DateTime.Now; + break; + } + } + + if (!found) + { + this.packets.Add(new BroadcastPacket(data, ipEnd)); + } + } + + this.StartListen(); + } + + /// + public BroadcastPacket[] GetPackets() + { + lock (this.packets) + { + var output = this.packets.ToArray(); + this.packets.Clear(); + return output; + } + } + + /// + public void Dispose() + { + if (this.socket != null) + { + try { this.socket.Shutdown(SocketShutdown.Both); } catch { } + try { this.socket.Close(); } catch { } + try { this.socket.Dispose(); } catch { } + this.socket = null; + } + } + } +}
\ No newline at end of file diff --git a/Impostor-dev/src/Impostor.Hazel/Udp/UdpBroadcaster.cs b/Impostor-dev/src/Impostor.Hazel/Udp/UdpBroadcaster.cs new file mode 100644 index 0000000..5fa1cca --- /dev/null +++ b/Impostor-dev/src/Impostor.Hazel/Udp/UdpBroadcaster.cs @@ -0,0 +1,79 @@ +using System; +using System.Net; +using System.Net.Sockets; +using System.Text; + +namespace Impostor.Hazel.Udp +{ + /// + public class UdpBroadcaster : IDisposable + { + private Socket socket; + private byte[] data; + private EndPoint endpoint; + private Action<string> logger; + + /// + public UdpBroadcaster(int port, Action<string> logger = null) + { + this.logger = logger; + this.socket = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp); + this.socket.EnableBroadcast = true; + this.socket.MulticastLoopback = false; + this.endpoint = new IPEndPoint(IPAddress.Broadcast, port); + } + + /// + public void SetData(string data) + { + int len = UTF8Encoding.UTF8.GetByteCount(data); + this.data = new byte[len + 2]; + this.data[0] = 4; + this.data[1] = 2; + + UTF8Encoding.UTF8.GetBytes(data, 0, data.Length, this.data, 2); + } + + /// + public void Broadcast() + { + if (this.data == null) + { + return; + } + + try + { + this.socket.BeginSendTo(data, 0, data.Length, SocketFlags.None, this.endpoint, this.FinishSendTo, null); + } + catch (Exception e) + { + this.logger?.Invoke("BroadcastListener: " + e); + } + } + + private void FinishSendTo(IAsyncResult evt) + { + try + { + this.socket.EndSendTo(evt); + } + catch (Exception e) + { + this.logger?.Invoke("BroadcastListener: " + e); + } + } + + /// + public void Dispose() + { + if (this.socket != null) + { + try { this.socket.Shutdown(SocketShutdown.Both); } catch { } + try { this.socket.Close(); } catch { } + try { this.socket.Dispose(); } catch { } + this.socket = null; + } + } + } +}
\ No newline at end of file diff --git a/Impostor-dev/src/Impostor.Hazel/Udp/UdpClientConnection.cs b/Impostor-dev/src/Impostor.Hazel/Udp/UdpClientConnection.cs new file mode 100644 index 0000000..5125ebe --- /dev/null +++ b/Impostor-dev/src/Impostor.Hazel/Udp/UdpClientConnection.cs @@ -0,0 +1,225 @@ +using System; +using System.Buffers; +using System.Net; +using System.Net.Sockets; +using System.Threading; +using System.Threading.Channels; +using System.Threading.Tasks; +using Impostor.Api.Net.Messages; +using Microsoft.Extensions.ObjectPool; +using Serilog; + +namespace Impostor.Hazel.Udp +{ + /// <summary> + /// Represents a client's connection to a server that uses the UDP protocol. + /// </summary> + /// <inheritdoc/> + public sealed class UdpClientConnection : UdpConnection + { + private static readonly ILogger Logger = Log.ForContext<UdpClientConnection>(); + + /// <summary> + /// The socket we're connected via. + /// </summary> + private readonly UdpClient _socket; + + private readonly Timer _reliablePacketTimer; + private readonly SemaphoreSlim _connectWaitLock; + private Task _listenTask; + + /// <summary> + /// Creates a new UdpClientConnection. + /// </summary> + /// <param name="remoteEndPoint">A <see cref="NetworkEndPoint"/> to connect to.</param> + public UdpClientConnection(IPEndPoint remoteEndPoint, ObjectPool<MessageReader> readerPool, IPMode ipMode = IPMode.IPv4) : base(null, readerPool) + { + EndPoint = remoteEndPoint; + RemoteEndPoint = remoteEndPoint; + IPMode = ipMode; + + _socket = new UdpClient + { + DontFragment = false + }; + + _reliablePacketTimer = new Timer(ManageReliablePacketsInternal, null, 100, Timeout.Infinite); + _connectWaitLock = new SemaphoreSlim(1, 1); + } + + ~UdpClientConnection() + { + Dispose(false); + } + + private async void ManageReliablePacketsInternal(object state) + { + await ManageReliablePackets(); + + try + { + _reliablePacketTimer.Change(100, Timeout.Infinite); + } + catch + { + // ignored + } + } + + /// <inheritdoc /> + protected override ValueTask WriteBytesToConnection(byte[] bytes, int length) + { + return WriteBytesToConnectionReal(bytes, length); + } + + private async ValueTask WriteBytesToConnectionReal(byte[] bytes, int length) + { + try + { + await _socket.SendAsync(bytes, length); + } + catch (NullReferenceException) { } + catch (ObjectDisposedException) + { + // Already disposed and disconnected... + } + catch (SocketException ex) + { + await DisconnectInternal(HazelInternalErrors.SocketExceptionSend, "Could not send data as a SocketException occurred: " + ex.Message); + } + } + + /// <inheritdoc /> + public override async ValueTask ConnectAsync(byte[] bytes = null) + { + State = ConnectionState.Connecting; + + try + { + _socket.Connect(RemoteEndPoint); + } + catch (SocketException e) + { + State = ConnectionState.NotConnected; + throw new HazelException("A SocketException occurred while binding to the port.", e); + } + + try + { + _listenTask = Task.Factory.StartNew(ListenAsync, TaskCreationOptions.LongRunning); + } + catch (ObjectDisposedException) + { + // If the socket's been disposed then we can just end there but make sure we're in NotConnected state. + // If we end up here I'm really lost... + State = ConnectionState.NotConnected; + return; + } + catch (SocketException e) + { + Dispose(); + throw new HazelException("A SocketException occurred while initiating a receive operation.", e); + } + + // Write bytes to the server to tell it hi (and to punch a hole in our NAT, if present) + // When acknowledged set the state to connected + await SendHello(bytes, () => + { + State = ConnectionState.Connected; + InitializeKeepAliveTimer(); + }); + + await _connectWaitLock.WaitAsync(TimeSpan.FromSeconds(10)); + } + + private async Task ListenAsync() + { + // Start packet handler. + await StartAsync(); + + // Listen. + while (State != ConnectionState.NotConnected) + { + UdpReceiveResult data; + + try + { + data = await _socket.ReceiveAsync(); + } + catch (SocketException e) + { + await DisconnectInternal(HazelInternalErrors.SocketExceptionReceive, "Socket exception while reading data: " + e.Message); + return; + } + catch (Exception) + { + return; + } + + if (data.Buffer.Length == 0) + { + await DisconnectInternal(HazelInternalErrors.ReceivedZeroBytes, "Received 0 bytes"); + return; + } + + // Write to client. + await Pipeline.Writer.WriteAsync(data.Buffer); + } + } + + protected override void SetState(ConnectionState state) + { + if (state == ConnectionState.Connected) + { + _connectWaitLock.Release(); + } + } + + /// <summary> + /// Sends a disconnect message to the end point. + /// You may include optional disconnect data. The SendOption must be unreliable. + /// </summary> + protected override async ValueTask<bool> SendDisconnect(MessageWriter data = null) + { + lock (this) + { + if (_state == ConnectionState.NotConnected) return false; + _state = ConnectionState.NotConnected; + } + + var bytes = EmptyDisconnectBytes; + if (data != null && data.Length > 0) + { + if (data.SendOption != MessageType.Unreliable) + { + throw new ArgumentException("Disconnect messages can only be unreliable."); + } + + bytes = data.ToByteArray(true); + bytes[0] = (byte)UdpSendOption.Disconnect; + } + + try + { + await _socket.SendAsync(bytes, bytes.Length, RemoteEndPoint); + } + catch { } + + return true; + } + + /// <inheritdoc /> + protected override void Dispose(bool disposing) + { + State = ConnectionState.NotConnected; + + try { _socket.Close(); } catch { } + try { _socket.Dispose(); } catch { } + + _reliablePacketTimer.Dispose(); + _connectWaitLock.Dispose(); + + base.Dispose(disposing); + } + } +} diff --git a/Impostor-dev/src/Impostor.Hazel/Udp/UdpConnection.KeepAlive.cs b/Impostor-dev/src/Impostor.Hazel/Udp/UdpConnection.KeepAlive.cs new file mode 100644 index 0000000..a73291b --- /dev/null +++ b/Impostor-dev/src/Impostor.Hazel/Udp/UdpConnection.KeepAlive.cs @@ -0,0 +1,167 @@ +using System; +using System.Collections.Concurrent; +using System.Diagnostics; +using System.Threading; +using System.Threading.Tasks; + +namespace Impostor.Hazel.Udp +{ + partial class UdpConnection + { + + /// <summary> + /// Class to hold packet data + /// </summary> + public class PingPacket : IRecyclable + { + private static readonly ObjectPoolCustom<PingPacket> PacketPool = new ObjectPoolCustom<PingPacket>(() => new PingPacket()); + + public readonly Stopwatch Stopwatch = new Stopwatch(); + + internal static PingPacket GetObject() + { + return PacketPool.GetObject(); + } + + public void Recycle() + { + Stopwatch.Stop(); + PacketPool.PutObject(this); + } + } + + internal ConcurrentDictionary<ushort, PingPacket> activePingPackets = new ConcurrentDictionary<ushort, PingPacket>(); + + /// <summary> + /// The interval from data being received or transmitted to a keepalive packet being sent in milliseconds. + /// </summary> + /// <remarks> + /// <para> + /// Keepalive packets serve to close connections when an endpoint abruptly disconnects and to ensure than any + /// NAT devices do not close their translation for our argument. By ensuring there is regular contact the + /// connection can detect and prevent these issues. + /// </para> + /// <para> + /// The default value is 10 seconds, set to System.Threading.Timeout.Infinite to disable keepalive packets. + /// </para> + /// </remarks> + public int KeepAliveInterval + { + get + { + return keepAliveInterval; + } + + set + { + keepAliveInterval = value; + ResetKeepAliveTimer(); + } + } + private int keepAliveInterval = 1500; + + public int MissingPingsUntilDisconnect { get; set; } = 6; + private volatile int pingsSinceAck = 0; + + /// <summary> + /// The timer creating keepalive pulses. + /// </summary> + private Timer keepAliveTimer; + + /// <summary> + /// Starts the keepalive timer. + /// </summary> + protected void InitializeKeepAliveTimer() + { + keepAliveTimer = new Timer( + HandleKeepAlive, + null, + keepAliveInterval, + keepAliveInterval + ); + } + + private async void HandleKeepAlive(object state) + { + if (this.State != ConnectionState.Connected) return; + + if (this.pingsSinceAck >= this.MissingPingsUntilDisconnect) + { + this.DisposeKeepAliveTimer(); + await this.DisconnectInternal(HazelInternalErrors.PingsWithoutResponse, $"Sent {this.pingsSinceAck} pings that remote has not responded to."); + return; + } + + try + { + Interlocked.Increment(ref pingsSinceAck); + await SendPing(); + } + catch + { + } + } + + // Pings are special, quasi-reliable packets. + // We send them to trigger responses that validate our connection is alive + // An unacked ping should never be the sole cause of a disconnect. + // Rather, the responses will reset our pingsSinceAck, enough unacked + // pings should cause a disconnect. + private async ValueTask SendPing() + { + ushort id = (ushort)Interlocked.Increment(ref lastIDAllocated); + + byte[] bytes = new byte[3]; + bytes[0] = (byte)UdpSendOption.Ping; + bytes[1] = (byte)(id >> 8); + bytes[2] = (byte)id; + + PingPacket pkt; + if (!this.activePingPackets.TryGetValue(id, out pkt)) + { + pkt = PingPacket.GetObject(); + if (!this.activePingPackets.TryAdd(id, pkt)) + { + throw new Exception("This shouldn't be possible"); + } + } + + pkt.Stopwatch.Restart(); + + await WriteBytesToConnection(bytes, bytes.Length); + + Statistics.LogReliableSend(0, bytes.Length); + } + + /// <summary> + /// Resets the keepalive timer to zero. + /// </summary> + private void ResetKeepAliveTimer() + { + try + { + keepAliveTimer.Change(keepAliveInterval, keepAliveInterval); + } + catch { } + } + + /// <summary> + /// Disposes of the keep alive timer. + /// </summary> + private void DisposeKeepAliveTimer() + { + if (this.keepAliveTimer != null) + { + this.keepAliveTimer.Dispose(); + } + + foreach (var kvp in activePingPackets) + { + if (this.activePingPackets.TryRemove(kvp.Key, out var pkt)) + { + pkt.Recycle(); + } + } + } + } +}
\ No newline at end of file diff --git a/Impostor-dev/src/Impostor.Hazel/Udp/UdpConnection.Reliable.cs b/Impostor-dev/src/Impostor.Hazel/Udp/UdpConnection.Reliable.cs new file mode 100644 index 0000000..a7a4309 --- /dev/null +++ b/Impostor-dev/src/Impostor.Hazel/Udp/UdpConnection.Reliable.cs @@ -0,0 +1,491 @@ +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Diagnostics; +using System.Threading; +using System.Threading.Tasks; +using Impostor.Api.Net.Messages; + +namespace Impostor.Hazel.Udp +{ + partial class UdpConnection + { + /// <summary> + /// The starting timeout, in miliseconds, at which data will be resent. + /// </summary> + /// <remarks> + /// <para> + /// For reliable delivery data is resent at specified intervals unless an acknowledgement is received from the + /// receiving device. The ResendTimeout specifies the interval between the packets being resent, each time a packet + /// is resent the interval is increased for that packet until the duration exceeds the <see cref="DisconnectTimeout"/> value. + /// </para> + /// <para> + /// Setting this to its default of 0 will mean the timeout is 2 times the value of the average ping, usually + /// resulting in a more dynamic resend that responds to endpoints on slower or faster connections. + /// </para> + /// </remarks> + public volatile int ResendTimeout = 0; + + /// <summary> + /// Max number of times to resend. 0 == no limit + /// </summary> + public volatile int ResendLimit = 0; + + /// <summary> + /// A compounding multiplier to back off resend timeout. + /// Applied to ping before first timeout when ResendTimeout == 0. + /// </summary> + public volatile float ResendPingMultiplier = 2; + + /// <summary> + /// Holds the last ID allocated. + /// </summary> + private int lastIDAllocated = 0; + + /// <summary> + /// The packets of data that have been transmitted reliably and not acknowledged. + /// </summary> + internal ConcurrentDictionary<ushort, Packet> reliableDataPacketsSent = new ConcurrentDictionary<ushort, Packet>(); + + /// <summary> + /// Packet ids that have not been received, but are expected. + /// </summary> + private HashSet<ushort> reliableDataPacketsMissing = new HashSet<ushort>(); + + /// <summary> + /// The packet id that was received last. + /// </summary> + private volatile ushort reliableReceiveLast = ushort.MaxValue; + + private object PingLock = new object(); + + /// <summary> + /// Returns the average ping to this endpoint. + /// </summary> + /// <remarks> + /// This returns the average ping for a one-way trip as calculated from the reliable packets that have been sent + /// and acknowledged by the endpoint. + /// </remarks> + public float AveragePingMs = 500; + + /// <summary> + /// The maximum times a message should be resent before marking the endpoint as disconnected. + /// </summary> + /// <remarks> + /// Reliable packets will be resent at an interval defined in <see cref="ResendTimeout"/> for the number of times + /// specified here. Once a packet has been retransmitted this number of times and has not been acknowledged the + /// connection will be marked as disconnected and the <see cref="Connection.Disconnected">Disconnected</see> event + /// will be invoked. + /// </remarks> + public volatile int DisconnectTimeout = 5000; + + /// <summary> + /// Class to hold packet data + /// </summary> + public class Packet : IRecyclable + { + /// <summary> + /// Object pool for this event. + /// </summary> + public static readonly ObjectPoolCustom<Packet> PacketPool = new ObjectPoolCustom<Packet>(() => new Packet()); + + /// <summary> + /// Returns an instance of this object from the pool. + /// </summary> + /// <returns></returns> + internal static Packet GetObject() + { + return PacketPool.GetObject(); + } + + public ushort Id; + private byte[] Data; + private UdpConnection Connection; + private int Length; + + public int NextTimeout; + public volatile bool Acknowledged; + + public Action AckCallback; + + public int Retransmissions; + public Stopwatch Stopwatch = new Stopwatch(); + + Packet() + { + } + + internal void Set(ushort id, UdpConnection connection, byte[] data, int length, int timeout, Action ackCallback) + { + this.Id = id; + this.Data = data; + this.Connection = connection; + this.Length = length; + + this.Acknowledged = false; + this.NextTimeout = timeout; + this.AckCallback = ackCallback; + this.Retransmissions = 0; + + this.Stopwatch.Restart(); + } + + // Packets resent + public async ValueTask<int> Resend() + { + var connection = this.Connection; + if (!this.Acknowledged && connection != null) + { + long lifetime = this.Stopwatch.ElapsedMilliseconds; + if (lifetime >= connection.DisconnectTimeout) + { + if (connection.reliableDataPacketsSent.TryRemove(this.Id, out Packet self)) + { + await connection.DisconnectInternal(HazelInternalErrors.ReliablePacketWithoutResponse, $"Reliable packet {self.Id} (size={this.Length}) was not ack'd after {lifetime}ms ({self.Retransmissions} resends)"); + + self.Recycle(); + } + + return 0; + } + + if (lifetime >= this.NextTimeout) + { + ++this.Retransmissions; + if (connection.ResendLimit != 0 + && this.Retransmissions > connection.ResendLimit) + { + if (connection.reliableDataPacketsSent.TryRemove(this.Id, out Packet self)) + { + await connection.DisconnectInternal(HazelInternalErrors.ReliablePacketWithoutResponse, $"Reliable packet {self.Id} (size={this.Length}) was not ack'd after {self.Retransmissions} resends ({lifetime}ms)"); + + self.Recycle(); + } + + return 0; + } + + this.NextTimeout += (int)Math.Min(this.NextTimeout * connection.ResendPingMultiplier, 1000); + try + { + await connection.WriteBytesToConnection(this.Data, this.Length); + connection.Statistics.LogMessageResent(); + return 1; + } + catch (InvalidOperationException) + { + await connection.DisconnectInternal(HazelInternalErrors.ConnectionDisconnected, "Could not resend data as connection is no longer connected"); + } + } + } + + return 0; + } + + /// <summary> + /// Returns this object back to the object pool from whence it came. + /// </summary> + public void Recycle() + { + this.Acknowledged = true; + this.Connection = null; + + PacketPool.PutObject(this); + } + } + + internal async ValueTask<int> ManageReliablePackets() + { + int output = 0; + if (this.reliableDataPacketsSent.Count > 0) + { + foreach (var kvp in this.reliableDataPacketsSent) + { + Packet pkt = kvp.Value; + + try + { + output += await pkt.Resend(); + } + catch { } + } + } + + return output; + } + + /// <summary> + /// Adds a 2 byte ID to the packet at offset and stores the packet reference for retransmission. + /// </summary> + /// <param name="buffer">The buffer to attach to.</param> + /// <param name="offset">The offset to attach at.</param> + /// <param name="ackCallback">The callback to make once the packet has been acknowledged.</param> + protected void AttachReliableID(byte[] buffer, int offset, int sendLength, Action ackCallback = null) + { + ushort id = (ushort)Interlocked.Increment(ref lastIDAllocated); + + buffer[offset] = (byte)(id >> 8); + buffer[offset + 1] = (byte)id; + + Packet packet = Packet.GetObject(); + packet.Set( + id, + this, + buffer, + sendLength, + ResendTimeout > 0 ? ResendTimeout : (int)Math.Min(AveragePingMs * this.ResendPingMultiplier, 300), + ackCallback); + + if (!reliableDataPacketsSent.TryAdd(id, packet)) + { + throw new Exception("That shouldn't be possible"); + } + } + + public static int ClampToInt(float value, int min, int max) + { + if (value < min) return min; + if (value > max) return max; + return (int)value; + } + + /// <summary> + /// Sends the bytes reliably and stores the send. + /// </summary> + /// <param name="sendOption"></param> + /// <param name="data">The byte array to write to.</param> + /// <param name="ackCallback">The callback to make once the packet has been acknowledged.</param> + private async ValueTask ReliableSend(byte sendOption, byte[] data, Action ackCallback = null) + { + //Inform keepalive not to send for a while + ResetKeepAliveTimer(); + + byte[] bytes = new byte[data.Length + 3]; + + //Add message type + bytes[0] = sendOption; + + //Add reliable ID + AttachReliableID(bytes, 1, bytes.Length, ackCallback); + + //Copy data into new array + Buffer.BlockCopy(data, 0, bytes, bytes.Length - data.Length, data.Length); + + //Write to connection + await WriteBytesToConnection(bytes, bytes.Length); + + Statistics.LogReliableSend(data.Length, bytes.Length); + } + + /// <summary> + /// Handles a reliable message being received and invokes the data event. + /// </summary> + /// <param name="message">The buffer received.</param> + private async ValueTask ReliableMessageReceive(MessageReader message) + { + if (await ProcessReliableReceive(message.Buffer, 1)) + { + message.Offset += 3; + message.Length -= 3; + message.Position = 0; + + await InvokeDataReceived(message, MessageType.Reliable); + } + + Statistics.LogReliableReceive(message.Length - 3, message.Length); + } + + /// <summary> + /// Handles receives from reliable packets. + /// </summary> + /// <param name="bytes">The buffer containing the data.</param> + /// <param name="offset">The offset of the reliable header.</param> + /// <returns>Whether the packet was a new packet or not.</returns> + private async ValueTask<bool> ProcessReliableReceive(ReadOnlyMemory<byte> bytes, int offset) + { + var b1 = bytes.Span[offset]; + var b2 = bytes.Span[offset + 1]; + + //Get the ID form the packet + var id = (ushort)((b1 << 8) + b2); + + //Send an acknowledgement + await SendAck(id); + + /* + * It gets a little complicated here (note the fact I'm actually using a multiline comment for once...) + * + * In a simple world if our data is greater than the last reliable packet received (reliableReceiveLast) + * then it is guaranteed to be a new packet, if it's not we can see if we are missing that packet (lookup + * in reliableDataPacketsMissing). + * + * --------rrl############# (1) + * + * (where --- are packets received already and #### are packets that will be counted as new) + * + * Unfortunately if id becomes greater than 65535 it will loop back to zero so we will add a pointer that + * specifies any packets with an id behind it are also new (overwritePointer). + * + * ####op----------rrl##### (2) + * + * ------rll#########op---- (3) + * + * Anything behind than the reliableReceiveLast pointer (but greater than the overwritePointer is either a + * missing packet or something we've already received so when we change the pointers we need to make sure + * we keep note of what hasn't been received yet (reliableDataPacketsMissing). + * + * So... + */ + + lock (reliableDataPacketsMissing) + { + //Calculate overwritePointer + ushort overwritePointer = (ushort)(reliableReceiveLast - 32768); + + //Calculate if it is a new packet by examining if it is within the range + bool isNew; + if (overwritePointer < reliableReceiveLast) + isNew = id > reliableReceiveLast || id <= overwritePointer; //Figure (2) + else + isNew = id > reliableReceiveLast && id <= overwritePointer; //Figure (3) + + //If it's new or we've not received anything yet + if (isNew) + { + // Mark items between the most recent receive and the id received as missing + if (id > reliableReceiveLast) + { + for (ushort i = (ushort)(reliableReceiveLast + 1); i < id; i++) + { + reliableDataPacketsMissing.Add(i); + } + } + else + { + int cnt = (ushort.MaxValue - reliableReceiveLast) + id; + for (ushort i = 1; i < cnt; ++i) + { + reliableDataPacketsMissing.Add((ushort)(i + reliableReceiveLast)); + } + } + + //Update the most recently received + reliableReceiveLast = id; + } + + //Else it could be a missing packet + else + { + //See if we're missing it, else this packet is a duplicate as so we return false + if (!reliableDataPacketsMissing.Remove(id)) + { + return false; + } + } + } + + return true; + } + + /// <summary> + /// Handles acknowledgement packets to us. + /// </summary> + /// <param name="bytes">The buffer containing the data.</param> + private void AcknowledgementMessageReceive(ReadOnlySpan<byte> bytes) + { + this.pingsSinceAck = 0; + + ushort id = (ushort)((bytes[1] << 8) + bytes[2]); + AcknowledgeMessageId(id); + + if (bytes.Length == 4) + { + byte recentPackets = bytes[3]; + for (int i = 1; i <= 8; ++i) + { + if ((recentPackets & 1) != 0) + { + AcknowledgeMessageId((ushort)(id - i)); + } + + recentPackets >>= 1; + } + } + + Statistics.LogReliableReceive(0, bytes.Length); + } + + private void AcknowledgeMessageId(ushort id) + { + // Dispose of timer and remove from dictionary + if (reliableDataPacketsSent.TryRemove(id, out Packet packet)) + { + float rt = packet.Stopwatch.ElapsedMilliseconds; + + packet.AckCallback?.Invoke(); + packet.Recycle(); + + lock (PingLock) + { + this.AveragePingMs = Math.Max(50, this.AveragePingMs * .7f + rt * .3f); + } + } + else if (this.activePingPackets.TryRemove(id, out PingPacket pingPkt)) + { + float rt = pingPkt.Stopwatch.ElapsedMilliseconds; + + pingPkt.Recycle(); + + lock (PingLock) + { + this.AveragePingMs = Math.Max(50, this.AveragePingMs * .7f + rt * .3f); + } + } + } + + /// <summary> + /// Sends an acknowledgement for a packet given its identification bytes. + /// </summary> + /// <param name="byte1">The first identification byte.</param> + /// <param name="byte2">The second identification byte.</param> + private async ValueTask SendAck(ushort id) + { + byte recentPackets = 0; + lock (this.reliableDataPacketsMissing) + { + for (int i = 1; i <= 8; ++i) + { + if (!this.reliableDataPacketsMissing.Contains((ushort)(id - i))) + { + recentPackets |= (byte)(1 << (i - 1)); + } + } + } + + byte[] bytes = new byte[] + { + (byte)UdpSendOption.Acknowledgement, + (byte)(id >> 8), + (byte)(id >> 0), + recentPackets + }; + + try + { + await WriteBytesToConnection(bytes, bytes.Length); + } + catch (InvalidOperationException) { } + } + + private void DisposeReliablePackets() + { + foreach (var kvp in reliableDataPacketsSent) + { + if (this.reliableDataPacketsSent.TryRemove(kvp.Key, out var pkt)) + { + pkt.Recycle(); + } + } + } + } +} diff --git a/Impostor-dev/src/Impostor.Hazel/Udp/UdpConnection.cs b/Impostor-dev/src/Impostor.Hazel/Udp/UdpConnection.cs new file mode 100644 index 0000000..5288d3c --- /dev/null +++ b/Impostor-dev/src/Impostor.Hazel/Udp/UdpConnection.cs @@ -0,0 +1,312 @@ +using System; +using System.Threading; +using System.Threading.Channels; +using System.Threading.Tasks; +using Impostor.Api.Net.Messages; +using Microsoft.Extensions.ObjectPool; +using Serilog; + +namespace Impostor.Hazel.Udp +{ + /// <summary> + /// Represents a connection that uses the UDP protocol. + /// </summary> + /// <inheritdoc /> + public abstract partial class UdpConnection : NetworkConnection + { + protected static readonly byte[] EmptyDisconnectBytes = { (byte)UdpSendOption.Disconnect }; + + private static readonly ILogger Logger = Log.ForContext<UdpConnection>(); + private readonly ConnectionListener _listener; + private readonly ObjectPool<MessageReader> _readerPool; + private readonly CancellationTokenSource _stoppingCts; + + private bool _isDisposing; + private bool _isFirst = true; + private Task _executingTask; + + protected UdpConnection(ConnectionListener listener, ObjectPool<MessageReader> readerPool) + { + _listener = listener; + _readerPool = readerPool; + _stoppingCts = new CancellationTokenSource(); + + Pipeline = Channel.CreateUnbounded<byte[]>(new UnboundedChannelOptions + { + SingleReader = true, + SingleWriter = true + }); + } + + internal Channel<byte[]> Pipeline { get; } + + public Task StartAsync() + { + // Store the task we're executing + _executingTask = Task.Factory.StartNew(ReadAsync, TaskCreationOptions.LongRunning); + + // If the task is completed then return it, this will bubble cancellation and failure to the caller + if (_executingTask.IsCompleted) + { + return _executingTask; + } + + // Otherwise it's running + return Task.CompletedTask; + } + + public void Stop() + { + // Stop called without start + if (_executingTask == null) + { + return; + } + + // Signal cancellation to methods. + _stoppingCts.Cancel(); + + try + { + // Cancel reader. + Pipeline.Writer.Complete(); + } + catch (ChannelClosedException) + { + // Already done. + } + + // Remove references. + if (!_isDisposing) + { + Dispose(true); + } + } + + private async Task ReadAsync() + { + var reader = new MessageReader(_readerPool); + + while (!_stoppingCts.IsCancellationRequested) + { + var result = await Pipeline.Reader.ReadAsync(_stoppingCts.Token); + + try + { + reader.Update(result); + + await HandleReceive(reader); + } + catch (Exception e) + { + Logger.Error(e, "Exception during ReadAsync"); + Dispose(true); + break; + } + } + } + + /// <summary> + /// Writes the given bytes to the connection. + /// </summary> + /// <param name="bytes">The bytes to write.</param> + /// <param name="length"></param> + protected abstract ValueTask WriteBytesToConnection(byte[] bytes, int length); + + /// <inheritdoc/> + public override async ValueTask SendAsync(IMessageWriter msg) + { + if (this._state != ConnectionState.Connected) + throw new InvalidOperationException("Could not send data as this Connection is not connected. Did you disconnect?"); + + byte[] buffer = new byte[msg.Length]; + Buffer.BlockCopy(msg.Buffer, 0, buffer, 0, msg.Length); + + switch (msg.SendOption) + { + case MessageType.Reliable: + ResetKeepAliveTimer(); + + AttachReliableID(buffer, 1, buffer.Length); + await WriteBytesToConnection(buffer, buffer.Length); + Statistics.LogReliableSend(buffer.Length - 3, buffer.Length); + break; + + default: + await WriteBytesToConnection(buffer, buffer.Length); + Statistics.LogUnreliableSend(buffer.Length - 1, buffer.Length); + break; + } + } + + /// <inheritdoc/> + /// <remarks> + /// <include file="DocInclude/common.xml" path="docs/item[@name='Connection_SendBytes_General']/*" /> + /// <para> + /// Udp connections can currently send messages using <see cref="SendOption.None"/> and + /// <see cref="SendOption.Reliable"/>. Fragmented messages are not currently supported and will default to + /// <see cref="SendOption.None"/> until implemented. + /// </para> + /// </remarks> + public override async ValueTask SendBytes(byte[] bytes, MessageType sendOption = MessageType.Unreliable) + { + //Add header information and send + await HandleSend(bytes, (byte)sendOption); + } + + /// <summary> + /// Handles the reliable/fragmented sending from this connection. + /// </summary> + /// <param name="data">The data being sent.</param> + /// <param name="sendOption">The <see cref="SendOption"/> specified as its byte value.</param> + /// <param name="ackCallback">The callback to invoke when this packet is acknowledged.</param> + /// <returns>The bytes that should actually be sent.</returns> + protected async ValueTask HandleSend(byte[] data, byte sendOption, Action ackCallback = null) + { + switch (sendOption) + { + case (byte)UdpSendOption.Ping: + case (byte)MessageType.Reliable: + case (byte)UdpSendOption.Hello: + await ReliableSend(sendOption, data, ackCallback); + break; + + //Treat all else as unreliable + default: + await UnreliableSend(sendOption, data); + break; + } + } + + /// <summary> + /// Handles the receiving of data. + /// </summary> + /// <param name="message">The buffer containing the bytes received.</param> + protected async ValueTask HandleReceive(MessageReader message) + { + // Check if the first message received is the hello packet. + if (_isFirst) + { + _isFirst = false; + + // Slice 4 bytes to get handshake data. + if (_listener != null) + { + using (var handshake = message.Copy(4)) + { + await _listener.InvokeNewConnection(handshake, this); + } + } + } + + switch (message.Buffer[0]) + { + //Handle reliable receives + case (byte)MessageType.Reliable: + await ReliableMessageReceive(message); + break; + + //Handle acknowledgments + case (byte)UdpSendOption.Acknowledgement: + AcknowledgementMessageReceive(message.Buffer); + break; + + //We need to acknowledge hello and ping messages but dont want to invoke any events! + case (byte)UdpSendOption.Ping: + await ProcessReliableReceive(message.Buffer, 1); + Statistics.LogHelloReceive(message.Length); + break; + case (byte)UdpSendOption.Hello: + await ProcessReliableReceive(message.Buffer, 1); + Statistics.LogHelloReceive(message.Length); + break; + + case (byte)UdpSendOption.Disconnect: + using (var reader = message.Copy(1)) + { + await DisconnectRemote("The remote sent a disconnect request", reader); + } + break; + + //Treat everything else as unreliable + default: + using (var reader = message.Copy(1)) + { + await InvokeDataReceived(reader, MessageType.Unreliable); + } + Statistics.LogUnreliableReceive(message.Length - 1, message.Length); + break; + } + } + + /// <summary> + /// Sends bytes using the unreliable UDP protocol. + /// </summary> + /// <param name="sendOption">The SendOption to attach.</param> + /// <param name="data">The data.</param> + ValueTask UnreliableSend(byte sendOption, byte[] data) + { + return UnreliableSend(sendOption, data, 0, data.Length); + } + + /// <summary> + /// Sends bytes using the unreliable UDP protocol. + /// </summary> + /// <param name="data">The data.</param> + /// <param name="sendOption">The SendOption to attach.</param> + /// <param name="offset"></param> + /// <param name="length"></param> + async ValueTask UnreliableSend(byte sendOption, byte[] data, int offset, int length) + { + byte[] bytes = new byte[length + 1]; + + //Add message type + bytes[0] = sendOption; + + //Copy data into new array + Buffer.BlockCopy(data, offset, bytes, bytes.Length - length, length); + + //Write to connection + await WriteBytesToConnection(bytes, bytes.Length); + + Statistics.LogUnreliableSend(length, bytes.Length); + } + + /// <summary> + /// Sends a hello packet to the remote endpoint. + /// </summary> + /// <param name="bytes"></param> + /// <param name="acknowledgeCallback">The callback to invoke when the hello packet is acknowledged.</param> + protected ValueTask SendHello(byte[] bytes, Action acknowledgeCallback) + { + //First byte of handshake is version indicator so add data after + byte[] actualBytes; + if (bytes == null) + { + actualBytes = new byte[1]; + } + else + { + actualBytes = new byte[bytes.Length + 1]; + Buffer.BlockCopy(bytes, 0, actualBytes, 1, bytes.Length); + } + + return HandleSend(actualBytes, (byte)UdpSendOption.Hello, acknowledgeCallback); + } + + /// <inheritdoc/> + protected override void Dispose(bool disposing) + { + if (disposing) + { + _isDisposing = true; + + Stop(); + DisposeKeepAliveTimer(); + DisposeReliablePackets(); + } + + base.Dispose(disposing); + } + } +} diff --git a/Impostor-dev/src/Impostor.Hazel/Udp/UdpConnectionListener.cs b/Impostor-dev/src/Impostor.Hazel/Udp/UdpConnectionListener.cs new file mode 100644 index 0000000..573a00c --- /dev/null +++ b/Impostor-dev/src/Impostor.Hazel/Udp/UdpConnectionListener.cs @@ -0,0 +1,281 @@ +using System; +using System.Buffers; +using System.Collections.Concurrent; +using System.Net; +using System.Net.Sockets; +using System.Threading; +using System.Threading.Channels; +using System.Threading.Tasks; +using Microsoft.Extensions.ObjectPool; +using Serilog; + +namespace Impostor.Hazel.Udp +{ + /// <summary> + /// Listens for new UDP connections and creates UdpConnections for them. + /// </summary> + /// <inheritdoc /> + public class UdpConnectionListener : NetworkConnectionListener + { + private static readonly ILogger Logger = Log.ForContext<UdpConnectionListener>(); + + /// <summary> + /// A callback for early connection rejection. + /// * Return false to reject connection. + /// * A null response is ok, we just won't send anything. + /// </summary> + public AcceptConnectionCheck AcceptConnection; + public delegate bool AcceptConnectionCheck(IPEndPoint endPoint, byte[] input, out byte[] response); + + private readonly UdpClient _socket; + private readonly ObjectPool<MessageReader> _readerPool; + private readonly MemoryPool<byte> _pool; + private readonly Timer _reliablePacketTimer; + private readonly ConcurrentDictionary<EndPoint, UdpServerConnection> _allConnections; + private readonly CancellationTokenSource _stoppingCts; + private readonly UdpConnectionRateLimit _connectionRateLimit; + private Task _executingTask; + + /// <summary> + /// Creates a new UdpConnectionListener for the given <see cref="IPAddress"/>, port and <see cref="IPMode"/>. + /// </summary> + /// <param name="endPoint">The endpoint to listen on.</param> + /// <param name="ipMode"></param> + public UdpConnectionListener(IPEndPoint endPoint, ObjectPool<MessageReader> readerPool, IPMode ipMode = IPMode.IPv4) + { + EndPoint = endPoint; + IPMode = ipMode; + + _readerPool = readerPool; + _pool = MemoryPool<byte>.Shared; + _socket = new UdpClient(endPoint); + + try + { + _socket.DontFragment = false; + } + catch (SocketException) + { + } + + _reliablePacketTimer = new Timer(ManageReliablePackets, null, 100, Timeout.Infinite); + + _allConnections = new ConcurrentDictionary<EndPoint, UdpServerConnection>(); + + _stoppingCts = new CancellationTokenSource(); + _stoppingCts.Token.Register(() => + { + _socket.Dispose(); + }); + + _connectionRateLimit = new UdpConnectionRateLimit(); + } + + public int ConnectionCount => this._allConnections.Count; + + private async void ManageReliablePackets(object state) + { + foreach (var kvp in _allConnections) + { + var sock = kvp.Value; + await sock.ManageReliablePackets(); + } + + try + { + this._reliablePacketTimer.Change(100, Timeout.Infinite); + } + catch { } + } + + /// <inheritdoc /> + public override Task StartAsync() + { + // Store the task we're executing + _executingTask = Task.Factory.StartNew(ListenAsync, TaskCreationOptions.LongRunning); + + // If the task is completed then return it, this will bubble cancellation and failure to the caller + if (_executingTask.IsCompleted) + { + return _executingTask; + } + + // Otherwise it's running + return Task.CompletedTask; + } + + private async Task StopAsync() + { + // Stop called without start + if (_executingTask == null) + { + return; + } + + try + { + // Signal cancellation to the executing method + _stoppingCts.Cancel(); + } + finally + { + // Wait until the task completes or the timeout triggers + await Task.WhenAny(_executingTask, Task.Delay(TimeSpan.FromSeconds(5))); + } + } + + /// <summary> + /// Instructs the listener to begin listening. + /// </summary> + private async Task ListenAsync() + { + try + { + while (!_stoppingCts.IsCancellationRequested) + { + UdpReceiveResult data; + + try + { + data = await _socket.ReceiveAsync(); + + if (data.Buffer.Length == 0) + { + Logger.Fatal("Hazel read 0 bytes from UDP server socket."); + continue; + } + } + catch (SocketException) + { + // Client no longer reachable, pretend it didn't happen + continue; + } + catch (ObjectDisposedException) + { + // Socket was disposed, don't care. + return; + } + + // Get client from active clients + if (!_allConnections.TryGetValue(data.RemoteEndPoint, out var client)) + { + // Check for malformed connection attempts + if (data.Buffer[0] != (byte)UdpSendOption.Hello) + { + continue; + } + + // Check rateLimit. + if (!_connectionRateLimit.IsAllowed(data.RemoteEndPoint.Address)) + { + Logger.Warning("Ratelimited connection attempt from {0}.", data.RemoteEndPoint); + continue; + } + + // Create new client + client = new UdpServerConnection(this, data.RemoteEndPoint, IPMode, _readerPool); + + // Store the client + if (!_allConnections.TryAdd(data.RemoteEndPoint, client)) + { + throw new HazelException("Failed to add a connection. This should never happen."); + } + + // Activate the reader loop of the client + await client.StartAsync(); + } + + // Write to client. + await client.Pipeline.Writer.WriteAsync(data.Buffer); + } + } + catch (Exception e) + { + Logger.Error(e, "Listen loop error"); + } + } + +#if DEBUG + public int TestDropRate = -1; + private int dropCounter = 0; +#endif + + /// <summary> + /// Sends data from the listener socket. + /// </summary> + /// <param name="bytes">The bytes to send.</param> + /// <param name="endPoint">The endpoint to send to.</param> + internal async ValueTask SendData(byte[] bytes, int length, IPEndPoint endPoint) + { + if (length > bytes.Length) return; + +#if DEBUG + if (TestDropRate > 0) + { + if (Interlocked.Increment(ref dropCounter) % TestDropRate == 0) + { + return; + } + } +#endif + + try + { + await _socket.SendAsync(bytes, length, endPoint); + } + catch (SocketException e) + { + Logger.Error(e, "Could not send data as a SocketException occurred"); + } + catch (ObjectDisposedException) + { + //Keep alive timer probably ran, ignore + return; + } + } + + /// <summary> + /// Sends data from the listener socket. + /// </summary> + /// <param name="bytes">The bytes to send.</param> + /// <param name="length"></param> + /// <param name="endPoint">The endpoint to send to.</param> + internal void SendDataSync(byte[] bytes, int length, IPEndPoint endPoint) + { + try + { + _socket.Send(bytes, length, endPoint); + } + catch (SocketException e) + { + Logger.Error(e, "Could not send data sync as a SocketException occurred"); + } + } + + /// <summary> + /// Removes a virtual connection from the list. + /// </summary> + /// <param name="endPoint">The endpoint of the virtual connection.</param> + internal void RemoveConnectionTo(EndPoint endPoint) + { + this._allConnections.TryRemove(endPoint, out var conn); + } + + /// <inheritdoc /> + public override async ValueTask DisposeAsync() + { + foreach (var kvp in _allConnections) + { + kvp.Value.Dispose(); + } + + await StopAsync(); + + await _reliablePacketTimer.DisposeAsync(); + + _connectionRateLimit.Dispose(); + + await base.DisposeAsync(); + } + } +} diff --git a/Impostor-dev/src/Impostor.Hazel/Udp/UdpConnectionRateLimit.cs b/Impostor-dev/src/Impostor.Hazel/Udp/UdpConnectionRateLimit.cs new file mode 100644 index 0000000..64881d3 --- /dev/null +++ b/Impostor-dev/src/Impostor.Hazel/Udp/UdpConnectionRateLimit.cs @@ -0,0 +1,75 @@ +using System; +using System.Collections.Concurrent; +using System.Net; +using System.Threading; +using Serilog; + +namespace Impostor.Hazel.Udp +{ + public class UdpConnectionRateLimit : IDisposable + { + private static readonly ILogger Logger = Log.ForContext<UdpConnectionRateLimit>(); + + // Allow burst to 5 connections. + // Decrease by 1 every second. + private const int MaxConnections = 5; + private const int FalloffMs = 1000; + + private readonly ConcurrentDictionary<IPAddress, int> _connectionCount; + private readonly Timer _timer; + private bool _isDisposed; + + public UdpConnectionRateLimit() + { + _connectionCount = new ConcurrentDictionary<IPAddress, int>(); + _timer = new Timer(UpdateRateLimit, null, FalloffMs, Timeout.Infinite); + } + + private void UpdateRateLimit(object state) + { + try + { + foreach (var pair in _connectionCount) + { + var count = pair.Value - 1; + if (count > 0) + { + _connectionCount.TryUpdate(pair.Key, count, pair.Value); + } + else + { + _connectionCount.TryRemove(pair); + } + } + } + catch (Exception e) + { + Logger.Error(e, "Exception caught in UpdateRateLimit."); + } + finally + { + if (!_isDisposed) + { + _timer.Change(FalloffMs, Timeout.Infinite); + } + } + } + + public bool IsAllowed(IPAddress key) + { + if (_connectionCount.TryGetValue(key, out var value) && value >= MaxConnections) + { + return false; + } + + _connectionCount.AddOrUpdate(key, _ => 1, (_, i) => i + 1); + return true; + } + + public void Dispose() + { + _isDisposed = true; + _timer.Dispose(); + } + } +}
\ No newline at end of file diff --git a/Impostor-dev/src/Impostor.Hazel/Udp/UdpServerConnection.cs b/Impostor-dev/src/Impostor.Hazel/Udp/UdpServerConnection.cs new file mode 100644 index 0000000..22eed98 --- /dev/null +++ b/Impostor-dev/src/Impostor.Hazel/Udp/UdpServerConnection.cs @@ -0,0 +1,97 @@ +using System; +using System.Net; +using System.Threading.Tasks; +using Impostor.Api.Net.Messages; +using Microsoft.Extensions.ObjectPool; + +namespace Impostor.Hazel.Udp +{ + /// <summary> + /// Represents a servers's connection to a client that uses the UDP protocol. + /// </summary> + /// <inheritdoc/> + internal sealed class UdpServerConnection : UdpConnection + { + /// <summary> + /// The connection listener that we use the socket of. + /// </summary> + /// <remarks> + /// Udp server connections utilize the same socket in the listener for sends/receives, this is the listener that + /// created this connection and is hence the listener this conenction sends and receives via. + /// </remarks> + public UdpConnectionListener Listener { get; private set; } + + /// <summary> + /// Creates a UdpConnection for the virtual connection to the endpoint. + /// </summary> + /// <param name="listener">The listener that created this connection.</param> + /// <param name="endPoint">The endpoint that we are connected to.</param> + /// <param name="IPMode">The IPMode we are connected using.</param> + internal UdpServerConnection(UdpConnectionListener listener, IPEndPoint endPoint, IPMode IPMode, ObjectPool<MessageReader> readerPool) : base(listener, readerPool) + { + this.Listener = listener; + this.RemoteEndPoint = endPoint; + this.EndPoint = endPoint; + this.IPMode = IPMode; + + State = ConnectionState.Connected; + this.InitializeKeepAliveTimer(); + } + + /// <inheritdoc /> + protected override async ValueTask WriteBytesToConnection(byte[] bytes, int length) + { + await Listener.SendData(bytes, length, RemoteEndPoint); + } + + /// <inheritdoc /> + /// <remarks> + /// This will always throw a HazelException. + /// </remarks> + public override ValueTask ConnectAsync(byte[] bytes = null) + { + throw new InvalidOperationException("Cannot manually connect a UdpServerConnection, did you mean to use UdpClientConnection?"); + } + + /// <summary> + /// Sends a disconnect message to the end point. + /// </summary> + protected override async ValueTask<bool> SendDisconnect(MessageWriter data = null) + { + lock (this) + { + if (this._state != ConnectionState.Connected) return false; + this._state = ConnectionState.NotConnected; + } + + var bytes = EmptyDisconnectBytes; + if (data != null && data.Length > 0) + { + if (data.SendOption != MessageType.Unreliable) throw new ArgumentException("Disconnect messages can only be unreliable."); + + bytes = data.ToByteArray(true); + bytes[0] = (byte)UdpSendOption.Disconnect; + } + + try + { + await Listener.SendData(bytes, bytes.Length, RemoteEndPoint); + } + catch { } + + return true; + } + + protected override void Dispose(bool disposing) + { + Listener.RemoveConnectionTo(RemoteEndPoint); + + if (disposing) + { + SendDisconnect(); + } + + base.Dispose(disposing); + } + } +} |