diff options
Diffstat (limited to 'Tools/Hazel-Networking/Hazel')
57 files changed, 12055 insertions, 0 deletions
diff --git a/Tools/Hazel-Networking/Hazel/ByteSpan.cs b/Tools/Hazel-Networking/Hazel/ByteSpan.cs new file mode 100644 index 0000000..7cfa3b5 --- /dev/null +++ b/Tools/Hazel-Networking/Hazel/ByteSpan.cs @@ -0,0 +1,191 @@ +using System; + +namespace Hazel +{ + /// <summary> + /// This is a minimal implementation of `System.Span` in .NET 5.0 + /// </summary> + public struct ByteSpan + { + private readonly byte[] array_; + + /// <summary> + /// Createa a new span object containing an entire array + /// </summary> + public ByteSpan(byte[] array) + { + if (array == null) + { + this.array_ = null; + this.Offset = 0; + this.Length = 0; + } + else + { + this.array_ = array; + this.Offset = 0; + this.Length = array.Length; + } + } + + /// <summary> + /// Creates a new span object containing a subset of an array + /// </summary> + public ByteSpan(byte[] array, int offset, int length) + { + if (array == null) + { + if (offset != 0) + { + throw new ArgumentException("Invalid offset", nameof(offset)); + } + if (length != 0) + { + throw new ArgumentException("Invalid length", nameof(offset)); + } + + this.array_ = null; + this.Offset = 0; + this.Length = 0; + } + else + { + if (offset < 0 || offset > array.Length) + { + throw new ArgumentException("Invalid offset", nameof(offset)); + } + if (length < 0) + { + throw new ArgumentException($"Invalid length: {length}", nameof(length)); + } + if ((offset + length) > array.Length) + { + throw new ArgumentException($"Invalid length. Length: {length} Offset: {offset} Array size: {array.Length}", nameof(length)); + } + + this.array_ = array; + this.Offset = offset; + this.Length = length; + } + } + + /// <summary> + /// Returns the underlying array. + /// + /// WARNING: This does not return the span, but the entire underlying storage block + /// </summary> + public byte[] GetUnderlyingArray() + { + return this.array_; + } + + /// <summary> + /// Returns the offset into the underlying array + /// </summary> + public int Offset { get; } + + /// <summary> + /// Returns the length of the current span + /// </summary> + public int Length { get; } + + /// <summary> + /// Gets the span element at the specified index + /// </summary> + public byte this[int index] + { + get + { + if (index < 0 || index >= this.Length) + { + throw new IndexOutOfRangeException(); + } + + return this.array_[this.Offset + index]; + } + set + { + if (index < 0 || index >= this.Length) + { + throw new IndexOutOfRangeException(); + } + + this.array_[this.Offset + index] = value; + } + } + + /// <summary> + /// Create a new span that is a subset of this span [offset, this.Length-offset) + /// </summary> + public ByteSpan Slice(int offset) + { + return Slice(offset, this.Length - offset); + } + + /// <summary> + /// Create a new span that is a subset of this span [offset, length) + /// </summary> + public ByteSpan Slice(int offset, int length) + { + return new ByteSpan(this.array_, this.Offset + offset, length); + } + + /// <summary> + /// Copies the contents of the span to an array + /// </summary> + public void CopyTo(byte[] array, int offset) + { + CopyTo(new ByteSpan(array, offset, array.Length - offset)); + } + + /// <summary> + /// Copies the contents of the span to another span + /// </summary> + public void CopyTo(ByteSpan destination) + { + if (destination.Length < this.Length) + { + throw new ArgumentException("Destination span is shorter than source", nameof(destination)); + } + + if (Length > 0) + { + Buffer.BlockCopy(this.array_, this.Offset, destination.array_, destination.Offset, this.Length); + } + } + + /// <summary> + /// Create a new array with the contents of this span + /// </summary> + public byte[] ToArray() + { + byte[] result = new byte[Length]; + CopyTo(result); + return result; + } + + public override string ToString() + { + return string.Join(" ", this.ToArray()); + } + + /// <summary> + /// Implicit conversion from byte[] -> ByteSpan + /// </summary> + public static implicit operator ByteSpan(byte[] array) + { + return new ByteSpan(array); + } + + /// <summary> + /// Retuns an empty span object + /// </summary> + public static ByteSpan Empty + { + get + { + return new ByteSpan(null); + } + } + } +} diff --git a/Tools/Hazel-Networking/Hazel/ByteSpanExtensions.cs b/Tools/Hazel-Networking/Hazel/ByteSpanExtensions.cs new file mode 100644 index 0000000..3a9d1ac --- /dev/null +++ b/Tools/Hazel-Networking/Hazel/ByteSpanExtensions.cs @@ -0,0 +1,131 @@ +namespace Hazel +{ + /// <summary> + /// Extension functions for (en/de)coding integer values + /// </summary> + public static class ByteSpanBigEndianExtensions + { + // Write a 16-bit integer in big-endian format to output[0..2) + public static void WriteBigEndian16(this ByteSpan output, ushort value, int offset = 0) + { + output[offset + 0] = (byte)(value >> 8); + output[offset + 1] = (byte)(value >> 0); + } + + // Write a 24-bit integer in big-endian format to output[0..3) + public static void WriteBigEndian24(this ByteSpan output, uint value, int offset = 0) + { + output[offset + 0] = (byte)(value >> 16); + output[offset + 1] = (byte)(value >> 8); + output[offset + 2] = (byte)(value >> 0); + } + + // Write a 32-bit integer in big-endian format to output[0..4) + public static void WriteBigEndian32(this ByteSpan output, uint value, int offset) + { + output[offset + 0] = (byte)(value >> 24); + output[offset + 1] = (byte)(value >> 16); + output[offset + 2] = (byte)(value >> 8); + output[offset + 3] = (byte)(value >> 0); + } + + // Write a 48-bit integer in big-endian format to output[0..6) + public static void WriteBigEndian48(this ByteSpan output, ulong value, int offset = 0) + { + output[offset + 0] = (byte)(value >> 40); + output[offset + 1] = (byte)(value >> 32); + output[offset + 2] = (byte)(value >> 24); + output[offset + 3] = (byte)(value >> 16); + output[offset + 4] = (byte)(value >> 8); + output[offset + 5] = (byte)(value >> 0); + } + + // Write a 64-bit integer in big-endian format to output[0..8) + public static void WriteBigEndian64(this ByteSpan output, ulong value, int offset = 0) + { + output[offset + 0] = (byte)(value >> 56); + output[offset + 1] = (byte)(value >> 48); + output[offset + 2] = (byte)(value >> 40); + output[offset + 3] = (byte)(value >> 32); + output[offset + 4] = (byte)(value >> 24); + output[offset + 5] = (byte)(value >> 16); + output[offset + 6] = (byte)(value >> 8); + output[offset + 7] = (byte)(value >> 0); + } + + // Read a 16-bit integer in big-endian format from input[0..2) + public static ushort ReadBigEndian16(this ByteSpan input, int offset = 0) + { + ushort value = 0; + value |= (ushort)(input[offset + 0] << 8); + value |= (ushort)(input[offset + 1] << 0); + return value; + } + + // Read a 24-bit integer in big-endian format from input[0..3) + public static uint ReadBigEndian24(this ByteSpan input, int offset = 0) + { + uint value = 0; + value |= (uint)input[offset + 0] << 16; + value |= (uint)input[offset + 1] << 8; + value |= (uint)input[offset + 2] << 0; + return value; + } + + // Read a 48-bit integer in big-endian format from input[0..3) + public static ulong ReadBigEndian48(this ByteSpan input, int offset = 0) + { + ulong value = 0; + value |= (ulong)input[offset + 0] << 40; + value |= (ulong)input[offset + 1] << 32; + value |= (ulong)input[offset + 2] << 24; + value |= (ulong)input[offset + 3] << 16; + value |= (ulong)input[offset + 4] << 8; + value |= (ulong)input[offset + 5] << 0; + return value; + } + } + + public static class ByteSpanLittleEndianExtensions + { + // Read a 24-bit integer in little-endian format from input[0..3) + public static uint ReadLittleEndian24(this ByteSpan input, int offset = 0) + { + uint value = 0; + value |= (uint)input[offset + 0]; + value |= (uint)input[offset + 1] << 8; + value |= (uint)input[offset + 2] << 16; + return value; + } + + // Read a 24-bit integer in little-endian format from input[0..4) + public static uint ReadLittleEndian32(this ByteSpan input, int offset = 0) + { + uint value = 0; + value |= (uint)input[offset + 0]; + value |= (uint)input[offset + 1] << 8; + value |= (uint)input[offset + 2] << 16; + value |= (uint)input[offset + 3] << 24; + return value; + } + + /// <summary> + /// Reuse an existing span if there is enough space, + /// otherwise allocate new storage + /// </summary> + /// <param name="source"> + /// Source span we should attempt to reuse + /// </param> + /// <param name="requiredSize">Required size (bytes)</param> + public static ByteSpan ReuseSpanIfPossible(this ByteSpan source, int requiredSize) + { + if (source.Length >= requiredSize) + { + return source.Slice(0, requiredSize); + } + + return new byte[requiredSize]; + } + + } +} diff --git a/Tools/Hazel-Networking/Hazel/Connection.cs b/Tools/Hazel-Networking/Hazel/Connection.cs new file mode 100644 index 0000000..da2f59a --- /dev/null +++ b/Tools/Hazel-Networking/Hazel/Connection.cs @@ -0,0 +1,234 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Net.Sockets; +using System.Net; +using System.Threading; + +namespace Hazel +{ + /// <summary> + /// Base class for all connections. + /// </summary> + /// <remarks> + /// <para> + /// Connection is the base class for all connections that Hazel can make. It provides common functionality and a + /// standard interface to allow connections to be swapped easily. + /// </para> + /// <para> + /// Any class inheriting from Connection should provide the 3 standard guarantees that Hazel provides: + /// <list type="bullet"> + /// <item> + /// <description>Thread Safe</description> + /// </item> + /// <item> + /// <description>Connection Orientated</description> + /// </item> + /// <item> + /// <description>Packet/Message Based</description> + /// </item> + /// </list> + /// </para> + /// </remarks> + /// <threadsafety static="true" instance="true"/> + public abstract class Connection : IDisposable + { + /// <summary> + /// Called when a message has been received. + /// </summary> + /// <remarks> + /// <para> + /// DataReceived is invoked everytime a message is received from the end point of this connection, the message + /// that was received can be found in the <see cref="DataReceivedEventArgs"/> alongside other information from the + /// event. + /// </para> + /// <include file="DocInclude/common.xml" path="docs/item[@name='Event_Thread_Safety_Warning']/*" /> + /// </remarks> + /// <example> + /// <code language="C#" source="DocInclude/TcpClientExample.cs"/> + /// </example> + public event Action<DataReceivedEventArgs> DataReceived; + + public int TestLagMs = -1; + public int TestDropRate = 0; + protected int testDropCount = 0; + + /// <summary> + /// Called when the end point disconnects or an error occurs. + /// </summary> + /// <remarks> + /// <para> + /// Disconnected is invoked when the connection is closed due to an exception occuring or because the remote + /// end point disconnected. If it was invoked due to an exception occuring then the exception is available + /// in the <see cref="DisconnectedEventArgs"/> passed with the event. + /// </para> + /// <include file="DocInclude/common.xml" path="docs/item[@name='Event_Thread_Safety_Warning']/*" /> + /// </remarks> + /// <example> + /// <code language="C#" source="DocInclude/TcpClientExample.cs"/> + /// </example> + public event EventHandler<DisconnectedEventArgs> Disconnected; + + /// <summary> + /// The remote end point of this Connection. + /// </summary> + /// <remarks> + /// This is the end point that this connection is connected to (i.e. the other device). This returns an abstract + /// <see cref="ConnectionEndPoint"/> which can then be cast to an appropriate end point depending on the + /// connection type. + /// </remarks> + public IPEndPoint EndPoint { get; protected set; } + + public IPMode IPMode { get; protected set; } + + /// <summary> + /// The traffic statistics about this Connection. + /// </summary> + /// <remarks> + /// Contains statistics about the number of messages and bytes sent and received by this connection. + /// </remarks> + public ConnectionStatistics Statistics { get; protected set; } + + /// <summary> + /// The state of this connection. + /// </summary> + /// <remarks> + /// All implementers should be aware that when this is set to ConnectionState.Connected it will + /// release all threads that are blocked on <see cref="WaitOnConnect"/>. + /// </remarks> + public ConnectionState State + { + get + { + return this._state; + } + + protected set + { + this._state = value; + this.SetState(value); + } + } + + protected ConnectionState _state; + protected virtual void SetState(ConnectionState state) { } + + /// <summary> + /// Constructor that initializes the ConnecitonStatistics object. + /// </summary> + /// <remarks> + /// This constructor initialises <see cref="Statistics"/> with empty statistics and sets <see cref="State"/> to + /// <see cref="ConnectionState.NotConnected"/>. + /// </remarks> + protected Connection() + { + this.Statistics = new ConnectionStatistics(); + this.State = ConnectionState.NotConnected; + } + + /// <summary> + /// Sends a number of bytes to the end point of the connection using the specified <see cref="SendOption"/>. + /// </summary> + /// <param name="msg">The message to send.</param> + public abstract SendErrors Send(MessageWriter msg); + + /// <summary> + /// Connects the connection to a server and begins listening. + /// This method blocks and may thrown if there is a problem connecting. + /// </summary> + /// <param name="bytes">The bytes of data to send in the handshake.</param> + /// <param name="timeout">The number of milliseconds to wait before giving up on the connect attempt.</param> + public abstract void Connect(byte[] bytes = null, int timeout = 5000); + + /// <summary> + /// Connects the connection to a server and begins listening. + /// This method does not block. + /// </summary> + /// <param name="bytes">The bytes of data to send in the handshake.</param> + public abstract void ConnectAsync(byte[] bytes = null); + + /// <summary> + /// Invokes the DataReceived event. + /// </summary> + /// <param name="msg">The bytes received.</param> + /// <param name="sendOption">The <see cref="SendOption"/> the message was received with.</param> + /// <remarks> + /// Invokes the <see cref="DataReceived"/> event on this connection to alert subscribers a new message has been + /// received. The bytes and the send option that the message was sent with should be passed in to give to the + /// subscribers. + /// </remarks> + protected void InvokeDataReceived(MessageReader msg, SendOption sendOption) + { + // Make a copy to avoid race condition between null check and invocation + Action<DataReceivedEventArgs> handler = DataReceived; + if (handler != null) + { + try + { + handler(new DataReceivedEventArgs(this, msg, sendOption)); + } + catch { } + } + else + { + msg.Recycle(); + } + } + + /// <summary> + /// Invokes the Disconnected event. + /// </summary> + /// <param name="e">The exception, if any, that occurred to cause this.</param> + /// <param name="reader">Extra disconnect data</param> + /// <remarks> + /// Invokes the <see cref="Disconnected"/> event to alert subscribres this connection has been disconnected either + /// by the end point or because an error occurred. If an error occurred the error should be passed in in order to + /// pass to the subscribers, otherwise null can be passed in. + /// </remarks> + protected void InvokeDisconnected(string e, MessageReader reader) + { + // Make a copy to avoid race condition between null check and invocation + EventHandler<DisconnectedEventArgs> handler = Disconnected; + if (handler != null) + { + DisconnectedEventArgs args = new DisconnectedEventArgs(e, reader); + try + { + handler(this, args); + } + catch + { + } + } + } + + /// <summary> + /// For times when you want to force the disconnect handler to fire as well as close it. + /// If you only want to close it, just use Dispose. + /// </summary> + public abstract void Disconnect(string reason, MessageWriter writer = null); + + /// <summary> + /// Disposes of this NetworkConnection. + /// </summary> + public void Dispose() + { + Dispose(true); + GC.SuppressFinalize(this); + } + + /// <summary> + /// Disposes of this NetworkConnection. + /// </summary> + /// <param name="disposing">Are we currently disposing?</param> + protected virtual void Dispose(bool disposing) + { + if (disposing) + { + this.DataReceived = null; + this.Disconnected = null; + } + } + } +} diff --git a/Tools/Hazel-Networking/Hazel/ConnectionListener.cs b/Tools/Hazel-Networking/Hazel/ConnectionListener.cs new file mode 100644 index 0000000..f952847 --- /dev/null +++ b/Tools/Hazel-Networking/Hazel/ConnectionListener.cs @@ -0,0 +1,160 @@ +using System; +using System.Net; + +namespace Hazel +{ + /// <summary> + /// Base class for all connection listeners. + /// </summary> + /// <remarks> + /// <para> + /// ConnectionListeners are server side objects that listen for clients and create matching server side connections + /// for each client in a similar way to TCP does. These connections should be ready for communication immediately. + /// </para> + /// <para> + /// Each time a client connects the <see cref="NewConnection"/> event will be invoked to alert all subscribers to + /// the new connection. A disconnected event is then present on the <see cref="Connection"/> that is passed to the + /// subscribers. + /// </para> + /// </remarks> + /// <threadsafety static="true" instance="true"/> + public abstract class ConnectionListener : IDisposable + { + /// <summary> + /// The max size Hazel attempts to read from the network. + /// Defaults to 8096. + /// </summary> + /// <remarks> + /// 8096 is 5 times the standard modern MTU of 1500, so it's already too large imo. + /// If Hazel ever implements fragmented packets, then we might consider a larger value since combining 5 + /// packets into 1 reader would be realistic and would cause reallocations. That said, Hazel is not meant + /// for transferring large contiguous blocks of data, so... please don't? + /// </remarks> + public int ReceiveBufferSize = 8096; + + public readonly ListenerStatistics Statistics = new ListenerStatistics(); + + public abstract double AveragePing { get; } + public abstract int ConnectionCount { get; } + public abstract int SendQueueLength { get; } + public abstract int ReceiveQueueLength { get; } + + /// <summary> + /// A callback for early connection rejection. + /// * Return false to reject connection. + /// * A null response is ok, we just won't send anything. + /// </summary> + public AcceptConnectionCheck AcceptConnection; + public delegate bool AcceptConnectionCheck(IPEndPoint endPoint, byte[] input, out byte[] response); + + /// <summary> + /// Invoked when a new client connects. + /// </summary> + /// <remarks> + /// <para> + /// NewConnection is invoked each time a client connects to the listener. The + /// <see cref="NewConnectionEventArgs"/> contains the new <see cref="Connection"/> for communication with this + /// client. + /// </para> + /// <para> + /// Hazel may or may not store connections so it is your responsibility to keep track and properly Dispose of + /// connections to your server. + /// </para> + /// <include file="DocInclude/common.xml" path="docs/item[@name='Event_Thread_Safety_Warning']/*" /> + /// </remarks> + /// <example> + /// <code language="C#" source="DocInclude/TcpListenerExample.cs"/> + /// </example> + public event Action<NewConnectionEventArgs> NewConnection; + + /// <summary> + /// Invoked when an internal error causes the listener to be unable to continue handling messages. + /// </summary> + /// <remarks> + /// Support for this is still pretty limited. At the time of writing, only iOS devices need this in one case: + /// When iOS suspends an app, it might also free our socket while not allowing Unity to run in the background. + /// When Unity resumes, it can't know that time passed or the socket is freed, so we used to continuously throw internal errors. + /// </remarks> + public event Action<HazelInternalErrors> OnInternalError; + + /// <summary> + /// Makes this connection listener begin listening for connections. + /// </summary> + /// <remarks> + /// <para> + /// This instructs the listener to begin listening for new clients connecting to the server. When a new client + /// connects the <see cref="NewConnection"/> event will be invoked containing the connection to the new client. + /// </para> + /// <para> + /// To stop listening you should call <see cref="Dispose()"/>. + /// </para> + /// </remarks> + /// <example> + /// <code language="C#" source="DocInclude/TcpListenerExample.cs"/> + /// </example> + public abstract void Start(); + + /// <summary> + /// Invokes the NewConnection event with the supplied connection. + /// </summary> + /// <param name="msg">The user sent bytes that were received as part of the handshake.</param> + /// <param name="connection">The connection to pass in the arguments.</param> + /// <remarks> + /// Implementers should call this to invoke the <see cref="NewConnection"/> event before data is received so that + /// subscribers do not miss any data that may have been sent immediately after connecting. + /// </remarks> + protected void InvokeNewConnection(MessageReader msg, Connection connection) + { + // Make a copy to avoid race condition between null check and invocation + Action<NewConnectionEventArgs> handler = NewConnection; + if (handler != null) + { + try + { + handler(new NewConnectionEventArgs(msg, connection)); + } + catch (Exception e) + { + } + } + } + + + /// <summary> + /// Invokes the InternalError event with the supplied reason. + /// </summary> + protected void InvokeInternalError(HazelInternalErrors reason) + { + // Make a copy to avoid race condition between null check and invocation + Action<HazelInternalErrors> handler = this.OnInternalError; + if (handler != null) + { + try + { + handler(reason); + } + catch + { + } + } + } + + /// <summary> + /// Call to dispose of the connection listener. + /// </summary> + public void Dispose() + { + Dispose(true); + } + + /// <summary> + /// Called when the object is being disposed. + /// </summary> + /// <param name="disposing">Are we disposing?</param> + protected virtual void Dispose(bool disposing) + { + this.NewConnection = null; + this.OnInternalError = null; + } + } +} diff --git a/Tools/Hazel-Networking/Hazel/ConnectionState.cs b/Tools/Hazel-Networking/Hazel/ConnectionState.cs new file mode 100644 index 0000000..5d3f5c9 --- /dev/null +++ b/Tools/Hazel-Networking/Hazel/ConnectionState.cs @@ -0,0 +1,28 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; + +namespace Hazel +{ + /// <summary> + /// Represents the state a <see cref="Connection"/> is currently in. + /// </summary> + public enum ConnectionState + { + /// <summary> + /// The Connection has either not been established yet or has been disconnected. + /// </summary> + NotConnected, + + /// <summary> + /// The Connection is currently connecting to an endpoint. + /// </summary> + Connecting, + + /// <summary> + /// The Connection is connected and data can be transfered. + /// </summary> + Connected, + } +} diff --git a/Tools/Hazel-Networking/Hazel/ConnectionStatistics.cs b/Tools/Hazel-Networking/Hazel/ConnectionStatistics.cs new file mode 100644 index 0000000..f2c3ed9 --- /dev/null +++ b/Tools/Hazel-Networking/Hazel/ConnectionStatistics.cs @@ -0,0 +1,574 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading; + + +namespace Hazel +{ + /// <summary> + /// Holds statistics about the traffic through a <see cref="Connection"/>. + /// </summary> + /// <threadsafety static="true" instance="true"/> + public class ConnectionStatistics + { + private const int ExpectedMTU = 1200; + + /// <summary> + /// The total number of messages sent. + /// </summary> + public int MessagesSent + { + get + { + return UnreliableMessagesSent + ReliableMessagesSent + FragmentedMessagesSent + AcknowledgementMessagesSent + HelloMessagesSent; + } + } + + private int packetsSent; + public int PacketsSent => this.packetsSent; + + private int reliablePacketsAcknowledged; + public int ReliablePacketsAcknowledged => this.reliablePacketsAcknowledged; + + /// <summary> + /// The number of messages sent larger than 576 bytes. This is smaller than most default MTUs. + /// </summary> + /// <remarks> + /// This is the number of unreliable messages that were sent from the <see cref="Connection"/>, incremented + /// each time that LogUnreliableSend is called by the Connection. Messages that caused an error are not + /// counted and messages are only counted once all other operations in the send are complete. + /// </remarks> + public int FragmentableMessagesSent + { + get + { + return fragmentableMessagesSent; + } + } + + /// <summary> + /// The number of messages sent larger than 576 bytes. + /// </summary> + int fragmentableMessagesSent; + + /// <summary> + /// The number of unreliable messages sent. + /// </summary> + /// <remarks> + /// This is the number of unreliable messages that were sent from the <see cref="Connection"/>, incremented + /// each time that LogUnreliableSend is called by the Connection. Messages that caused an error are not + /// counted and messages are only counted once all other operations in the send are complete. + /// </remarks> + public int UnreliableMessagesSent + { + get + { + return unreliableMessagesSent; + } + } + + /// <summary> + /// The number of unreliable messages sent. + /// </summary> + int unreliableMessagesSent; + + /// <summary> + /// The number of reliable messages sent. + /// </summary> + /// <remarks> + /// This is the number of reliable messages that were sent from the <see cref="Connection"/>, incremented + /// each time that LogReliableSend is called by the Connection. Messages that caused an error are not + /// counted and messages are only counted once all other operations in the send are complete. + /// </remarks> + public int ReliableMessagesSent + { + get + { + return reliableMessagesSent; + } + } + + /// <summary> + /// The number of unreliable messages sent. + /// </summary> + int reliableMessagesSent; + + /// <summary> + /// The number of fragmented messages sent. + /// </summary> + /// <remarks> + /// This is the number of fragmented messages that were sent from the <see cref="Connection"/>, incremented + /// each time that LogFragmentedSend is called by the Connection. Messages that caused an error are not + /// counted and messages are only counted once all other operations in the send are complete. + /// </remarks> + public int FragmentedMessagesSent + { + get + { + return fragmentedMessagesSent; + } + } + + /// <summary> + /// The number of fragmented messages sent. + /// </summary> + int fragmentedMessagesSent; + + /// <summary> + /// The number of acknowledgement messages sent. + /// </summary> + /// <remarks> + /// This is the number of acknowledgements that were sent from the <see cref="Connection"/>, incremented + /// each time that LogAcknowledgementSend is called by the Connection. Messages that caused an error are not + /// counted and messages are only counted once all other operations in the send are complete. + /// </remarks> + public int AcknowledgementMessagesSent + { + get + { + return acknowledgementMessagesSent; + } + } + + /// <summary> + /// The number of acknowledgement messages sent. + /// </summary> + int acknowledgementMessagesSent; + + /// <summary> + /// The number of hello messages sent. + /// </summary> + /// <remarks> + /// This is the number of hello messages that were sent from the <see cref="Connection"/>, incremented + /// each time that LogHelloSend is called by the Connection. Messages that caused an error are not + /// counted and messages are only counted once all other operations in the send are complete. + /// </remarks> + public int HelloMessagesSent + { + get + { + return helloMessagesSent; + } + } + + /// <summary> + /// The number of hello messages sent. + /// </summary> + int helloMessagesSent; + + /// <summary> + /// The number of bytes of data sent. + /// </summary> + /// <remarks> + /// <para> + /// This is the number of bytes of data (i.e. user bytes) that were sent from the <see cref="Connection"/>, + /// accumulated each time that LogSend is called by the Connection. Messages that caused an error are not + /// counted and messages are only counted once all other operations in the send are complete. + /// </para> + /// <para> + /// For the number of bytes including protocol bytes see <see cref="TotalBytesSent"/>. + /// </para> + /// </remarks> + public long DataBytesSent + { + get + { + return Interlocked.Read(ref dataBytesSent); + } + } + + /// <summary> + /// The number of bytes of data sent. + /// </summary> + long dataBytesSent; + + /// <summary> + /// The number of bytes sent in total. + /// </summary> + /// <remarks> + /// <para> + /// This is the total number of bytes (the data bytes plus protocol bytes) that were sent from the + /// <see cref="Connection"/>, accumulated each time that LogSend is called by the Connection. Messages that + /// caused an error are not counted and messages are only counted once all other operations in the send are + /// complete. + /// </para> + /// <para> + /// For the number of data bytes excluding protocol bytes see <see cref="DataBytesSent"/>. + /// </para> + /// </remarks> + public long TotalBytesSent + { + get + { + return Interlocked.Read(ref totalBytesSent); + } + } + + /// <summary> + /// The number of bytes sent in total. + /// </summary> + long totalBytesSent; + + /// <summary> + /// The total number of messages received. + /// </summary> + public int MessagesReceived + { + get + { + return UnreliableMessagesReceived + ReliableMessagesReceived + FragmentedMessagesReceived + AcknowledgementMessagesReceived + helloMessagesReceived; + } + } + + /// <summary> + /// The number of unreliable messages received. + /// </summary> + /// <remarks> + /// This is the number of unreliable messages that were received by the <see cref="Connection"/>, incremented + /// each time that LogUnreliableReceive is called by the Connection. Messages are counted before the receive event is invoked. + /// </remarks> + public int UnreliableMessagesReceived + { + get + { + return unreliableMessagesReceived; + } + } + + /// <summary> + /// The number of unreliable messages received. + /// </summary> + int unreliableMessagesReceived; + + /// <summary> + /// The number of reliable messages received. + /// </summary> + /// <remarks> + /// This is the number of reliable messages that were received by the <see cref="Connection"/>, incremented + /// each time that LogReliableReceive is called by the Connection. Messages are counted before the receive event is invoked. + /// </remarks> + public int ReliableMessagesReceived + { + get + { + return reliableMessagesReceived; + } + } + + /// <summary> + /// The number of reliable messages received. + /// </summary> + int reliableMessagesReceived; + + /// <summary> + /// The number of fragmented messages received. + /// </summary> + /// <remarks> + /// This is the number of fragmented messages that were received by the <see cref="Connection"/>, incremented + /// each time that LogFragmentedReceive is called by the Connection. Messages are counted before the receive event is invoked. + /// </remarks> + public int FragmentedMessagesReceived + { + get + { + return fragmentedMessagesReceived; + } + } + + /// <summary> + /// The number of fragmented messages received. + /// </summary> + int fragmentedMessagesReceived; + + /// <summary> + /// The number of acknowledgement messages received. + /// </summary> + /// <remarks> + /// This is the number of acknowledgement messages that were received by the <see cref="Connection"/>, incremented + /// each time that LogAcknowledgemntReceive is called by the Connection. Messages are counted before the receive event is invoked. + /// </remarks> + public int AcknowledgementMessagesReceived + { + get + { + return acknowledgementMessagesReceived; + } + } + + /// <summary> + /// The number of acknowledgement messages received. + /// </summary> + int acknowledgementMessagesReceived; + + /// <summary> + /// The number of ping messages received. + /// </summary> + /// <remarks> + /// This is the number of hello messages that were received by the <see cref="Connection"/>, incremented + /// each time that LogHelloReceive is called by the Connection. Messages are counted before the receive event is invoked. + /// </remarks> + public int PingMessagesReceived + { + get + { + return pingMessagesReceived; + } + } + + /// <summary> + /// The number of hello messages received. + /// </summary> + int pingMessagesReceived; + + /// <summary> + /// The number of hello messages received. + /// </summary> + /// <remarks> + /// This is the number of hello messages that were received by the <see cref="Connection"/>, incremented + /// each time that LogHelloReceive is called by the Connection. Messages are counted before the receive event is invoked. + /// </remarks> + public int HelloMessagesReceived + { + get + { + return helloMessagesReceived; + } + } + + /// <summary> + /// The number of hello messages received. + /// </summary> + int helloMessagesReceived; + + /// <summary> + /// The number of bytes of data received. + /// </summary> + /// <remarks> + /// <para> + /// This is the number of bytes of data (i.e. user bytes) that were received by the <see cref="Connection"/>, + /// accumulated each time that LogReceive is called by the Connection. Messages are counted before the receive + /// event is invoked. + /// </para> + /// <para> + /// For the number of bytes including protocol bytes see <see cref="TotalBytesReceived"/>. + /// </para> + /// </remarks> + public long DataBytesReceived + { + get + { + return Interlocked.Read(ref dataBytesReceived); + } + } + + /// <summary> + /// The number of bytes of data received. + /// </summary> + long dataBytesReceived; + + /// <summary> + /// The number of bytes received in total. + /// </summary> + /// <remarks> + /// <para> + /// This is the total number of bytes (the data bytes plus protocol bytes) that were received by the + /// <see cref="Connection"/>, accumulated each time that LogReceive is called by the Connection. Messages are + /// counted before the receive event is invoked. + /// </para> + /// <para> + /// For the number of data bytes excluding protocol bytes see <see cref="DataBytesReceived"/>. + /// </para> + /// </remarks> + public long TotalBytesReceived + { + get + { + return Interlocked.Read(ref totalBytesReceived); + } + } + + /// <summary> + /// The number of bytes received in total. + /// </summary> + long totalBytesReceived; + + public int MessagesResent { get { return messagesResent; } } + int messagesResent; + + /// <summary> + /// Logs the sending of an unreliable data packet in the statistics. + /// </summary> + /// <param name="dataLength">The number of bytes of data sent.</param> + /// <remarks> + /// This should be called after the data has been sent and should only be called for data that is sent sucessfully. + /// </remarks> + internal void LogUnreliableSend(int dataLength) + { + Interlocked.Increment(ref unreliableMessagesSent); + Interlocked.Add(ref dataBytesSent, dataLength); + + } + + /// <param name="totalLength">The total number of bytes sent.</param> + internal void LogPacketSend(int totalLength) + { + Interlocked.Increment(ref this.packetsSent); + Interlocked.Add(ref totalBytesSent, totalLength); + + if (totalLength > ExpectedMTU) + { + Interlocked.Increment(ref fragmentableMessagesSent); + } + } + + /// <summary> + /// Logs the sending of a reliable data packet in the statistics. + /// </summary> + /// <param name="dataLength">The number of bytes of data sent.</param> + /// <remarks> + /// This should be called after the data has been sent and should only be called for data that is sent sucessfully. + /// </remarks> + internal void LogReliableSend(int dataLength) + { + Interlocked.Increment(ref reliableMessagesSent); + Interlocked.Add(ref dataBytesSent, dataLength); + } + + /// <summary> + /// Logs the sending of a fragmented data packet in the statistics. + /// </summary> + /// <param name="dataLength">The number of bytes of data sent.</param> + /// <param name="totalLength">The total number of bytes sent.</param> + /// <remarks> + /// This should be called after the data has been sent and should only be called for data that is sent sucessfully. + /// </remarks> + internal void LogFragmentedSend(int dataLength) + { + Interlocked.Increment(ref fragmentedMessagesSent); + Interlocked.Add(ref dataBytesSent, dataLength); + } + + /// <summary> + /// Logs the sending of a acknowledgement data packet in the statistics. + /// </summary> + /// <param name="totalLength">The total number of bytes sent.</param> + /// <remarks> + /// This should be called after the data has been sent and should only be called for data that is sent sucessfully. + /// </remarks> + internal void LogAcknowledgementSend() + { + Interlocked.Increment(ref acknowledgementMessagesSent); + } + + /// <summary> + /// Logs the sending of a hellp data packet in the statistics. + /// </summary> + /// <param name="totalLength">The total number of bytes sent.</param> + /// <remarks> + /// This should be called after the data has been sent and should only be called for data that is sent sucessfully. + /// </remarks> + internal void LogHelloSend() + { + Interlocked.Increment(ref helloMessagesSent); + } + + /// <summary> + /// Logs the receiving of an unreliable data packet in the statistics. + /// </summary> + /// <param name="dataLength">The number of bytes of data received.</param> + /// <param name="totalLength">The total number of bytes received.</param> + /// <remarks> + /// This should be called before the received event is invoked so it is up to date for subscribers to that event. + /// </remarks> + internal void LogUnreliableReceive(int dataLength, int totalLength) + { + Interlocked.Increment(ref unreliableMessagesReceived); + Interlocked.Add(ref dataBytesReceived, dataLength); + Interlocked.Add(ref totalBytesReceived, totalLength); + } + + /// <summary> + /// Logs the receiving of a reliable data packet in the statistics. + /// </summary> + /// <param name="dataLength">The number of bytes of data received.</param> + /// <param name="totalLength">The total number of bytes received.</param> + /// <remarks> + /// This should be called before the received event is invoked so it is up to date for subscribers to that event. + /// </remarks> + internal void LogReliableReceive(int dataLength, int totalLength) + { + Interlocked.Increment(ref reliableMessagesReceived); + Interlocked.Add(ref dataBytesReceived, dataLength); + Interlocked.Add(ref totalBytesReceived, totalLength); + } + + /// <summary> + /// Logs the receiving of a fragmented data packet in the statistics. + /// </summary> + /// <param name="dataLength">The number of bytes of data received.</param> + /// <param name="totalLength">The total number of bytes received.</param> + /// <remarks> + /// This should be called before the received event is invoked so it is up to date for subscribers to that event. + /// </remarks> + internal void LogFragmentedReceive(int dataLength, int totalLength) + { + Interlocked.Increment(ref fragmentedMessagesReceived); + Interlocked.Add(ref dataBytesReceived, dataLength); + Interlocked.Add(ref totalBytesReceived, totalLength); + } + + /// <summary> + /// Logs the receiving of an acknowledgement data packet in the statistics. + /// </summary> + /// <param name="totalLength">The total number of bytes received.</param> + /// <remarks> + /// This should be called before the received event is invoked so it is up to date for subscribers to that event. + /// </remarks> + internal void LogAcknowledgementReceive(int totalLength) + { + Interlocked.Increment(ref acknowledgementMessagesReceived); + Interlocked.Add(ref totalBytesReceived, totalLength); + } + + /// <summary> + /// Logs the unique acknowledgement of a ping or reliable data packet. + /// </summary> + internal void LogReliablePacketAcknowledged() + { + Interlocked.Increment(ref this.reliablePacketsAcknowledged); + } + + /// <summary> + /// Logs the receiving of a hello data packet in the statistics. + /// </summary> + /// <param name="totalLength">The total number of bytes received.</param> + /// <remarks> + /// This should be called before the received event is invoked so it is up to date for subscribers to that event. + /// </remarks> + internal void LogPingReceive(int totalLength) + { + Interlocked.Increment(ref pingMessagesReceived); + Interlocked.Add(ref totalBytesReceived, totalLength); + } + + /// <summary> + /// Logs the receiving of a hello data packet in the statistics. + /// </summary> + /// <param name="totalLength">The total number of bytes received.</param> + /// <remarks> + /// This should be called before the received event is invoked so it is up to date for subscribers to that event. + /// </remarks> + internal void LogHelloReceive(int totalLength) + { + Interlocked.Increment(ref helloMessagesReceived); + Interlocked.Add(ref totalBytesReceived, totalLength); + } + + internal void LogMessageResent() + { + Interlocked.Increment(ref messagesResent); + } + } +} diff --git a/Tools/Hazel-Networking/Hazel/Crypto/AesGcm.cs b/Tools/Hazel-Networking/Hazel/Crypto/AesGcm.cs new file mode 100644 index 0000000..bfbbc01 --- /dev/null +++ b/Tools/Hazel-Networking/Hazel/Crypto/AesGcm.cs @@ -0,0 +1,369 @@ +using System; +using System.Diagnostics; +using System.Security.Cryptography; + +namespace Hazel.Crypto +{ + /// <summary> + /// Implementation of AEAD_AES128_GCM based on: + /// * RFC 5116 [1] + /// * NIST SP 800-38d [2] + /// + /// [1] https://tools.ietf.org/html/rfc5116 + /// [2] https://nvlpubs.nist.gov/nistpubs/Legacy/SP/nistspecialpublication800-38d.pdf + /// + /// Adapted from: https://gist.github.com/mendsley/777e6bd9ae7eddcb2b0c0fe18247dc60 + /// </summary> + public class Aes128Gcm : IDisposable + { + public const int KeySize = 16; + public const int NonceSize = 12; + public const int CiphertextOverhead = TagSize; + + private const int TagSize = 16; + + private readonly IAes encryptor_; + + private readonly ByteSpan hashSubkey_; + private readonly ByteSpan blockJ_; + private readonly ByteSpan blockS_; + private readonly ByteSpan blockZ_; + private readonly ByteSpan blockV_; + private readonly ByteSpan blockScratch_; + + /// <summary> + /// Creates a new instance of an AEAD_AES128_GCM cipher + /// </summary> + /// <param name="key">Symmetric key</param> + public Aes128Gcm(ByteSpan key) + { + if (key.Length != KeySize) + { + throw new ArgumentException("Invalid key length", nameof(key)); + } + + // Create the AES block cipher + this.encryptor_ = CryptoProvider.CreateAes(key); + + // Allocate scratch space + ByteSpan scratchSpace = new byte[96]; + this.hashSubkey_ = scratchSpace.Slice(0, 16); + this.blockJ_ = scratchSpace.Slice(16, 16); + this.blockS_ = scratchSpace.Slice(32, 16); + this.blockZ_ = scratchSpace.Slice(48, 16); + this.blockV_ = scratchSpace.Slice(64, 16); + this.blockScratch_ = scratchSpace.Slice(80, 16); + + // Create the GHASH subkey by encrypting the 0-block + this.encryptor_.EncryptBlock(this.hashSubkey_, this.hashSubkey_); + } + + /// <summary> + /// Encryptes the specified plaintext and generates an authentication + /// tag for the provided additional data. Returns the byte array + /// containg both the ciphertext and authentication tag. + /// </summary> + /// <param name="output"> + /// Array in which to encode the encrypted ciphertext and + /// authentication tag. This array must be large enough to hold + /// `plaintext.Lengh + CiphertextOverhead` bytes. + /// </param> + /// <param name="nonce">Unique value for this message</param> + /// <param name="plaintext">Plaintext data to encrypt</param> + /// <param name="associatedData"> + /// Additional data used to authenticate the message + /// </param> + public void Seal(ByteSpan output, ByteSpan nonce, ByteSpan plaintext, ByteSpan associatedData) + { + if (nonce.Length != NonceSize) + { + throw new ArgumentException("Invalid nonce size", nameof(nonce)); + } + if (output.Length < plaintext.Length + CiphertextOverhead) + { + throw new ArgumentException("Invalid output size", nameof(output)); + } + + // Create the initial counter block + nonce.CopyTo(this.blockJ_); + + // Encrypt the plaintext to output + GCTR(output, this.blockJ_, 2, plaintext); + + // Generate and append the authentication tag + int tagOffset = plaintext.Length; + GenerateAuthenticationTag(output.Slice(tagOffset), output.Slice(0, tagOffset), associatedData); + } + + /// <summary> + /// Validates the authentication tag against the provided additional + /// data, then decrypts the cipher text returning the original + /// plaintext. + /// </summary> + /// <param name="nonce"> + /// The unique value used to seal this message + /// </param> + /// <param name="ciphertext"> + /// Combined ciphertext and authentication tag + /// </param> + /// <param name="associatedData"> + /// Additional data used to authenticate the message + /// </param> + /// <param name="output"> + /// On successful validation and decryprion, Open writes the original + /// plaintext to output. Must contain enough space to hold + /// `ciphertext.Length - CiphertextOverhead` bytes. + /// </param> + /// <returns> + /// True if the data was validated and successfully decrypted. + /// Otherwise, false. + /// </returns> + public bool Open(ByteSpan output, ByteSpan nonce, ByteSpan ciphertext, ByteSpan associatedData) + { + if (nonce.Length != NonceSize) + { + throw new ArgumentException("Invalid nonce size", nameof(nonce)); + } + if (ciphertext.Length < CiphertextOverhead) + { + throw new ArgumentException("Invalid ciphertext size", nameof(ciphertext)); + } + else if (output.Length < ciphertext.Length - CiphertextOverhead) + { + throw new ArgumentException("Invalid output size", nameof(output)); + } + + // Split ciphertext into actual ciphertext and authentication + // tag components. + ByteSpan authenticationTag = ciphertext.Slice(ciphertext.Length - TagSize); + ciphertext = ciphertext.Slice(0, ciphertext.Length - TagSize); + + // Create the initial counter block + nonce.CopyTo(this.blockJ_); + + // Verify the tags match + GenerateAuthenticationTag(this.blockScratch_, ciphertext, associatedData); + if (0 == Const.ConstantCompareSpans(this.blockScratch_, authenticationTag)) + { + return false; + } + + // Decrypt the cipher text to output + GCTR(output, this.blockJ_, 2, ciphertext); + return true; + } + + /// <summary> + /// Release resources acquired by the cipher + /// </summary> + public void Dispose() + { + this.encryptor_.Dispose(); + } + + // Generate the authentication tag for a ciphertext+associated data + void GenerateAuthenticationTag(ByteSpan output, ByteSpan ciphertext, ByteSpan associatedData) + { + Debug.Assert(output.Length >= 16); + + // Hash `Associated data || Ciphertext || len(AssociatedD data) || len(Ciphertext)` + // into `blockS` + { + // Clear hash output block + SetSpanToZeros(this.blockS_); + + // Write associated data blocks to hash + int fullBlocks = associatedData.Length / 16; + GHASH(this.blockS_, associatedData, fullBlocks); + if (fullBlocks * 16 < associatedData.Length) + { + SetSpanToZeros(this.blockScratch_); + associatedData.Slice(fullBlocks * 16).CopyTo(this.blockScratch_); + GHASH(this.blockS_, this.blockScratch_, 1); + } + + // Write ciphertext blocks to hash + fullBlocks = ciphertext.Length / 16; + GHASH(this.blockS_, ciphertext, fullBlocks); + if (fullBlocks * 16 < ciphertext.Length) + { + SetSpanToZeros(this.blockScratch_); + ciphertext.Slice(fullBlocks * 16).CopyTo(this.blockScratch_); + GHASH(this.blockS_, this.blockScratch_, 1); + } + + // Write bit sizes to hash + ulong associatedDataLengthInBits = (ulong)(8 * associatedData.Length); + ulong ciphertextDataLengthInBits = (ulong)(8 * ciphertext.Length); + this.blockScratch_.WriteBigEndian64(associatedDataLengthInBits); + this.blockScratch_.WriteBigEndian64(ciphertextDataLengthInBits, 8); + + GHASH(this.blockS_, this.blockScratch_, 1); + } + + // Encrypt the tag. GCM requires this because `GASH` is not + // cryptographically secure. An attacker could derive our hash + // subkey `hashSubkey_` from an unencrypted tag. + GCTR(output, this.blockJ_, 1, this.blockS_); + } + + // Run the GCTR cipher + void GCTR(ByteSpan output, ByteSpan counterBlock, uint counter, ByteSpan data) + { + Debug.Assert(counterBlock.Length == 16); + Debug.Assert(output.Length >= data.Length); + + // Loop through plaintext blocks + int writeIndex = 0; + int numBlocks = (data.Length + 15) / 16; + for (int ii = 0; ii != numBlocks; ++ii) + { + // Encode counter into block + // CB[1] = J0 + // CB[i] = inc[32](CB[i-1]) + counterBlock.WriteBigEndian32(counter, 12); + ++counter; + + // CIPH[k](CB[i]) + this.encryptor_.EncryptBlock(counterBlock.Slice(0, 16), this.blockScratch_); + + // Y[i] = X[i] xor CIPH[k](CB[i]) + for (int jj = 0; jj != 16 && writeIndex < data.Length; ++jj, ++writeIndex) + { + output[writeIndex] = (byte)(data[writeIndex] ^ this.blockScratch_[jj]); + } + } + } + + // Run the GHASH function + void GHASH(ByteSpan output, ByteSpan data, int numBlocks) + { + ///TODO(mendsley): See Ref[6] for opitmizations of GHASH on both hardware and software + /// + ///[6] D. McGrew, J. Viega, The Galois/Counter Mode of Operation (GCM), Natl. Inst. Stand. + ///Technol. [Web page], http://www.csrc.nist.gov/groups/ST/toolkit/BCM/documents/ + ///proposedmodes / gcm / gcm - revised - spec.pdf, May 31, 2005. + + Debug.Assert(output.Length == 16); + Debug.Assert(data.Length >= numBlocks * 16); + + int readIndex = 0; + for (int ii = 0; ii != numBlocks; ++ii) + { + for (int jj = 0; jj != 16; ++jj, ++readIndex) + { + // Y[ii-1] xor X[ii] + output[jj] ^= data[readIndex]; + } + + // Y[ii] = (Y[ii-1] xor X[ii]) · H + MultiplyGF128Elements(output, this.hashSubkey_, this.blockZ_, this.blockV_); + } + } + + // Multiply two Galois field elements `X` and `Y` together and store + // the result in `X` such that at the end of the function: + // X = X·Y + static void MultiplyGF128Elements(ByteSpan X, ByteSpan Y, ByteSpan scratchZ, ByteSpan scratchV) + { + Debug.Assert(X.Length == 16); + Debug.Assert(Y.Length == 16); + Debug.Assert(scratchZ.Length == 16); + Debug.Assert(scratchV.Length == 16); + + // Galois (finite) fields represented by GF(p) define a set of + // closed algebraic operations. For AES128_GCM we'll be dealing + // with the GF(2^128) field. + // + // We treat each incoming 16 byte block as a polynomial in field + // and define multiplication between two polynomials as the + // polynomial product reduced by (mod) the field polynomial: + // 1 + x + x^2 + x^7 + x^128 + // + // Field polynomials are represented by a 128 bit string. Bit n is + // the coefficient of the x^n term. We use little-endian bit + // ordering (not to be confused with byte ordering) for these + // coefficients. E.g. X[0] & 0x00000001 represents the 7th bit in + // the bit string defined by X, _not_ the 0th bit. + // + + // What follows is a modified version of the "peasant's algorithm" + // to multiply two numbers: + // + // Z contains the accumulated product + // V is a copy of Y (so we can modify it via shifting). + // + // We calculate Z = X·V as follows + // We loop through each of the 128 bits in X maintaining the + // following loop invariant: X·V + Z = the final product + // + // On each iteration `ii`: + // + // If the `ii`th bit of `X` is set, add the add the polynomial + // in `V` to `X`: `X[n] = X[n] ^ V[n]` + // + // Double V (Shift one bit right since we're storing little + // endian bit). This has the effect of multiplying V by the + // polynomial `x`. We track the unrepresentable coefficient + // of `x^128` by storing the most significant bit before the + // shift `V[15] >> 7` as `carry` + // + // Check if we've overflowed our multiplication. If overflow + // occurred, there will be a non-zero coefficient for the + // `x^128` term in the step above `carry` + // + // If we have overflowed, our polynomial is exactly of degree + // 129 (since we're only multiplying by `x`). We reduce the + // polynomial back into degree 128 by adding our field's + // irreducible polynomial: 1 + x + x^2 + x^7 + x^128. This + // reduction cancels out the x^128 term (x^128 + x^128 in GF(2) + // is zero). Therefore this modulo can be achieved by simply + // adding the irreducible polynomial to the new value of `V`. The + // irreducible polynomial is represented by the bit string: + // `11100001` followed by 120 `0`s. We can add this value to `V` + // by: `V[0] = V[0] ^ 0xE1`. + SetSpanToZeros(scratchZ); + X.CopyTo(scratchV); + + for (int ii = 0; ii != 128; ++ii) + { + int bitIndex = 7 - (ii % 8); + if ((Y[ii / 8] & (1 << bitIndex)) != 0) + { + for (int jj = 0; jj != 16; ++jj) + { + scratchZ[jj] ^= scratchV[jj]; + } + } + + bool carry = false; + for (int jj = 0; jj != 16; ++jj) + { + bool newCarry = (scratchV[jj] & 0x01) != 0; + scratchV[jj] >>= 1; + if (carry) + { + scratchV[jj] |= 0x80; + } + carry = newCarry; + } + + if (carry) + { + scratchV[0] ^= 0xE1; + } + } + + scratchZ.CopyTo(X); + } + + // Set the contents of a span to all zero + static void SetSpanToZeros(ByteSpan span) + { + for (int ii = 0, nn = span.Length; ii != nn; ++ii) + { + span[ii] = 0; + } + } + } +} diff --git a/Tools/Hazel-Networking/Hazel/Crypto/Const.cs b/Tools/Hazel-Networking/Hazel/Crypto/Const.cs new file mode 100644 index 0000000..4dfef47 --- /dev/null +++ b/Tools/Hazel-Networking/Hazel/Crypto/Const.cs @@ -0,0 +1,82 @@ +using System.Diagnostics; + +namespace Hazel.Crypto +{ + public static class Const + { + + /// <summary> + /// Compare two bytes for equality. + /// + /// This takes care to always use a constant amount of time to prevent + /// leaking information through side-channel attacks. + /// + /// This is aceived by collapsing the xor bits down into a single bit. + /// + /// Ported from: + /// https://github.com/mendsley/tiny/blob/master/include/tiny/crypto/constant.h + /// </summary> + /// <returns> + /// Returns `1` is the two bytes or equivalent. Otherwise, returns `0` + /// </returns> + public static byte ConstantCompareByte(byte a, byte b) + { + byte result = (byte)(~(a ^ b)); + + // collapse bits down to the LSB + result &= (byte)(result >> 4); + result &= (byte)(result >> 2); + result &= (byte)(result >> 1); + + return result; + } + + /// <summary> + /// Compare two equal length spans for equality. + /// + /// This takes care to always use a constant amount of time to prevent + /// leaking information through side-channel attacks. + /// + /// Ported from: + /// https://github.com/mendsley/tiny/blob/master/include/tiny/crypto/constant.h + /// </summary> + /// <returns> + /// Returns `1` if the spans are equivalent. Others, returns `0`. + /// </returns> + public static byte ConstantCompareSpans(ByteSpan a, ByteSpan b) + { + Debug.Assert(a.Length == b.Length); + + byte value = 0; + for (int ii = 0, nn = a.Length; ii != nn; ++ii) + { + value |= (byte)(a[ii] ^ b[ii]); + } + + return ConstantCompareByte(value, 0); + } + + /// <summary> + /// Compare a span against an all zero span + /// + /// This takes care to always use a constant amount of time to prevent + /// leaking information through side-channel attacks. + /// + /// Ported from: + /// https://github.com/mendsley/tiny/blob/master/include/tiny/crypto/constant.h + /// </summary> + /// <returns> + /// Returns `1` if the spans is all zeros. Others, returns `0`. + /// </returns> + public static byte ConstantCompareZeroSpan(ByteSpan a) + { + byte value = 0; + for (int ii = 0, nn = a.Length; ii != nn; ++ii) + { + value |= (byte)(a[ii] ^ 0); + } + + return ConstantCompareByte(value, 0); + } + } +} diff --git a/Tools/Hazel-Networking/Hazel/Crypto/CryptoProvider.cs b/Tools/Hazel-Networking/Hazel/Crypto/CryptoProvider.cs new file mode 100644 index 0000000..2c56c70 --- /dev/null +++ b/Tools/Hazel-Networking/Hazel/Crypto/CryptoProvider.cs @@ -0,0 +1,36 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace Hazel.Crypto +{ + public static class CryptoProvider + { + public delegate IAes CreateAesOverrideDelegate(ByteSpan key); + + /// <summary> + /// Override the default AES creation function + /// </summary> + public static CreateAesOverrideDelegate OverrideCreateAes = null; + + /// <summary> + /// Create a new AES cipher + /// </summary> + /// <param name="key">Encrtyption key</param> + public static IAes CreateAes(ByteSpan key) + { + if (OverrideCreateAes != null) + { + IAes result = OverrideCreateAes(key); + if (null != result) + { + return result; + } + } + + return new DefaultAes(key); + } + } +} diff --git a/Tools/Hazel-Networking/Hazel/Crypto/DefaultAes.cs b/Tools/Hazel-Networking/Hazel/Crypto/DefaultAes.cs new file mode 100644 index 0000000..da72fb8 --- /dev/null +++ b/Tools/Hazel-Networking/Hazel/Crypto/DefaultAes.cs @@ -0,0 +1,49 @@ +using System; +using System.Security.Cryptography; + +namespace Hazel.Crypto +{ + /// <summary> + /// AES provider using the default System.Security.Cryptography implementation + /// </summary> + public class DefaultAes : IAes + { + private readonly ICryptoTransform encryptor_; + + /// <summary> + /// Create a new default instance of the AES block cipher + /// </summary> + /// <param name="key">Encryption key</param> + public DefaultAes(ByteSpan key) + { + // Create the AES block cipher + using (Aes aes = Aes.Create()) + { + aes.KeySize = key.Length * 8; + aes.BlockSize = aes.KeySize; + aes.Mode = CipherMode.ECB; + aes.Padding = PaddingMode.Zeros; + aes.Key = key.ToArray(); + + this.encryptor_ = aes.CreateEncryptor(); + } + } + + /// <inheritdoc/> + public void Dispose() + { + this.encryptor_.Dispose(); + } + + /// <inheritdoc/> + public int EncryptBlock(ByteSpan inputSpan, ByteSpan outputSpan) + { + if (inputSpan.Length != outputSpan.Length) + { + throw new ArgumentException($"ouputSpan length ({outputSpan.Length}) does not match inputSpan length ({inputSpan.Length})", nameof(outputSpan)); + } + + return this.encryptor_.TransformBlock(inputSpan.GetUnderlyingArray(), inputSpan.Offset, inputSpan.Length, outputSpan.GetUnderlyingArray(), outputSpan.Offset); + } + } +} diff --git a/Tools/Hazel-Networking/Hazel/Crypto/IAes.cs b/Tools/Hazel-Networking/Hazel/Crypto/IAes.cs new file mode 100644 index 0000000..6c494cd --- /dev/null +++ b/Tools/Hazel-Networking/Hazel/Crypto/IAes.cs @@ -0,0 +1,27 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace Hazel.Crypto +{ + /// <summary> + /// AES encryption interface + /// </summary> + public interface IAes : IDisposable + { + /// <summary> + /// Encrypts the specified region of the input byte array and copies + /// the resulting transform to the specified region of the output + /// array. + /// </summary> + /// <param name="inputSpan">The input for which to encrypt</param> + /// <param name="outputSpan"> + /// The otput to which to write the encrypted data. This span can + /// overlap with `inputSpan`. + /// </param> + /// <returns>The number of bytes written</returns> + int EncryptBlock(ByteSpan inputSpan, ByteSpan outputSpan); + } +} diff --git a/Tools/Hazel-Networking/Hazel/Crypto/Sha256Stream.cs b/Tools/Hazel-Networking/Hazel/Crypto/Sha256Stream.cs new file mode 100644 index 0000000..1903693 --- /dev/null +++ b/Tools/Hazel-Networking/Hazel/Crypto/Sha256Stream.cs @@ -0,0 +1,86 @@ +using System; +using System.Security.Cryptography; + +namespace Hazel.Crypto +{ + /// <summary> + /// Streams data into a SHA256 digest + /// </summary> + public class Sha256Stream : IDisposable + { + /// <summary> + /// Size of the SHA256 digest in bytes + /// </summary> + public const int DigestSize = 32; + + private SHA256 hash = SHA256.Create(); + private bool isHashFinished = false; + + struct EmptyArray + { + public static readonly byte[] Value = new byte[0]; + } + + /// <summary> + /// Create a new instance of a SHA256 stream + /// </summary> + public Sha256Stream() + { + } + + /// <summary> + /// Release resources associated with the stream + /// </summary> + public void Dispose() + { + this.hash?.Dispose(); + this.hash = null; + + GC.SuppressFinalize(this); + } + + /// <summary> + /// Reset the stream to its initial state + /// </summary> + public void Reset() + { + this.hash?.Dispose(); + this.hash = SHA256.Create(); + this.isHashFinished = false; + } + + /// <summary> + /// Add data to the stream + /// </summary> + public void AddData(ByteSpan data) + { + while (data.Length > 0) + { + int offset = this.hash.TransformBlock(data.GetUnderlyingArray(), data.Offset, data.Length, null, 0); + data = data.Slice(offset); + } + } + + /// <summary> + /// Calculate the final hash of the stream data + /// </summary> + /// <param name="output"> + /// Target span to which the hash will be written + /// </param> + public void CopyOrCalculateFinalHash(ByteSpan output) + { + if (output.Length != DigestSize) + { + throw new ArgumentException($"Expected a span of {DigestSize} bytes. Got a span of {output.Length} bytes", nameof(output)); + } + + if (this.isHashFinished == false) + { + this.hash.TransformFinalBlock(EmptyArray.Value, 0, 0); + this.isHashFinished = true; + } + + new ByteSpan(this.hash.Hash).CopyTo(output); + } + } +} diff --git a/Tools/Hazel-Networking/Hazel/Crypto/SpanCryptoExtensions.cs b/Tools/Hazel-Networking/Hazel/Crypto/SpanCryptoExtensions.cs new file mode 100644 index 0000000..03164ec --- /dev/null +++ b/Tools/Hazel-Networking/Hazel/Crypto/SpanCryptoExtensions.cs @@ -0,0 +1,36 @@ +using System; +using System.Security.Cryptography; + +namespace Hazel.Crypto +{ + public static class SpanCryptoExtensions + { + /// <summary> + /// Clear a span's contents to zero + /// </summary> + public static void SecureClear(this ByteSpan span) + { + if (span.Length > 0) + { + Array.Clear(span.GetUnderlyingArray(), span.Offset, span.Length); + } + } + + /// <summary> + /// Fill a byte span with random data + /// </summary> + /// <param name="random">Entropy source</param> + public static void FillWithRandom(this ByteSpan span, RandomNumberGenerator random) + { + if (span.Offset == 0 && span.Length == span.GetUnderlyingArray().Length) + { + random.GetBytes(span.GetUnderlyingArray()); + return; + } + + byte[] temp = new byte[span.Length]; + random.GetBytes(temp); + new ByteSpan(temp).CopyTo(span); + } + } +} diff --git a/Tools/Hazel-Networking/Hazel/Crypto/X25519.cs b/Tools/Hazel-Networking/Hazel/Crypto/X25519.cs new file mode 100644 index 0000000..3f4624b --- /dev/null +++ b/Tools/Hazel-Networking/Hazel/Crypto/X25519.cs @@ -0,0 +1,844 @@ +using System; +using System.Diagnostics; + +namespace Hazel.Crypto +{ + /// <summary> + /// The x25519 key agreement algorithm + /// </summary> + public static class X25519 + { + public const int KeySize = 32; + + /// <summary> + /// Element in the GF(2^255 - 19) field + /// </summary> + public partial struct FieldElement + { + public int x0, x1, x2, x3, x4; + public int x5, x6, x7, x8, x9; + }; + + private static readonly byte[] BasePoint = {9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + + /// <summary> + /// Performs the core x25519 function: Multiplying an EC point by a scalar value + /// </summary> + public static bool Func(ByteSpan output, ByteSpan scalar, ByteSpan point) + { + InternalFunc(output, scalar, point); + if (Const.ConstantCompareZeroSpan(output) == 1) + { + return false; + } + + return true; + } + + /// <summary> + /// Multiplies the base x25519 point by the provided scalar value + /// </summary> + public static void Func(ByteSpan output, ByteSpan scalar) + { + InternalFunc(output, scalar, BasePoint); + } + + // The FieldElement code below is ported from the original + // public domain reference implemtation of X25519 + // by D. J. Bernstien + // + // See: https://cr.yp.to/ecdh.html + + private static void InternalFunc(ByteSpan output, ByteSpan scalar, ByteSpan point) + { + if (output.Length != KeySize) + { + throw new ArgumentException("Invalid output size", nameof(output)); + } + else if (scalar.Length != KeySize) + { + throw new ArgumentException("Invalid scalar size", nameof(scalar)); + } + else if (point.Length != KeySize) + { + throw new ArgumentException("Invalid point size", nameof(point)); + } + + // copy the scalar so we can properly mask it + ByteSpan maskedScalar = new byte[32]; + scalar.CopyTo(maskedScalar); + maskedScalar[0] &= 248; + maskedScalar[31] &= 127; + maskedScalar[31] |= 64; + + FieldElement x1 = FieldElement.FromBytes(point); + FieldElement x2 = FieldElement.One(); + FieldElement x3 = x1; + FieldElement z2 = FieldElement.Zero(); + FieldElement z3 = FieldElement.One(); + + FieldElement tmp0 = new FieldElement(); + FieldElement tmp1 = new FieldElement(); + + int swap = 0; + for (int pos = 254; pos >= 0; --pos) + { + int b = (int)maskedScalar[pos / 8] >> (int)(pos % 8); + b &= 1; + swap ^= b; + + FieldElement.ConditionalSwap(ref x2, ref x3, swap); + FieldElement.ConditionalSwap(ref z2, ref z3, swap); + swap = b; + + FieldElement.Sub(ref tmp0, ref x3, ref z3); + FieldElement.Sub(ref tmp1, ref x2, ref z2); + FieldElement.Add(ref x2, ref x2, ref z2); + FieldElement.Add(ref z2, ref x3, ref z3); + FieldElement.Multiply(ref z3, ref tmp0, ref x2); + FieldElement.Multiply(ref z2, ref z2, ref tmp1); + FieldElement.Square(ref tmp0, ref tmp1); + FieldElement.Square(ref tmp1, ref x2); + FieldElement.Add(ref x3, ref z3, ref z2); + FieldElement.Sub(ref z2, ref z3, ref z2); + FieldElement.Multiply(ref x2, ref tmp1, ref tmp0); + FieldElement.Sub(ref tmp1, ref tmp1, ref tmp0); + FieldElement.Square(ref z2, ref z2); + FieldElement.Multiply121666(ref z3, ref tmp1); + FieldElement.Square(ref x3, ref x3); + FieldElement.Add(ref tmp0, ref tmp0, ref z3); + FieldElement.Multiply(ref z3, ref x1, ref z2); + FieldElement.Multiply(ref z2, ref tmp1, ref tmp0); + } + + FieldElement.ConditionalSwap(ref x2, ref x3, swap); + FieldElement.ConditionalSwap(ref z2, ref z3, swap); + + FieldElement.Invert(ref z2, ref z2); + FieldElement.Multiply(ref x2, ref x2, ref z2); + x2.CopyTo(output); + } + + + /// <summary> + /// Mathematical operators over GF(2^255 - 19) + /// </summary> + partial struct FieldElement + { + /// <summary> + /// Convert a byte array to a field element + /// </summary> + public static FieldElement FromBytes(ByteSpan bytes) + { + Debug.Assert(bytes.Length >= KeySize); + + long tmp0 = (long)bytes.ReadLittleEndian32(); + long tmp1 = (long)bytes.ReadLittleEndian24(4) << 6; + long tmp2 = (long)bytes.ReadLittleEndian24(7) << 5; + long tmp3 = (long)bytes.ReadLittleEndian24(10) << 3; + long tmp4 = (long)bytes.ReadLittleEndian24(13) << 2; + long tmp5 = (long)bytes.ReadLittleEndian32(16); + long tmp6 = (long)bytes.ReadLittleEndian24(20) << 7; + long tmp7 = (long)bytes.ReadLittleEndian24(23) << 5; + long tmp8 = (long)bytes.ReadLittleEndian24(26) << 4; + long tmp9 = (long)(bytes.ReadLittleEndian24(29) & 0x007FFFFF) << 2; + + long carry9 = (tmp9 + (1L<<24)) >> 25; + tmp0 += carry9 * 19; + tmp9 -= carry9 << 25; + long carry1 = (tmp1 + (1L<<24)) >> 25; + tmp2 += carry1; + tmp1 -= carry1 << 25; + long carry3 = (tmp3 + (1L<<24)) >> 25; + tmp4 += carry3; + tmp3 -= carry3 << 25; + long carry5 = (tmp5 + (1L<<24)) >> 25; + tmp6 += carry5; + tmp5 -= carry5 << 25; + long carry7 = (tmp7 + (1L<<24)) >> 25; + tmp8 += carry7; + tmp7 -= carry7 << 25; + + long carry0 = (tmp0 + (1L<<25)) >> 26; + tmp1 += carry0; + tmp0 -= carry0 << 26; + long carry2 = (tmp2 + (1L<<25)) >> 26; + tmp3 += carry2; + tmp2 -= carry2 << 26; + long carry4 = (tmp4 + (1L<<25)) >> 26; + tmp5 += carry4; + tmp4 -= carry4 << 26; + long carry6 = (tmp6 + (1L<<25)) >> 26; + tmp7 += carry6; + tmp6 -= carry6 << 26; + long carry8 = (tmp8 + (1L<<25)) >> 26; + tmp9 += carry8; + tmp8 -= carry8 << 26; + + return new FieldElement + { + x0 = (int)tmp0, + x1 = (int)tmp1, + x2 = (int)tmp2, + x3 = (int)tmp3, + x4 = (int)tmp4, + x5 = (int)tmp5, + x6 = (int)tmp6, + x7 = (int)tmp7, + x8 = (int)tmp8, + x9 = (int)tmp9, + }; + } + + /// <summary> + /// Convert the field element to a byte array + /// </summary> + public void CopyTo(ByteSpan output) + { + Debug.Assert(output.Length >= 32); + + long q = (19 * this.x9 + (1L << 24)) >> 25; + q = ((long)this.x0 + q) >> 26; + q = ((long)this.x1 + q) >> 25; + q = ((long)this.x2 + q) >> 26; + q = ((long)this.x3 + q) >> 25; + q = ((long)this.x4 + q) >> 26; + q = ((long)this.x5 + q) >> 25; + q = ((long)this.x6 + q) >> 26; + q = ((long)this.x7 + q) >> 25; + q = ((long)this.x8 + q) >> 26; + q = ((long)this.x9 + q) >> 25; + + this.x0 = (int)((long)this.x0 + (19L * q)); + + int carry0 = (int)(this.x0 >> 26); + this.x1 = (int)((int)this.x1 + carry0); + this.x0 = (int)((int)this.x0 - (carry0 << 26)); + int carry1 = (int)(this.x1 >> 25); + this.x2 = (int)((int)this.x2 + carry1); + this.x1 = (int)((int)this.x1 - (carry1 << 25)); + int carry2 = (int)(this.x2 >> 26); + this.x3 = (int)((int)this.x3 + carry2); + this.x2 = (int)((int)this.x2 - (carry2 << 26)); + int carry3 = (int)(this.x3 >> 25); + this.x4 = (int)((int)this.x4 + carry3); + this.x3 = (int)((int)this.x3 - (carry3 << 25)); + int carry4 = (int)(this.x4 >> 26); + this.x5 = (int)((int)this.x5 + carry4); + this.x4 = (int)((int)this.x4 - (carry4 << 26)); + int carry5 = (int)(this.x5 >> 25); + this.x6 = (int)((int)this.x6 + carry5); + this.x5 = (int)((int)this.x5 - (carry5 << 25)); + int carry6 = (int)(this.x6 >> 26); + this.x7 = (int)((int)this.x7 + carry6); + this.x6 = (int)((int)this.x6 - (carry6 << 26)); + int carry7 = (int)(this.x7 >> 25); + this.x8 = (int)((int)this.x8 + carry7); + this.x7 = (int)((int)this.x7 - (carry7 << 25)); + int carry8 = (int)(this.x8 >> 26); + this.x9 = (int)((int)this.x9 + carry8); + this.x8 = (int)((int)this.x8 - (carry8 << 26)); + int carry9 = (int)(this.x9 >> 25); + this.x9 = (int)((int)this.x9 - (carry9 << 25)); + + output[ 0] = (byte)(this.x0 >> 0); + output[ 1] = (byte)(this.x0 >> 8); + output[ 2] = (byte)(this.x0 >> 16); + output[ 3] = (byte)((this.x0 >> 24) | (this.x1 << 2)); + output[ 4] = (byte)(this.x1 >> 6); + output[ 5] = (byte)(this.x1 >> 14); + output[ 6] = (byte)((this.x1 >> 22) | (this.x2 << 3)); + output[ 7] = (byte)(this.x2 >> 5); + output[ 8] = (byte)(this.x2 >> 13); + output[ 9] = (byte)((this.x2 >> 21) | (this.x3 << 5)); + output[10] = (byte)(this.x3 >> 3); + output[11] = (byte)(this.x3 >> 11); + output[12] = (byte)((this.x3 >> 19) | (this.x4 << 6)); + output[13] = (byte)(this.x4 >> 2); + output[14] = (byte)(this.x4 >> 10); + output[15] = (byte)(this.x4 >> 18); + output[16] = (byte)(this.x5 >> 0); + output[17] = (byte)(this.x5 >> 8); + output[18] = (byte)(this.x5 >> 16); + output[19] = (byte)((this.x5 >> 24) | (this.x6 << 1)); + output[20] = (byte)(this.x6 >> 7); + output[21] = (byte)(this.x6 >> 15); + output[22] = (byte)((this.x6 >> 23) | (this.x7 << 3)); + output[23] = (byte)(this.x7 >> 5); + output[24] = (byte)(this.x7 >> 13); + output[25] = (byte)((this.x7 >> 21) | (this.x8 << 4)); + output[26] = (byte)(this.x8 >> 4); + output[27] = (byte)(this.x8 >> 12); + output[28] = (byte)((this.x8 >> 20) | (this.x9 << 6)); + output[29] = (byte)(this.x9 >> 2); + output[30] = (byte)(this.x9 >> 10); + output[31] = (byte)(this.x9 >> 18); + } + + /// <summary> + /// Set the field element to `0` + /// </summary> + public static FieldElement Zero() + { + return new FieldElement(); + } + + /// <summary> + /// Set the field element to `1` + /// </summary> + public static FieldElement One() + { + FieldElement result = Zero(); + result.x0 = 1; + return result; + } + + /// <summary> + /// Add two field elements + /// </summary> + public static void Add(ref FieldElement output, ref FieldElement a, ref FieldElement b) + { + output.x0 = a.x0 + b.x0; + output.x1 = a.x1 + b.x1; + output.x2 = a.x2 + b.x2; + output.x3 = a.x3 + b.x3; + output.x4 = a.x4 + b.x4; + output.x5 = a.x5 + b.x5; + output.x6 = a.x6 + b.x6; + output.x7 = a.x7 + b.x7; + output.x8 = a.x8 + b.x8; + output.x9 = a.x9 + b.x9; + } + + /// <summary> + /// Subtract two field elements + /// </summary> + public static void Sub(ref FieldElement output, ref FieldElement a, ref FieldElement b) + { + output.x0 = a.x0 - b.x0; + output.x1 = a.x1 - b.x1; + output.x2 = a.x2 - b.x2; + output.x3 = a.x3 - b.x3; + output.x4 = a.x4 - b.x4; + output.x5 = a.x5 - b.x5; + output.x6 = a.x6 - b.x6; + output.x7 = a.x7 - b.x7; + output.x8 = a.x8 - b.x8; + output.x9 = a.x9 - b.x9; + } + + /// <summary> + /// Multiply two field elements + /// </summary> + public static void Multiply(ref FieldElement output, ref FieldElement a, ref FieldElement b) + { + int b1_19 = 19 * b.x1; + int b2_19 = 19 * b.x2; + int b3_19 = 19 * b.x3; + int b4_19 = 19 * b.x4; + int b5_19 = 19 * b.x5; + int b6_19 = 19 * b.x6; + int b7_19 = 19 * b.x7; + int b8_19 = 19 * b.x8; + int b9_19 = 19 * b.x9; + + int a1_2 = 2 * a.x1; + int a3_2 = 2 * a.x3; + int a5_2 = 2 * a.x5; + int a7_2 = 2 * a.x7; + int a9_2 = 2 * a.x9; + + long a0b0 = (long)a.x0 * (long)b.x0; + long a0b1 = (long)a.x0 * (long)b.x1; + long a0b2 = (long)a.x0 * (long)b.x2; + long a0b3 = (long)a.x0 * (long)b.x3; + long a0b4 = (long)a.x0 * (long)b.x4; + long a0b5 = (long)a.x0 * (long)b.x5; + long a0b6 = (long)a.x0 * (long)b.x6; + long a0b7 = (long)a.x0 * (long)b.x7; + long a0b8 = (long)a.x0 * (long)b.x8; + long a0b9 = (long)a.x0 * (long)b.x9; + long a1b0 = (long)a.x1 * (long)b.x0; + long a1b1_2 = (long)a1_2 * (long)b.x1; + long a1b2 = (long)a.x1 * (long)b.x2; + long a1b3_2 = (long)a1_2 * (long)b.x3; + long a1b4 = (long)a.x1 * (long)b.x4; + long a1b5_2 = (long)a1_2 * (long)b.x5; + long a1b6 = (long)a.x1 * (long)b.x6; + long a1b7_2 = (long)a1_2 * (long)b.x7; + long a1b8 = (long)a.x1 * (long)b.x8; + long a1b9_38 = (long)a1_2 * (long)b9_19; + long a2b0 = (long)a.x2 * (long)b.x0; + long a2b1 = (long)a.x2 * (long)b.x1; + long a2b2 = (long)a.x2 * (long)b.x2; + long a2b3 = (long)a.x2 * (long)b.x3; + long a2b4 = (long)a.x2 * (long)b.x4; + long a2b5 = (long)a.x2 * (long)b.x5; + long a2b6 = (long)a.x2 * (long)b.x6; + long a2b7 = (long)a.x2 * (long)b.x7; + long a2b8_19 = (long)a.x2 * (long)b8_19; + long a2b9_19 = (long)a.x2 * (long)b9_19; + long a3b0 = (long)a.x3 * (long)b.x0; + long a3b1_2 = (long)a3_2 * (long)b.x1; + long a3b2 = (long)a.x3 * (long)b.x2; + long a3b3_2 = (long)a3_2 * (long)b.x3; + long a3b4 = (long)a.x3 * (long)b.x4; + long a3b5_2 = (long)a3_2 * (long)b.x5; + long a3b6 = (long)a.x3 * (long)b.x6; + long a3b7_38 = (long)a3_2 * (long)b7_19; + long a3b8_19 = (long)a.x3 * (long)b8_19; + long a3b9_38 = (long)a3_2 * (long)b9_19; + long a4b0 = (long)a.x4 * (long)b.x0; + long a4b1 = (long)a.x4 * (long)b.x1; + long a4b2 = (long)a.x4 * (long)b.x2; + long a4b3 = (long)a.x4 * (long)b.x3; + long a4b4 = (long)a.x4 * (long)b.x4; + long a4b5 = (long)a.x4 * (long)b.x5; + long a4b6_19 = (long)a.x4 * (long)b6_19; + long a4b7_19 = (long)a.x4 * (long)b7_19; + long a4b8_19 = (long)a.x4 * (long)b8_19; + long a4b9_19 = (long)a.x4 * (long)b9_19; + long a5b0 = (long)a.x5 * (long)b.x0; + long a5b1_2 = (long)a5_2 * (long)b.x1; + long a5b2 = (long)a.x5 * (long)b.x2; + long a5b3_2 = (long)a5_2 * (long)b.x3; + long a5b4 = (long)a.x5 * (long)b.x4; + long a5b5_38 = (long)a5_2 * (long)b5_19; + long a5b6_19 = (long)a.x5 * (long)b6_19; + long a5b7_38 = (long)a5_2 * (long)b7_19; + long a5b8_19 = (long)a.x5 * (long)b8_19; + long a5b9_38 = (long)a5_2 * (long)b9_19; + long a6b0 = (long)a.x6 * (long)b.x0; + long a6b1 = (long)a.x6 * (long)b.x1; + long a6b2 = (long)a.x6 * (long)b.x2; + long a6b3 = (long)a.x6 * (long)b.x3; + long a6b4_19 = (long)a.x6 * (long)b4_19; + long a6b5_19 = (long)a.x6 * (long)b5_19; + long a6b6_19 = (long)a.x6 * (long)b6_19; + long a6b7_19 = (long)a.x6 * (long)b7_19; + long a6b8_19 = (long)a.x6 * (long)b8_19; + long a6b9_19 = (long)a.x6 * (long)b9_19; + long a7b0 = (long)a.x7 * (long)b.x0; + long a7b1_2 = (long)a7_2 * (long)b.x1; + long a7b2 = (long)a.x7 * (long)b.x2; + long a7b3_38 = (long)a7_2 * (long)b3_19; + long a7b4_19 = (long)a.x7 * (long)b4_19; + long a7b5_38 = (long)a7_2 * (long)b5_19; + long a7b6_19 = (long)a.x7 * (long)b6_19; + long a7b7_38 = (long)a7_2 * (long)b7_19; + long a7b8_19 = (long)a.x7 * (long)b8_19; + long a7b9_38 = (long)a7_2 * (long)b9_19; + long a8b0 = (long)a.x8 * (long)b.x0; + long a8b1 = (long)a.x8 * (long)b.x1; + long a8b2_19 = (long)a.x8 * (long)b2_19; + long a8b3_19 = (long)a.x8 * (long)b3_19; + long a8b4_19 = (long)a.x8 * (long)b4_19; + long a8b5_19 = (long)a.x8 * (long)b5_19; + long a8b6_19 = (long)a.x8 * (long)b6_19; + long a8b7_19 = (long)a.x8 * (long)b7_19; + long a8b8_19 = (long)a.x8 * (long)b8_19; + long a8b9_19 = (long)a.x8 * (long)b9_19; + long a9b0 = (long)a.x9 * (long)b.x0; + long a9b1_38 = (long)a9_2 * (long)b1_19; + long a9b2_19 = (long)a.x9 * (long)b2_19; + long a9b3_38 = (long)a9_2 * (long)b3_19; + long a9b4_19 = (long)a.x9 * (long)b4_19; + long a9b5_38 = (long)a9_2 * (long)b5_19; + long a9b6_19 = (long)a.x9 * (long)b6_19; + long a9b7_38 = (long)a9_2 * (long)b7_19; + long a9b8_19 = (long)a.x9 * (long)b8_19; + long a9b9_38 = (long)a9_2 * (long)b9_19; + + long h0 = a0b0 + a1b9_38 + a2b8_19 + a3b7_38 + a4b6_19 + a5b5_38 + a6b4_19 + a7b3_38 + a8b2_19 + a9b1_38; + long h1 = a0b1 + a1b0 + a2b9_19 + a3b8_19 + a4b7_19 + a5b6_19 + a6b5_19 + a7b4_19 + a8b3_19 + a9b2_19; + long h2 = a0b2 + a1b1_2 + a2b0 + a3b9_38 + a4b8_19 + a5b7_38 + a6b6_19 + a7b5_38 + a8b4_19 + a9b3_38; + long h3 = a0b3 + a1b2 + a2b1 + a3b0 + a4b9_19 + a5b8_19 + a6b7_19 + a7b6_19 + a8b5_19 + a9b4_19; + long h4 = a0b4 + a1b3_2 + a2b2 + a3b1_2 + a4b0 + a5b9_38 + a6b8_19 + a7b7_38 + a8b6_19 + a9b5_38; + long h5 = a0b5 + a1b4 + a2b3 + a3b2 + a4b1 + a5b0 + a6b9_19 + a7b8_19 + a8b7_19 + a9b6_19; + long h6 = a0b6 + a1b5_2 + a2b4 + a3b3_2 + a4b2 + a5b1_2 + a6b0 + a7b9_38 + a8b8_19 + a9b7_38; + long h7 = a0b7 + a1b6 + a2b5 + a3b4 + a4b3 + a5b2 + a6b1 + a7b0 + a8b9_19 + a9b8_19; + long h8 = a0b8 + a1b7_2 + a2b6 + a3b5_2 + a4b4 + a5b3_2 + a6b2 + a7b1_2 + a8b0 + a9b9_38; + long h9 = a0b9 + a1b8 + a2b7 + a3b6 + a4b5 + a5b4 + a6b3 + a7b2 + a8b1 + a9b0; + + long carry0 = (h0 + (1L << 25)) >> 26; + h1 += carry0; + h0 -= carry0 << 26; + long carry4 = (h4 + (1L << 25)) >> 26; + h5 += carry4; + h4 -= carry4 << 26; + + long carry1 = (h1 + (1L << 24)) >> 25; + h2 += carry1; + h1 -= carry1 << 25; + long carry5 = (h5 + (1L << 24)) >> 25; + h6 += carry5; + h5 -= carry5 << 25; + + long carry2 = (h2 + (1L << 25)) >> 26; + h3 += carry2; + h2 -= carry2 << 26; + long carry6 = (h6 + (1L << 25)) >> 26; + h7 += carry6; + h6 -= carry6 << 26; + + long carry3 = (h3 + (1L << 24)) >> 25; + h4 += carry3; + h3 -= carry3 << 25; + long carry7 = (h7 + (1L << 24)) >> 25; + h8 += carry7; + h7 -= carry7 << 25; + + carry4 = (h4 + (1L << 25)) >> 26; + h5 += carry4; + h4 -= carry4 << 26; + long carry8 = (h8 + (1L << 25)) >> 26; + h9 += carry8; + h8 -= carry8 << 26; + + long carry9 = (h9 + (1L << 24)) >> 25; + h0 += carry9 * 19; + h9 -= carry9 << 25; + + carry0 = (h0 + (1L << 25)) >> 26; + h1 += carry0; + h0 -= carry0 << 26; + + output.x0 = (int)h0; + output.x1 = (int)h1; + output.x2 = (int)h2; + output.x3 = (int)h3; + output.x4 = (int)h4; + output.x5 = (int)h5; + output.x6 = (int)h6; + output.x7 = (int)h7; + output.x8 = (int)h8; + output.x9 = (int)h9; + } + + /// <summary> + /// Square a field element + /// </summary> + public static void Square(ref FieldElement output, ref FieldElement a) + { + int a0_2 = a.x0 * 2; + int a1_2 = a.x1 * 2; + int a2_2 = a.x2 * 2; + int a3_2 = a.x3 * 2; + int a4_2 = a.x4 * 2; + int a5_2 = a.x5 * 2; + int a6_2 = a.x6 * 2; + int a7_2 = a.x7 * 2; + + int a5_38 = a.x5 * 38; + int a6_19 = a.x6 * 19; + int a7_38 = a.x7 * 38; + int a8_19 = a.x8 * 19; + int a9_38 = a.x9 * 38; + + long a0a0 = (long)a.x0 * (long)a.x0; + long a0a1_2 = (long)a0_2 * (long)a.x1; + long a0a2_2 = (long)a0_2 * (long)a.x2; + long a0a3_2 = (long)a0_2 * (long)a.x3; + long a0a4_2 = (long)a0_2 * (long)a.x4; + long a0a5_2 = (long)a0_2 * (long)a.x5; + long a0a6_2 = (long)a0_2 * (long)a.x6; + long a0a7_2 = (long)a0_2 * (long)a.x7; + long a0a8_2 = (long)a0_2 * (long)a.x8; + long a0a9_2 = (long)a0_2 * (long)a.x9; + long a1a1_2 = (long)a1_2 * (long)a.x1; + long a1a2_2 = (long)a1_2 * (long)a.x2; + long a1a3_4 = (long)a1_2 * (long)a3_2; + long a1a4_2 = (long)a1_2 * (long)a.x4; + long a1a5_4 = (long)a1_2 * (long)a5_2; + long a1a6_2 = (long)a1_2 * (long)a.x6; + long a1a7_4 = (long)a1_2 * (long)a7_2; + long a1a8_2 = (long)a1_2 * (long)a.x8; + long a1a9_76 = (long)a1_2 * (long)a9_38; + long a2a2 = (long)a.x2 * (long)a.x2; + long a2a3_2 = (long)a2_2 * (long)a.x3; + long a2a4_2 = (long)a2_2 * (long)a.x4; + long a2a5_2 = (long)a2_2 * (long)a.x5; + long a2a6_2 = (long)a2_2 * (long)a.x6; + long a2a7_2 = (long)a2_2 * (long)a.x7; + long a2a8_38 = (long)a2_2 * (long)a8_19; + long a2a9_38 = (long)a.x2 * (long)a9_38; + long a3a3_2 = (long)a3_2 * (long)a.x3; + long a3a4_2 = (long)a3_2 * (long)a.x4; + long a3a5_4 = (long)a3_2 * (long)a5_2; + long a3a6_2 = (long)a3_2 * (long)a.x6; + long a3a7_76 = (long)a3_2 * (long)a7_38; + long a3a8_38 = (long)a3_2 * (long)a8_19; + long a3a9_76 = (long)a3_2 * (long)a9_38; + long a4a4 = (long)a.x4 * (long)a.x4; + long a4a5_2 = (long)a4_2 * (long)a.x5; + long a4a6_38 = (long)a4_2 * (long)a6_19; + long a4a7_38 = (long)a.x4 * (long)a7_38; + long a4a8_38 = (long)a4_2 * (long)a8_19; + long a4a9_38 = (long)a.x4 * (long)a9_38; + long a5a5_38 = (long)a.x5 * (long)a5_38; + long a5a6_38 = (long)a5_2 * (long)a6_19; + long a5a7_76 = (long)a5_2 * (long)a7_38; + long a5a8_38 = (long)a5_2 * (long)a8_19; + long a5a9_76 = (long)a5_2 * (long)a9_38; + long a6a6_19 = (long)a.x6 * (long)a6_19; + long a6a7_38 = (long)a.x6 * (long)a7_38; + long a6a8_38 = (long)a6_2 * (long)a8_19; + long a6a9_38 = (long)a.x6 * (long)a9_38; + long a7a7_38 = (long)a.x7 * (long)a7_38; + long a7a8_38 = (long)a7_2 * (long)a8_19; + long a7a9_76 = (long)a7_2 * (long)a9_38; + long a8a8_19 = (long)a.x8 * (long)a8_19; + long a8a9_38 = (long)a.x8 * (long)a9_38; + long a9a9_38 = (long)a.x9 * (long)a9_38; + + long h0 = a0a0 + a1a9_76 + a2a8_38 + a3a7_76 + a4a6_38 + a5a5_38; + long h1 = a0a1_2 + a2a9_38 + a3a8_38 + a4a7_38 + a5a6_38; + long h2 = a0a2_2 + a1a1_2 + a3a9_76 + a4a8_38 + a5a7_76 + a6a6_19; + long h3 = a0a3_2 + a1a2_2 + a4a9_38 + a5a8_38 + a6a7_38; + long h4 = a0a4_2 + a1a3_4 + a2a2 + a5a9_76 + a6a8_38 + a7a7_38; + long h5 = a0a5_2 + a1a4_2 + a2a3_2 + a6a9_38 + a7a8_38; + long h6 = a0a6_2 + a1a5_4 + a2a4_2 + a3a3_2 + a7a9_76 + a8a8_19; + long h7 = a0a7_2 + a1a6_2 + a2a5_2 + a3a4_2 + a8a9_38; + long h8 = a0a8_2 + a1a7_4 + a2a6_2 + a3a5_4 + a4a4 + a9a9_38; + long h9 = a0a9_2 + a1a8_2 + a2a7_2 + a3a6_2 + a4a5_2; + + long carry0 = (h0 + (1L << 25)) >> 26; + h1 += carry0; + h0 -= carry0 << 26; + long carry4 = (h4 + (1L << 25)) >> 26; + h5 += carry4; + h4 -= carry4 << 26; + + long carry1 = (h1 + (1L << 24)) >> 25; + h2 += carry1; + h1 -= carry1 << 25; + long carry5 = (h5 + (1L << 24)) >> 25; + h6 += carry5; + h5 -= carry5 << 25; + + long carry2 = (h2 + (1L << 25)) >> 26; + h3 += carry2; + h2 -= carry2 << 26; + long carry6 = (h6 + (1L << 25)) >> 26; + h7 += carry6; + h6 -= carry6 << 26; + + long carry3 = (h3 + (1L << 24)) >> 25; + h4 += carry3; + h3 -= carry3 << 25; + long carry7 = (h7 + (1L << 24)) >> 25; + h8 += carry7; + h7 -= carry7 << 25; + + carry4 = (h4 + (1L << 25)) >> 26; + h5 += carry4; + h4 -= carry4 << 26; + long carry8 = (h8 + (1L << 25)) >> 26; + h9 += carry8; + h8 -= carry8 << 26; + + long carry9 = (h9 + (1L << 24)) >> 25; + h0 += carry9 * 19; + h9 -= carry9 << 25; + + carry0 = (h0 + (1L << 25)) >> 26; + h1 += carry0; + h0 -= carry0 << 26; + + output.x0 = (int)h0; + output.x1 = (int)h1; + output.x2 = (int)h2; + output.x3 = (int)h3; + output.x4 = (int)h4; + output.x5 = (int)h5; + output.x6 = (int)h6; + output.x7 = (int)h7; + output.x8 = (int)h8; + output.x9 = (int)h9; + } + + /// <summary> + /// Multiplay a field element by 121666 + /// </summary> + public static void Multiply121666(ref FieldElement output, ref FieldElement a) + { + long h0 = (long)a.x0 * 121666L; + long h1 = (long)a.x1 * 121666L; + long h2 = (long)a.x2 * 121666L; + long h3 = (long)a.x3 * 121666L; + long h4 = (long)a.x4 * 121666L; + long h5 = (long)a.x5 * 121666L; + long h6 = (long)a.x6 * 121666L; + long h7 = (long)a.x7 * 121666L; + long h8 = (long)a.x8 * 121666L; + long h9 = (long)a.x9 * 121666L; + + long carry9 = (h9 + (1L<<24)) >> 25; + h0 += carry9 * 19; + h9 -= carry9 << 25; + long carry1 = (h1 + (1L<<24)) >> 25; + h2 += carry1; + h1 -= carry1 << 25; + long carry3 = (h3 + (1L<<24)) >> 25; + h4 += carry3; + h3 -= carry3 << 25; + long carry5 = (h5 + (1L<<24)) >> 25; + h6 += carry5; + h5 -= carry5 << 25; + long carry7 = (h7 + (1L<<24)) >> 25; + h8 += carry7; + h7 -= carry7 << 25; + + long carry0 = (h0 + (1L << 25)) >> 26; + h1 += carry0; + h0 -= carry0 << 26; + long carry2 = (h2 + (1L << 25)) >> 26; + h3 += carry2; + h2 -= carry2 << 26; + long carry4 = (h4 + (1L << 25)) >> 26; + h5 += carry4; + h4 -= carry4 << 26; + long carry6 = (h6 + (1L << 25)) >> 26; + h7 += carry6; + h6 -= carry6 << 26; + long carry8 = (h8 + (1L << 25)) >> 26; + h9 += carry8; + h8 -= carry8 << 26; + + output.x0 = (int)h0; + output.x1 = (int)h1; + output.x2 = (int)h2; + output.x3 = (int)h3; + output.x4 = (int)h4; + output.x5 = (int)h5; + output.x6 = (int)h6; + output.x7 = (int)h7; + output.x8 = (int)h8; + output.x9 = (int)h9; + } + + /// <summary> + /// Invert a field element + /// </summary> + public static void Invert(ref FieldElement output, ref FieldElement a) + { + FieldElement t0 = new FieldElement(); + Square(ref t0, ref a); + + FieldElement t1 = new FieldElement(); + Square(ref t1, ref t0); + Square(ref t1, ref t1); + + FieldElement t2= new FieldElement(); + Multiply(ref t1, ref a, ref t1); + Multiply(ref t0, ref t0, ref t1); + Square(ref t2, ref t0); + //Square(ref t2, ref t2); + + Multiply(ref t1, ref t1, ref t2); + Square(ref t2, ref t1); + for (int ii = 1; ii < 5; ++ii) + { + Square(ref t2, ref t2); + } + + Multiply(ref t1, ref t2, ref t1); + Square(ref t2, ref t1); + for (int ii = 1; ii < 10; ++ii) + { + Square(ref t2, ref t2); + } + + FieldElement t3 = new FieldElement(); + Multiply(ref t2, ref t2, ref t1); + Square(ref t3, ref t2); + for (int ii = 1; ii < 20; ++ii) + { + Square(ref t3, ref t3); + } + + Multiply(ref t2, ref t3, ref t2); + Square(ref t2, ref t2); + for (int ii = 1; ii < 10; ++ii) + { + Square(ref t2, ref t2); + } + + Multiply(ref t1, ref t2, ref t1); + Square(ref t2, ref t1); + for (int ii = 1; ii < 50; ++ii) + { + Square(ref t2, ref t2); + } + + Multiply(ref t2, ref t2, ref t1); + Square(ref t3, ref t2); + for (int ii = 1; ii < 100; ++ii) + { + Square(ref t3, ref t3); + } + + Multiply(ref t2, ref t3, ref t2); + Square(ref t2, ref t2); + for (int ii = 1; ii < 50; ++ii) + { + Square(ref t2, ref t2); + } + + Multiply(ref t1, ref t2, ref t1); + Square(ref t1, ref t1); + for (int ii = 1; ii < 5; ++ii) + { + Square(ref t1, ref t1); + } + + Multiply(ref output, ref t1, ref t0); + } + + /// <summary> + /// Swaps `a` and `b` if `swap` is 1 + /// </summary> + public static void ConditionalSwap(ref FieldElement a, ref FieldElement b, int swap) + { + Debug.Assert(swap == 0 || swap == 1); + swap = -swap; + + FieldElement temp = new FieldElement + { + x0 = swap & (a.x0 ^ b.x0), + x1 = swap & (a.x1 ^ b.x1), + x2 = swap & (a.x2 ^ b.x2), + x3 = swap & (a.x3 ^ b.x3), + x4 = swap & (a.x4 ^ b.x4), + x5 = swap & (a.x5 ^ b.x5), + x6 = swap & (a.x6 ^ b.x6), + x7 = swap & (a.x7 ^ b.x7), + x8 = swap & (a.x8 ^ b.x8), + x9 = swap & (a.x9 ^ b.x9), + }; + + a.x0 ^= temp.x0; + a.x1 ^= temp.x1; + a.x2 ^= temp.x2; + a.x3 ^= temp.x3; + a.x4 ^= temp.x4; + a.x5 ^= temp.x5; + a.x6 ^= temp.x6; + a.x7 ^= temp.x7; + a.x8 ^= temp.x8; + a.x9 ^= temp.x9; + + b.x0 ^= temp.x0; + b.x1 ^= temp.x1; + b.x2 ^= temp.x2; + b.x3 ^= temp.x3; + b.x4 ^= temp.x4; + b.x5 ^= temp.x5; + b.x6 ^= temp.x6; + b.x7 ^= temp.x7; + b.x8 ^= temp.x8; + b.x9 ^= temp.x9; + } + } + } +} diff --git a/Tools/Hazel-Networking/Hazel/DataReceivedEventArgs.cs b/Tools/Hazel-Networking/Hazel/DataReceivedEventArgs.cs new file mode 100644 index 0000000..35609fc --- /dev/null +++ b/Tools/Hazel-Networking/Hazel/DataReceivedEventArgs.cs @@ -0,0 +1,29 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; + +namespace Hazel +{ + public struct DataReceivedEventArgs + { + public readonly Connection Sender; + + /// <summary> + /// The bytes received from the client. + /// </summary> + public readonly MessageReader Message; + + /// <summary> + /// The <see cref="SendOption"/> the data was sent with. + /// </summary> + public readonly SendOption SendOption; + + public DataReceivedEventArgs(Connection sender, MessageReader msg, SendOption sendOption) + { + this.Sender = sender; + this.Message = msg; + this.SendOption = sendOption; + } + } +} diff --git a/Tools/Hazel-Networking/Hazel/DisconnectedEventArgs.cs b/Tools/Hazel-Networking/Hazel/DisconnectedEventArgs.cs new file mode 100644 index 0000000..a7fb05c --- /dev/null +++ b/Tools/Hazel-Networking/Hazel/DisconnectedEventArgs.cs @@ -0,0 +1,24 @@ +using System; + +namespace Hazel +{ + public class DisconnectedEventArgs : EventArgs + { + /// <summary> + /// Optional disconnect reason. May be null. + /// </summary> + public readonly string Reason; + + /// <summary> + /// Optional data sent with a disconnect message. May be null. + /// You must not recycle this. If you need the message outside of a callback, you should copy it. + /// </summary> + public readonly MessageReader Message; + + public DisconnectedEventArgs(string reason, MessageReader message) + { + this.Reason = reason; + this.Message = message; + } + } +} diff --git a/Tools/Hazel-Networking/Hazel/Dtls/AesGcmRecordProtection.cs b/Tools/Hazel-Networking/Hazel/Dtls/AesGcmRecordProtection.cs new file mode 100644 index 0000000..65df39e --- /dev/null +++ b/Tools/Hazel-Networking/Hazel/Dtls/AesGcmRecordProtection.cs @@ -0,0 +1,147 @@ +using Hazel.Crypto; +using System; +using System.Diagnostics; + +namespace Hazel.Dtls +{ + /// <summary> + /// *_AES_128_GCM_* cipher suite + /// </summary> + public class Aes128GcmRecordProtection: IRecordProtection + { + private const int ImplicitNonceSize = 4; + private const int ExplicitNonceSize = 8; + + private readonly Aes128Gcm serverWriteCipher; + private readonly Aes128Gcm clientWriteCipher; + + private readonly ByteSpan serverWriteIV; + private readonly ByteSpan clientWriteIV; + + /// <summary> + /// Create a new instance of the AES128_GCM record protection + /// </summary> + /// <param name="masterSecret">Shared secret</param> + /// <param name="serverRandom">Server random data</param> + /// <param name="clientRandom">Client random data</param> + public Aes128GcmRecordProtection(ByteSpan masterSecret, ByteSpan serverRandom, ByteSpan clientRandom) + { + ByteSpan combinedRandom = new byte[serverRandom.Length + clientRandom.Length]; + serverRandom.CopyTo(combinedRandom); + clientRandom.CopyTo(combinedRandom.Slice(serverRandom.Length)); + + // Expand master_secret to encryption keys + const int ExpandedSize = 0 + + 0 // mac_key_length + + 0 // mac_key_length + + Aes128Gcm.KeySize // enc_key_length + + Aes128Gcm.KeySize // enc_key_length + + ImplicitNonceSize // fixed_iv_length + + ImplicitNonceSize // fixed_iv_length + ; + + ByteSpan expandedKey = new byte[ExpandedSize]; + PrfSha256.ExpandSecret(expandedKey, masterSecret, PrfLabel.KEY_EXPANSION, combinedRandom); + + ByteSpan clientWriteKey = expandedKey.Slice(0, Aes128Gcm.KeySize); + ByteSpan serverWriteKey = expandedKey.Slice(Aes128Gcm.KeySize, Aes128Gcm.KeySize); + this.clientWriteIV = expandedKey.Slice(2 * Aes128Gcm.KeySize, ImplicitNonceSize); + this.serverWriteIV = expandedKey.Slice(2 * Aes128Gcm.KeySize + ImplicitNonceSize, ImplicitNonceSize); + + this.serverWriteCipher = new Aes128Gcm(serverWriteKey); + this.clientWriteCipher = new Aes128Gcm(clientWriteKey); + } + + /// <inheritdoc /> + public void Dispose() + { + this.serverWriteCipher.Dispose(); + this.clientWriteCipher.Dispose(); + } + + /// <inheritdoc /> + private static int GetEncryptedSizeImpl(int dataSize) + { + return dataSize + Aes128Gcm.CiphertextOverhead; + } + + /// <inheritdoc /> + public int GetEncryptedSize(int dataSize) + { + return GetEncryptedSizeImpl(dataSize); + } + + private static int GetDecryptedSizeImpl(int dataSize) + { + return dataSize - Aes128Gcm.CiphertextOverhead; + } + + /// <inheritdoc /> + public int GetDecryptedSize(int dataSize) + { + return GetDecryptedSizeImpl(dataSize); + } + + /// <inheritdoc /> + public void EncryptServerPlaintext(ByteSpan output, ByteSpan input, ref Record record) + { + EncryptPlaintext(output, input, ref record, this.serverWriteCipher, this.serverWriteIV); + } + + /// <inheritdoc /> + public void EncryptClientPlaintext(ByteSpan output, ByteSpan input, ref Record record) + { + EncryptPlaintext(output, input, ref record, this.clientWriteCipher, this.clientWriteIV); + } + + private static void EncryptPlaintext(ByteSpan output, ByteSpan input, ref Record record, Aes128Gcm cipher, ByteSpan writeIV) + { + Debug.Assert(output.Length >= GetEncryptedSizeImpl(input.Length)); + + // Build GCM nonce (authenticated data) + ByteSpan nonce = new byte[ImplicitNonceSize + ExplicitNonceSize]; + writeIV.CopyTo(nonce); + nonce.WriteBigEndian16(record.Epoch, ImplicitNonceSize); + nonce.WriteBigEndian48(record.SequenceNumber, ImplicitNonceSize + 2); + + // Serialize record as additional data + Record plaintextRecord = record; + plaintextRecord.Length = (ushort)input.Length; + ByteSpan associatedData = new byte[Record.Size]; + plaintextRecord.Encode(associatedData); + + cipher.Seal(output, nonce, input, associatedData); + } + + /// <inheritdoc /> + public bool DecryptCiphertextFromServer(ByteSpan output, ByteSpan input, ref Record record) + { + return DecryptCiphertext(output, input, ref record, this.serverWriteCipher, this.serverWriteIV); + } + + /// <inheritdoc /> + public bool DecryptCiphertextFromClient(ByteSpan output, ByteSpan input, ref Record record) + { + return DecryptCiphertext(output, input, ref record, this.clientWriteCipher, this.clientWriteIV); + } + + private static bool DecryptCiphertext(ByteSpan output, ByteSpan input, ref Record record, Aes128Gcm cipher, ByteSpan writeIV) + { + Debug.Assert(output.Length >= GetDecryptedSizeImpl(input.Length)); + + // Build GCM nonce (authenticated data) + ByteSpan nonce = new byte[ImplicitNonceSize + ExplicitNonceSize]; + writeIV.CopyTo(nonce); + nonce.WriteBigEndian16(record.Epoch, ImplicitNonceSize); + nonce.WriteBigEndian48(record.SequenceNumber, ImplicitNonceSize + 2); + + // Serialize record as additional data + Record plaintextRecord = record; + plaintextRecord.Length = (ushort)GetDecryptedSizeImpl(input.Length); + ByteSpan associatedData = new byte[Record.Size]; + plaintextRecord.Encode(associatedData); + + return cipher.Open(output, nonce, input, associatedData); + } + } +} diff --git a/Tools/Hazel-Networking/Hazel/Dtls/DtlsConnectionListener.cs b/Tools/Hazel-Networking/Hazel/Dtls/DtlsConnectionListener.cs new file mode 100644 index 0000000..61f41d3 --- /dev/null +++ b/Tools/Hazel-Networking/Hazel/Dtls/DtlsConnectionListener.cs @@ -0,0 +1,1424 @@ +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Diagnostics; +using System.Net; +using System.Security.Cryptography; +using System.Security.Cryptography.X509Certificates; +using System.Threading; +using Hazel.Udp.FewerThreads; +using Hazel.Crypto; + +namespace Hazel.Dtls +{ + /// <summary> + /// Listens for new UDP-DTLS connections and creates UdpConnections for them. + /// </summary> + /// <inheritdoc /> + public class DtlsConnectionListener : ThreadLimitedUdpConnectionListener + { + private const int MaxCertFragmentSizeV0 = 1200; + + // Min MTU - UDP+IP header - 1 (for good measure. :)) + private const int MaxCertFragmentSizeV1 = 576 - 32 - 1; + + /// <summary> + /// Current state of handshake sequence + /// </summary> + enum HandshakeState + { + ExpectingHello, + ExpectingClientKeyExchange, + ExpectingChangeCipherSpec, + ExpectingFinish + } + + /// <summary> + /// State to manage the current epoch `N` + /// </summary> + struct CurrentEpoch + { + public ulong NextOutgoingSequence; + + public ulong NextExpectedSequence; + public ulong PreviousSequenceWindowBitmask; + + public IRecordProtection RecordProtection; + public IRecordProtection PreviousRecordProtection; + + // Need to keep these around so we can re-transmit our + // last handshake record flight + public ByteSpan ExpectedClientFinishedVerification; + public ByteSpan ServerFinishedVerification; + public ulong NextOutgoingSequenceForPreviousEpoch; + } + + /// <summary> + /// State to manage the transition from the current + /// epoch `N` to epoch `N+1` + /// </summary> + struct NextEpoch + { + public ushort Epoch; + + public HandshakeState State; + public CipherSuite SelectedCipherSuite; + + public ulong NextOutgoingSequence; + + public IHandshakeCipherSuite Handshake; + public IRecordProtection RecordProtection; + + public ByteSpan ClientRandom; + public ByteSpan ServerRandom; + + public Sha256Stream VerificationStream; + + public ByteSpan ClientVerification; + public ByteSpan ServerVerification; + + } + + /// <summary> + /// Per-peer state + /// </summary> + sealed class PeerData : IDisposable + { + public ushort Epoch; + public bool CanHandleApplicationData; + + public HazelDtlsSessionInfo Session; + + public CurrentEpoch CurrentEpoch; + public NextEpoch NextEpoch; + + public ConnectionId ConnectionId; + + public readonly List<ByteSpan> QueuedApplicationDataMessage = new List<ByteSpan>(); + public readonly ConcurrentBag<MessageReader> ApplicationData = new ConcurrentBag<MessageReader>(); + public readonly ProtocolVersion ProtocolVersion; + + public DateTime StartOfNegotiation; + + public PeerData(ConnectionId connectionId, ulong nextExpectedSequenceNumber, ProtocolVersion protocolVersion) + { + ByteSpan block = new byte[2 * Finished.Size]; + this.CurrentEpoch.ServerFinishedVerification = block.Slice(0, Finished.Size); + this.CurrentEpoch.ExpectedClientFinishedVerification = block.Slice(Finished.Size, Finished.Size); + this.ProtocolVersion = protocolVersion; + + ResetPeer(connectionId, nextExpectedSequenceNumber); + } + + public void ResetPeer(ConnectionId connectionId, ulong nextExpectedSequenceNumber) + { + Dispose(); + + this.Epoch = 0; + this.CanHandleApplicationData = false; + this.QueuedApplicationDataMessage.Clear(); + + this.CurrentEpoch.NextOutgoingSequence = 2; // Account for our ClientHelloVerify + this.CurrentEpoch.NextExpectedSequence = nextExpectedSequenceNumber; + this.CurrentEpoch.PreviousSequenceWindowBitmask = 0; + this.CurrentEpoch.RecordProtection = NullRecordProtection.Instance; + this.CurrentEpoch.PreviousRecordProtection = null; + this.CurrentEpoch.ServerFinishedVerification.SecureClear(); + this.CurrentEpoch.ExpectedClientFinishedVerification.SecureClear(); + + this.NextEpoch.State = HandshakeState.ExpectingHello; + this.NextEpoch.RecordProtection = null; + this.NextEpoch.Handshake = null; + this.NextEpoch.ClientRandom = new byte[Random.Size]; + this.NextEpoch.ServerRandom = new byte[Random.Size]; + this.NextEpoch.VerificationStream = new Sha256Stream(); + this.NextEpoch.ClientVerification = new byte[Finished.Size]; + this.NextEpoch.ServerVerification = new byte[Finished.Size]; + + this.ConnectionId = connectionId; + + this.StartOfNegotiation = DateTime.UtcNow; + } + + public void Dispose() + { + this.CurrentEpoch.RecordProtection?.Dispose(); + this.CurrentEpoch.PreviousRecordProtection?.Dispose(); + this.NextEpoch.RecordProtection?.Dispose(); + this.NextEpoch.Handshake?.Dispose(); + this.NextEpoch.VerificationStream?.Dispose(); + + while (this.ApplicationData.TryTake(out var msg)) + { + try + { + msg.Recycle(); + } + catch { } + } + } + } + + private RandomNumberGenerator random; + + // Private key component of certificate's public key + private ByteSpan encodedCertificate; + private RSA certificatePrivateKey; + + // HMAC key to validate ClientHello cookie + private ThreadedHmacHelper hmacHelper; + private HMAC CurrentCookieHmac { + get + { + return hmacHelper.GetCurrentCookieHmacsForThread(); + } + } + private HMAC PreviousCookieHmac + { + get + { + return hmacHelper.GetPreviousCookieHmacsForThread(); + } + } + + private ConcurrentStack<ConnectionId> staleConnections = new ConcurrentStack<ConnectionId>(); + private readonly ConcurrentDictionary<IPEndPoint, PeerData> existingPeers = new ConcurrentDictionary<IPEndPoint, PeerData>(); + public int PeerCount => this.existingPeers.Count; + + // TODO: Move these into an DtlsErrorStatistics class + public int NonPeerNonHelloPacketsDropped; + public int NonVerifiedFinishedHandshake; + public int NonPeerVerifyHelloRequests; + public int PeerVerifyHelloRequests; + + private int connectionSerial_unsafe = 0; + + private Timer staleConnectionUpkeep; + + /// <summary> + /// Create a new instance of the DTLS listener + /// </summary> + /// <param name="numWorkers"></param> + /// <param name="endPoint"></param> + /// <param name="logger"></param> + /// <param name="ipMode"></param> + public DtlsConnectionListener(int numWorkers, IPEndPoint endPoint, ILogger logger, IPMode ipMode = IPMode.IPv4) + : base(numWorkers, endPoint, logger, ipMode) + { + this.random = RandomNumberGenerator.Create(); + + this.staleConnectionUpkeep = new Timer(this.HandleStaleConnections, null, 2500, 1000); + this.hmacHelper = new ThreadedHmacHelper(logger); + } + + /// <inheritdoc /> + protected override void Dispose(bool disposing) + { + base.Dispose(disposing); + + this.staleConnectionUpkeep.Dispose(); + + this.random?.Dispose(); + this.random = null; + + this.hmacHelper?.Dispose(); + this.hmacHelper = null; + + foreach (var pair in this.existingPeers) + { + pair.Value.Dispose(); + } + this.existingPeers.Clear(); + } + + /// <summary> + /// Set the certificate key pair for the listener + /// </summary> + /// <param name="certificate">Certificate for the server</param> + public void SetCertificate(X509Certificate2 certificate) + { + if (!certificate.HasPrivateKey) + { + throw new ArgumentException("Certificate must have a private key attached", nameof(certificate)); + } + + RSA privateKey = certificate.GetRSAPrivateKey(); + if (privateKey == null) + { + throw new ArgumentException("Certificate must be signed by an RSA key", nameof(certificate)); + } + + this.certificatePrivateKey?.Dispose(); + this.certificatePrivateKey = privateKey; + + this.encodedCertificate = Certificate.Encode(certificate); + } + + /// <summary> + /// Handle an incoming datagram from the network. + /// + /// This is primarily a wrapper around ProcessIncomingMessage + /// to ensure `reader.Recycle()` is always called + /// </summary> + protected override void ReadCallback(MessageReader reader, IPEndPoint peerAddress, ConnectionId connectionId) + { + try + { + ByteSpan message = new ByteSpan(reader.Buffer, reader.Offset + reader.Position, reader.BytesRemaining); + this.ProcessIncomingMessage(message, peerAddress); + } + finally + { + reader.Recycle(); + } + } + + /// <summary> + /// Handle an incoming datagram from the network + /// </summary> + private void ProcessIncomingMessage(ByteSpan message, IPEndPoint peerAddress) + { + PeerData peer = null; + if (!this.existingPeers.TryGetValue(peerAddress, out peer)) + { + lock (this.existingPeers) + { + if (!this.existingPeers.TryGetValue(peerAddress, out peer)) + { + HandleNonPeerRecord(message, peerAddress); + return; + } + } + } + + ConnectionId peerConnectionId; + + lock (peer) + { + peerConnectionId = peer.ConnectionId; + + // Each incoming packet may contain multiple DTLS + // records + while (message.Length > 0) + { + Record record; + if (!Record.Parse(out record, peer.ProtocolVersion, message)) + { + this.Logger.WriteError($"Dropping malformed record from `{peerAddress}`"); + return; + } + message = message.Slice(Record.Size); + + if (message.Length < record.Length) + { + this.Logger.WriteError($"Dropping malformed record from `{peerAddress}` Length({record.Length}) AvailableBytes({message.Length})"); + return; + } + + ByteSpan recordPayload = message.Slice(0, record.Length); + message = message.Slice(record.Length); + + // Early-out and drop ApplicationData records + if (record.ContentType == ContentType.ApplicationData && !peer.CanHandleApplicationData) + { + this.Logger.WriteInfo($"Dropping ApplicationData record from `{peerAddress}` Cannot process yet"); + continue; + } + + // Drop records from a different epoch + if (record.Epoch != peer.Epoch) + { + // Handle existing client negotiating a new connection + if (record.Epoch == 0 && record.ContentType == ContentType.Handshake) + { + ByteSpan handshakePayload = recordPayload; + + Handshake handshake; + if (!Handshake.Parse(out handshake, recordPayload)) + { + this.Logger.WriteError($"Dropping malformed re-negotiation Handshake from `{peerAddress}`"); + continue; + } + handshakePayload = handshakePayload.Slice(Handshake.Size); + + if (handshake.FragmentOffset != 0 || handshake.Length != handshake.FragmentLength) + { + this.Logger.WriteError($"Dropping fragmented re-negotiation Handshake from `{peerAddress}`"); + continue; + } + else if (handshake.MessageType != HandshakeType.ClientHello) + { + this.Logger.WriteVerbose($"Dropping non-ClientHello re-negotiation Handshake from `{peerAddress}`"); + continue; + } + else if (handshakePayload.Length < handshake.Length) + { + this.Logger.WriteError($"Dropping malformed re-negotiation Handshake from `{peerAddress}`: Length({handshake.Length}) AvailableBytes({handshakePayload.Length})"); + } + + if (!this.HandleClientHello(peer, peerAddress, ref record, ref handshake, recordPayload, handshakePayload)) + { + return; + } + continue; + } + + this.Logger.WriteVerbose($"Dropping bad-epoch record from `{peerAddress}` RecordEpoch({record.Epoch}) CurrentEpoch({peer.Epoch})"); + continue; + } + + // Prevent replay attacks by dropping records + // we've already processed + int windowIndex = (int)(peer.CurrentEpoch.NextExpectedSequence - record.SequenceNumber - 1); + ulong windowMask = 1ul << windowIndex; + if (record.SequenceNumber < peer.CurrentEpoch.NextExpectedSequence) + { + if (windowIndex >= 64) + { + this.Logger.WriteInfo($"Dropping too-old record from `{peerAddress}` Sequence({record.SequenceNumber}) Expected({peer.CurrentEpoch.NextExpectedSequence})"); + continue; + } + + if ((peer.CurrentEpoch.PreviousSequenceWindowBitmask & windowMask) != 0) + { + this.Logger.WriteInfo($"Dropping duplicate record from `{peerAddress}`"); + continue; + } + } + + // Validate record authenticity + int decryptedSize = peer.CurrentEpoch.RecordProtection.GetDecryptedSize(recordPayload.Length); + if (decryptedSize < 0) + { + this.Logger.WriteInfo($"Dropping malformed record: Length {recordPayload.Length} Decrypted length: {decryptedSize}"); + continue; + } + + ByteSpan decryptedPayload = recordPayload.ReuseSpanIfPossible(decryptedSize); + ProtocolVersion protocolVersion = peer.ProtocolVersion; + + if (!peer.CurrentEpoch.RecordProtection.DecryptCiphertextFromClient(decryptedPayload, recordPayload, ref record)) + { + this.Logger.WriteVerbose($"Dropping non-authentic {record.ContentType} record from `{peerAddress}`"); + return; + } + + recordPayload = decryptedPayload; + + // Update our squence number bookeeping + if (record.SequenceNumber >= peer.CurrentEpoch.NextExpectedSequence) + { + int windowShift = (int)(record.SequenceNumber + 1 - peer.CurrentEpoch.NextExpectedSequence); + peer.CurrentEpoch.PreviousSequenceWindowBitmask <<= windowShift; + peer.CurrentEpoch.NextExpectedSequence = record.SequenceNumber + 1; + } + else + { + peer.CurrentEpoch.PreviousSequenceWindowBitmask |= windowMask; + } + + // This is handy for debugging, but too verbose even for verbose. + // this.Logger.WriteVerbose($"Record type {record.ContentType} ({peer.NextEpoch.State})"); + switch (record.ContentType) + { + case ContentType.ChangeCipherSpec: + if (peer.NextEpoch.State != HandshakeState.ExpectingChangeCipherSpec) + { + this.Logger.WriteError($"Dropping unexpected ChangeChiperSpec record from `{peerAddress}` State({peer.NextEpoch.State})"); + break; + } + else if (peer.NextEpoch.RecordProtection == null) + { + ///NOTE(mendsley): This _should_ not + /// happen on a well-formed server. + Debug.Assert(false, "How did we receive a ChangeCipherSpec message without a pending record protection instance?"); + + this.Logger.WriteError($"Dropping ChangeCipherSpec message from `{peerAddress}`: No pending record protection"); + break; + } + + if (!ChangeCipherSpec.Parse(recordPayload)) + { + this.Logger.WriteError($"Dropping malformed ChangeCipherSpec message from `{peerAddress}`"); + break; + } + + // Migrate to the next epoch + peer.Epoch = peer.NextEpoch.Epoch; + peer.CanHandleApplicationData = false; // Need a Finished message + peer.CurrentEpoch.NextOutgoingSequenceForPreviousEpoch = peer.CurrentEpoch.NextOutgoingSequence; + peer.CurrentEpoch.PreviousRecordProtection?.Dispose(); + peer.CurrentEpoch.PreviousRecordProtection = peer.CurrentEpoch.RecordProtection; + peer.CurrentEpoch.RecordProtection = peer.NextEpoch.RecordProtection; + peer.CurrentEpoch.NextOutgoingSequence = 1; + peer.CurrentEpoch.NextExpectedSequence = 1; + peer.CurrentEpoch.PreviousSequenceWindowBitmask = 0; + peer.NextEpoch.ClientVerification.CopyTo(peer.CurrentEpoch.ExpectedClientFinishedVerification); + peer.NextEpoch.ServerVerification.CopyTo(peer.CurrentEpoch.ServerFinishedVerification); + + peer.NextEpoch.State = HandshakeState.ExpectingHello; + peer.NextEpoch.Handshake?.Dispose(); + peer.NextEpoch.Handshake = null; + peer.NextEpoch.NextOutgoingSequence = 1; + peer.NextEpoch.RecordProtection = null; + peer.NextEpoch.VerificationStream.Reset(); + peer.NextEpoch.ClientVerification.SecureClear(); + peer.NextEpoch.ServerVerification.SecureClear(); + break; + + case ContentType.Alert: + this.Logger.WriteError($"Dropping unsupported Alert record from `{peerAddress}`"); + break; + + case ContentType.Handshake: + if (!ProcessHandshake(peer, peerAddress, ref record, recordPayload)) + { + return; + } + break; + + case ContentType.ApplicationData: + // Forward data to the application + MessageReader reader = MessageReader.GetSized(recordPayload.Length); + reader.Length = recordPayload.Length; + recordPayload.CopyTo(reader.Buffer); + + peer.ApplicationData.Add(reader); + break; + } + } + } + + // The peer lock must be exited before leaving the DtlsConnectionListener context to prevent deadlocks + // because ApplicationData processing may reenter this context + while (peer.ApplicationData.TryTake(out var appMsg)) + { + base.ReadCallback(appMsg, peerAddress, peerConnectionId); + } + } + + /// <summary> + /// Process an incoming Handshake protocol message + /// </summary> + /// <param name="peer">Originating peer</param> + /// <param name="peerAddress">Peer's network address</param> + /// <param name="record">Parent record</param> + /// <param name="message">Record payload</param> + /// <returns> + /// True if further processing of the underlying datagram + /// should be continues. Otherwise, false. + /// </returns> + private bool ProcessHandshake(PeerData peer, IPEndPoint peerAddress, ref Record record, ByteSpan message) + { + // Each record may have multiple handshake payloads + while (message.Length > 0) + { + ByteSpan originalMessage = message; + + Handshake handshake; + if (!Handshake.Parse(out handshake, message)) + { + this.Logger.WriteError($"Dropping malformed Handshake message from `{peerAddress}`"); + return false; + } + message = message.Slice(Handshake.Size); + + if (message.Length < handshake.Length) + { + this.Logger.WriteError($"Dropping malformed Handshake message from `{peerAddress}`"); + return false; + } + + ByteSpan payload = message.Slice(0, (int)message.Length); + message = message.Slice((int)handshake.Length); + originalMessage = originalMessage.Slice(0, Handshake.Size + (int)handshake.Length); + + // We do not support fragmented handshake messages + // from the client + if (handshake.FragmentOffset != 0 || handshake.FragmentLength != handshake.Length) + { + this.Logger.WriteError($"Dropping fragmented Handshake message from `{peerAddress}` Offset({handshake.FragmentOffset}) FragmentLength({handshake.FragmentLength}) Length({handshake.Length})"); + continue; + } + + ByteSpan packet; + ByteSpan writer; + +#if DEBUG + this.Logger.WriteVerbose($"Received handshake {handshake.MessageType} ({peer.NextEpoch.State})"); +#endif + switch (handshake.MessageType) + { + case HandshakeType.ClientHello: + if (!this.HandleClientHello(peer, peerAddress, ref record, ref handshake, originalMessage, payload)) + { + return false; + } + break; + + case HandshakeType.ClientKeyExchange: + if (peer.NextEpoch.State != HandshakeState.ExpectingClientKeyExchange) + { + this.Logger.WriteError($"Dropping unexpected ClientKeyExchange message form `{peerAddress}` State({peer.NextEpoch.State})"); + continue; + } + else if (handshake.MessageSequence != 5) + { + this.Logger.WriteError($"Dropping bad-sequence ClientKeyExchange message from `{peerAddress}` MessageSequence({handshake.MessageSequence})"); + continue; + } + + ByteSpan sharedSecret = new byte[peer.NextEpoch.Handshake.SharedKeySize()]; + if (!peer.NextEpoch.Handshake.VerifyClientMessageAndGenerateSharedKey(sharedSecret, payload)) + { + this.Logger.WriteError($"Dropping malformed ClientKeyExchange message from `{peerAddress}`"); + return false; + } + + // Record incoming ClientKeyExchange message + // to verification stream + peer.NextEpoch.VerificationStream.AddData(originalMessage); + + ByteSpan randomSeed = new byte[2 * Random.Size]; + peer.NextEpoch.ClientRandom.CopyTo(randomSeed); + peer.NextEpoch.ServerRandom.CopyTo(randomSeed.Slice(Random.Size)); + + const int MasterSecretSize = 48; + ByteSpan masterSecret = new byte[MasterSecretSize]; + PrfSha256.ExpandSecret( + masterSecret + , sharedSecret + , PrfLabel.MASTER_SECRET + , randomSeed + ); + + // Create the record protection for the upcoming epoch + switch (peer.NextEpoch.SelectedCipherSuite) + { + case CipherSuite.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256: + peer.NextEpoch.RecordProtection = new Aes128GcmRecordProtection( + masterSecret + , peer.NextEpoch.ServerRandom + , peer.NextEpoch.ClientRandom); + break; + + default: + Debug.Assert(false, $"How did we agree to a cipher suite {peer.NextEpoch.SelectedCipherSuite} we can't create?"); + this.Logger.WriteError($"Dropping ClientKeyExchange message from `{peerAddress}` Unsuppored cipher suite"); + return false; + } + + // Generate verification signatures + ByteSpan handshakeStreamHash = new byte[Sha256Stream.DigestSize]; + peer.NextEpoch.VerificationStream.CopyOrCalculateFinalHash(handshakeStreamHash); + + PrfSha256.ExpandSecret( + peer.NextEpoch.ClientVerification + , masterSecret + , PrfLabel.CLIENT_FINISHED + , handshakeStreamHash + ); + PrfSha256.ExpandSecret( + peer.NextEpoch.ServerVerification + , masterSecret + , PrfLabel.SERVER_FINISHED + , handshakeStreamHash + ); + + + // Update handshake state + masterSecret.SecureClear(); + peer.NextEpoch.State = HandshakeState.ExpectingChangeCipherSpec; + break; + + case HandshakeType.Finished: + // Unlike other handshake messages, this is + // for the current epoch - not the next epoch + + // Cannot process a Finished message for + // epoch 0 + if (peer.Epoch == 0) + { + this.Logger.WriteError($"Dropping Finished message for 0-epoch from `{peerAddress}`"); + continue; + } + // Cannot process a Finished message when we + // are negotiating the next epoch + else if (peer.NextEpoch.State != HandshakeState.ExpectingHello) + { + this.Logger.WriteError($"Dropping Finished message while negotiating new epoch from `{peerAddress}`"); + continue; + } + // Cannot process a Finished message without + // verify data + else if (peer.CurrentEpoch.ExpectedClientFinishedVerification.Length != Finished.Size || peer.CurrentEpoch.ServerFinishedVerification.Length != Finished.Size) + { + ///NOTE(mendsley): This _should_ not + /// happen on a well-formed server. + Debug.Assert(false, "How do we have an established non-zero epoch without verify data?"); + + this.Logger.WriteError($"Dropping Finished message (no verify data) from `{peerAddress}`"); + return false; + } + // Cannot process a Finished message without + // record protection for the previous epoch + else if (peer.CurrentEpoch.PreviousRecordProtection == null) + { + ///NOTE(mendsley): This _should_ not + /// happen on a well-formed server. + Debug.Assert(false, "How do we have an established non-zero epoch with record protection for the previous epoch?"); + + this.Logger.WriteError($"Dropping Finished message from `{peerAddress}`: No previous epoch record protection"); + return false; + } + + // Verify message sequence + if (handshake.MessageSequence != 6) + { + this.Logger.WriteError($"Dropping bad-sequence Finished message from `{peerAddress}` MessageSequence({handshake.MessageSequence})"); + continue; + } + + // Verify the client has the correct + // handshake sequence + if (payload.Length != Finished.Size) + { + this.Logger.WriteError($"Dropping malformed Finished message from `{peerAddress}`"); + return false; + } + else if (1 != Crypto.Const.ConstantCompareSpans(payload, peer.CurrentEpoch.ExpectedClientFinishedVerification)) + { + +#if DEBUG + this.Logger.WriteError($"Dropping non-verified Finished Handshake from `{peerAddress}`"); +#else + Interlocked.Increment(ref this.NonVerifiedFinishedHandshake); +#endif + + // Abort the connection here + // + // The client is either broken, or + // doen not agree on our epoch settings. + // + // Either way, there is not a feasible + // way to progress the connection. + MarkConnectionAsStale(peer.ConnectionId); + this.existingPeers.TryRemove(peerAddress, out _); + + return false; + } + + ProtocolVersion protocolVersion = peer.ProtocolVersion; + + // Describe our ChangeCipherSpec+Finished + Handshake outgoingHandshake = new Handshake(); + outgoingHandshake.MessageType = HandshakeType.Finished; + outgoingHandshake.Length = Finished.Size; + outgoingHandshake.MessageSequence = 7; + outgoingHandshake.FragmentOffset = 0; + outgoingHandshake.FragmentLength = outgoingHandshake.Length; + + Record changeCipherSpecRecord = new Record(); + changeCipherSpecRecord.ContentType = ContentType.ChangeCipherSpec; + changeCipherSpecRecord.ProtocolVersion = protocolVersion; + changeCipherSpecRecord.Epoch = (ushort)(peer.Epoch - 1); + changeCipherSpecRecord.SequenceNumber = peer.CurrentEpoch.NextOutgoingSequenceForPreviousEpoch; + changeCipherSpecRecord.Length = (ushort)peer.CurrentEpoch.PreviousRecordProtection.GetEncryptedSize(ChangeCipherSpec.Size); + ++peer.CurrentEpoch.NextOutgoingSequenceForPreviousEpoch; + + int plaintextFinishedPayloadSize = Handshake.Size + (int)outgoingHandshake.Length; + Record finishedRecord = new Record(); + finishedRecord.ContentType = ContentType.Handshake; + finishedRecord.ProtocolVersion = protocolVersion; + finishedRecord.Epoch = peer.Epoch; + finishedRecord.SequenceNumber = peer.CurrentEpoch.NextOutgoingSequence; + finishedRecord.Length = (ushort)peer.CurrentEpoch.RecordProtection.GetEncryptedSize(plaintextFinishedPayloadSize); + ++peer.CurrentEpoch.NextOutgoingSequence; + + // Encode the flight into wire format + packet = new byte[Record.Size + changeCipherSpecRecord.Length + Record.Size + finishedRecord.Length]; + writer = packet; + changeCipherSpecRecord.Encode(writer); + writer = writer.Slice(Record.Size); + ChangeCipherSpec.Encode(writer); + + ByteSpan startOfFinishedRecord = packet.Slice(Record.Size + changeCipherSpecRecord.Length); + writer = startOfFinishedRecord; + finishedRecord.Encode(writer); + writer = writer.Slice(Record.Size); + outgoingHandshake.Encode(writer); + writer = writer.Slice(Handshake.Size); + peer.CurrentEpoch.ServerFinishedVerification.CopyTo(writer); + + // Protect the ChangeChipherSpec record + peer.CurrentEpoch.PreviousRecordProtection.EncryptServerPlaintext( + packet.Slice(Record.Size, changeCipherSpecRecord.Length), + packet.Slice(Record.Size, ChangeCipherSpec.Size), + ref changeCipherSpecRecord + ); + + // Protect the Finished Handshake record + peer.CurrentEpoch.RecordProtection.EncryptServerPlaintext( + startOfFinishedRecord.Slice(Record.Size, finishedRecord.Length), + startOfFinishedRecord.Slice(Record.Size, plaintextFinishedPayloadSize), + ref finishedRecord + ); + + // Current epoch can now handle application data + peer.CanHandleApplicationData = true; + + base.QueueRawData(packet, peerAddress); + break; + + // Drop messages that we do not support + case HandshakeType.CertificateVerify: + this.Logger.WriteError($"Dropping unsupported Handshake message from `{peerAddress}` MessageType({handshake.MessageType})"); + continue; + + // Drop messages that originate from the server + case HandshakeType.HelloRequest: + case HandshakeType.ServerHello: + case HandshakeType.HelloVerifyRequest: + case HandshakeType.Certificate: + case HandshakeType.ServerKeyExchange: + case HandshakeType.CertificateRequest: + case HandshakeType.ServerHelloDone: + this.Logger.WriteError($"Dropping server Handshake message from `{peerAddress}` MessageType({handshake.MessageType})"); + continue; + } + } + + return true; + } + + /// <summary> + /// Handle a ClientHello message for a peer + /// </summary> + /// <param name="peer">Originating peer</param> + /// <param name="peerAddress">Peer address</param> + /// <param name="record">Parent record</param> + /// <param name="handshake">Parent Handshake header</param> + /// <param name="payload">Handshake payload</param> + private bool HandleClientHello(PeerData peer, IPEndPoint peerAddress, ref Record record, ref Handshake handshake, ByteSpan originalMessage, ByteSpan payload) + { + // Verify message sequence + if (handshake.MessageSequence != 0) + { + this.Logger.WriteError($"Dropping bad-sequence ClientHello from `{peerAddress}` MessageSequence({handshake.MessageSequence})`"); + return true; + } + + // Make sure we can handle a ClientHello message + if (peer.NextEpoch.State != HandshakeState.ExpectingHello && peer.NextEpoch.State != HandshakeState.ExpectingClientKeyExchange) + { + // Always handle ClientHello for epoch 0 + if (record.Epoch != 0) + { + this.Logger.WriteError($"Dropping ClientHello from `{peer}` Not expecting ClientHello"); + return true; + } + } + + ProtocolVersion protocolVersion = peer.ProtocolVersion; + if (!ClientHello.Parse(out ClientHello clientHello, protocolVersion, payload)) + { + this.Logger.WriteError($"Dropping malformed ClientHello Handshake message from `{peerAddress}`"); + return false; + } + + // Find an acceptable cipher suite we can use + CipherSuite selectedCipherSuite = CipherSuite.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256; + if (!clientHello.ContainsCipherSuite(CipherSuite.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256) || !clientHello.ContainsCurve(NamedCurve.x25519)) + { + this.Logger.WriteError($"Dropping ClientHello from `{peerAddress}` No compatible cipher suite"); + return false; + } + + // If this message was not signed by us, + // request a signed message before doing anything else + if (!HelloVerifyRequest.VerifyCookie(clientHello.Cookie, peerAddress, this.CurrentCookieHmac)) + { + if (!HelloVerifyRequest.VerifyCookie(clientHello.Cookie, peerAddress, this.PreviousCookieHmac)) + { + ulong outgoingSequence = 1; + IRecordProtection recordProtection = NullRecordProtection.Instance; + if (record.Epoch != 0) + { + outgoingSequence = peer.CurrentEpoch.NextExpectedSequence; + ++peer.CurrentEpoch.NextOutgoingSequenceForPreviousEpoch; + + recordProtection = peer.CurrentEpoch.RecordProtection; + } + +#if DEBUG + this.Logger.WriteError($"Sending HelloVerifyRequest to peer `{peerAddress}`"); +#else + Interlocked.Increment(ref this.PeerVerifyHelloRequests); +#endif + this.SendHelloVerifyRequest(peerAddress, outgoingSequence, record.Epoch, recordProtection, protocolVersion); + return true; + } + } + + // Client is initiating a brand new connection. We need + // to destroy the existing connection and establish a + // new session. + if (record.Epoch == 0 && peer.Epoch != 0) + { + ConnectionId oldConnectionId = peer.ConnectionId; + peer.ResetPeer(this.AllocateConnectionId(peerAddress), record.SequenceNumber + 1); + + // Inform the parent layer that the existing + // connection should be abandoned. + MarkConnectionAsStale(oldConnectionId); + } + + // Determine if this is an original message, or a retransmission + bool recordMessagesForVerifyData = false; + if (peer.NextEpoch.State == HandshakeState.ExpectingHello) + { + // Create our handhake cipher suite + IHandshakeCipherSuite handshakeCipherSuite = null; + switch (selectedCipherSuite) + { + case CipherSuite.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256: + if (clientHello.ContainsCurve(NamedCurve.x25519)) + { + handshakeCipherSuite = new X25519EcdheRsaSha256(this.random); + } + else + { + this.Logger.WriteError($"Dropping ClientHello from `{peerAddress}` Could not create TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 cipher suite"); + return false; + } + + break; + + default: + this.Logger.WriteError($"Dropping ClientHello from `{peerAddress}` Could not create handshake cipher suite"); + return false; + } + + peer.Session = clientHello.Session; + + // Update the state of our epoch transition + peer.NextEpoch.Epoch = (ushort)(record.Epoch + 1); + peer.NextEpoch.State = HandshakeState.ExpectingClientKeyExchange; + peer.NextEpoch.SelectedCipherSuite = selectedCipherSuite; + peer.NextEpoch.Handshake = handshakeCipherSuite; + clientHello.Random.CopyTo(peer.NextEpoch.ClientRandom); + peer.NextEpoch.ServerRandom.FillWithRandom(this.random); + recordMessagesForVerifyData = true; + +#if DEBUG + this.Logger.WriteVerbose($"ClientRandom: {peer.NextEpoch.ClientRandom} ServerRandom: {peer.NextEpoch.ServerRandom}"); +#endif + + // Copy the original ClientHello + // handshake to our verification stream + peer.NextEpoch.VerificationStream.AddData( + originalMessage.Slice( + 0 + , Handshake.Size + (int)handshake.Length + ) + ); + } + + // The initial record flight from the server + // contains the following Handshake messages: + // * ServerHello + // * Certificate + // * ServerKeyExchange + // * ServerHelloDone + // + // The Certificate message is almost always + // too large to fit into a single datagram, + // so it is pre-fragmented + // (see `SetCertificates`). Therefore, we + // need to send multiple record packets for + // this flight. + // + // The first record contains the ServerHello + // handshake message, as well as the first + // portion of the Certificate message. + // + // We then send a record packet until the + // entire Certificate message has been sent + // to the client. + // + // The final record packet contains the + // ServerKeyExchange and the ServerHelloDone + // messages. + + // Describe first record of the flight + ServerHello serverHello = new ServerHello(); + serverHello.ServerProtocolVersion = protocolVersion; + serverHello.Random = peer.NextEpoch.ServerRandom; + serverHello.CipherSuite = selectedCipherSuite; + + Handshake serverHelloHandshake = new Handshake(); + serverHelloHandshake.MessageType = HandshakeType.ServerHello; + serverHelloHandshake.Length = ServerHello.MinSize; + serverHelloHandshake.MessageSequence = 1; + serverHelloHandshake.FragmentOffset = 0; + serverHelloHandshake.FragmentLength = serverHelloHandshake.Length; + + int maxCertFragmentSize = peer.Session.Version == 0 ? MaxCertFragmentSizeV0 : MaxCertFragmentSizeV1; + + // The first certificate data needs to leave room for + // * Record header + // * ServerHello header + // * ServerHello payload + // * Certificate header + + var certificateData = this.encodedCertificate; + int initialCertPadding = Record.Size + Handshake.Size + serverHello.Size + Handshake.Size; + int certInitialFragmentSize = Math.Min(certificateData.Length, maxCertFragmentSize - initialCertPadding); + + Handshake certificateHandshake = new Handshake(); + certificateHandshake.MessageType = HandshakeType.Certificate; + certificateHandshake.Length = (uint)certificateData.Length; + certificateHandshake.MessageSequence = 2; + certificateHandshake.FragmentOffset = 0; + certificateHandshake.FragmentLength = (uint)certInitialFragmentSize; + + int initialRecordPayloadSize = 0 + + Handshake.Size + serverHello.Size + + Handshake.Size + (int)certificateHandshake.FragmentLength + ; + Record initialRecord = new Record(); + initialRecord.ContentType = ContentType.Handshake; + initialRecord.ProtocolVersion = protocolVersion; + initialRecord.Epoch = peer.Epoch; + initialRecord.SequenceNumber = peer.CurrentEpoch.NextOutgoingSequence; + initialRecord.Length = (ushort)peer.CurrentEpoch.RecordProtection.GetEncryptedSize(initialRecordPayloadSize); + ++peer.CurrentEpoch.NextOutgoingSequence; + + // Convert initial record of the flight to + // wire format + ByteSpan packet = new byte[Record.Size + initialRecord.Length]; + ByteSpan writer = packet; + initialRecord.Encode(writer); + writer = writer.Slice(Record.Size); + serverHelloHandshake.Encode(writer); + writer = writer.Slice(Handshake.Size); + serverHello.Encode(writer); + writer = writer.Slice(ServerHello.MinSize); + certificateHandshake.Encode(writer); + writer = writer.Slice(Handshake.Size); + certificateData.Slice(0, certInitialFragmentSize).CopyTo(writer); + certificateData = certificateData.Slice(certInitialFragmentSize); + + // Protect initial record of the flight + peer.CurrentEpoch.RecordProtection.EncryptServerPlaintext( + packet.Slice(Record.Size, initialRecord.Length), + packet.Slice(Record.Size, initialRecordPayloadSize), + ref initialRecord + ); + + base.QueueRawData(packet, peerAddress); + + // Record record payload for verification + if (recordMessagesForVerifyData) + { + Handshake fullCeritficateHandshake = certificateHandshake; + fullCeritficateHandshake.FragmentLength = fullCeritficateHandshake.Length; + + packet = new byte[Handshake.Size + ServerHello.MinSize + Handshake.Size]; + writer = packet; + serverHelloHandshake.Encode(writer); + writer = writer.Slice(Handshake.Size); + serverHello.Encode(writer); + writer = writer.Slice(ServerHello.MinSize); + fullCeritficateHandshake.Encode(writer); + writer = writer.Slice(Handshake.Size); + + peer.NextEpoch.VerificationStream.AddData(packet); + peer.NextEpoch.VerificationStream.AddData(this.encodedCertificate); + } + + // Process additional certificate records + // Subsequent certificate data needs to leave room for + // * Record header + // * Certificate header + const int CertPadding = Record.Size + Handshake.Size; + while (certificateData.Length > 0) + { + int certFragmentSize = Math.Min(certificateData.Length, maxCertFragmentSize - CertPadding); + + certificateHandshake.FragmentOffset += certificateHandshake.FragmentLength; + certificateHandshake.FragmentLength = (uint)certFragmentSize; + + int additionalRecordPayloadSize = Handshake.Size + (int)certificateHandshake.FragmentLength; + Record additionalRecord = new Record(); + additionalRecord.ContentType = ContentType.Handshake; + additionalRecord.ProtocolVersion = protocolVersion; + additionalRecord.Epoch = peer.Epoch; + additionalRecord.SequenceNumber = peer.CurrentEpoch.NextOutgoingSequence; + additionalRecord.Length = (ushort)peer.CurrentEpoch.RecordProtection.GetEncryptedSize(additionalRecordPayloadSize); + ++peer.CurrentEpoch.NextOutgoingSequence; + + // Convert record to wire format + packet = new byte[Record.Size + additionalRecord.Length]; + writer = packet; + additionalRecord.Encode(writer); + writer = writer.Slice(Record.Size); + certificateHandshake.Encode(writer); + writer = writer.Slice(Handshake.Size); + certificateData.Slice(0, certFragmentSize).CopyTo(writer); + + certificateData = certificateData.Slice(certFragmentSize); + + // Protect record + peer.CurrentEpoch.RecordProtection.EncryptServerPlaintext( + packet.Slice(Record.Size, additionalRecord.Length), + packet.Slice(Record.Size, additionalRecordPayloadSize), + ref additionalRecord + ); + + base.QueueRawData(packet, peerAddress); + } + + // Describe final record of the flight + Handshake serverKeyExchangeHandshake = new Handshake(); + serverKeyExchangeHandshake.MessageType = HandshakeType.ServerKeyExchange; + serverKeyExchangeHandshake.Length = (uint)peer.NextEpoch.Handshake.CalculateServerMessageSize(this.certificatePrivateKey); + serverKeyExchangeHandshake.MessageSequence = 3; + serverKeyExchangeHandshake.FragmentOffset = 0; + serverKeyExchangeHandshake.FragmentLength = serverKeyExchangeHandshake.Length; + + Handshake serverHelloDoneHandshake = new Handshake(); + serverHelloDoneHandshake.MessageType = HandshakeType.ServerHelloDone; + serverHelloDoneHandshake.Length = 0; + serverHelloDoneHandshake.MessageSequence = 4; + serverHelloDoneHandshake.FragmentOffset = 0; + serverHelloDoneHandshake.FragmentLength = 0; + + int finalRecordPayloadSize = 0 + + Handshake.Size + (int)serverKeyExchangeHandshake.Length + + Handshake.Size + (int)serverHelloDoneHandshake.Length + ; + Record finalRecord = new Record(); + finalRecord.ContentType = ContentType.Handshake; + finalRecord.ProtocolVersion = protocolVersion; + finalRecord.Epoch = peer.Epoch; + finalRecord.SequenceNumber = peer.CurrentEpoch.NextOutgoingSequence; + finalRecord.Length = (ushort)peer.CurrentEpoch.RecordProtection.GetEncryptedSize(finalRecordPayloadSize); + ++peer.CurrentEpoch.NextOutgoingSequence; + + // Convert final record of the flight to wire + // format + packet = new byte[Record.Size + finalRecord.Length]; + writer = packet; + finalRecord.Encode(writer); + writer = writer.Slice(Record.Size); + serverKeyExchangeHandshake.Encode(writer); + writer = writer.Slice(Handshake.Size); + peer.NextEpoch.Handshake.EncodeServerKeyExchangeMessage(writer, this.certificatePrivateKey); + writer = writer.Slice((int)serverKeyExchangeHandshake.Length); + serverHelloDoneHandshake.Encode(writer); + + // Record record payload for verification + if (recordMessagesForVerifyData) + { + peer.NextEpoch.VerificationStream.AddData( + packet.Slice( + packet.Offset + Record.Size + , finalRecordPayloadSize + ) + ); + } + + // Protect final record of the flight + peer.CurrentEpoch.RecordProtection.EncryptServerPlaintext( + packet.Slice(Record.Size, finalRecord.Length), + packet.Slice(Record.Size, finalRecordPayloadSize), + ref finalRecord + ); + + base.QueueRawData(packet, peerAddress); + + return true; + } + + /// <summary> + /// Handle an incoming packet that is not tied to an existing peer + /// </summary> + /// <param name="message">Incoming datagram</param> + /// <param name="peerAddress">Originating address</param> + private void HandleNonPeerRecord(ByteSpan message, IPEndPoint peerAddress) + { + Record record; + if (!Record.Parse(out record, expectedProtocolVersion: null, message)) + { + this.Logger.WriteError($"Dropping malformed record from non-peer `{peerAddress}`"); + return; + } + message = message.Slice(Record.Size); + + // The protocol only supports receiving a single record + // from a non-peer. + if (record.Length != message.Length) + { + // NOTE(mendsley): This isn't always fatal. + // However, this is an indication that something + // fishy is going on. In the best case, there's a + // bug on the client or in the UDP stack (some + // stacks don't both to verify the checksum). In the + // worst case we're dealing with a malicious actor. + // In the malicious case, we'll end up dropping the + // connection later in the process. + if (message.Length < record.Length) + { + this.Logger.WriteInfo($"Dropping bad record from non-peer `{peerAddress}`. Msg length {message.Length} < {record.Length}"); + return; + } + } + + // We only accept zero-epoch records from non-peers + if (record.Epoch != 0) + { + ///NOTE(mendsley): Not logging anything here, as + /// this could easily be latent data arriving from a + /// recently disconnected peer. + return; + } + + // We only accept Handshake protocol messages from non-peers + if (record.ContentType != ContentType.Handshake) + { + this.Logger.WriteError($"Dropping non-handhsake message from non-peer `{peerAddress}`"); + return; + } + + ByteSpan originalMessage = message; + + Handshake handshake; + if (!Handshake.Parse(out handshake, message)) + { + this.Logger.WriteError($"Dropping malformed handshake message from non-peer `{peerAddress}`"); + return; + } + + // We only accept ClientHello messages from non-peers + if (handshake.MessageType != HandshakeType.ClientHello) + { +#if DEBUG + this.Logger.WriteError($"Dropping non-ClientHello ({handshake.MessageType}) message from non-peer `{peerAddress}`"); +#else + Interlocked.Increment(ref this.NonPeerNonHelloPacketsDropped); +#endif + return; + } + message = message.Slice(Handshake.Size); + + if (!ClientHello.Parse(out ClientHello clientHello, expectedProtocolVersion: null, message)) + { + this.Logger.WriteError($"Dropping malformed ClientHello message from non-peer `{peerAddress}`"); + return; + } + + // If this ClientHello is not signed by us, request the + // client send us a signed message + if (!HelloVerifyRequest.VerifyCookie(clientHello.Cookie, peerAddress, this.CurrentCookieHmac)) + { + if (!HelloVerifyRequest.VerifyCookie(clientHello.Cookie, peerAddress, this.PreviousCookieHmac)) + { +#if DEBUG + this.Logger.WriteVerbose($"Sending HelloVerifyRequest to non-peer `{peerAddress}`"); +#else + Interlocked.Increment(ref this.NonPeerVerifyHelloRequests); +#endif + this.SendHelloVerifyRequest(peerAddress, 1, 0, NullRecordProtection.Instance, clientHello.ClientProtocolVersion); + return; + } + } + + // Allocate state for the new peer and register it + PeerData peer = new PeerData(this.AllocateConnectionId(peerAddress), record.SequenceNumber + 1, clientHello.ClientProtocolVersion); + this.ProcessHandshake(peer, peerAddress, ref record, originalMessage); + this.existingPeers[peerAddress] = peer; + } + + //Send a HelloVerifyRequest handshake message to a peer + private void SendHelloVerifyRequest(IPEndPoint peerAddress, ulong recordSequence, ushort epoch, IRecordProtection recordProtection, ProtocolVersion protocolVersion) + { + Handshake handshake = new Handshake(); + handshake.MessageType = HandshakeType.HelloVerifyRequest; + handshake.Length = HelloVerifyRequest.Size; + handshake.MessageSequence = 0; + handshake.FragmentOffset = 0; + handshake.FragmentLength = handshake.Length; + + int plaintextPayloadSize = Handshake.Size + (int)handshake.Length; + + Record record = new Record(); + record.ContentType = ContentType.Handshake; + record.ProtocolVersion = protocolVersion; + record.Epoch = epoch; + record.SequenceNumber = recordSequence; + record.Length = (ushort)recordProtection.GetEncryptedSize(plaintextPayloadSize); + + // Encode record to wire format + ByteSpan packet = new byte[Record.Size + record.Length]; + ByteSpan writer = packet; + record.Encode(writer); + writer = writer.Slice(Record.Size); + handshake.Encode(writer); + writer = writer.Slice(Handshake.Size); + HelloVerifyRequest.Encode(writer, peerAddress, this.CurrentCookieHmac, protocolVersion); + + // Protect record payload + recordProtection.EncryptServerPlaintext( + packet.Slice(Record.Size, record.Length), + packet.Slice(Record.Size, plaintextPayloadSize), + ref record + ); + + base.QueueRawData(packet, peerAddress); + } + + /// <summary> + /// Handle a requrest to send a datagram to the network + /// </summary> + protected override void QueueRawData(ByteSpan span, IPEndPoint remoteEndPoint) + { + PeerData peer; + if (!this.existingPeers.TryGetValue(remoteEndPoint, out peer)) + { + // Drop messages if we don't know how to send them + return; + } + + lock (peer) + { + // If we're negotiating a new epoch, queue data + if (peer.Epoch == 0 || peer.NextEpoch.State != HandshakeState.ExpectingHello) + { + ByteSpan copyOfSpan = new byte[span.Length]; + span.CopyTo(copyOfSpan); + + peer.QueuedApplicationDataMessage.Add(copyOfSpan); + return; + } + + ProtocolVersion protocolVersion = peer.ProtocolVersion; + + // Send any queued application data now + for (int ii = 0, nn = peer.QueuedApplicationDataMessage.Count; ii != nn; ++ii) + { + ByteSpan queuedSpan = peer.QueuedApplicationDataMessage[ii]; + + Record outgoingRecord = new Record(); + outgoingRecord.ContentType = ContentType.ApplicationData; + outgoingRecord.ProtocolVersion = protocolVersion; + outgoingRecord.Epoch = peer.Epoch; + outgoingRecord.SequenceNumber = peer.CurrentEpoch.NextOutgoingSequence; + outgoingRecord.Length = (ushort)peer.CurrentEpoch.RecordProtection.GetEncryptedSize(queuedSpan.Length); + ++peer.CurrentEpoch.NextOutgoingSequence; + + // Encode the record to wire format + ByteSpan packet = new byte[Record.Size + outgoingRecord.Length]; + ByteSpan writer = packet; + outgoingRecord.Encode(writer); + writer = writer.Slice(Record.Size); + queuedSpan.CopyTo(writer); + + // Protect the record + peer.CurrentEpoch.RecordProtection.EncryptServerPlaintext( + packet.Slice(Record.Size, outgoingRecord.Length), + packet.Slice(Record.Size, queuedSpan.Length), + ref outgoingRecord + ); + + base.QueueRawData(packet, remoteEndPoint); + } + peer.QueuedApplicationDataMessage.Clear(); + + { + Record outgoingRecord = new Record(); + outgoingRecord.ContentType = ContentType.ApplicationData; + outgoingRecord.ProtocolVersion = protocolVersion; + outgoingRecord.Epoch = peer.Epoch; + outgoingRecord.SequenceNumber = peer.CurrentEpoch.NextOutgoingSequence; + outgoingRecord.Length = (ushort)peer.CurrentEpoch.RecordProtection.GetEncryptedSize(span.Length); + ++peer.CurrentEpoch.NextOutgoingSequence; + + // Encode the record to wire format + ByteSpan packet = new byte[Record.Size + outgoingRecord.Length]; + ByteSpan writer = packet; + outgoingRecord.Encode(writer); + writer = writer.Slice(Record.Size); + span.CopyTo(writer); + + // Protect the record + peer.CurrentEpoch.RecordProtection.EncryptServerPlaintext( + packet.Slice(Record.Size, outgoingRecord.Length), + packet.Slice(Record.Size, span.Length), + ref outgoingRecord + ); + + base.QueueRawData(packet, remoteEndPoint); + } + } + } + + private void HandleStaleConnections(object _) + { + TimeSpan maxAge = TimeSpan.FromSeconds(2.5f); + DateTime now = DateTime.UtcNow; + foreach (KeyValuePair<IPEndPoint, PeerData> kvp in this.existingPeers) + { + PeerData peer = kvp.Value; + lock (peer) + { + if (peer.Epoch == 0 || peer.NextEpoch.State != HandshakeState.ExpectingHello) + { + TimeSpan negotiationAge = now - peer.StartOfNegotiation; + if (negotiationAge > maxAge) + { + MarkConnectionAsStale(peer.ConnectionId); + } + } + } + } + + ConnectionId connectionId; + while (this.staleConnections.TryPop(out connectionId)) + { + ThreadLimitedUdpServerConnection connection; + if (this.allConnections.TryGetValue(connectionId, out connection)) + { + connection.Disconnect("Stale Connection", null); + } + } + } + + protected void MarkConnectionAsStale(ConnectionId connectionId) + { + if (this.allConnections.ContainsKey(connectionId)) + { + this.staleConnections.Push(connectionId); + } + } + + /// <inheritdoc /> + internal override void RemovePeerRecord(ConnectionId connectionId) + { + if (this.existingPeers.TryRemove(connectionId.EndPoint, out var peer)) + { + peer.Dispose(); + } + } + + /// <summary> + /// Allocate a new connection id + /// </summary> + private ConnectionId AllocateConnectionId(IPEndPoint endPoint) + { + int rawSerialId = Interlocked.Increment(ref this.connectionSerial_unsafe); + return ConnectionId.Create(endPoint, rawSerialId); + } + + } +} diff --git a/Tools/Hazel-Networking/Hazel/Dtls/DtlsUnityConnection.cs b/Tools/Hazel-Networking/Hazel/Dtls/DtlsUnityConnection.cs new file mode 100644 index 0000000..4da2051 --- /dev/null +++ b/Tools/Hazel-Networking/Hazel/Dtls/DtlsUnityConnection.cs @@ -0,0 +1,1246 @@ +using Hazel.Crypto; +using Hazel.Udp; +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Net; +using System.Security.Cryptography; +using System.Security.Cryptography.X509Certificates; + +namespace Hazel.Dtls +{ + /// <summary> + /// Connects to a UDP-DTLS server + /// </summary> + /// <inheritdoc /> + public class DtlsUnityConnection : UnityUdpClientConnection + { + /// <summary> + /// Current state of the handshake sequence + /// </summary> + enum HandshakeState + { + Initializing, + + ExpectingServerHello, + ExpectingCertificate, + ExpectingServerKeyExchange, + ExpectingServerHelloDone, + + ExpectingChangeCipherSpec, + ExpectingFinished, + + Established, + } + + /// <summary> + /// State data for the current epoch + /// </summary> + struct CurrentEpoch + { + public ulong NextOutgoingSequence; + + public ulong NextExpectedSequence; + public ulong PreviousSequenceWindowBitmask; + + public IRecordProtection RecordProtection; + } + + struct FragmentRange + { + public int Offset; + public int Length; + } + + /// <summary> + /// State data for the next epoch + /// </summary> + struct NextEpoch + { + public ushort Epoch; + + public HandshakeState State; + + public ulong NextOutgoingSequence; + + public DateTime NegotiationStartTime; + public DateTime NextPacketResendTime; + public int PacketResendCount; + + public CipherSuite SelectedCipherSuite; + public IRecordProtection RecordProtection; + public IHandshakeCipherSuite Handshake; + public ByteSpan Cookie; + public Sha256Stream VerificationStream; + public RSA ServerPublicKey; + + public ByteSpan ClientRandom; + public ByteSpan ServerRandom; + + public ByteSpan MasterSecret; + public ByteSpan ServerVerification; + + public List<FragmentRange> CertificateFragments; + public ByteSpan CertificatePayload; + } + + struct QueuedAppData + { + public byte[] Bytes; + public byte SendOption; + public Action AckCallback; + } + + private const ProtocolVersion DtlsVersion = ProtocolVersion.UDP; + + internal byte HazelSessionVersion = HazelDtlsSessionInfo.CurrentClientSessionVersion; + + private readonly object syncRoot = new object(); + private readonly RandomNumberGenerator random = RandomNumberGenerator.Create(); + + private ushort epoch; + private CurrentEpoch currentEpoch; + private NextEpoch nextEpoch; + private TimeSpan handshakeResendTimeout = TimeSpan.FromMilliseconds(200); + + private readonly Queue<QueuedAppData> queuedApplicationData = new Queue<QueuedAppData>(); + + private X509Certificate2Collection serverCertificates = new X509Certificate2Collection(); + + public bool HandshakeComplete + { + get + { + lock (this.syncRoot) + { + return this.nextEpoch.State == HandshakeState.Established; + } + } + } + + /// <summary> + /// Create a new instance of the DTLS connection + /// </summary> + /// <inheritdoc /> + public DtlsUnityConnection(ILogger logger, IPEndPoint remoteEndPoint, IPMode ipMode = IPMode.IPv4) + : base(logger, remoteEndPoint, ipMode) + { + this.nextEpoch.ServerRandom = new byte[Random.Size]; + this.nextEpoch.ClientRandom = new byte[Random.Size]; + this.nextEpoch.ServerVerification = new byte[Finished.Size]; + this.nextEpoch.CertificateFragments = new List<FragmentRange>(); + + this.ResetConnectionState(); + } + + /// <inheritdoc /> + protected override void Dispose(bool disposing) + { + base.Dispose(disposing); + + lock (this.syncRoot) + { + this.ResetConnectionState(); + } + } + + /// <summary> + /// Set the list of valid server certificates + /// </summary> + /// <param name="certificateCollection"> + /// List of certificates of authentic servers + /// </param> + public void SetValidServerCertificates(X509Certificate2Collection certificateCollection) + { + lock (this.syncRoot) + { + foreach (X509Certificate2 certificate in certificateCollection) + { + if (!(certificate.PublicKey.Key is RSA)) + { + throw new ArgumentException("Certificate must be signed with an RSA key", nameof(certificateCollection)); + } + } + + this.serverCertificates = certificateCollection; + } + } + + /// <summary> + /// Set the packet resend timer for handshake messages + /// </summary> + public void SetHandshakeResendTimeout(TimeSpan timeout) + { + lock (this.syncRoot) + { + this.handshakeResendTimeout = timeout; + } + } + + /// <summary> + /// Reset existing connection state + /// </summary> + private void ResetConnectionState() + { + this.currentEpoch.NextOutgoingSequence = 1; + this.currentEpoch.NextExpectedSequence = 1; + this.currentEpoch.PreviousSequenceWindowBitmask = 0; + this.currentEpoch.RecordProtection?.Dispose(); + this.currentEpoch.RecordProtection = NullRecordProtection.Instance; + + this.nextEpoch.Epoch = 1; + this.nextEpoch.State = HandshakeState.Initializing; + this.nextEpoch.NextOutgoingSequence = 1; + this.nextEpoch.NegotiationStartTime = DateTime.MinValue; + this.nextEpoch.NextPacketResendTime = DateTime.MinValue; + this.nextEpoch.SelectedCipherSuite = CipherSuite.TLS_NULL_WITH_NULL_NULL; + this.nextEpoch.RecordProtection?.Dispose(); + this.nextEpoch.RecordProtection = null; + this.nextEpoch.Handshake?.Dispose(); + this.nextEpoch.Handshake = null; + this.nextEpoch.Cookie = ByteSpan.Empty; + this.nextEpoch.VerificationStream?.Dispose(); + this.nextEpoch.VerificationStream = new Sha256Stream(); + this.nextEpoch.ServerPublicKey = null; + this.nextEpoch.ServerRandom.SecureClear(); + this.nextEpoch.ClientRandom.SecureClear(); + this.nextEpoch.MasterSecret.SecureClear(); + this.nextEpoch.ServerVerification.SecureClear(); + this.nextEpoch.CertificateFragments.Clear(); + this.nextEpoch.CertificatePayload = ByteSpan.Empty; + + this.epoch = 0; + while (this.queuedApplicationData.TryDequeue(out _)) ; + } + + /// <summary> + /// Abort the existing connection and restart the process + /// </summary> + protected override void RestartConnection() + { + lock (this.syncRoot) + { + this.ResetConnectionState(); + this.nextEpoch.ClientRandom.FillWithRandom(this.random); + this.SendClientHello(isRetransmit: false); + } + + base.RestartConnection(); + } + + /// <inheritdoc /> + protected override void ResendPacketsIfNeeded() + { + lock (this.syncRoot) + { + // Check if we need to resend handshake message + if (this.nextEpoch.State != HandshakeState.Established) + { + DateTime now = DateTime.UtcNow; + if (now >= this.nextEpoch.NextPacketResendTime) + { + double negotiationDurationMs = (now - this.nextEpoch.NegotiationStartTime).TotalMilliseconds; + this.nextEpoch.PacketResendCount++; + + if ((this.ResendLimit > 0 && this.nextEpoch.PacketResendCount > this.ResendLimit) + || negotiationDurationMs > this.DisconnectTimeoutMs) + { + this.DisconnectInternal(HazelInternalErrors.DtlsNegotiationFailed, $"DTLS negotiation failed after {this.nextEpoch.PacketResendCount} resends ({(int)negotiationDurationMs} ms)."); + } + else + { + switch (this.nextEpoch.State) + { + case HandshakeState.ExpectingServerHello: + case HandshakeState.ExpectingCertificate: + case HandshakeState.ExpectingServerKeyExchange: + case HandshakeState.ExpectingServerHelloDone: + this.SendClientHello(isRetransmit: true); + break; + + case HandshakeState.ExpectingChangeCipherSpec: + case HandshakeState.ExpectingFinished: + this.SendClientKeyExchangeFlight(isRetransmit: true); + break; + + case HandshakeState.Established: + default: + break; + } + } + } + } + } + + base.ResendPacketsIfNeeded(); + } + + /// <summary> + /// Flush any queued application data packets + /// </summary> + private void FlushQueuedApplicationData() + { + while (this.queuedApplicationData.TryDequeue(out var queuedData)) + { + base.HandleSend(queuedData.Bytes, queuedData.SendOption, queuedData.AckCallback); + } + } + + /// <summary> + /// Request from the application to write data to the DTLS + /// stream. If appropriate, returns a byte span to send to + /// the wire. + /// </summary> + /// <param name="bytes">Plaintext bytes to write</param> + /// <param name="length">Length of the bytes to write</param> + /// <returns> + /// Encrypted data to put on the wire if appropriate, + /// otherwise an empty span + /// </returns> + private ByteSpan WriteBytesToConnectionInternal(byte[] bytes, int length) + { + lock (this.syncRoot) + { + Record outgoinRecord = new Record(); + outgoinRecord.ContentType = ContentType.ApplicationData; + outgoinRecord.ProtocolVersion = DtlsVersion; + outgoinRecord.Epoch = this.epoch; + outgoinRecord.SequenceNumber = this.currentEpoch.NextOutgoingSequence; + outgoinRecord.Length = (ushort)this.currentEpoch.RecordProtection.GetEncryptedSize(length); + ++this.currentEpoch.NextOutgoingSequence; + + // Encode the record to wire format + ByteSpan packet = new byte[Record.Size + outgoinRecord.Length]; + ByteSpan writer = packet; + outgoinRecord.Encode(writer); + writer = writer.Slice(Record.Size); + new ByteSpan(bytes, 0, length).CopyTo(writer); + + // Protect the record + this.currentEpoch.RecordProtection.EncryptClientPlaintext( + packet.Slice(Record.Size, outgoinRecord.Length), + packet.Slice(Record.Size, length), + ref outgoinRecord + ); + + return packet; + } + } + + protected override void HandleSend(byte[] data, byte sendOption, Action ackCallback = null) + { + lock (this.syncRoot) + { + // If we're negotiating a new epoch, queue data + if (this.nextEpoch.State != HandshakeState.Established) + { + this.queuedApplicationData.Enqueue(new QueuedAppData + { + Bytes = data, + SendOption = sendOption, + AckCallback = ackCallback + }); + + return; + } + } + + base.HandleSend(data, sendOption, ackCallback); + } + + /// <inheritdoc /> + protected override void WriteBytesToConnection(byte[] bytes, int length) + { + ByteSpan wireData = this.WriteBytesToConnectionInternal(bytes, length); + if (wireData.Length > 0) + { + Debug.Assert(wireData.Offset == 0, "Got a non-zero write data offset"); + base.WriteBytesToConnection(wireData.GetUnderlyingArray(), wireData.Length); + } + } + + /// <inheritdoc /> + protected override void WriteBytesToConnectionSync(byte[] bytes, int length) + { + ByteSpan wireData = this.WriteBytesToConnectionInternal(bytes, length); + if (wireData.Length > 0) + { + Debug.Assert(wireData.Offset == 0, "Got a non-zero write data offset"); + base.WriteBytesToConnectionSync(wireData.GetUnderlyingArray(), wireData.Length); + } + } + + /// <inheritdoc /> + protected internal override void HandleReceive(MessageReader reader, int bytesReceived) + { + ByteSpan message = new ByteSpan(reader.Buffer, reader.Offset + reader.Position, reader.BytesRemaining); + lock (this.syncRoot) + { + this.HandleReceive(message); + } + + reader.Recycle(); + } + + /// <summary> + /// Handle an incoming datagram + /// </summary> + /// <param name="span">Bytes of the datagram</param> + private void HandleReceive(ByteSpan span) + { + // Each incoming packet may contain multiple DTLS + // records + while (span.Length > 0) + { + Record record; + if (!Record.Parse(out record, DtlsVersion, span)) + { + this.logger.WriteError("Dropping malformed record"); + return; + } + span = span.Slice(Record.Size); + + if (span.Length < record.Length) + { + this.logger.WriteError($"Dropping malformed record. Length({record.Length}) Available Bytes({span.Length})"); + return; + } + + ByteSpan recordPayload = span.Slice(0, record.Length); + span = span.Slice(record.Length); + + // Early out and drop ApplicationData records + if (record.ContentType == ContentType.ApplicationData && this.nextEpoch.State != HandshakeState.Established) + { + this.logger.WriteError("Dropping ApplicationData record. Cannot process yet"); + continue; + } + + // Drop records from a different epoch + if (record.Epoch != this.epoch) + { + this.logger.WriteError($"Dropping bad-epoch record. RecordEpoch({record.Epoch}) Epoch({this.epoch})"); + continue; + } + + // Prevent replay attacks by dropping records + // we've already processed + int windowIndex = (int)(this.currentEpoch.NextExpectedSequence - record.SequenceNumber - 1); + ulong windowMask = 1ul << windowIndex; + if (record.SequenceNumber < this.currentEpoch.NextExpectedSequence) + { + if (windowIndex >= 64) + { + this.logger.WriteError($"Dropping too-old record: Sequnce({record.SequenceNumber}) Expected({this.currentEpoch.NextExpectedSequence})"); + continue; + } + + if ((this.currentEpoch.PreviousSequenceWindowBitmask & windowMask) != 0) + { + this.logger.WriteWarning("Dropping duplicate record"); + continue; + } + } + + // Verify record authenticity + int decryptedSize = this.currentEpoch.RecordProtection.GetDecryptedSize(recordPayload.Length); + ByteSpan decryptedPayload = recordPayload.ReuseSpanIfPossible(decryptedSize); + + if (!this.currentEpoch.RecordProtection.DecryptCiphertextFromServer(decryptedPayload, recordPayload, ref record)) + { + this.logger.WriteError("Dropping non-authentic record"); + return; + } + + recordPayload = decryptedPayload; + + // Update out sequence number bookkeeping + if (record.SequenceNumber >= this.currentEpoch.NextExpectedSequence) + { + int windowShift = (int)(record.SequenceNumber + 1 - this.currentEpoch.NextExpectedSequence); + this.currentEpoch.PreviousSequenceWindowBitmask <<= windowShift; + this.currentEpoch.NextExpectedSequence = record.SequenceNumber + 1; + } + else + { + this.currentEpoch.PreviousSequenceWindowBitmask |= windowMask; + } + + // This is handy for debugging, but too verbose even for verbose. + // this.logger.WriteVerbose($"Content type was {record.ContentType} ({this.nextEpoch.State})"); + switch (record.ContentType) + { + case ContentType.ChangeCipherSpec: + if (this.nextEpoch.State != HandshakeState.ExpectingChangeCipherSpec) + { + this.logger.WriteError($"Dropping unexpected ChangeCipherSpec State({this.nextEpoch.State})"); + break; + } + else if (this.nextEpoch.RecordProtection == null) + { + ///NOTE(mendsley): This _should_ not + /// happen on a well-formed client. + Debug.Assert(false, "How did we receive a ChangeCipherSpec message without a pending record protection instance?"); + break; + } + + if (!ChangeCipherSpec.Parse(recordPayload)) + { + this.logger.WriteError("Dropping malformed ChangeCipherSpec message"); + break; + } + + // Migrate to the next epoch + this.epoch = this.nextEpoch.Epoch; + this.currentEpoch.RecordProtection = this.nextEpoch.RecordProtection; + this.currentEpoch.NextOutgoingSequence = this.nextEpoch.NextOutgoingSequence; + this.currentEpoch.NextExpectedSequence = 1; + this.currentEpoch.PreviousSequenceWindowBitmask = 0; + + this.nextEpoch.State = HandshakeState.ExpectingFinished; + this.nextEpoch.SelectedCipherSuite = CipherSuite.TLS_NULL_WITH_NULL_NULL; + this.nextEpoch.RecordProtection = null; + this.nextEpoch.Handshake?.Dispose(); + this.nextEpoch.Cookie = ByteSpan.Empty; + this.nextEpoch.VerificationStream.Reset(); + this.nextEpoch.ServerPublicKey = null; + this.nextEpoch.ServerRandom.SecureClear(); + this.nextEpoch.ClientRandom.SecureClear(); + this.nextEpoch.MasterSecret.SecureClear(); + break; + + case ContentType.Alert: + this.logger.WriteError("Dropping unsupported alert record"); + continue; + + case ContentType.Handshake: + if (!ProcessHandshake(ref record, recordPayload)) + { + return; + } + break; + + case ContentType.ApplicationData: + // Forward data to the application + MessageReader reader = MessageReader.GetSized(recordPayload.Length); + reader.Length = recordPayload.Length; + recordPayload.CopyTo(reader.Buffer); + + base.HandleReceive(reader, recordPayload.Length); + break; + } + } + } + + /// <summary> + /// Process an incoming Handshake protocol message + /// </summary> + /// <param name="record">Parent record</param> + /// <param name="message">Record payload</param> + /// <returns> + /// True if further processing of the underlying datagram + /// should be continues. Otherwise, false. + /// </returns> + private bool ProcessHandshake(ref Record record, ByteSpan message) + { + // Each record may have multiple Handshake messages + while (message.Length > 0) + { + ByteSpan originalPayload = message; + + Handshake handshake; + if (!Handshake.Parse(out handshake, message)) + { + this.logger.WriteError("Dropping malformed handshake message"); + return false; + } + message = message.Slice(Handshake.Size); + + // Check for fragmented messages + if (handshake.FragmentOffset != 0 || handshake.FragmentLength != handshake.Length) + { + // We only support fragmentation on Certificate messages + if (handshake.MessageType != HandshakeType.Certificate) + { + this.logger.WriteError($"Dropping fragmented handshake message Type({handshake.MessageType}) Offset({handshake.FragmentOffset}) FragmentLength({handshake.FragmentLength}) Length({handshake.Length})"); + continue; + } + + if (message.Length < handshake.FragmentLength) + { + this.logger.WriteError($"Dropping malformed fragmented handshake message: AvailableBytes({message.Length}) Size({handshake.FragmentLength})"); + return false; + } + + originalPayload = originalPayload.Slice(0, (int)(Handshake.Size + handshake.FragmentLength)); + message = message.Slice((int)handshake.FragmentLength); + } + else + { + if (message.Length < handshake.Length) + { + this.logger.WriteError($"Dropping malformed handshake message: AvailableBytes({message.Length}) Size({handshake.Length})"); + return false; + } + + originalPayload = originalPayload.Slice(0, (int)(Handshake.Size + handshake.Length)); + message = message.Slice((int)handshake.Length); + } + + ByteSpan payload = originalPayload.Slice(Handshake.Size); + +#if DEBUG + this.logger.WriteVerbose($"Handshake record was {handshake.MessageType} (Frag: {handshake.FragmentOffset}) ({this.nextEpoch.State})"); +#endif + switch (handshake.MessageType) + { + case HandshakeType.HelloVerifyRequest: + if (this.nextEpoch.State != HandshakeState.ExpectingServerHello) + { + this.logger.WriteError($"Dropping unexpected HelloVerifyRequest handshake message State({this.nextEpoch.State})"); + continue; + } + else if (handshake.MessageSequence != 0) + { + this.logger.WriteError($"Dropping bad-sequence HelloVerifyRequest MessageSequence({handshake.MessageSequence})"); + continue; + } + + HelloVerifyRequest helloVerifyRequest; + if (!HelloVerifyRequest.Parse(out helloVerifyRequest, DtlsVersion, payload)) + { + this.logger.WriteError("Dropping malformed HelloVerifyRequest handshake message"); + continue; + } + + // If the cookie differs, save it and restart the handshake + if (this.nextEpoch.Cookie.Length == helloVerifyRequest.Cookie.Length + && Const.ConstantCompareSpans(this.nextEpoch.Cookie, helloVerifyRequest.Cookie) == 1) + { + this.logger.WriteWarning("Dropping duplicate HelloVerifyRequest handshake message"); + continue; + } + + this.nextEpoch.Cookie = new byte[helloVerifyRequest.Cookie.Length]; + helloVerifyRequest.Cookie.CopyTo(this.nextEpoch.Cookie); + this.nextEpoch.ClientRandom.FillWithRandom(this.random); + + // We don't need to resend here. We already have the cookie so we already sent it once. + this.SendClientHello(isRetransmit: false); + + break; + + case HandshakeType.ServerHello: + if (this.nextEpoch.State != HandshakeState.ExpectingServerHello) + { + this.logger.WriteError($"Dropping unexpected ServerHello handshake message State({this.nextEpoch.State})"); + continue; + } + else if (handshake.MessageSequence != 1) + { + this.logger.WriteError($"Dropping bad-sequence ServerHello MessageSequence({handshake.MessageSequence})"); + continue; + } + + ServerHello serverHello; + if (!ServerHello.Parse(out serverHello, payload)) + { + this.logger.WriteError("Dropping malformed ServerHello message"); + continue; + } + + switch (serverHello.CipherSuite) + { + case CipherSuite.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256: + this.nextEpoch.Handshake = new X25519EcdheRsaSha256(this.random); + break; + + default: + this.logger.WriteError($"Dropping malformed ServerHello message. Unsupported CipherSuite({serverHello.CipherSuite})"); + continue; + } + + // Save server parameters + this.nextEpoch.SelectedCipherSuite = serverHello.CipherSuite; + serverHello.Random.CopyTo(this.nextEpoch.ServerRandom); + this.nextEpoch.State = HandshakeState.ExpectingCertificate; + this.nextEpoch.CertificateFragments.Clear(); + this.nextEpoch.CertificatePayload = ByteSpan.Empty; + +#if DEBUG + this.logger.WriteVerbose($"ClientRandom: {this.nextEpoch.ClientRandom} ServerRandom: {this.nextEpoch.ServerRandom}"); +#endif + + // Append ServerHelllo message to the verification stream + this.nextEpoch.VerificationStream.AddData(originalPayload); + break; + + case HandshakeType.Certificate: + if (this.nextEpoch.State != HandshakeState.ExpectingCertificate) + { + this.logger.WriteError($"Dropping unexpected Certificate handshake message State({this.nextEpoch.State})"); + continue; + } + else if (handshake.MessageSequence != 2) + { + this.logger.WriteError($"Dropping bad-sequence Certificate MessageSequence({handshake.MessageSequence})"); + continue; + } + + // If this is a fragmented message + if (handshake.FragmentLength != handshake.Length) + { + if (this.nextEpoch.CertificatePayload.Length != handshake.Length) + { + this.nextEpoch.CertificatePayload = new byte[handshake.Length]; + this.nextEpoch.CertificateFragments.Clear(); + } + + // Should we add this fragment? + // According to the RFC 9147 Section 5.5, we are supposed to be tolerant of overlapping segments + // But if we... weren't... Hazel isn't going to change the fragment sizes. So would it really hurt? + // So let's just ignore that and assume that the sender always wants to send the same fragments. + if (IsFragmentOverlapping(this.nextEpoch.CertificateFragments, handshake.FragmentOffset, handshake.FragmentLength)) + { + continue; + } + + payload.CopyTo(this.nextEpoch.CertificatePayload.Slice((int)handshake.FragmentOffset, (int)handshake.FragmentLength)); + this.nextEpoch.CertificateFragments.Add(new FragmentRange {Offset = (int)handshake.FragmentOffset, Length = (int)handshake.FragmentLength }); + this.nextEpoch.CertificateFragments.Sort((FragmentRange lhs, FragmentRange rhs) => { + return lhs.Offset.CompareTo(rhs.Offset); + }); + + // Have we completed the message? + int currentOffset = 0; + bool valid = true; + foreach (FragmentRange range in this.nextEpoch.CertificateFragments) + { + if (range.Offset != currentOffset) + { + valid = false; + break; + } + + currentOffset += range.Length; + } + + if (currentOffset != this.nextEpoch.CertificatePayload.Length) + { + valid = false; + } + + // Still waiting on more fragments? + if (!valid) + { + continue; + } + + // Replace the message payload, and continue + this.nextEpoch.CertificateFragments.Clear(); + payload = this.nextEpoch.CertificatePayload; + } + + X509Certificate2 certificate; + if (!Certificate.Parse(out certificate, payload)) + { + this.logger.WriteError("Dropping malformed Certificate message"); + continue; + } + + // Verify the certificate is authenticate + if (!this.serverCertificates.Contains(certificate)) + { + this.logger.WriteError("Dropping malformed Certificate message: Certificate not authentic"); + continue; + } + + RSA publicKey = certificate.PublicKey.Key as RSA; + if (publicKey == null) + { + this.logger.WriteError("Dropping malfomed Certificate message: Certificate is not RSA signed"); + continue; + } + + // Add the final Certificate message to the verification stream + Handshake fullCertificateHandhake = handshake; + fullCertificateHandhake.FragmentOffset = 0; + fullCertificateHandhake.FragmentLength = fullCertificateHandhake.Length; + + ByteSpan serializedCertificateHandshake = new byte[Handshake.Size]; + fullCertificateHandhake.Encode(serializedCertificateHandshake); + this.nextEpoch.VerificationStream.AddData(serializedCertificateHandshake); + this.nextEpoch.VerificationStream.AddData(payload); + + this.nextEpoch.ServerPublicKey = publicKey; + this.nextEpoch.State = HandshakeState.ExpectingServerKeyExchange; + break; + + case HandshakeType.ServerKeyExchange: + if (this.nextEpoch.State != HandshakeState.ExpectingServerKeyExchange) + { + this.logger.WriteError($"Dropping unexpected ServerKeyExchange handshake message State({this.nextEpoch.State})"); + continue; + } + else if (this.nextEpoch.ServerPublicKey == null) + { + ///NOTE(mendsley): This _should_ not + /// happen on a well-formed client + Debug.Assert(false, "How are we processing a ServerKeyExchange message without a server public key?"); + + this.logger.WriteError($"Dropping unexpected ServerKeyExchange handshake message: No server public key"); + continue; + } + else if (this.nextEpoch.Handshake == null) + { + ///NOTE(mendsley): This _should_ not + /// happen on a well-formed client + Debug.Assert(false, "How did we receive a ServerKeyExchange message without a handshake instance?"); + + this.logger.WriteError($"Dropping unexpected ServerKeyExchange handshake message: No key agreement interface"); + continue; + } + else if (handshake.MessageSequence != 3) + { + this.logger.WriteError($"Dropping bad-sequence ServerKeyExchange MessageSequence({handshake.MessageSequence})"); + continue; + } + + ByteSpan sharedSecret = new byte[this.nextEpoch.Handshake.SharedKeySize()]; + if (!this.nextEpoch.Handshake.VerifyServerMessageAndGenerateSharedKey(sharedSecret, payload, this.nextEpoch.ServerPublicKey)) + { + this.logger.WriteError("Dropping malformed ServerKeyExchangeMessage"); + return false; + } + + // Generate the session master secret + ByteSpan randomSeed = new byte[2 * Random.Size]; + this.nextEpoch.ClientRandom.CopyTo(randomSeed); + this.nextEpoch.ServerRandom.CopyTo(randomSeed.Slice(Random.Size)); + + const int MasterSecretSize = 48; + ByteSpan masterSecret = new byte[MasterSecretSize]; + PrfSha256.ExpandSecret( + masterSecret + , sharedSecret + , PrfLabel.MASTER_SECRET + , randomSeed + ); + + // Create record protection for the upcoming epoch + switch (this.nextEpoch.SelectedCipherSuite) + { + case CipherSuite.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256: + this.nextEpoch.RecordProtection = new Aes128GcmRecordProtection( + masterSecret + , this.nextEpoch.ServerRandom + , this.nextEpoch.ClientRandom + ); + break; + + default: + ///NOTE(mendsley): this _should_ not + /// happen on a well-formed client. + Debug.Assert(false, "SeverHello processing already approved this ciphersuite"); + + this.logger.WriteError($"Dropping malformed ServerKeyExchangeMessage: Could not create record protection"); + return false; + } + + this.nextEpoch.State = HandshakeState.ExpectingServerHelloDone; + this.nextEpoch.MasterSecret = masterSecret; + + // Append ServerKeyExchange to the verification stream + this.nextEpoch.VerificationStream.AddData(originalPayload); + break; + + case HandshakeType.ServerHelloDone: + if (this.nextEpoch.State != HandshakeState.ExpectingServerHelloDone) + { + this.logger.WriteError($"Dropping unexpected ServerHelloDone handshake message State({this.nextEpoch.State})"); + continue; + } + else if (handshake.MessageSequence != 4) + { + this.logger.WriteError($"Dropping bad-sequence ServerHelloDone MessageSequence({handshake.MessageSequence})"); + continue; + } + + this.nextEpoch.State = HandshakeState.ExpectingChangeCipherSpec; + + // Append ServerHelloDone to the verification stream + this.nextEpoch.VerificationStream.AddData(originalPayload); + + this.SendClientKeyExchangeFlight(isRetransmit: false); + break; + + case HandshakeType.Finished: + if (this.nextEpoch.State != HandshakeState.ExpectingFinished) + { + this.logger.WriteError($"Dropping unexpected Finished handshake message State({this.nextEpoch.State})"); + continue; + } + else if (payload.Length != Finished.Size) + { + this.logger.WriteError($"Dropping malformed Finished handshake message Size({payload.Length})"); + continue; + } + else if (handshake.MessageSequence != 7) + { + this.logger.WriteError($"Dropping bad-sequence Finished MessageSequence({handshake.MessageSequence})"); + continue; + } + + // Verify the digest from the server + if (1 != Crypto.Const.ConstantCompareSpans(payload, this.nextEpoch.ServerVerification)) + { + this.logger.WriteError("Dropping non-verified Finished handshake message"); + return false; + } + + ++this.nextEpoch.Epoch; + this.nextEpoch.State = HandshakeState.Established; + this.nextEpoch.NegotiationStartTime = DateTime.MinValue; + this.nextEpoch.NextPacketResendTime = DateTime.MinValue; + this.nextEpoch.ServerVerification.SecureClear(); + this.nextEpoch.MasterSecret.SecureClear(); + + this.FlushQueuedApplicationData(); + break; + + // Drop messages we do not support + case HandshakeType.CertificateRequest: + case HandshakeType.HelloRequest: + this.logger.WriteError($"Dropping unsupported handshake message MessageType({handshake.MessageType})"); + break; + + // Drop messages that originate from the client + case HandshakeType.ClientHello: + case HandshakeType.ClientKeyExchange: + case HandshakeType.CertificateVerify: + this.logger.WriteError($"Dropping client handshake message MessageType({handshake.MessageType})"); + break; + } + } + + return true; + } + + private bool IsFragmentOverlapping(List<FragmentRange> fragments, uint newOffset, uint newLength) + { + foreach (var frag in fragments) + { + // New fragment overlaps an existing one + if (newOffset <= frag.Offset + && frag.Offset < newOffset + newLength) + { + return true; + } + + // Existing fragment overlaps this new one + if (frag.Offset <= newOffset + && newOffset < frag.Offset + frag.Length) + { + return true; + } + } + + return false; + } + + /// <summary> + /// Send (resend) a ClientHello message to the server + /// </summary> + protected virtual void SendClientHello(bool isRetransmit) + { +#if DEBUG + var verb = isRetransmit ? "Resending" : "Sending"; + this.logger.WriteVerbose($"{verb} ClientHello in state: {this.nextEpoch.State}. Epoch: {this.epoch} Cookie: {this.nextEpoch.Cookie} Random: {this.nextEpoch.ClientRandom}"); +#endif + + // Describe our ClientHello flight + ClientHello clientHello = new ClientHello(); + clientHello.ClientProtocolVersion = DtlsVersion; + clientHello.Random = this.nextEpoch.ClientRandom; + clientHello.Cookie = this.nextEpoch.Cookie; + clientHello.Session = new HazelDtlsSessionInfo(this.HazelSessionVersion); + clientHello.CipherSuites = new byte[2]; + clientHello.CipherSuites.WriteBigEndian16((ushort)CipherSuite.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256); + clientHello.SupportedCurves = new byte[2]; + clientHello.SupportedCurves.WriteBigEndian16((ushort)NamedCurve.x25519); + + Handshake handshake = new Handshake(); + handshake.MessageType = HandshakeType.ClientHello; + handshake.Length = (uint)clientHello.CalculateSize(); + handshake.MessageSequence = 0; + handshake.FragmentOffset = 0; + handshake.FragmentLength = handshake.Length; + + // Describe the record + int plaintextLength = (int)(Handshake.Size + handshake.Length); + Record outgoingRecord = new Record(); + outgoingRecord.ContentType = ContentType.Handshake; + outgoingRecord.ProtocolVersion = DtlsVersion; + outgoingRecord.Epoch = this.epoch; + outgoingRecord.SequenceNumber = this.currentEpoch.NextOutgoingSequence; + outgoingRecord.Length = (ushort)this.currentEpoch.RecordProtection.GetEncryptedSize(plaintextLength); + ++this.currentEpoch.NextOutgoingSequence; + + // Convert the record to wire format + ByteSpan packet = new byte[Record.Size + outgoingRecord.Length]; + ByteSpan writer = packet; + outgoingRecord.Encode(packet); + writer = writer.Slice(Record.Size); + handshake.Encode(writer); + writer = writer.Slice(Handshake.Size); + clientHello.Encode(writer); + + // If this is our first valid attempt at contacting the server: + // - Reset our verification stream + // - Write ClientHello to the verification stream + // - We next expect a ServerHello + // + // ClientHello+Cookie triggers many sequential packets in response + // It's important to make forward progress as the packets may be reordered in-flight + // But with enough resends, we will read them all in an appropriate order + if (!isRetransmit) + { + this.nextEpoch.VerificationStream.Reset(); + this.nextEpoch.VerificationStream.AddData( + packet.Slice(Record.Size, Handshake.Size + (int)handshake.Length) + ); + + this.nextEpoch.State = HandshakeState.ExpectingServerHello; + } + + // Protect the record + this.currentEpoch.RecordProtection.EncryptClientPlaintext( + packet.Slice(Record.Size, outgoingRecord.Length), + packet.Slice(Record.Size, plaintextLength), + ref outgoingRecord + ); + + if (this.nextEpoch.NegotiationStartTime == DateTime.MinValue) this.nextEpoch.NegotiationStartTime = DateTime.UtcNow; + this.nextEpoch.NextPacketResendTime = DateTime.UtcNow + this.handshakeResendTimeout; + + base.WriteBytesToConnection(packet.GetUnderlyingArray(), packet.Length); + } + + protected void Test_SendClientHello(Func<ClientHello, ByteSpan, ByteSpan> encodeCallback) + { + // Reset our verification stream + this.nextEpoch.VerificationStream.Reset(); + + // Describe our ClientHello flight + ClientHello clientHello = new ClientHello(); + clientHello.Random = this.nextEpoch.ClientRandom; + clientHello.Cookie = this.nextEpoch.Cookie; + clientHello.CipherSuites = new byte[2]; + clientHello.CipherSuites.WriteBigEndian16((ushort)CipherSuite.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256); + clientHello.SupportedCurves = new byte[2]; + clientHello.SupportedCurves.WriteBigEndian16((ushort)NamedCurve.x25519); + + Handshake handshake = new Handshake(); + handshake.MessageType = HandshakeType.ClientHello; + handshake.Length = (uint)clientHello.CalculateSize(); + handshake.MessageSequence = 0; + handshake.FragmentOffset = 0; + handshake.FragmentLength = handshake.Length; + + // Describe the record + int plaintextLength = (int)(Handshake.Size + handshake.Length); + Record outgoingRecord = new Record(); + outgoingRecord.ContentType = ContentType.Handshake; + outgoingRecord.ProtocolVersion = DtlsVersion; + outgoingRecord.Epoch = this.epoch; + outgoingRecord.SequenceNumber = this.currentEpoch.NextOutgoingSequence; + outgoingRecord.Length = (ushort)this.currentEpoch.RecordProtection.GetEncryptedSize(plaintextLength); + ++this.currentEpoch.NextOutgoingSequence; + + // Convert the record to wire format + ByteSpan packet = new byte[Record.Size + outgoingRecord.Length]; + ByteSpan writer = packet; + outgoingRecord.Encode(packet); + writer = writer.Slice(Record.Size); + handshake.Encode(writer); + writer = writer.Slice(Handshake.Size); + + writer = encodeCallback(clientHello, writer); + + // Write ClientHello to the verification stream + this.nextEpoch.VerificationStream.AddData( + packet.Slice( + Record.Size + , Handshake.Size + (int)handshake.Length + ) + ); + + // Protect the record + this.currentEpoch.RecordProtection.EncryptClientPlaintext( + packet.Slice(Record.Size, outgoingRecord.Length), + packet.Slice(Record.Size, plaintextLength), + ref outgoingRecord + ); + + this.nextEpoch.State = HandshakeState.ExpectingServerHello; + if (this.nextEpoch.NegotiationStartTime == DateTime.MinValue) this.nextEpoch.NegotiationStartTime = DateTime.UtcNow; + this.nextEpoch.NextPacketResendTime = DateTime.UtcNow + this.handshakeResendTimeout; + base.WriteBytesToConnection(packet.GetUnderlyingArray(), packet.Length); + } + + /// <summary> + /// Send (resend) the ClientKeyExchange flight + /// </summary> + /// <param name="isRetransmit"> + /// True if this is a retransmit of the flight. Otherwise, + /// false + /// </param> + protected virtual void SendClientKeyExchangeFlight(bool isRetransmit) + { +#if DEBUG + var verb = isRetransmit ? "Resending" : "Sending"; + this.logger.WriteVerbose($"{verb} ClientKeyExchangeFlight in state: {this.nextEpoch.State}"); +#endif + if (this.nextEpoch.State == HandshakeState.Established) + { + return; + } + + // Describe our flight + Handshake keyExchangeHandshake = new Handshake(); + keyExchangeHandshake.MessageType = HandshakeType.ClientKeyExchange; + keyExchangeHandshake.Length = (ushort)this.nextEpoch.Handshake.CalculateClientMessageSize(); + keyExchangeHandshake.MessageSequence = 5; + keyExchangeHandshake.FragmentOffset = 0; + keyExchangeHandshake.FragmentLength = keyExchangeHandshake.Length; + + Record keyExchangeRecord = new Record(); + keyExchangeRecord.ContentType = ContentType.Handshake; + keyExchangeRecord.ProtocolVersion = DtlsVersion; + keyExchangeRecord.Epoch = this.epoch; + keyExchangeRecord.SequenceNumber = this.currentEpoch.NextOutgoingSequence; + keyExchangeRecord.Length = (ushort)this.currentEpoch.RecordProtection.GetEncryptedSize(Handshake.Size + (int)keyExchangeHandshake.Length); + ++this.currentEpoch.NextOutgoingSequence; + + Record changeCipherSpecRecord = new Record(); + changeCipherSpecRecord.ContentType = ContentType.ChangeCipherSpec; + changeCipherSpecRecord.ProtocolVersion = DtlsVersion; + changeCipherSpecRecord.Epoch = this.epoch; + changeCipherSpecRecord.SequenceNumber = this.currentEpoch.NextOutgoingSequence; + changeCipherSpecRecord.Length = (ushort)this.currentEpoch.RecordProtection.GetEncryptedSize(ChangeCipherSpec.Size); + ++this.currentEpoch.NextOutgoingSequence; + + Handshake finishedHandshake = new Handshake(); + finishedHandshake.MessageType = HandshakeType.Finished; + finishedHandshake.Length = Finished.Size; + finishedHandshake.MessageSequence = 6; + finishedHandshake.FragmentOffset = 0; + finishedHandshake.FragmentLength = finishedHandshake.Length; + + Record finishedRecord = new Record(); + finishedRecord.ContentType = ContentType.Handshake; + finishedRecord.ProtocolVersion = DtlsVersion; + finishedRecord.Epoch = this.nextEpoch.Epoch; + finishedRecord.SequenceNumber = this.nextEpoch.NextOutgoingSequence; + finishedRecord.Length = (ushort)this.nextEpoch.RecordProtection.GetEncryptedSize(Handshake.Size + (int)finishedHandshake.Length); + ++this.nextEpoch.NextOutgoingSequence; + + // Encode flight to wire format + int packetLength = 0 + + Record.Size + keyExchangeRecord.Length + + Record.Size + changeCipherSpecRecord.Length + + Record.Size + finishedRecord.Length; + ; + ByteSpan packet = new byte[packetLength]; + ByteSpan writer = packet; + + keyExchangeRecord.Encode(writer); + writer = writer.Slice(Record.Size); + keyExchangeHandshake.Encode(writer); + writer = writer.Slice(Handshake.Size); + this.nextEpoch.Handshake.EncodeClientKeyExchangeMessage(writer); + + ByteSpan startOfChangeCipherSpecRecord = packet.Slice(Record.Size + keyExchangeRecord.Length); + writer = startOfChangeCipherSpecRecord; + changeCipherSpecRecord.Encode(writer); + writer = writer.Slice(Record.Size); + ChangeCipherSpec.Encode(writer); + writer = writer.Slice(ChangeCipherSpec.Size); + + ByteSpan startOfFinishedRecord = startOfChangeCipherSpecRecord.Slice(Record.Size + changeCipherSpecRecord.Length); + writer = startOfFinishedRecord; + finishedRecord.Encode(writer); + writer = writer.Slice(Record.Size); + finishedHandshake.Encode(writer); + writer = writer.Slice(Handshake.Size); + + // Interject here to writer our client key exchange + // message into the verification stream + if (!isRetransmit) + { + this.nextEpoch.VerificationStream.AddData( + packet.Slice( + Record.Size + , Handshake.Size + (int)keyExchangeHandshake.Length + ) + ); + } + + // Calculate the hash of the verification stream + ByteSpan handshakeHash = new byte[Sha256Stream.DigestSize]; + this.nextEpoch.VerificationStream.CopyOrCalculateFinalHash(handshakeHash); + + // Expand our master secret into Finished digests for the client and server + PrfSha256.ExpandSecret( + this.nextEpoch.ServerVerification + , this.nextEpoch.MasterSecret + , PrfLabel.SERVER_FINISHED + , handshakeHash + ); + + PrfSha256.ExpandSecret( + writer.Slice(0, Finished.Size) + , this.nextEpoch.MasterSecret + , PrfLabel.CLIENT_FINISHED + , handshakeHash + ); + writer = writer.Slice(Finished.Size); + + // Protect the ClientKeyExchange record + this.currentEpoch.RecordProtection.EncryptClientPlaintext( + packet.Slice(Record.Size, keyExchangeRecord.Length), + packet.Slice(Record.Size, Handshake.Size + (int)keyExchangeHandshake.Length), + ref keyExchangeRecord + ); + + // Protect the ChangeCipherSpec record + this.currentEpoch.RecordProtection.EncryptClientPlaintext( + startOfChangeCipherSpecRecord.Slice(Record.Size, changeCipherSpecRecord.Length), + startOfChangeCipherSpecRecord.Slice(Record.Size, ChangeCipherSpec.Size), + ref changeCipherSpecRecord + ); + + // Protect the Finished record + this.nextEpoch.RecordProtection.EncryptClientPlaintext( + startOfFinishedRecord.Slice(Record.Size, finishedRecord.Length), + startOfFinishedRecord.Slice(Record.Size, Handshake.Size + (int)finishedHandshake.Length), + ref finishedRecord + ); + + this.nextEpoch.State = HandshakeState.ExpectingChangeCipherSpec; + this.nextEpoch.NextPacketResendTime = DateTime.UtcNow + this.handshakeResendTimeout; +#if DEBUG + if (DropClientKeyExchangeFlight()) + { + return; + } +#endif + base.WriteBytesToConnection(packet.GetUnderlyingArray(), packet.Length); + } + + protected virtual bool DropClientKeyExchangeFlight() + { + return false; + } + } +} diff --git a/Tools/Hazel-Networking/Hazel/Dtls/Handshake.cs b/Tools/Hazel-Networking/Hazel/Dtls/Handshake.cs new file mode 100644 index 0000000..f840053 --- /dev/null +++ b/Tools/Hazel-Networking/Hazel/Dtls/Handshake.cs @@ -0,0 +1,734 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Net; +using System.Security.Cryptography; +using System.Security.Cryptography.X509Certificates; + +namespace Hazel.Dtls +{ + /// <summary> + /// Handshake message type + /// </summary> + public enum HandshakeType : byte + { + HelloRequest = 0, + ClientHello = 1, + ServerHello = 2, + HelloVerifyRequest = 3, + Certificate = 11, + ServerKeyExchange = 12, + CertificateRequest = 13, + ServerHelloDone = 14, + CertificateVerify = 15, + ClientKeyExchange = 16, + Finished = 20, + } + + /// <summary> + /// List of cipher suites + /// </summary> + public enum CipherSuite + { + TLS_NULL_WITH_NULL_NULL = 0x0000, + TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 = 0xC02F, + } + + /// <summary> + /// List of compression methods + /// </summary> + public enum CompressionMethod : byte + { + Null = 0, + } + + /// <summary> + /// Extension type + /// </summary> + public enum ExtensionType : ushort + { + EllipticCurves = 10, + } + + /// <summary> + /// Named curves + /// </summary> + public enum NamedCurve : ushort + { + Reserved = 0, + secp256r1 = 23, + x25519 = 29, + } + + /// <summary> + /// Elliptic curve type + /// </summary> + public enum ECCurveType : byte + { + NamedCurve = 3, + } + + /// <summary> + /// Hash algorithms + /// </summary> + public enum HashAlgorithm : byte + { + None = 0, + Sha256 = 4, + } + + /// <summary> + /// Signature algorithms + /// </summary> + public enum SignatureAlgorithm : byte + { + Anonymous = 0, + RSA = 1, + ECDSA = 3, + } + + /// <summary> + /// Random state for entropy + /// </summary> + public struct Random + { + public const int Size = 0 + + 4 // gmt_unix_time + + 28 // random_bytes + ; + } + + /// <summary> + /// Encode/decode handshake protocol header + /// </summary> + public struct Handshake + { + public HandshakeType MessageType; + public uint Length; + public ushort MessageSequence; + public uint FragmentOffset; + public uint FragmentLength; + + public const int Size = 12; + + /// <summary> + /// Parse a Handshake protocol header from wire format + /// </summary> + /// <returns>True if we successfully decode a handshake header. Otherwise false</returns> + public static bool Parse(out Handshake header, ByteSpan span) + { + header = new Handshake(); + + if (span.Length < Size) + { + return false; + } + + header.MessageType = (HandshakeType)span[0]; + header.Length = span.ReadBigEndian24(1); + header.MessageSequence = span.ReadBigEndian16(4); + header.FragmentOffset = span.ReadBigEndian24(6); + header.FragmentLength = span.ReadBigEndian24(9); + return true; + } + + /// <summary> + /// Encode the Handshake protocol header to wire format + /// </summary> + /// <param name="span"></param> + public void Encode(ByteSpan span) + { + span[0] = (byte)this.MessageType; + span.WriteBigEndian24(this.Length, 1); + span.WriteBigEndian16(this.MessageSequence, 4); + span.WriteBigEndian24(this.FragmentOffset, 6); + span.WriteBigEndian24(this.FragmentLength, 9); + } + } + + /// <summary> + /// Encode/decode ClientHello Handshake message + /// </summary> + public struct ClientHello + { + public ProtocolVersion ClientProtocolVersion; + public ByteSpan Random; + public ByteSpan Cookie; + public HazelDtlsSessionInfo Session; + public ByteSpan CipherSuites; + public ByteSpan SupportedCurves; + + public const int MinSize = 0 + + 2 // client_version + + Dtls.Random.Size // random + + 1 // session_id (size) + + 1 // cookie (size) + + 2 // cipher_suites (size) + + 1 // compression_methods (size) + + 1 // compression_method[0] (NULL) + + + 2 // extensions size + + + 0 // NamedCurveList extensions[0] + + 2 // extensions[0].extension_type + + 2 // extensions[0].extension_data (length) + + 2 // extensions[0].named_curve_list (size) + ; + + /// <summary> + /// Calculate the size in bytes required for the ClientHello payload + /// </summary> + /// <returns></returns> + public int CalculateSize() + { + return MinSize + + this.Session.PayloadSize + + this.Cookie.Length + + this.CipherSuites.Length + + this.SupportedCurves.Length + ; + } + + /// <summary> + /// Parse a Handshake ClientHello payload from wire format + /// </summary> + /// <returns>True if we successfully decode the ClientHello message. Otherwise false</returns> + public static bool Parse(out ClientHello result, ProtocolVersion? expectedProtocolVersion, ByteSpan span) + { + result = new ClientHello(); + if (span.Length < MinSize) + { + return false; + } + + result.ClientProtocolVersion = (ProtocolVersion)span.ReadBigEndian16(); + if (expectedProtocolVersion.HasValue && result.ClientProtocolVersion != expectedProtocolVersion.Value) + { + return false; + } + + span = span.Slice(2); + + result.Random = span.Slice(0, Dtls.Random.Size); + span = span.Slice(Dtls.Random.Size); + + if (!HazelDtlsSessionInfo.Parse(out result.Session, span)) + { + return false; + } + + span = span.Slice(result.Session.FullSize); + + byte cookieSize = span[0]; + if (span.Length < 1 + cookieSize) + { + return false; + } + result.Cookie = span.Slice(1, cookieSize); + span = span.Slice(1 + cookieSize); + + ushort cipherSuiteSize = span.ReadBigEndian16(); + if (span.Length < 2 + cipherSuiteSize) + { + return false; + } + else if (cipherSuiteSize % 2 != 0) + { + return false; + } + result.CipherSuites = span.Slice(2, cipherSuiteSize); + span = span.Slice(2 + cipherSuiteSize); + + int compressionMethodsSize = span[0]; + bool foundNullCompressionMethod = false; + for (int ii = 0; ii != compressionMethodsSize; ++ii) + { + if (span[1+ii] == (byte)CompressionMethod.Null) + { + foundNullCompressionMethod = true; + break; + } + } + + if (!foundNullCompressionMethod + || span.Length < 1 + compressionMethodsSize) + { + return false; + } + + span = span.Slice(1 + compressionMethodsSize); + + // Parse extensions + if (span.Length > 0) + { + if (span.Length < 2) + { + return false; + } + + ushort extensionsSize = span.ReadBigEndian16(); + span = span.Slice(2); + if (span.Length != extensionsSize) + { + return false; + } + + while (span.Length > 0) + { + // Parse extension header + if (span.Length < 4) + { + return false; + } + + ExtensionType extensionType = (ExtensionType)span.ReadBigEndian16(0); + ushort extensionLength = span.ReadBigEndian16(2); + + if (span.Length < 4 + extensionLength) + { + return false; + } + + ByteSpan extensionData = span.Slice(4, extensionLength); + span = span.Slice(4 + extensionLength); + result.ParseExtension(extensionType, extensionData); + } + } + + return true; + } + + /// <summary> + /// Decode a ClientHello extension + /// </summary> + /// <param name="extensionType">Extension type</param> + /// <param name="extensionData">Extension data</param> + private void ParseExtension(ExtensionType extensionType, ByteSpan extensionData) + { + switch (extensionType) + { + case ExtensionType.EllipticCurves: + if (extensionData.Length % 2 != 0) + { + break; + } + else if (extensionData.Length < 2) + { + break; + } + + ushort namedCurveSize = extensionData.ReadBigEndian16(0); + if (namedCurveSize % 2 != 0) + { + break; + } + + this.SupportedCurves = extensionData.Slice(2, namedCurveSize); + break; + } + } + + /// <summary> + /// Determines if the ClientHello message advertises support + /// for the specified cipher suite + /// </summary> + public bool ContainsCipherSuite(CipherSuite cipherSuite) + { + ByteSpan iterator = this.CipherSuites; + while (iterator.Length >= 2) + { + if (iterator.ReadBigEndian16() == (ushort)cipherSuite) + { + return true; + } + + iterator = iterator.Slice(2); + } + + return false; + } + + /// <summary> + /// Determines if the ClientHello message advertises support + /// for the specified curve + /// </summary> + public bool ContainsCurve(NamedCurve curve) + { + ByteSpan iterator = this.SupportedCurves; + while (iterator.Length >= 2) + { + if (iterator.ReadBigEndian16() == (ushort)curve) + { + return true; + } + + iterator = iterator.Slice(2); + } + + return false; + } + + /// <summary> + /// Encode Handshake ClientHello payload to wire format + /// </summary> + public void Encode(ByteSpan span) + { + span.WriteBigEndian16((ushort)this.ClientProtocolVersion); + span = span.Slice(2); + + Debug.Assert(this.Random.Length == Dtls.Random.Size); + this.Random.CopyTo(span); + span = span.Slice(Dtls.Random.Size); + + this.Session.Encode(span); + span = span.Slice(this.Session.FullSize); + + span[0] = (byte)this.Cookie.Length; + this.Cookie.CopyTo(span.Slice(1)); + span = span.Slice(1 + this.Cookie.Length); + + span.WriteBigEndian16((ushort)this.CipherSuites.Length); + this.CipherSuites.CopyTo(span.Slice(2)); + span = span.Slice(2 + this.CipherSuites.Length); + + span[0] = 1; + span[1] = (byte)CompressionMethod.Null; + span = span.Slice(2); + + // Extensions size + span.WriteBigEndian16((ushort)(6 + this.SupportedCurves.Length)); + span = span.Slice(2); + + // Supported curves extension + span.WriteBigEndian16((ushort)ExtensionType.EllipticCurves); + span.WriteBigEndian16((ushort)(2 + this.SupportedCurves.Length), 2); + span.WriteBigEndian16((ushort)this.SupportedCurves.Length, 4); + this.SupportedCurves.CopyTo(span.Slice(6)); + } + } + + /// <summary> + /// Encode/Decode session information in ClientHello + /// </summary> + public struct HazelDtlsSessionInfo + { + public const byte CurrentClientSessionSize = 1; + public const byte CurrentClientSessionVersion = 1; + + public byte FullSize => (byte)(1 + this.PayloadSize); + public byte PayloadSize; + public byte Version; + + public HazelDtlsSessionInfo(byte version) + { + this.Version = version; + switch (version) + { + case 0: // Does not write version byte + this.PayloadSize = 0; + return; + case 1: // Writes version byte only + this.PayloadSize = 1; + return; + } + + throw new ArgumentOutOfRangeException("Unimplemented Hazel session version"); + } + + public void Encode(ByteSpan writer) + { + writer[0] = this.PayloadSize; + + if (this.Version > 0) + { + writer[1] = this.Version; + } + } + + public static bool Parse(out HazelDtlsSessionInfo result, ByteSpan reader) + { + result = new HazelDtlsSessionInfo(); + if (reader.Length < 1) + { + return false; + } + + result.PayloadSize = reader[0]; + + // Back compat, length may be zero, version defaults to 0. + if (result.PayloadSize == 0) + { + result.Version = 0; + return true; + } + + // Forward compat, if length > 1, ignore the rest + result.Version = reader[1]; + return true; + } + } + + /// <summary> + /// Encode/decode Handshake HelloVerifyRequest message + /// </summary> + public struct HelloVerifyRequest + { + public const int CookieSize = 20; + public const int Size = 0 + + 2 // server_version + + 1 // cookie (size) + + CookieSize // cookie (data) + ; + + public ProtocolVersion ServerProtocolVersion; + public ByteSpan Cookie; + + /// <summary> + /// Parse a Handshake HelloVerifyRequest payload from wire + /// format + /// </summary> + /// <returns> + /// True if we successfully decode the HelloVerifyRequest + /// message. Otherwise false. + /// </returns> + public static bool Parse(out HelloVerifyRequest result, ProtocolVersion? expectedProtocolVersion, ByteSpan span) + { + result = new HelloVerifyRequest(); + if (span.Length < 3) + { + return false; + } + + result.ServerProtocolVersion = (ProtocolVersion)span.ReadBigEndian16(0); + if (expectedProtocolVersion.HasValue && result.ServerProtocolVersion != expectedProtocolVersion.Value) + { + return false; + } + + byte cookieSize = span[2]; + span = span.Slice(3); + + if (span.Length < cookieSize) + { + return false; + } + + result.Cookie = span; + return true; + } + + /// <summary> + /// Encode a HelloVerifyRequest payload to wire format + /// </summary> + /// <param name="peerAddress">Address of the remote peer</param> + /// <param name="hmac">Listener HMAC signature provider</param> + public static void Encode(ByteSpan span, EndPoint peerAddress, HMAC hmac, ProtocolVersion protocolVersion) + { + ByteSpan cookie = ComputeAddressMac(peerAddress, hmac); + + span.WriteBigEndian16((ushort)protocolVersion); + span[2] = (byte)CookieSize; + cookie.CopyTo(span.Slice(3)); + } + + /// <summary> + /// Generate an HMAC for a peer address + /// </summary> + /// <param name="peerAddress">Address of the remote peer</param> + /// <param name="hmac">Listener HMAC signature provider</param> + public static ByteSpan ComputeAddressMac(EndPoint peerAddress, HMAC hmac) + { + SocketAddress address = peerAddress.Serialize(); + byte[] data = new byte[address.Size]; + for (int ii = 0, nn = data.Length; ii != nn; ++ii) + { + data[ii] = address[ii]; + } + + ///NOTE(mendsley): Lame that we need to allocate+copy here + ByteSpan signature = hmac.ComputeHash(data); + return signature.Slice(0, CookieSize); + } + + /// <summary> + /// Verify a client's cookie was signed by our listener + /// </summary> + /// <param name="cookie">Wire format cookie</param> + /// <param name="peerAddress">Address of the remote peer</param> + /// <param name="hmac">Listener HMAC signature provider</param> + /// <returns>True if the cookie is valid. Otherwise false</returns> + public static bool VerifyCookie(ByteSpan cookie, EndPoint peerAddress, HMAC hmac) + { + if (cookie.Length != CookieSize) + { + return false; + } + + ByteSpan expectedHash = ComputeAddressMac(peerAddress, hmac); + if (expectedHash.Length != cookie.Length) + { + return false; + } + + return (1 == Crypto.Const.ConstantCompareSpans(cookie, expectedHash)); + } + } + + /// <summary> + /// Encode/decode Handshake ServerHello message + /// </summary> + public struct ServerHello + { + public ProtocolVersion ServerProtocolVersion; + public ByteSpan Random; + public CipherSuite CipherSuite; + public HazelDtlsSessionInfo Session; + + public const int MinSize = 0 + + 2 // server_version + + Dtls.Random.Size // random + + 1 // session_id (size) + + 2 // cipher_suite + + 1 // compression_method + ; + + public int Size => MinSize + Session.PayloadSize; + + /// <summary> + /// Parse a Handshake ServerHello payload from wire format + /// </summary> + /// <returns> + /// True if we successfully decode the ServerHello + /// message. Otherwise false. + /// </returns> + public static bool Parse(out ServerHello result, ByteSpan span) + { + result = new ServerHello(); + if (span.Length < MinSize) + { + return false; + } + + result.ServerProtocolVersion = (ProtocolVersion)span.ReadBigEndian16(); + span = span.Slice(2); + + result.Random = span.Slice(0, Dtls.Random.Size); + span = span.Slice(Dtls.Random.Size); + + if (!HazelDtlsSessionInfo.Parse(out result.Session, span)) + { + return false; + } + + span = span.Slice(result.Session.FullSize); + + result.CipherSuite = (CipherSuite)span.ReadBigEndian16(); + span = span.Slice(2); + + CompressionMethod compressionMethod = (CompressionMethod)span[0]; + if (compressionMethod != CompressionMethod.Null) + { + return false; + } + + return true; + } + + /// <summary> + /// Encode Handshake ServerHello to wire format + /// </summary> + public void Encode(ByteSpan span) + { + Debug.Assert(this.Random.Length == Dtls.Random.Size); + + span.WriteBigEndian16((ushort)this.ServerProtocolVersion, 0); + span = span.Slice(2); + + this.Random.CopyTo(span); + span = span.Slice(Dtls.Random.Size); + + this.Session.Encode(span); + span = span.Slice(this.Session.FullSize); + + span.WriteBigEndian16((ushort)this.CipherSuite); + span = span.Slice(2); + + span[0] = (byte)CompressionMethod.Null; + } + } + + /// <summary> + /// Encode/decode Handshake Certificate message + /// </summary> + public struct Certificate + { + /// <summary> + /// Encode a certificate to wire formate + /// </summary> + public static ByteSpan Encode(X509Certificate2 certificate) + { + ByteSpan certData = certificate.GetRawCertData(); + int totalSize = certData.Length + 3 + 3; + + ByteSpan result = new byte[totalSize]; + + ByteSpan writer = result; + writer.WriteBigEndian24((uint)certData.Length + 3); + writer = writer.Slice(3); + writer.WriteBigEndian24((uint)certData.Length); + writer = writer.Slice(3); + + certData.CopyTo(writer); + return result; + } + + /// <summary> + /// Parse a Handshake Certificate payload from wire format + /// </summary> + /// <returns>True if we successfully decode the Certificate message. Otherwise false</returns> + public static bool Parse(out X509Certificate2 certificate, ByteSpan span) + { + certificate = null; + if (span.Length < 6) + { + return false; + } + + uint totalSize = span.ReadBigEndian24(); + span = span.Slice(3); + + if (span.Length < totalSize) + { + return false; + } + + uint certificateSize = span.ReadBigEndian24(); + span = span.Slice(3); + if (span.Length < certificateSize) + { + return false; + } + + byte[] rawData = new byte[certificateSize]; + span.CopyTo(rawData, 0); + try + { + certificate = new X509Certificate2(rawData); + } + catch (Exception) + { + return false; + } + + return true; + } + } + + /// <summary> + /// Encode/decode Handshake Finished message + /// </summary> + public struct Finished + { + public const int Size = 12; + } +} diff --git a/Tools/Hazel-Networking/Hazel/Dtls/IHandshakeCipherSuite.cs b/Tools/Hazel-Networking/Hazel/Dtls/IHandshakeCipherSuite.cs new file mode 100644 index 0000000..eedd977 --- /dev/null +++ b/Tools/Hazel-Networking/Hazel/Dtls/IHandshakeCipherSuite.cs @@ -0,0 +1,63 @@ +using System; + +namespace Hazel.Dtls +{ + /// <summary> + /// DTLS cipher suite interface for the handshake portion of + /// the connection. + /// </summary> + public interface IHandshakeCipherSuite : IDisposable + { + /// <summary> + /// Gets the size of the shared key + /// </summary> + /// <returns>Size of the shared key in bytes </returns> + int SharedKeySize(); + + /// <summary> + /// Calculate the size of the ServerKeyExchnage message + /// </summary> + /// <param name="privateKey"> + /// Private key that will be used to sign the message + /// </param> + /// <returns>Size of the message in bytes</returns> + int CalculateServerMessageSize(object privateKey); + + /// <summary> + /// Encodes the ServerKeyExchange message + /// </summary> + /// <param name="privateKey">Private key to use for signing</param> + void EncodeServerKeyExchangeMessage(ByteSpan output, object privateKey); + + /// <summary> + /// Verifies the authenticity of a server key exchange + /// message and calculates the shared secret. + /// </summary> + /// <returns> + /// True if the authenticity has been validated and a shared key + /// was generated. Otherwise, false. + /// </returns> + bool VerifyServerMessageAndGenerateSharedKey(ByteSpan output, ByteSpan serverKeyExchangeMessage, object publicKey); + + /// <summary> + /// Calculate the size of the ClientKeyExchange message + /// </summary> + /// <returns>Size of the message in bytes</returns> + int CalculateClientMessageSize(); + + /// <summary> + /// Encodes the ClientKeyExchangeMessage + /// </summary> + void EncodeClientKeyExchangeMessage(ByteSpan output); + + /// <summary> + /// Verifies the validity of a client key exchange message + /// and calculats the hsared secret. + /// </summary> + /// <returns> + /// True if the client exchange message is valid and a + /// shared key was generated. Otherwise, false. + /// </returns> + bool VerifyClientMessageAndGenerateSharedKey(ByteSpan output, ByteSpan clientKeyExchangeMessage); + } +} diff --git a/Tools/Hazel-Networking/Hazel/Dtls/IRecordProtection.cs b/Tools/Hazel-Networking/Hazel/Dtls/IRecordProtection.cs new file mode 100644 index 0000000..cbee1b0 --- /dev/null +++ b/Tools/Hazel-Networking/Hazel/Dtls/IRecordProtection.cs @@ -0,0 +1,84 @@ +using System; + +namespace Hazel.Dtls +{ + /// <summary> + /// DTLS cipher suite interface for protection of record payload. + /// </summary> + public interface IRecordProtection : IDisposable + { + /// <summary> + /// Calculate the size of an encrypted plaintext + /// </summary> + /// <param name="dataSize">Size of plaintext in bytes</param> + /// <returns>Size of encrypted ciphertext in bytes</returns> + int GetEncryptedSize(int dataSize); + + /// <summary> + /// Calculate the size of decrypted ciphertext + /// </summary> + /// <param name="dataSize">Size of ciphertext in bytes</param> + /// <returns>Size of decrypted plaintext in bytes</returns> + int GetDecryptedSize(int dataSize); + + /// <summary> + /// Encrypt a plaintext intput with server keys + /// + /// Output may overlap with input. + /// </summary> + /// <param name="output">Output ciphertext</param> + /// <param name="input">Input plaintext</param> + /// <param name="record">Parent DTLS record</param> + void EncryptServerPlaintext(ByteSpan output, ByteSpan input, ref Record record); + + /// <summary> + /// Encrypt a plaintext intput with client keys + /// + /// Output may overlap with input. + /// </summary> + /// <param name="output">Output ciphertext</param> + /// <param name="input">Input plaintext</param> + /// <param name="record">Parent DTLS record</param> + void EncryptClientPlaintext(ByteSpan output, ByteSpan input, ref Record record); + + /// <summary> + /// Decrypt a ciphertext intput with server keys + /// + /// Output may overlap with input. + /// </summary> + /// <param name="output">Output plaintext</param> + /// <param name="input">Input ciphertext</param> + /// <param name="record">Parent DTLS record</param> + /// <returns>True if the input was authenticated and decrypted. Otherwise false</returns> + bool DecryptCiphertextFromServer(ByteSpan output, ByteSpan input, ref Record record); + + /// <summary> + /// Decrypt a ciphertext intput with client keys + /// + /// Output may overlap with input. + /// </summary> + /// <param name="output">Output plaintext</param> + /// <param name="input">Input ciphertext</param> + /// <param name="record">Parent DTLS record</param> + /// <returns>True if the input was authenticated and decrypted. Otherwise false</returns> + bool DecryptCiphertextFromClient(ByteSpan output, ByteSpan input, ref Record record); + } + + /// <summary> + /// Factory to create record protection from cipher suite identifiers + /// </summary> + public sealed class RecordProtectionFactory + { + public static IRecordProtection Create(CipherSuite cipherSuite, ByteSpan masterSecret, ByteSpan serverRandom, ByteSpan clientRandom) + { + switch (cipherSuite) + { + case CipherSuite.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256: + return new Aes128GcmRecordProtection(masterSecret, serverRandom, clientRandom); + + default: + return null; + } + } + } +} diff --git a/Tools/Hazel-Networking/Hazel/Dtls/NullRecordProtection.cs b/Tools/Hazel-Networking/Hazel/Dtls/NullRecordProtection.cs new file mode 100644 index 0000000..76fa132 --- /dev/null +++ b/Tools/Hazel-Networking/Hazel/Dtls/NullRecordProtection.cs @@ -0,0 +1,66 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace Hazel.Dtls +{ + /// <summary> + /// Passthrough record protection implementaion + /// </summary> + public class NullRecordProtection : IRecordProtection + { + public readonly static NullRecordProtection Instance = new NullRecordProtection(); + + public void Dispose() + { + } + + public int GetEncryptedSize(int dataSize) + { + return dataSize; + } + + public int GetDecryptedSize(int dataSize) + { + return dataSize; + } + + public void EncryptServerPlaintext(ByteSpan output, ByteSpan input, ref Record record) + { + CopyMaybeOverlappingSpans(output, input); + } + + public void EncryptClientPlaintext(ByteSpan output, ByteSpan input, ref Record record) + { + CopyMaybeOverlappingSpans(output, input); + } + + public bool DecryptCiphertextFromServer(ByteSpan output, ByteSpan input, ref Record record) + { + CopyMaybeOverlappingSpans(output, input); + return true; + } + + public bool DecryptCiphertextFromClient(ByteSpan output, ByteSpan input, ref Record record) + { + CopyMaybeOverlappingSpans(output, input); + return true; + } + + private static void CopyMaybeOverlappingSpans(ByteSpan output, ByteSpan input) + { + // Early out if the ranges `output` is equal to `input` + if (output.GetUnderlyingArray() == input.GetUnderlyingArray()) + { + if (output.Offset == input.Offset && output.Length == input.Length) + { + return; + } + } + + input.CopyTo(output); + } + } +} diff --git a/Tools/Hazel-Networking/Hazel/Dtls/PrfSha256.cs b/Tools/Hazel-Networking/Hazel/Dtls/PrfSha256.cs new file mode 100644 index 0000000..1fa7f17 --- /dev/null +++ b/Tools/Hazel-Networking/Hazel/Dtls/PrfSha256.cs @@ -0,0 +1,84 @@ +using System.Text; +using System.Security.Cryptography; + +namespace Hazel.Dtls +{ + /// <summary> + /// Common Psuedorandom Function labels for TLS + /// </summary> + public struct PrfLabel + { + public static readonly ByteSpan MASTER_SECRET = LabelToBytes("master secert"); + public static readonly ByteSpan KEY_EXPANSION = LabelToBytes("key expansion"); + public static readonly ByteSpan CLIENT_FINISHED = LabelToBytes("client finished"); + public static readonly ByteSpan SERVER_FINISHED = LabelToBytes("server finished"); + + /// <summary> + /// Convert a text label to a byte sequence + /// </summary> + public static ByteSpan LabelToBytes(string label) + { + return Encoding.ASCII.GetBytes(label); + } + } + + /// <summary> + /// The P_SHA256 Psuedorandom Function + /// </summary> + public struct PrfSha256 + { + /// <summary> + /// Expand a secret key + /// </summary> + /// <param name="output">Output span. Length determines how much data to generate</param> + /// <param name="key">Original key to expand</param> + /// <param name="label">Label (treated as a salt)</param> + /// <param name="initialSeed">Seed for expansion (treated as a salt)</param> + public static void ExpandSecret(ByteSpan output, ByteSpan key, string label, ByteSpan initialSeed) + { + ExpandSecret(output, key, PrfLabel.LabelToBytes(label), initialSeed); + } + + /// <summary> + /// Expand a secret key + /// </summary> + /// <param name="output">Output span. Length determines how much data to generate</param> + /// <param name="key">Original key to expand</param> + /// <param name="label">Label (treated as a salt)</param> + /// <param name="initialSeed">Seed for expansion (treated as a salt)</param> + public static void ExpandSecret(ByteSpan output, ByteSpan key, ByteSpan label, ByteSpan initialSeed) + { + ByteSpan writer = output; + + byte[] roundSeed = new byte[label.Length + initialSeed.Length]; + label.CopyTo(roundSeed); + initialSeed.CopyTo(roundSeed, label.Length); + + byte[] hashA = roundSeed; + + using (HMACSHA256 hmac = new HMACSHA256(key.ToArray())) + { + byte[] input = new byte[hmac.OutputBlockSize + roundSeed.Length]; + new ByteSpan(roundSeed).CopyTo(input, hmac.OutputBlockSize); + + while (writer.Length > 0) + { + // Update hashA + hashA = hmac.ComputeHash(hashA); + + // generate hash input + new ByteSpan(hashA).CopyTo(input); + + ByteSpan roundOutput = hmac.ComputeHash(input); + if (roundOutput.Length > writer.Length) + { + roundOutput = roundOutput.Slice(0, writer.Length); + } + + roundOutput.CopyTo(writer); + writer = writer.Slice(roundOutput.Length); + } + } + } + } +} diff --git a/Tools/Hazel-Networking/Hazel/Dtls/Record.cs b/Tools/Hazel-Networking/Hazel/Dtls/Record.cs new file mode 100644 index 0000000..23eaa95 --- /dev/null +++ b/Tools/Hazel-Networking/Hazel/Dtls/Record.cs @@ -0,0 +1,123 @@ +namespace Hazel.Dtls +{ + /// <summary> + /// DTLS version constants + /// </summary> + public enum ProtocolVersion : ushort + { + /// <summary> + /// Use to obfuscate DTLS as regular UDP packets + /// </summary> + UDP = 0, + + /// <summary> + /// DTLS 1.2 + /// </summary> + DTLS1_2 = 0xFEFD, + } + + /// <summary> + /// DTLS record content type + /// </summary> + public enum ContentType : byte + { + ChangeCipherSpec = 20, + Alert = 21, + Handshake = 22, + ApplicationData = 23, + } + + /// <summary> + /// Encode/decode DTLS record header + /// </summary> + public struct Record + { + public ContentType ContentType; + public ProtocolVersion ProtocolVersion; + public ushort Epoch; + public ulong SequenceNumber; + public ushort Length; + + public const int Size = 13; + + /// <summary> + /// Parse a DTLS record from wire format + /// </summary> + /// <returns>True if we successfully parse the record header. Otherwise false</returns> + public static bool Parse(out Record record, ProtocolVersion? expectedProtocolVersion, ByteSpan span) + { + record = new Record(); + + if (span.Length < Size) + { + return false; + } + + record.ContentType = (ContentType)span[0]; + record.ProtocolVersion = (ProtocolVersion)span.ReadBigEndian16(1); + record.Epoch = span.ReadBigEndian16(3); + record.SequenceNumber = span.ReadBigEndian48(5); + record.Length = span.ReadBigEndian16(11); + + if (expectedProtocolVersion.HasValue && record.ProtocolVersion != expectedProtocolVersion.Value) + { + return false; + } + + return true; + } + + /// <summary> + /// Encode a DTLS record to wire format + /// </summary> + public void Encode(ByteSpan span) + { + span[0] = (byte)this.ContentType; + span.WriteBigEndian16((ushort)this.ProtocolVersion, 1); + span.WriteBigEndian16(this.Epoch, 3); + span.WriteBigEndian48(this.SequenceNumber, 5); + span.WriteBigEndian16(this.Length, 11); + } + } + + public struct ChangeCipherSpec + { + public const int Size = 1; + + enum Value : byte + { + ChangeCipherSpec = 1, + } + + /// <summary> + /// Parse a ChangeCipherSpec record from wire format + /// </summary> + /// <returns> + /// True if we successfully parse the ChangeCipherSpec + /// record. Otherwise, false. + /// </returns> + public static bool Parse(ByteSpan span) + { + if (span.Length != 1) + { + return false; + } + + Value value = (Value)span[0]; + if (value != Value.ChangeCipherSpec) + { + return false; + } + + return true; + } + + /// <summary> + /// Encode a ChangeCipherSpec record to wire format + /// </summary> + public static void Encode(ByteSpan span) + { + span[0] = (byte)Value.ChangeCipherSpec; + } + } +} diff --git a/Tools/Hazel-Networking/Hazel/Dtls/ThreadedHmacHelper.cs b/Tools/Hazel-Networking/Hazel/Dtls/ThreadedHmacHelper.cs new file mode 100644 index 0000000..38da061 --- /dev/null +++ b/Tools/Hazel-Networking/Hazel/Dtls/ThreadedHmacHelper.cs @@ -0,0 +1,159 @@ +using System; +using System.Collections.Concurrent; +using System.Security.Cryptography; +using System.Threading; + +namespace Hazel.Dtls +{ + internal class ThreadedHmacHelper : IDisposable + { + private class ThreadHmacs + { + public HMAC currentHmac; + public HMAC previousHmac; + public HMAC hmacToDispose; + } + + private static readonly int CookieHmacRotationTimeout = (int)TimeSpan.FromHours(1.0).TotalMilliseconds; + + private readonly ILogger logger; + private readonly ConcurrentDictionary<int, ThreadHmacs> hmacs; + private Timer rotateKeyTimer; + private RandomNumberGenerator cryptoRandom; + private byte[] currentHmacKey; + + public ThreadedHmacHelper(ILogger logger) + { + this.hmacs = new ConcurrentDictionary<int, ThreadHmacs>(); + this.rotateKeyTimer = new Timer(RotateKeys, null, CookieHmacRotationTimeout, CookieHmacRotationTimeout); + this.cryptoRandom = RandomNumberGenerator.Create(); + + this.logger = logger; + SetHmacKey(); + } + + /// <summary> + /// [ThreadSafe] Get the current cookie hmac for the current thread. + /// </summary> + public HMAC GetCurrentCookieHmacsForThread() + { + return GetHmacsForThread().currentHmac; + } + + /// <summary> + /// [ThreadSafe] Get the previous cookie hmac for the current thread. + /// </summary> + public HMAC GetPreviousCookieHmacsForThread() + { + return GetHmacsForThread().previousHmac; + } + + public void Dispose() + { + ManualResetEvent signalRotateKeyTimerEnded = new ManualResetEvent(false); + this.rotateKeyTimer.Dispose(signalRotateKeyTimerEnded); + signalRotateKeyTimerEnded.WaitOne(); + signalRotateKeyTimerEnded.Dispose(); + signalRotateKeyTimerEnded = null; + this.rotateKeyTimer = null; + + this.cryptoRandom.Dispose(); + this.cryptoRandom = null; + + foreach (var threadIdToHmac in this.hmacs) + { + ThreadHmacs threadHmacs = threadIdToHmac.Value; + threadHmacs.currentHmac?.Dispose(); + threadHmacs.currentHmac = null; + threadHmacs.previousHmac?.Dispose(); + threadHmacs.previousHmac = null; + threadHmacs.hmacToDispose?.Dispose(); + threadHmacs.hmacToDispose = null; + } + + this.hmacs.Clear(); + } + + private ThreadHmacs GetHmacsForThread() + { + int threadId = Thread.CurrentThread.ManagedThreadId; + + if (!this.hmacs.TryGetValue(threadId, out ThreadHmacs threadHmacs)) + { + threadHmacs = CreateNewThreadHmacs(); + + if (!this.hmacs.TryAdd(threadId, threadHmacs)) + { + this.logger.WriteError($"Cannot add threadHmacs for thread {threadId} during GetHmacsForThread! Should never happen!"); + } + } + + return threadHmacs; + } + + /// <summary> + /// Rotates the hmacs of all active threads + /// </summary> + private void RotateKeys(object _) + { + SetHmacKey(); + + foreach (var threadIds in this.hmacs) + { + RotateKey(threadIds.Key); + } + } + + /// <summary> + /// Rotate hmacs of single thread + /// </summary> + /// <param name="threadId">Managed thread Id of thread calling this method.</param> + private void RotateKey(int threadId) + { + ThreadHmacs threadHmacs; + + if (!this.hmacs.TryGetValue(threadId, out threadHmacs)) + { + this.logger.WriteError($"Cannot find thread {threadId} in hmacs during rotation! Should never happen!"); + return; + } + + // No thread should still have a reference to hmacToDispose, which should now have a lifetime of > 1 hour + threadHmacs.hmacToDispose?.Dispose(); + threadHmacs.hmacToDispose = threadHmacs.previousHmac; + threadHmacs.previousHmac = threadHmacs.currentHmac; + threadHmacs.currentHmac = CreateNewCookieHMAC(); + } + + private ThreadHmacs CreateNewThreadHmacs() + { + return new ThreadHmacs + { + previousHmac = CreateNewCookieHMAC(), + currentHmac = CreateNewCookieHMAC() + }; + } + + /// <summary> + /// Create a new cookie HMAC signer + /// </summary> + private HMAC CreateNewCookieHMAC() + { + const string HMACProvider = "System.Security.Cryptography.HMACSHA1"; + HMAC hmac = HMAC.Create(HMACProvider); + hmac.Key = this.currentHmacKey; + return hmac; + } + + /// <summary> + /// Creates a new cryptographically secure random Hmac key + /// </summary> + private void SetHmacKey() + { + // MSDN recommends 64 bytes key for HMACSHA-1 + byte[] newKey = new byte[64]; + this.cryptoRandom.GetBytes(newKey); + this.currentHmacKey = newKey; + } + } +} diff --git a/Tools/Hazel-Networking/Hazel/Dtls/X25519EcdheRsaSha256.cs b/Tools/Hazel-Networking/Hazel/Dtls/X25519EcdheRsaSha256.cs new file mode 100644 index 0000000..f567252 --- /dev/null +++ b/Tools/Hazel-Networking/Hazel/Dtls/X25519EcdheRsaSha256.cs @@ -0,0 +1,202 @@ +using Hazel.Crypto; +using System; +using System.Diagnostics; +using System.Security.Cryptography; + +namespace Hazel.Dtls +{ + /// <summary> + /// ECDHE_RSA_*_256 cipher suite + /// </summary> + public class X25519EcdheRsaSha256 : IHandshakeCipherSuite + { + private readonly ByteSpan privateAgreementKey; + private SHA256 sha256 = SHA256.Create(); + + /// <summary> + /// Create a new instance of the x25519 key exchange + /// </summary> + /// <param name="random">Random data source</param> + public X25519EcdheRsaSha256(RandomNumberGenerator random) + { + byte[] buffer = new byte[X25519.KeySize]; + random.GetBytes(buffer); + this.privateAgreementKey = buffer; + } + + /// <inheritdoc /> + public void Dispose() + { + this.sha256?.Dispose(); + this.sha256 = null; + } + + /// <inheritdoc /> + public int SharedKeySize() + { + return X25519.KeySize; + } + + /// <summary> + /// Calculate the server message size given an RSA key size + /// </summary> + /// <param name="keySize"> + /// Size of the private key (in bits) + /// </param> + /// <returns> + /// Size of the ServerKeyExchange message in bytes + /// </returns> + private static int CalculateServerMessageSize(int keySize) + { + int signatureSize = keySize / 8; + + return 0 + + 1 // ECCurveType ServerKeyExchange.params.curve_params.curve_type + + 2 // NamedCurve ServerKeyExchange.params.curve_params.namedcurve + + 1 + X25519.KeySize // ECPoint ServerKeyExchange.params.public + + 1 // HashAlgorithm ServerKeyExchange.algorithm.hash + + 1 // SignatureAlgorithm ServerKeyExchange.signed_params.algorithm.signature + + 2 // ServerKeyExchange.signed_params.size + + signatureSize // ServerKeyExchange.signed_params.opaque + ; + } + + /// <inheritdoc /> + public int CalculateServerMessageSize(object privateKey) + { + RSA rsaPrivateKey = privateKey as RSA; + if (rsaPrivateKey == null) + { + throw new ArgumentException("Invalid private key", nameof(privateKey)); + } + + return CalculateServerMessageSize(rsaPrivateKey.KeySize); + } + + /// <inheritdoc /> + public void EncodeServerKeyExchangeMessage(ByteSpan output, object privateKey) + { + RSA rsaPrivateKey = privateKey as RSA; + if (rsaPrivateKey == null) + { + throw new ArgumentException("Invalid private key", nameof(privateKey)); + } + + output[0] = (byte)ECCurveType.NamedCurve; + output.WriteBigEndian16((ushort)NamedCurve.x25519, 1); + output[3] = (byte)X25519.KeySize; + X25519.Func(output.Slice(4, X25519.KeySize), this.privateAgreementKey); + + // Hash the key parameters + byte[] paramterDigest = this.sha256.ComputeHash(output.GetUnderlyingArray(), output.Offset, 4 + X25519.KeySize); + + // Sign the paramter digest + RSAPKCS1SignatureFormatter signer = new RSAPKCS1SignatureFormatter(rsaPrivateKey); + signer.SetHashAlgorithm("SHA256"); + ByteSpan signature = signer.CreateSignature(paramterDigest); + + Debug.Assert(signature.Length == rsaPrivateKey.KeySize/8); + output[4 + X25519.KeySize] = (byte)HashAlgorithm.Sha256; + output[5 + X25519.KeySize] = (byte)SignatureAlgorithm.RSA; + output.Slice(6+X25519.KeySize).WriteBigEndian16((ushort)signature.Length); + signature.CopyTo(output.Slice(8+X25519.KeySize)); + } + + /// <inheritdoc /> + public bool VerifyServerMessageAndGenerateSharedKey(ByteSpan output, ByteSpan serverKeyExchangeMessage, object publicKey) + { + RSA rsaPublicKey = publicKey as RSA; + if (rsaPublicKey == null) + { + return false; + } + else if (output.Length != X25519.KeySize) + { + return false; + } + + // Verify message is compatible with this cipher suite + if (serverKeyExchangeMessage.Length != CalculateServerMessageSize(rsaPublicKey.KeySize)) + { + return false; + } + else if (serverKeyExchangeMessage[0] != (byte)ECCurveType.NamedCurve) + { + return false; + } + else if (serverKeyExchangeMessage.ReadBigEndian16(1) != (ushort)NamedCurve.x25519) + { + return false; + } + else if (serverKeyExchangeMessage[3] != X25519.KeySize) + { + return false; + } + else if (serverKeyExchangeMessage[4 + X25519.KeySize] != (byte)HashAlgorithm.Sha256) + { + return false; + } + else if (serverKeyExchangeMessage[5 + X25519.KeySize] != (byte)SignatureAlgorithm.RSA) + { + return false; + } + + ByteSpan keyParameters = serverKeyExchangeMessage.Slice(0, 4+X25519.KeySize); + ByteSpan othersPublicKey = keyParameters.Slice(4); + ushort signatureSize = serverKeyExchangeMessage.ReadBigEndian16(6 + X25519.KeySize); + ByteSpan signature = serverKeyExchangeMessage.Slice(4+keyParameters.Length); + + if (signatureSize != signature.Length) + { + return false; + } + + // Hash the key parameters + byte[] parameterDigest = this.sha256.ComputeHash(keyParameters.GetUnderlyingArray(), keyParameters.Offset, keyParameters.Length); + + // Verify the signature + RSAPKCS1SignatureDeformatter verifier = new RSAPKCS1SignatureDeformatter(rsaPublicKey); + verifier.SetHashAlgorithm("SHA256"); + if (!verifier.VerifySignature(parameterDigest, signature.ToArray())) + { + return false; + } + + // Signature has been validated, generate the shared key + return X25519.Func(output, this.privateAgreementKey, othersPublicKey); + } + + private static int ClientMessageSize = 0 + + 1 + X25519.KeySize // ECPoint ClientKeyExchange.ecdh_Yc + ; + + /// <inheritdoc /> + public int CalculateClientMessageSize() + { + return ClientMessageSize; + } + + /// <inheritdoc /> + public void EncodeClientKeyExchangeMessage(ByteSpan output) + { + output[0] = (byte)X25519.KeySize; + X25519.Func(output.Slice(1, X25519.KeySize), this.privateAgreementKey); + } + + /// <inheritdoc /> + public bool VerifyClientMessageAndGenerateSharedKey(ByteSpan output, ByteSpan clientKeyExchangeMessage) + { + if (clientKeyExchangeMessage.Length != ClientMessageSize) + { + return false; + } + else if (clientKeyExchangeMessage[0] != (byte)X25519.KeySize) + { + return false; + } + + ByteSpan othersPublicKey = clientKeyExchangeMessage.Slice(1); + return X25519.Func(output, this.privateAgreementKey, othersPublicKey); + } + } +} diff --git a/Tools/Hazel-Networking/Hazel/Extensions.cs b/Tools/Hazel-Networking/Hazel/Extensions.cs new file mode 100644 index 0000000..dd3a1bc --- /dev/null +++ b/Tools/Hazel-Networking/Hazel/Extensions.cs @@ -0,0 +1,34 @@ +using System.Collections.Generic; + +namespace Hazel +{ + public static class Extensions + { + public static void Swap<T>(this IList<T> self, int idx0, int idx1) + { + var temp = self[idx0]; + self[idx0] = self[idx1]; + self[idx1] = temp; + } + + public static int ClampToInt(this float value, int min, int max) + { + int output = (int)value; + if (output < min) output = min; + else if (output > max) output = max; + return output; + } + + public static bool TryDequeue<T>(this Queue<T> self, out T item) + { + if (self.Count > 0) + { + item = self.Dequeue(); + return true; + } + + item = default; + return false; + } + } +} diff --git a/Tools/Hazel-Networking/Hazel/FewerThreads/HazelThreadPool.cs b/Tools/Hazel-Networking/Hazel/FewerThreads/HazelThreadPool.cs new file mode 100644 index 0000000..fb36b00 --- /dev/null +++ b/Tools/Hazel-Networking/Hazel/FewerThreads/HazelThreadPool.cs @@ -0,0 +1,44 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading; +using System.Threading.Tasks; + +namespace Hazel +{ + internal class HazelThreadPool + { + private Thread[] threads; + + public HazelThreadPool(int numThreads, ThreadStart action) + { + this.threads = new Thread[numThreads]; + for (int i = 0; i < this.threads.Length; ++i) + { + this.threads[i] = new Thread(action); + } + } + + public void Start() + { + for (int i = 0; i < this.threads.Length; ++i) + { + this.threads[i].Start(); + } + } + + public void Join() + { + for (int i = 0; i < this.threads.Length; ++i) + { + var thread = this.threads[i]; + try + { + thread.Join(); + } + catch { } + } + } + } +}
\ No newline at end of file diff --git a/Tools/Hazel-Networking/Hazel/FewerThreads/ThreadLimitedUdpConnectionListener.cs b/Tools/Hazel-Networking/Hazel/FewerThreads/ThreadLimitedUdpConnectionListener.cs new file mode 100644 index 0000000..e37be45 --- /dev/null +++ b/Tools/Hazel-Networking/Hazel/FewerThreads/ThreadLimitedUdpConnectionListener.cs @@ -0,0 +1,402 @@ +using System; +using System.Collections.Concurrent; +using System.Linq; +using System.Net; +using System.Net.Sockets; +using System.Threading; + +namespace Hazel.Udp.FewerThreads +{ + /// <summary> + /// Listens for new UDP connections and creates UdpConnections for them. + /// </summary> + /// <inheritdoc /> + public class ThreadLimitedUdpConnectionListener : NetworkConnectionListener + { + private struct SendMessageInfo + { + public ByteSpan Span; + public IPEndPoint Recipient; + } + + private struct ReceiveMessageInfo + { + public MessageReader Message; + public IPEndPoint Sender; + public ConnectionId ConnectionId; + } + + private const int SendReceiveBufferSize = 1024 * 1024; + + private Socket socket; + protected ILogger Logger; + + private Thread reliablePacketThread; + private Thread receiveThread; + private Thread sendThread; + private HazelThreadPool processThreads; + + public bool ReceiveThreadRunning => this.receiveThread.ThreadState == ThreadState.Running; + + public struct ConnectionId : IEquatable<ConnectionId> + { + public IPEndPoint EndPoint; + public int Serial; + + public static ConnectionId Create(IPEndPoint endPoint, int serial) + { + return new ConnectionId{ + EndPoint = endPoint, + Serial = serial, + }; + } + + public bool Equals(ConnectionId other) + { + return this.Serial == other.Serial + && this.EndPoint.Equals(other.EndPoint) + ; + } + + public override bool Equals(object obj) + { + if (obj is ConnectionId) + { + return this.Equals((ConnectionId)obj); + } + + return false; + } + + public override int GetHashCode() + { + ///NOTE(mendsley): We're only hashing the endpoint + /// here, as the common case will have one + /// connection per address+port tuple. + return this.EndPoint.GetHashCode(); + } + } + + protected ConcurrentDictionary<ConnectionId, ThreadLimitedUdpServerConnection> allConnections = new ConcurrentDictionary<ConnectionId, ThreadLimitedUdpServerConnection>(); + + private BlockingCollection<ReceiveMessageInfo> receiveQueue; + private BlockingCollection<SendMessageInfo> sendQueue = new BlockingCollection<SendMessageInfo>(); + + public int MaxAge + { + get + { + var now = DateTime.UtcNow; + TimeSpan max = new TimeSpan(); + foreach (var con in allConnections.Values) + { + var val = now - con.CreationTime; + if (val > max) max = val; + } + + return (int)max.TotalSeconds; + } + } + + public override double AveragePing => this.allConnections.Values.Sum(c => c.AveragePingMs) / this.allConnections.Count; + public override int ConnectionCount { get { return this.allConnections.Count; } } + public override int SendQueueLength { get { return this.sendQueue.Count; } } + public override int ReceiveQueueLength { get { return this.receiveQueue.Count; } } + + private bool isActive; + + public ThreadLimitedUdpConnectionListener(int numWorkers, IPEndPoint endPoint, ILogger logger, IPMode ipMode = IPMode.IPv4) + { + this.Logger = logger; + this.EndPoint = endPoint; + this.IPMode = ipMode; + + this.receiveQueue = new BlockingCollection<ReceiveMessageInfo>(10000); + + this.socket = UdpConnection.CreateSocket(this.IPMode); + this.socket.ExclusiveAddressUse = true; + this.socket.Blocking = false; + + this.socket.ReceiveBufferSize = SendReceiveBufferSize; + this.socket.SendBufferSize = SendReceiveBufferSize; + + this.reliablePacketThread = new Thread(ManageReliablePackets); + this.sendThread = new Thread(SendLoop); + this.receiveThread = new Thread(ReceiveLoop); + this.processThreads = new HazelThreadPool(numWorkers, ProcessingLoop); + } + + ~ThreadLimitedUdpConnectionListener() + { + this.Dispose(false); + } + + // This is just for booting people after they've been connected a certain amount of time... + public void DisconnectOldConnections(TimeSpan maxAge, MessageWriter disconnectMessage) + { + var now = DateTime.UtcNow; + foreach (var conn in this.allConnections.Values) + { + if (now - conn.CreationTime > maxAge) + { + conn.Disconnect("Stale Connection", disconnectMessage); + } + } + } + + private void ManageReliablePackets() + { + while (this.isActive) + { + foreach (var kvp in this.allConnections) + { + var sock = kvp.Value; + sock.ManageReliablePackets(); + } + + Thread.Sleep(100); + } + } + + public override void Start() + { + try + { + socket.Bind(EndPoint); + } + catch (SocketException e) + { + throw new HazelException("Could not start listening as a SocketException occurred", e); + } + + this.isActive = true; + this.reliablePacketThread.Start(); + this.sendThread.Start(); + this.receiveThread.Start(); + this.processThreads.Start(); + } + + private void ReceiveLoop() + { + while (this.isActive) + { + if (this.socket.Poll(1000, SelectMode.SelectRead)) + { + if (!isActive) break; + + EndPoint remoteEP = new IPEndPoint(this.EndPoint.Address, this.EndPoint.Port); + var message = MessageReader.GetSized(this.ReceiveBufferSize); + try + { + message.Length = socket.ReceiveFrom(message.Buffer, 0, message.Buffer.Length, SocketFlags.None, ref remoteEP); + } + catch (SocketException sx) + { + message.Recycle(); + if (sx.SocketErrorCode == SocketError.NotConnected) + { + this.InvokeInternalError(HazelInternalErrors.ConnectionDisconnected); + return; + } + + this.Logger.WriteError("Socket Ex in ReceiveLoop: " + sx.Message); + continue; + } + catch (Exception ex) + { + message.Recycle(); + this.Logger.WriteError("Stopped due to: " + ex.Message); + return; + } + + ConnectionId connectionId = ConnectionId.Create((IPEndPoint)remoteEP, 0); + this.ProcessIncomingMessageFromOtherThread(message, (IPEndPoint)remoteEP, connectionId); + } + } + } + + private void ProcessingLoop() + { + foreach (ReceiveMessageInfo msg in this.receiveQueue.GetConsumingEnumerable()) + { + try + { + this.ReadCallback(msg.Message, msg.Sender, msg.ConnectionId); + } + catch + { + + } + } + } + + protected void ProcessIncomingMessageFromOtherThread(MessageReader message, IPEndPoint remoteEndPoint, ConnectionId connectionId) + { + var info = new ReceiveMessageInfo() { Message = message, Sender = remoteEndPoint, ConnectionId = connectionId }; + if (!this.receiveQueue.TryAdd(info)) + { + this.Statistics.AddReceiveThreadBlocking(); + this.receiveQueue.Add(info); + } + } + + private void SendLoop() + { + foreach (SendMessageInfo msg in this.sendQueue.GetConsumingEnumerable()) + { + try + { + if (this.socket.Poll(Timeout.Infinite, SelectMode.SelectWrite)) + { + this.socket.SendTo(msg.Span.GetUnderlyingArray(), msg.Span.Offset, msg.Span.Length, SocketFlags.None, msg.Recipient); + this.Statistics.AddBytesSent(msg.Span.Length - msg.Span.Offset); + } + else + { + this.Logger.WriteError("Socket is no longer able to send"); + break; + } + } + catch (Exception e) + { + this.Logger.WriteError("Error in loop while sending: " + e.Message); + Thread.Sleep(1); + } + } + } + + protected virtual void ReadCallback(MessageReader message, IPEndPoint remoteEndPoint, ConnectionId connectionId) + { + int bytesReceived = message.Length; + bool aware = true; + bool isHello = message.Buffer[0] == (byte)UdpSendOption.Hello; + + // If we're aware of this connection use the one already + // If this is a new client then connect with them! + ThreadLimitedUdpServerConnection connection; + if (!this.allConnections.TryGetValue(connectionId, out connection)) + { + lock (this.allConnections) + { + if (!this.allConnections.TryGetValue(connectionId, out connection)) + { + // Check for malformed connection attempts + if (!isHello) + { + message.Recycle(); + return; + } + + if (AcceptConnection != null) + { + if (!AcceptConnection(remoteEndPoint, message.Buffer, out var response)) + { + message.Recycle(); + if (response != null) + { + SendDataRaw(response, remoteEndPoint); + } + + return; + } + } + + aware = false; + connection = new ThreadLimitedUdpServerConnection(this, connectionId, remoteEndPoint, this.IPMode, this.Logger); + if (!this.allConnections.TryAdd(connectionId, connection)) + { + throw new HazelException("Failed to add a connection. This should never happen."); + } + } + } + } + + // If it's a new connection invoke the NewConnection event. + // This needs to happen before handling the message because in localhost scenarios, the ACK and + // subsequent messages can happen before the NewConnection event sets up OnDataRecieved handlers + if (!aware) + { + // Skip header and hello byte; + message.Offset = 4; + message.Length = bytesReceived - 4; + message.Position = 0; + try + { + this.InvokeNewConnection(message, connection); + } + catch (Exception e) + { + this.Logger.WriteError("NewConnection handler threw: " + e); + } + } + + // Inform the connection of the buffer (new connections need to send an ack back to client) + connection.HandleReceive(message, bytesReceived); + } + + internal void SendDataRaw(byte[] response, IPEndPoint remoteEndPoint) + { + QueueRawData(response, remoteEndPoint); + } + + protected virtual void QueueRawData(ByteSpan span, IPEndPoint remoteEndPoint) + { + this.sendQueue.TryAdd(new SendMessageInfo() { Span = span, Recipient = remoteEndPoint }); + } + + /// <summary> + /// Removes a virtual connection from the list. + /// </summary> + /// <param name="endPoint">Connection key of the virtual connection.</param> + internal bool RemoveConnectionTo(ConnectionId connectionId) + { + return this.allConnections.TryRemove(connectionId, out _); + } + + /// <summary> + /// This is after all messages could be sent. Clean up anything extra. + /// </summary> + internal virtual void RemovePeerRecord(ConnectionId connectionId) + { + } + + protected override void Dispose(bool disposing) + { + foreach (var kvp in this.allConnections) + { + kvp.Value.Dispose(); + } + + bool wasActive = this.isActive; + this.isActive = false; + + // Flush outgoing packets + this.sendQueue?.CompleteAdding(); + + if (wasActive) + { + this.sendThread.Join(); + } + + try { this.socket.Shutdown(SocketShutdown.Both); } catch { } + try { this.socket.Close(); } catch { } + try { this.socket.Dispose(); } catch { } + + this.receiveQueue?.CompleteAdding(); + + if (wasActive) + { + this.reliablePacketThread.Join(); + this.receiveThread.Join(); + this.processThreads.Join(); + } + + this.receiveQueue?.Dispose(); + this.receiveQueue = null; + this.sendQueue?.Dispose(); + this.sendQueue = null; + + base.Dispose(disposing); + } + } +} diff --git a/Tools/Hazel-Networking/Hazel/FewerThreads/ThreadLimitedUdpServerConnection.cs b/Tools/Hazel-Networking/Hazel/FewerThreads/ThreadLimitedUdpServerConnection.cs new file mode 100644 index 0000000..bb139c7 --- /dev/null +++ b/Tools/Hazel-Networking/Hazel/FewerThreads/ThreadLimitedUdpServerConnection.cs @@ -0,0 +1,110 @@ +using System; +using System.Net; + +namespace Hazel.Udp.FewerThreads +{ + /// <summary> + /// Represents a servers's connection to a client that uses the UDP protocol. + /// </summary> + /// <inheritdoc/> + public sealed class ThreadLimitedUdpServerConnection : UdpConnection + { + public readonly DateTime CreationTime = DateTime.UtcNow; + + /// <summary> + /// The connection listener that we use the socket of. + /// </summary> + /// <remarks> + /// Udp server connections utilize the same socket in the listener for sends/receives, this is the listener that + /// created this connection and is hence the listener this conenction sends and receives via. + /// </remarks> + public ThreadLimitedUdpConnectionListener Listener { get; private set; } + + public ThreadLimitedUdpConnectionListener.ConnectionId ConnectionId { get; private set; } + + /// <summary> + /// Creates a UdpConnection for the virtual connection to the endpoint. + /// </summary> + /// <param name="listener">The listener that created this connection.</param> + /// <param name="endPoint">The endpoint that we are connected to.</param> + /// <param name="IPMode">The IPMode we are connected using.</param> + internal ThreadLimitedUdpServerConnection(ThreadLimitedUdpConnectionListener listener, ThreadLimitedUdpConnectionListener.ConnectionId connectionId, IPEndPoint endPoint, IPMode IPMode, ILogger logger) + : base(logger) + { + this.Listener = listener; + this.ConnectionId = connectionId; + this.EndPoint = endPoint; + this.IPMode = IPMode; + + State = ConnectionState.Connected; + this.InitializeKeepAliveTimer(); + } + + /// <inheritdoc /> + protected override void WriteBytesToConnection(byte[] bytes, int length) + { + if (bytes.Length != length) throw new ArgumentException("I made an assumption here. I hope you see this error."); + + // Hrm, well this is inaccurate for DTLS connections because the Listener does the encryption which may change the size. + // but I don't want to have a bunch of client references in the send queue... + // Does this perhaps mean the encryption is being done in the wrong class? + this.Statistics.LogPacketSend(length); + Listener.SendDataRaw(bytes, EndPoint); + } + + /// <inheritdoc /> + /// <remarks> + /// This will always throw a HazelException. + /// </remarks> + public override void Connect(byte[] bytes = null, int timeout = 5000) + { + throw new InvalidOperationException("Cannot manually connect a UdpServerConnection, did you mean to use UdpClientConnection?"); + } + + /// <inheritdoc /> + /// <remarks> + /// This will always throw a HazelException. + /// </remarks> + public override void ConnectAsync(byte[] bytes = null) + { + throw new InvalidOperationException("Cannot manually connect a UdpServerConnection, did you mean to use UdpClientConnection?"); + } + + /// <summary> + /// Sends a disconnect message to the end point. + /// </summary> + protected override bool SendDisconnect(MessageWriter data = null) + { + if (!Listener.RemoveConnectionTo(this.ConnectionId)) return false; + this._state = ConnectionState.NotConnected; + + var bytes = EmptyDisconnectBytes; + if (data != null && data.Length > 0) + { + if (data.SendOption != SendOption.None) throw new ArgumentException("Disconnect messages can only be unreliable."); + + bytes = data.ToByteArray(true); + bytes[0] = (byte)UdpSendOption.Disconnect; + } + + try + { + this.WriteBytesToConnection(bytes, bytes.Length); + } + catch { } + + return true; + } + + protected override void Dispose(bool disposing) + { + if (disposing) + { + SendDisconnect(); + } + + Listener.RemovePeerRecord(this.ConnectionId); + base.Dispose(disposing); + } + } +} diff --git a/Tools/Hazel-Networking/Hazel/Hazel.csproj b/Tools/Hazel-Networking/Hazel/Hazel.csproj new file mode 100644 index 0000000..3a7ea17 --- /dev/null +++ b/Tools/Hazel-Networking/Hazel/Hazel.csproj @@ -0,0 +1,14 @@ +<Project Sdk="Microsoft.NET.Sdk"> + + <PropertyGroup> + <TargetFrameworks>netstandard2.0;net472</TargetFrameworks> + <AllowUnsafeBlocks>true</AllowUnsafeBlocks> + </PropertyGroup> + + <ItemGroup> + <AssemblyAttribute Include="System.Runtime.CompilerServices.InternalsVisibleToAttribute"> + <_Parameter1>Hazel.UnitTests</_Parameter1> + </AssemblyAttribute> + </ItemGroup> + +</Project> diff --git a/Tools/Hazel-Networking/Hazel/HazelException.cs b/Tools/Hazel-Networking/Hazel/HazelException.cs new file mode 100644 index 0000000..c0db05a --- /dev/null +++ b/Tools/Hazel-Networking/Hazel/HazelException.cs @@ -0,0 +1,24 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; + +namespace Hazel +{ + /// <summary> + /// Wrapper for exceptions thrown from Hazel. + /// </summary> + [Serializable] + public class HazelException : Exception + { + internal HazelException(string msg) : base (msg) + { + + } + + internal HazelException(string msg, Exception e) : base (msg, e) + { + + } + } +} diff --git a/Tools/Hazel-Networking/Hazel/IPMode.cs b/Tools/Hazel-Networking/Hazel/IPMode.cs new file mode 100644 index 0000000..04c8c38 --- /dev/null +++ b/Tools/Hazel-Networking/Hazel/IPMode.cs @@ -0,0 +1,30 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; + + +namespace Hazel +{ + /// <summary> + /// Represents the IP version that a connection or listener will use. + /// </summary> + /// <remarks> + /// If you wand a client to connect or be able to connect using IPv6 then you should use <see cref="IPv4AndIPv6"/>, + /// this sets the underlying sockets to use IPv6 but still allow IPv4 sockets to connect for backwards compatability + /// and hence it is the default IPMode in most cases. + /// </remarks> + public enum IPMode + { + /// <summary> + /// Instruction to use IPv4 only, IPv6 connections will not be able to connect. + /// </summary> + IPv4, + + /// <summary> + /// Instruction to use IPv6 only, IPv4 connections will not be able to connect. IPv4 addresses can be connected + /// by converting to IPv6 addresses. + /// </summary> + IPv6 + } +} diff --git a/Tools/Hazel-Networking/Hazel/IRecyclable.cs b/Tools/Hazel-Networking/Hazel/IRecyclable.cs new file mode 100644 index 0000000..3e9769e --- /dev/null +++ b/Tools/Hazel-Networking/Hazel/IRecyclable.cs @@ -0,0 +1,29 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; + +namespace Hazel +{ + /// <summary> + /// Interface for all items that can be returned to an object pool. + /// </summary> + /// <threadsafety static="true" instance="true"/> + public interface IRecyclable + { + /// <summary> + /// Returns this object back to the object pool. + /// </summary> + /// <remarks> + /// <para> + /// Calling this when you are done with the object returns the object back to a pool in order to be reused. + /// This can reduce the amount of work the GC has to do dramatically but it is optional to call this. + /// </para> + /// <para> + /// Calling this indicates to Hazel that this can be reused and thus you should only call this when you are + /// completely finished with the object as the contents can be overwritten at any point after. + /// </para> + /// </remarks> + void Recycle(); + } +} diff --git a/Tools/Hazel-Networking/Hazel/ListenerStatistics.cs b/Tools/Hazel-Networking/Hazel/ListenerStatistics.cs new file mode 100644 index 0000000..428c567 --- /dev/null +++ b/Tools/Hazel-Networking/Hazel/ListenerStatistics.cs @@ -0,0 +1,23 @@ +using System.Threading; + +namespace Hazel +{ + public class ListenerStatistics + { + private int _receiveThreadBlocked; + public int ReceiveThreadBlocked => this._receiveThreadBlocked; + + private long _bytesSent; + public long BytesSent => this._bytesSent; + + internal void AddReceiveThreadBlocking() + { + Interlocked.Increment(ref _receiveThreadBlocked); + } + + internal void AddBytesSent(long bytes) + { + Interlocked.Add(ref _bytesSent, bytes); + } + } +} diff --git a/Tools/Hazel-Networking/Hazel/MessageReader.cs b/Tools/Hazel-Networking/Hazel/MessageReader.cs new file mode 100644 index 0000000..bd3b0d8 --- /dev/null +++ b/Tools/Hazel-Networking/Hazel/MessageReader.cs @@ -0,0 +1,452 @@ +using System; +using System.IO; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Text; + +namespace Hazel +{ + public class MessageReader : IRecyclable + { + public static readonly ObjectPool<MessageReader> ReaderPool = new ObjectPool<MessageReader>(() => new MessageReader()); + + public byte[] Buffer; + public byte Tag; + + public int Length; // 总长度 + public int Offset; // length和tag后面 + + public int BytesRemaining => this.Length - this.Position; + + private MessageReader Parent; + + public int Position + { + get { return this._position; } + set + { + this._position = value; + this.readHead = value + Offset; + } + } + + private int _position; + private int readHead; + + public static MessageReader GetSized(int minSize) + { + var output = ReaderPool.GetObject(); + + if (output.Buffer == null || output.Buffer.Length < minSize) + { + output.Buffer = new byte[minSize]; + } + else + { + Array.Clear(output.Buffer, 0, output.Buffer.Length); + } + + output.Offset = 0; + output.Position = 0; + output.Tag = byte.MaxValue; + return output; + } + + public static MessageReader Get(byte[] buffer) + { + var output = ReaderPool.GetObject(); + + output.Buffer = buffer; + output.Offset = 0; + output.Position = 0; + output.Length = buffer.Length; + output.Tag = byte.MaxValue; + + return output; + } + + public static MessageReader CopyMessageIntoParent(MessageReader source) + { + var output = MessageReader.GetSized(source.Length + 3); + System.Buffer.BlockCopy(source.Buffer, source.Offset - 3, output.Buffer, 0, source.Length + 3); + + output.Offset = 0; + output.Position = 0; + output.Length = source.Length + 3; + + return output; + } + + public static MessageReader Get(MessageReader source) + { + var output = MessageReader.GetSized(source.Buffer.Length); + System.Buffer.BlockCopy(source.Buffer, 0, output.Buffer, 0, source.Buffer.Length); + + output.Offset = source.Offset; + + output._position = source._position; + output.readHead = source.readHead; + + output.Length = source.Length; + output.Tag = source.Tag; + + return output; + } + + public static MessageReader Get(byte[] buffer, int offset) + { + // Ensure there is at least a header + if (offset + 3 > buffer.Length) return null; + + var output = ReaderPool.GetObject(); + + output.Buffer = buffer; + output.Offset = offset; + output.Position = 0; + + output.Length = output.ReadUInt16(); + output.Tag = output.ReadByte(); + + output.Offset += 3; + output.Position = 0; + + return output; + } + + /// <summary> + /// Produces a MessageReader using the parent's buffer. This MessageReader should **NOT** be recycled. + /// </summary> + public MessageReader ReadMessage() + { + // Ensure there is at least a header + if (this.BytesRemaining < 3) throw new InvalidDataException($"ReadMessage header is longer than message length: 3 of {this.BytesRemaining}"); + + var output = new MessageReader(); + + output.Parent = this; + output.Buffer = this.Buffer; + output.Offset = this.readHead; + output.Position = 0; + + output.Length = output.ReadUInt16(); + output.Tag = output.ReadByte(); + + output.Offset += 3; + output.Position = 0; + + if (this.BytesRemaining < output.Length + 3) throw new InvalidDataException($"Message Length at Position {this.readHead} is longer than message length: {output.Length + 3} of {this.BytesRemaining}"); + + this.Position += output.Length + 3; + return output; + } + + /// <summary> + /// Produces a MessageReader with a new buffer. This MessageReader should be recycled. + /// </summary> + public MessageReader ReadMessageAsNewBuffer() + { + if (this.BytesRemaining < 3) throw new InvalidDataException($"ReadMessage header is longer than message length: 3 of {this.BytesRemaining}"); + + var len = this.ReadUInt16(); + var tag = this.ReadByte(); + + if (this.BytesRemaining < len) throw new InvalidDataException($"Message Length at Position {this.readHead} is longer than message length: {len} of {this.BytesRemaining}"); + + var output = MessageReader.GetSized(len); + + Array.Copy(this.Buffer, this.readHead, output.Buffer, 0, len); + + output.Length = len; + output.Tag = tag; + + this.Position += output.Length; + return output; + } + + public MessageWriter StartWriter() + { + var output = new MessageWriter(this.Buffer); + output.Position = this.readHead; + return output; + } + + public MessageReader Duplicate() + { + var output = GetSized(this.Length); + Array.Copy(this.Buffer, this.Offset, output.Buffer, 0, this.Length); + output.Length = this.Length; + output.Offset = 0; + output.Position = 0; + + return output; + } + + public void RemoveMessage(MessageReader reader) + { + var temp = MessageReader.GetSized(reader.Buffer.Length); + try + { + var headerOffset = reader.Offset - 3; + var endOfMessage = reader.Offset + reader.Length; + var len = reader.Buffer.Length - endOfMessage; + + Array.Copy(reader.Buffer, endOfMessage, temp.Buffer, 0, len); + Array.Copy(temp.Buffer, 0, this.Buffer, headerOffset, len); + + this.AdjustLength(reader.Offset, reader.Length + 3); + } + finally + { + temp.Recycle(); + } + } + + public void InsertMessage(MessageReader reader, MessageWriter writer) + { + var temp = MessageReader.GetSized(reader.Buffer.Length); + try + { + var headerOffset = reader.Offset - 3; + var startOfMessage = reader.Offset; + var len = reader.Buffer.Length - startOfMessage; + int writerOffset = 3; + switch (writer.SendOption) + { + case SendOption.Reliable: + writerOffset = 3; + break; + case SendOption.None: + writerOffset = 1; + break; + } + + //store the original buffer in temp + Array.Copy(reader.Buffer, headerOffset, temp.Buffer, 0, len); + + //put the contents of writer in at headerOffset + Array.Copy(writer.Buffer, writerOffset, this.Buffer, headerOffset, writer.Length-writerOffset); + + //put the original buffer in after that + Array.Copy(temp.Buffer, 0, this.Buffer, headerOffset + (writer.Length-writerOffset), len - writer.Length); + + this.AdjustLength(-1 * reader.Offset , -1 * (writer.Length - writerOffset)); + } + finally + { + temp.Recycle(); + } + } + + private void AdjustLength(int offset, int amount) + { + if (this.readHead > offset) + { + this.Position -= amount; + } + + if (Parent != null) + { + var lengthOffset = this.Offset - 3; + var curLen = this.Buffer[lengthOffset] + | (this.Buffer[lengthOffset + 1] << 8); + + curLen -= amount; + this.Length -= amount; + + this.Buffer[lengthOffset] = (byte)curLen; + this.Buffer[lengthOffset + 1] = (byte)(this.Buffer[lengthOffset + 1] >> 8); + + Parent.AdjustLength(offset, amount); + } + } + + public void Recycle() + { + this.Parent = null; + ReaderPool.PutObject(this); + } + + #region Read Methods + public bool ReadBoolean() + { + byte val = this.FastByte(); + return val != 0; + } + + public sbyte ReadSByte() + { + return (sbyte)this.FastByte(); + } + + public byte ReadByte() + { + return this.FastByte(); + } + + public ushort ReadUInt16() + { + ushort output = + (ushort)(this.FastByte() + | this.FastByte() << 8); + return output; + } + + public short ReadInt16() + { + short output = + (short)(this.FastByte() + | this.FastByte() << 8); + return output; + } + + public uint ReadUInt32() + { + uint output = this.FastByte() + | (uint)this.FastByte() << 8 + | (uint)this.FastByte() << 16 + | (uint)this.FastByte() << 24; + + return output; + } + + public int ReadInt32() + { + int output = this.FastByte() + | this.FastByte() << 8 + | this.FastByte() << 16 + | this.FastByte() << 24; + + return output; + } + + public ulong ReadUInt64() + { + ulong output = (ulong)this.FastByte() + | (ulong)this.FastByte() << 8 + | (ulong)this.FastByte() << 16 + | (ulong)this.FastByte() << 24 + | (ulong)this.FastByte() << 32 + | (ulong)this.FastByte() << 40 + | (ulong)this.FastByte() << 48 + | (ulong)this.FastByte() << 56; + + return output; + } + + public long ReadInt64() + { + long output = (long)this.FastByte() + | (long)this.FastByte() << 8 + | (long)this.FastByte() << 16 + | (long)this.FastByte() << 24 + | (long)this.FastByte() << 32 + | (long)this.FastByte() << 40 + | (long)this.FastByte() << 48 + | (long)this.FastByte() << 56; + + return output; + } + + public unsafe float ReadSingle() + { + float output = 0; + fixed (byte* bufPtr = &this.Buffer[this.readHead]) + { + byte* outPtr = (byte*)&output; + + *outPtr = *bufPtr; + *(outPtr + 1) = *(bufPtr + 1); + *(outPtr + 2) = *(bufPtr + 2); + *(outPtr + 3) = *(bufPtr + 3); + } + + this.Position += 4; + return output; + } + + public string ReadString() + { + int len = this.ReadPackedInt32(); + if (this.BytesRemaining < len) throw new InvalidDataException($"Read length is longer than message length: {len} of {this.BytesRemaining}"); + + string output = UTF8Encoding.UTF8.GetString(this.Buffer, this.readHead, len); + + this.Position += len; + return output; + } + + public byte[] ReadBytesAndSize() + { + int len = this.ReadPackedInt32(); + if (this.BytesRemaining < len) throw new InvalidDataException($"Read length is longer than message length: {len} of {this.BytesRemaining}"); + + return this.ReadBytes(len); + } + + public byte[] ReadBytes(int length) + { + if (this.BytesRemaining < length) throw new InvalidDataException($"Read length is longer than message length: {length} of {this.BytesRemaining}"); + + byte[] output = new byte[length]; + Array.Copy(this.Buffer, this.readHead, output, 0, output.Length); + this.Position += output.Length; + return output; + } + + /// + public int ReadPackedInt32() + { + return (int)this.ReadPackedUInt32(); + } + + /// + public uint ReadPackedUInt32() + { + bool readMore = true; + int shift = 0; + uint output = 0; + + while (readMore) + { + if (this.BytesRemaining < 1) throw new InvalidDataException($"Read length is longer than message length."); + + byte b = this.ReadByte(); + if (b >= 0x80) + { + readMore = true; + b ^= 0x80; + } + else + { + readMore = false; + } + + output |= (uint)(b << shift); + shift += 7; + } + + return output; + } + #endregion + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private byte FastByte() + { + this._position++; + return this.Buffer[this.readHead++]; + } + + public unsafe static bool IsLittleEndian() + { + byte b; + unsafe + { + int i = 1; + byte* bp = (byte*)&i; + b = *bp; + } + + return b == 1; + } + } +} diff --git a/Tools/Hazel-Networking/Hazel/MessageWriter.cs b/Tools/Hazel-Networking/Hazel/MessageWriter.cs new file mode 100644 index 0000000..9caaaf2 --- /dev/null +++ b/Tools/Hazel-Networking/Hazel/MessageWriter.cs @@ -0,0 +1,365 @@ +using System; +using System.Collections.Generic; +using System.IO; +using System.Text; + +namespace Hazel +{ + /// <summary> + /// 嵌套结构的Message + /// 结构: + /// ------------------------------------ + /// 2bytes (ushort) 包长度 + /// 1bytes (tag) 协议ID,在AmongUS里是tags.cs里定义的tag和subtag + /// ------------------------------------ + /// 数据 包括嵌套的子协议 + /// ------------------------------------ + /// </summary> + public class MessageWriter : IRecyclable + { + public static int BufferSize = 64000; + public static readonly ObjectPool<MessageWriter> WriterPool = new ObjectPool<MessageWriter>(() => new MessageWriter(BufferSize)); + + public byte[] Buffer; + public int Length; // 总长度 + public int Position; // 写入游标 + + public SendOption SendOption { get; private set; } + + private Stack<int> messageStarts = new Stack<int>(); + + public MessageWriter(byte[] buffer) + { + this.Buffer = buffer; + this.Length = this.Buffer.Length; + } + + /// + public MessageWriter(int bufferSize) + { + this.Buffer = new byte[bufferSize]; + } + + /// <summary> + /// 去掉header + /// </summary> + /// <param name="includeHeader"></param> + /// <returns></returns> + /// <exception cref="NotImplementedException"></exception> + public byte[] ToByteArray(bool includeHeader) + { + if (includeHeader) + { + byte[] output = new byte[this.Length]; + System.Buffer.BlockCopy(this.Buffer, 0, output, 0, this.Length); + return output; + } + else + { + switch (this.SendOption) + { + case SendOption.Reliable: + { + byte[] output = new byte[this.Length - 3]; + System.Buffer.BlockCopy(this.Buffer, 3, output, 0, this.Length - 3); + return output; + } + case SendOption.None: + { + byte[] output = new byte[this.Length - 1]; + System.Buffer.BlockCopy(this.Buffer, 1, output, 0, this.Length - 1); + return output; + } + } + } + + throw new NotImplementedException(); + } + + /// + /// <param name="sendOption">The option specifying how the message should be sent.</param> + public static MessageWriter Get(SendOption sendOption = SendOption.None) + { + var output = WriterPool.GetObject(); + output.Clear(sendOption); + + return output; + } + + public bool HasBytes(int expected) + { + if (this.SendOption == SendOption.None) + { + return this.Length > 1 + expected; + } + + return this.Length > 3 + expected; + } + + /// + public void StartMessage(byte typeFlag) + { + var messageStart = this.Position; + messageStarts.Push(messageStart); + this.Buffer[messageStart] = 0; + this.Buffer[messageStart + 1] = 0; + this.Position += 2; + this.Write(typeFlag); + } + + /// + public void EndMessage() + { + var lastMessageStart = messageStarts.Pop(); + ushort length = (ushort)(this.Position - lastMessageStart - 3); // Minus length and type byte + this.Buffer[lastMessageStart] = (byte)length; + this.Buffer[lastMessageStart + 1] = (byte)(length >> 8); + } + + /// + public void CancelMessage() + { + this.Position = this.messageStarts.Pop(); + this.Length = this.Position; + } + + public void Clear(SendOption sendOption) + { + Array.Clear(this.Buffer, 0, this.Buffer.Length); + this.messageStarts.Clear(); + this.SendOption = sendOption; + this.Buffer[0] = (byte)sendOption; + switch (sendOption) + { + default: + case SendOption.None: + this.Length = this.Position = 1; + break; + case SendOption.Reliable: + this.Length = this.Position = 3; + break; + } + } + + /// + public void Recycle() + { + this.Position = this.Length = 0; + WriterPool.PutObject(this); + } + + #region WriteMethods + + public void CopyFrom(MessageReader target) + { + int offset, length; + if (target.Tag == byte.MaxValue) + { + offset = target.Offset; + length = target.Length; + } + else + { + offset = target.Offset - 3; + length = target.Length + 3; + } + + System.Buffer.BlockCopy(target.Buffer, offset, this.Buffer, this.Position, length); + this.Position += length; + if (this.Position > this.Length) this.Length = this.Position; + } + + public void Write(bool value) + { + this.Buffer[this.Position++] = (byte)(value ? 1 : 0); + if (this.Position > this.Length) this.Length = this.Position; + } + + public void Write(sbyte value) + { + this.Buffer[this.Position++] = (byte)value; + if (this.Position > this.Length) this.Length = this.Position; + } + + public void Write(byte value) + { + this.Buffer[this.Position++] = value; + if (this.Position > this.Length) this.Length = this.Position; + } + + public void Write(short value) + { + this.Buffer[this.Position++] = (byte)value; + this.Buffer[this.Position++] = (byte)(value >> 8); + if (this.Position > this.Length) this.Length = this.Position; + } + + public void Write(ushort value) + { + this.Buffer[this.Position++] = (byte)value; + this.Buffer[this.Position++] = (byte)(value >> 8); + if (this.Position > this.Length) this.Length = this.Position; + } + + public void Write(uint value) + { + this.Buffer[this.Position++] = (byte)value; + this.Buffer[this.Position++] = (byte)(value >> 8); + this.Buffer[this.Position++] = (byte)(value >> 16); + this.Buffer[this.Position++] = (byte)(value >> 24); + if (this.Position > this.Length) this.Length = this.Position; + } + + public void Write(int value) + { + this.Buffer[this.Position++] = (byte)value; + this.Buffer[this.Position++] = (byte)(value >> 8); + this.Buffer[this.Position++] = (byte)(value >> 16); + this.Buffer[this.Position++] = (byte)(value >> 24); + if (this.Position > this.Length) this.Length = this.Position; + } + + public void Write(ulong value) + { + this.Buffer[this.Position++] = (byte)value; + this.Buffer[this.Position++] = (byte)(value >> 8); + this.Buffer[this.Position++] = (byte)(value >> 16); + this.Buffer[this.Position++] = (byte)(value >> 24); + this.Buffer[this.Position++] = (byte)(value >> 32); + this.Buffer[this.Position++] = (byte)(value >> 40); + this.Buffer[this.Position++] = (byte)(value >> 48); + this.Buffer[this.Position++] = (byte)(value >> 56); + if (this.Position > this.Length) this.Length = this.Position; + } + + public void Write(long value) + { + this.Buffer[this.Position++] = (byte)value; + this.Buffer[this.Position++] = (byte)(value >> 8); + this.Buffer[this.Position++] = (byte)(value >> 16); + this.Buffer[this.Position++] = (byte)(value >> 24); + this.Buffer[this.Position++] = (byte)(value >> 32); + this.Buffer[this.Position++] = (byte)(value >> 40); + this.Buffer[this.Position++] = (byte)(value >> 48); + this.Buffer[this.Position++] = (byte)(value >> 56); + if (this.Position > this.Length) this.Length = this.Position; + } + + public unsafe void Write(float value) + { + fixed (byte* ptr = &this.Buffer[this.Position]) + { + byte* valuePtr = (byte*)&value; + + *ptr = *valuePtr; + *(ptr + 1) = *(valuePtr + 1); + *(ptr + 2) = *(valuePtr + 2); + *(ptr + 3) = *(valuePtr + 3); + } + + this.Position += 4; + if (this.Position > this.Length) this.Length = this.Position; + } + + public void Write(string value) + { + var bytes = UTF8Encoding.UTF8.GetBytes(value); + this.WritePacked(bytes.Length); + this.Write(bytes); + } + + public void WriteBytesAndSize(byte[] bytes) + { + this.WritePacked((uint)bytes.Length); + this.Write(bytes); + } + + public void WriteBytesAndSize(byte[] bytes, int length) + { + this.WritePacked((uint)length); + this.Write(bytes, length); + } + + public void WriteBytesAndSize(byte[] bytes, int offset, int length) + { + this.WritePacked((uint)length); + this.Write(bytes, offset, length); + } + + public void Write(byte[] bytes) + { + Array.Copy(bytes, 0, this.Buffer, this.Position, bytes.Length); + this.Position += bytes.Length; + if (this.Position > this.Length) this.Length = this.Position; + } + + public void Write(byte[] bytes, int offset, int length) + { + Array.Copy(bytes, offset, this.Buffer, this.Position, length); + this.Position += length; + if (this.Position > this.Length) this.Length = this.Position; + } + + public void Write(byte[] bytes, int length) + { + Array.Copy(bytes, 0, this.Buffer, this.Position, length); + this.Position += length; + if (this.Position > this.Length) this.Length = this.Position; + } + + /// + public void WritePacked(int value) + { + this.WritePacked((uint)value); + } + + /// + public void WritePacked(uint value) + { + do + { + byte b = (byte)(value & 0xFF); + if (value >= 0x80) + { + b |= 0x80; + } + + this.Write(b); + value >>= 7; + } while (value > 0); + } + #endregion + + public void Write(MessageWriter msg, bool includeHeader) + { + int offset = 0; + if (!includeHeader) + { + switch (msg.SendOption) + { + case SendOption.None: + offset = 1; + break; + case SendOption.Reliable: + offset = 3; + break; + } + } + + this.Write(msg.Buffer, offset, msg.Length - offset); + } + + public unsafe static bool IsLittleEndian() + { + byte b; + unsafe + { + int i = 1; + byte* bp = (byte*)&i; + b = *bp; + } + + return b == 1; + } + } +} diff --git a/Tools/Hazel-Networking/Hazel/NetworkConnection.cs b/Tools/Hazel-Networking/Hazel/NetworkConnection.cs new file mode 100644 index 0000000..d1de8a8 --- /dev/null +++ b/Tools/Hazel-Networking/Hazel/NetworkConnection.cs @@ -0,0 +1,117 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Net; +using System.Text; + + +namespace Hazel +{ + public enum HazelInternalErrors + { + SocketExceptionSend, + SocketExceptionReceive, + ReceivedZeroBytes, + PingsWithoutResponse, + ReliablePacketWithoutResponse, + ConnectionDisconnected, + DtlsNegotiationFailed + } + + /// <summary> + /// Abstract base class for a <see cref="Connection"/> to a remote end point via a network protocol like TCP or UDP. + /// </summary> + /// <threadsafety static="true" instance="true"/> + public abstract class NetworkConnection : Connection + { + /// <summary> + /// An event that gives us a chance to send well-formed disconnect messages to clients when an internal disconnect happens. + /// </summary> + public Func<HazelInternalErrors, MessageWriter> OnInternalDisconnect; + + public virtual float AveragePingMs { get; } + + public long GetIP4Address() + { + if (IPMode == IPMode.IPv4) + { + return this.EndPoint.Address.Address; + } + else + { + var bytes = this.EndPoint.Address.GetAddressBytes(); + return BitConverter.ToInt64(bytes, bytes.Length - 8); + } + } + + /// <summary> + /// Sends a disconnect message to the end point. + /// </summary> + protected abstract bool SendDisconnect(MessageWriter writer); + + /// <summary> + /// Called when the socket has been disconnected at the remote host. + /// </summary> + protected void DisconnectRemote(string reason, MessageReader reader) + { + if (this.SendDisconnect(null)) + { + try + { + InvokeDisconnected(reason, reader); + } + catch { } + } + + this.Dispose(); + } + + /// <summary> + /// Called when socket is disconnected internally + /// </summary> + internal void DisconnectInternal(HazelInternalErrors error, string reason) + { + var handler = this.OnInternalDisconnect; + if (handler != null) + { + MessageWriter messageToRemote = handler(error); + if (messageToRemote != null) + { + try + { + Disconnect(reason, messageToRemote); + } + finally + { + messageToRemote.Recycle(); + } + } + else + { + Disconnect(reason); + } + } + else + { + Disconnect(reason); + } + } + + /// <summary> + /// Called when the socket has been disconnected locally. + /// </summary> + public override void Disconnect(string reason, MessageWriter writer = null) + { + if (this.SendDisconnect(writer)) + { + try + { + InvokeDisconnected(reason, null); + } + catch { } + } + + this.Dispose(); + } + } +} diff --git a/Tools/Hazel-Networking/Hazel/NetworkConnectionListener.cs b/Tools/Hazel-Networking/Hazel/NetworkConnectionListener.cs new file mode 100644 index 0000000..af26c4c --- /dev/null +++ b/Tools/Hazel-Networking/Hazel/NetworkConnectionListener.cs @@ -0,0 +1,26 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Net; +using System.Text; + + +namespace Hazel +{ + /// <summary> + /// Abstract base class for a <see cref="ConnectionListener"/> for network based connections. + /// </summary> + /// <threadsafety static="true" instance="true"/> + public abstract class NetworkConnectionListener : ConnectionListener + { + /// <summary> + /// The local end point the listener is listening for new clients on. + /// </summary> + public IPEndPoint EndPoint { get; protected set; } + + /// <summary> + /// The <see cref="IPMode">IPMode</see> the listener is listening for new clients on. + /// </summary> + public IPMode IPMode { get; protected set; } + } +} diff --git a/Tools/Hazel-Networking/Hazel/NewConnectionEventArgs.cs b/Tools/Hazel-Networking/Hazel/NewConnectionEventArgs.cs new file mode 100644 index 0000000..c3fd62f --- /dev/null +++ b/Tools/Hazel-Networking/Hazel/NewConnectionEventArgs.cs @@ -0,0 +1,22 @@ +namespace Hazel +{ + public struct NewConnectionEventArgs + { + /// <summary> + /// The data received from the client in the handshake. + /// You must not recycle this. If you need the message outside of a callback, you should copy it. + /// </summary> + public readonly MessageReader HandshakeData; + + /// <summary> + /// The <see cref="Connection"/> to the new client. + /// </summary> + public readonly Connection Connection; + + public NewConnectionEventArgs(MessageReader handshakeData, Connection connection) + { + this.HandshakeData = handshakeData; + this.Connection = connection; + } + } +} diff --git a/Tools/Hazel-Networking/Hazel/ObjectPool.cs b/Tools/Hazel-Networking/Hazel/ObjectPool.cs new file mode 100644 index 0000000..510e55a --- /dev/null +++ b/Tools/Hazel-Networking/Hazel/ObjectPool.cs @@ -0,0 +1,108 @@ +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Threading; + +namespace Hazel +{ + /// <summary> + /// A fairly simple object pool for items that will be created a lot. + /// </summary> + /// <typeparam name="T">The type that is pooled.</typeparam> + /// <threadsafety static="true" instance="true"/> + public sealed class ObjectPool<T> where T : IRecyclable + { + private int numberCreated; + public int NumberCreated { get { return numberCreated; } } + + public int NumberInUse { get { return this.inuse.Count; } } + public int NumberNotInUse { get { return this.pool.Count; } } + public int Size { get { return this.NumberInUse + this.NumberNotInUse; } } + +#if HAZEL_BAG + private readonly ConcurrentBag<T> pool = new ConcurrentBag<T>(); +#else + private readonly List<T> pool = new List<T>(); +#endif + + // Unavailable objects + private readonly ConcurrentDictionary<T, bool> inuse = new ConcurrentDictionary<T, bool>(); + + /// <summary> + /// The generator for creating new objects. + /// </summary> + /// <returns></returns> + private readonly Func<T> objectFactory; + + /// <summary> + /// Internal constructor for our ObjectPool. + /// </summary> + internal ObjectPool(Func<T> objectFactory) + { + this.objectFactory = objectFactory; + } + + /// <summary> + /// Returns a pooled object of type T, if none are available another is created. + /// </summary> + /// <returns>An instance of T.</returns> + internal T GetObject() + { +#if HAZEL_BAG + if (!pool.TryTake(out T item)) + { + Interlocked.Increment(ref numberCreated); + item = objectFactory.Invoke(); + } +#else + T item; + lock (this.pool) + { + if (this.pool.Count > 0) + { + var idx = this.pool.Count - 1; + item = this.pool[idx]; + this.pool.RemoveAt(idx); + } + else + { + Interlocked.Increment(ref numberCreated); + item = objectFactory.Invoke(); + } + } +#endif + + if (!inuse.TryAdd(item, true)) + { + throw new Exception("Duplicate pull " + typeof(T).Name); + } + + return item; + } + + /// <summary> + /// Returns an object to the pool. + /// </summary> + /// <param name="item">The item to return.</param> + internal void PutObject(T item) + { + if (inuse.TryRemove(item, out bool b)) + { +#if HAZEL_BAG + pool.Add(item); +#else + lock (this.pool) + { + pool.Add(item); + } +#endif + } + else + { +#if DEBUG + throw new Exception("Duplicate add " + typeof(T).Name); +#endif + } + } + } +} diff --git a/Tools/Hazel-Networking/Hazel/SendErrors.cs b/Tools/Hazel-Networking/Hazel/SendErrors.cs new file mode 100644 index 0000000..6871c6a --- /dev/null +++ b/Tools/Hazel-Networking/Hazel/SendErrors.cs @@ -0,0 +1,15 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; + +namespace Hazel +{ + [Flags] + public enum SendErrors + { + None, + Disconnected, + Unknown + } +} diff --git a/Tools/Hazel-Networking/Hazel/SendOption.cs b/Tools/Hazel-Networking/Hazel/SendOption.cs new file mode 100644 index 0000000..c2ffb22 --- /dev/null +++ b/Tools/Hazel-Networking/Hazel/SendOption.cs @@ -0,0 +1,35 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; + +namespace Hazel +{ + /// <summary> + /// Specifies how a message should be sent between connections. + /// </summary> + [Flags] + public enum SendOption : byte + { + /// <summary> + /// Requests unreliable delivery with no framentation. + /// </summary> + /// <remarks> + /// Sending data using unreliable delivery means that data is not guaranteed to arrive at it's destination nor is + /// it guarenteed to arrive only once. However, unreliable delivery can be faster than other methods and it + /// typically requires a smaller number of protocol bytes than other methods. There is also typically less + /// processing involved and less memory needed as packets are not stored once sent. + /// </remarks> + None = 0, + + /// <summary> + /// Requests data be sent reliably but with no fragmentation. + /// </summary> + /// <remarks> + /// Sending data reliably means that data is guarenteed to arrive and to arrive only once. Reliable delivery + /// typically requires more processing, more memory (as packets need to be stored in case they need resending), + /// a larger number of protocol bytes and can be slower than unreliable delivery. + /// </remarks> + Reliable = 1, + } +} diff --git a/Tools/Hazel-Networking/Hazel/UPnP/ILogger.cs b/Tools/Hazel-Networking/Hazel/UPnP/ILogger.cs new file mode 100644 index 0000000..3c7abcf --- /dev/null +++ b/Tools/Hazel-Networking/Hazel/UPnP/ILogger.cs @@ -0,0 +1,65 @@ +using System; + +namespace Hazel +{ + public interface ILogger + { + void WriteVerbose(string msg); + void WriteError(string msg); + void WriteWarning(string msg); + void WriteInfo(string msg); + } + + public class NullLogger : ILogger + { + public static readonly NullLogger Instance = new NullLogger(); + + public void WriteVerbose(string msg) + { + } + + public void WriteError(string msg) + { + } + + public void WriteWarning(string msg) + { + } + + public void WriteInfo(string msg) + { + } + } + + public class ConsoleLogger : ILogger + { + private bool verbose; + public ConsoleLogger(bool verbose) + { + this.verbose = verbose; + } + + public void WriteVerbose(string msg) + { + if (this.verbose) + { + Console.WriteLine($"{DateTime.Now} [VERBOSE] {msg}"); + } + } + + public void WriteWarning(string msg) + { + Console.WriteLine($"{DateTime.Now} [WARN] {msg}"); + } + + public void WriteError(string msg) + { + Console.WriteLine($"{DateTime.Now} [ERROR] {msg}"); + } + + public void WriteInfo(string msg) + { + Console.WriteLine($"{DateTime.Now} [INFO] {msg}"); + } + } +} diff --git a/Tools/Hazel-Networking/Hazel/UPnP/NetUtility.cs b/Tools/Hazel-Networking/Hazel/UPnP/NetUtility.cs new file mode 100644 index 0000000..d856823 --- /dev/null +++ b/Tools/Hazel-Networking/Hazel/UPnP/NetUtility.cs @@ -0,0 +1,158 @@ +using System; +using System.Collections.Generic; +using System.Net; +using System.Net.NetworkInformation; +using System.Net.Sockets; + +namespace Hazel.UPnP +{ + internal class NetUtility + { + private static IList<NetworkInterface> GetValidNetworkInterfaces() + { + var nics = NetworkInterface.GetAllNetworkInterfaces(); + if (nics == null || nics.Length < 1) + return new NetworkInterface[0]; + + var validInterfaces = new List<NetworkInterface>(nics.Length); + + NetworkInterface best = null; + foreach (NetworkInterface adapter in nics) + { + if (adapter.NetworkInterfaceType == NetworkInterfaceType.Loopback || adapter.NetworkInterfaceType == NetworkInterfaceType.Unknown) + continue; + if (!adapter.Supports(NetworkInterfaceComponent.IPv4) && !adapter.Supports(NetworkInterfaceComponent.IPv6)) + continue; + if (best == null) + best = adapter; + if (adapter.OperationalStatus != OperationalStatus.Up) + continue; + + // make sure this adapter has any ip addresses + IPInterfaceProperties properties = adapter.GetIPProperties(); + foreach (UnicastIPAddressInformation unicastAddress in properties.UnicastAddresses) + { + if (unicastAddress != null && unicastAddress.Address != null) + { + // Yes it does, add this network interface. + validInterfaces.Add(adapter); + break; + } + } + } + + if (validInterfaces.Count == 0 && best != null) + validInterfaces.Add(best); + + return validInterfaces; + } + + /// <summary> + /// Gets the addresses from all active network interfaces, but at most one per interface. + /// </summary> + /// <param name="addressFamily">The <see cref="AddressFamily"/> of the addresses to return</param> + /// <returns>An <see cref="ICollection{T}"/> of <see cref="UnicastIPAddressInformation"/>.</returns> + public static ICollection<UnicastIPAddressInformation> GetAddressesFromNetworkInterfaces(AddressFamily addressFamily) + { + var unicastAddresses = new List<UnicastIPAddressInformation>(); + + foreach (NetworkInterface adapter in GetValidNetworkInterfaces()) + { + IPInterfaceProperties properties = adapter.GetIPProperties(); + foreach (UnicastIPAddressInformation unicastAddress in properties.UnicastAddresses) + { + if (unicastAddress != null && unicastAddress.Address != null && unicastAddress.Address.AddressFamily == addressFamily) + { + unicastAddresses.Add(unicastAddress); + break; + } + } + } + + return unicastAddresses; + } + + /// <summary> + /// Gets my local IPv4 address (not necessarily external) and subnet mask + /// </summary> + public static IPAddress GetMyAddress(out IPAddress mask) + { + var networkInterfaces = GetValidNetworkInterfaces(); + IPInterfaceProperties properties = null; + + if (networkInterfaces.Count > 0) + properties = networkInterfaces[0]?.GetIPProperties(); + + if (properties != null) + { + foreach (UnicastIPAddressInformation unicastAddress in properties.UnicastAddresses) + { + if (unicastAddress != null && unicastAddress.Address != null && unicastAddress.Address.AddressFamily == AddressFamily.InterNetwork) + { + mask = unicastAddress.IPv4Mask; + return unicastAddress.Address; + } + } + } + + mask = null; + return null; + } + + /// <summary> + /// Gets the broadcast address for the first network interface or, if not able to, + /// the limited broadcast address. + /// </summary> + /// <returns>An <see cref="IPAddress"/> for broadcasting.</returns> + public static IPAddress GetBroadcastAddress() + { + var networkInterfaces = GetValidNetworkInterfaces(); + IPInterfaceProperties properties = null; + + if (networkInterfaces.Count > 0) + properties = networkInterfaces[0]?.GetIPProperties(); + + if (properties != null) + { + foreach (UnicastIPAddressInformation unicastAddress in properties.UnicastAddresses) + { + IPAddress ipAddress = GetBroadcastAddress(unicastAddress); + if (ipAddress != null) + { + return ipAddress; + } + } + } + + return IPAddress.Broadcast; + } + + /// <summary> + /// Gets the broadcast address for the given <paramref name="unicastAddress"/>. + /// </summary> + /// <param name="unicastAddress">A <see cref="UnicastIPAddressInformation"/></param> + /// <returns>An <see cref="IPAddress"/> for broadcasting, null if the <paramref name="unicastAddress"/> + /// is not an IPv4 address.</returns> + public static IPAddress GetBroadcastAddress(UnicastIPAddressInformation unicastAddress) + { + if (unicastAddress != null && unicastAddress.Address != null && unicastAddress.Address.AddressFamily == AddressFamily.InterNetwork) + { + var mask = unicastAddress.IPv4Mask; + byte[] ipAdressBytes = unicastAddress.Address.GetAddressBytes(); + byte[] subnetMaskBytes = mask.GetAddressBytes(); + + if (ipAdressBytes.Length != subnetMaskBytes.Length) + throw new ArgumentException("Lengths of IP address and subnet mask do not match."); + + byte[] broadcastAddress = new byte[ipAdressBytes.Length]; + for (int i = 0; i < broadcastAddress.Length; i++) + { + broadcastAddress[i] = (byte)(ipAdressBytes[i] | (subnetMaskBytes[i] ^ 255)); + } + return new IPAddress(broadcastAddress); + } + + return null; + } + } +} diff --git a/Tools/Hazel-Networking/Hazel/UPnP/UPnPHelper.cs b/Tools/Hazel-Networking/Hazel/UPnP/UPnPHelper.cs new file mode 100644 index 0000000..771709e --- /dev/null +++ b/Tools/Hazel-Networking/Hazel/UPnP/UPnPHelper.cs @@ -0,0 +1,347 @@ +using System; +using System.IO; +using System.Xml; +using System.Net; +using System.Net.Sockets; +using System.Threading; + +namespace Hazel.UPnP +{ + /// <summary> + /// Status of the UPnP capabilities + /// </summary> + public enum UPnPStatus + { + /// <summary> + /// Still discovering UPnP capabilities + /// </summary> + Discovering, + + /// <summary> + /// UPnP is not available + /// </summary> + NotAvailable, + + /// <summary> + /// UPnP is available and ready to use + /// </summary> + Available + } + + public class UPnPHelper : IDisposable + { + private const int DiscoveryTimeOutMs = 1000; + + private string serviceUrl; + private string serviceName = ""; + + private ManualResetEvent discoveryComplete = new ManualResetEvent(false); + private Socket socket; + + private DateTime discoveryResponseDeadline; + + private EndPoint ep; + private byte[] buffer; + + private ILogger logger; + + /// <summary> + /// Status of the UPnP capabilities of this NetPeer + /// </summary> + public UPnPStatus Status { get; private set; } + + public UPnPHelper(ILogger logger) + { + this.logger = logger; + + this.socket = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp); + this.socket.EnableBroadcast = true; + this.socket.MulticastLoopback = false; + + this.socket.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.ReuseAddress, 1); + this.socket.Bind(new IPEndPoint(IPAddress.Any, 0)); + + this.ep = new IPEndPoint(IPAddress.Any, 1900); + this.buffer = new byte[ushort.MaxValue]; + + ListenForUPnP(); + + this.Discover(); + } + + private void ListenForUPnP() + { + try + { + socket.BeginReceiveFrom(this.buffer, 0, this.buffer.Length, SocketFlags.None, ref ep, HandleMessage, null); + } + catch(Exception e) + { + this.logger.WriteInfo("Exception listening for UPnP: " + e.Message); + } + } + + private void HandleMessage(IAsyncResult ar) + { + int len; + try + { + len = this.socket.EndReceiveFrom(ar, ref ep); + } + catch + { + return; + } + + string resp = System.Text.Encoding.UTF8.GetString(buffer, 0, len); + if (resp.Contains("upnp:rootdevice") || resp.Contains("UPnP/1.0")) + { + var locationStart = resp.IndexOf("location:", StringComparison.OrdinalIgnoreCase); + if (locationStart >= 0) + { + locationStart += 10; + var locationEnd = resp.IndexOf("\r", locationStart); + + resp = resp.Substring(locationStart, locationEnd - locationStart); + if (!ExtractServiceUrl(resp)) + { + ListenForUPnP(); + } + } + else + { + ListenForUPnP(); + } + } + else + { + ListenForUPnP(); + } + } + + internal void Discover() + { + string str = +"M-SEARCH * HTTP/1.1\r\n" + +"HOST: 239.255.255.250:1900\r\n" + +"ST:upnp:rootdevice\r\n" + +"MAN:\"ssdp:discover\"\r\n" + +"MX:3\r\n\r\n"; + + discoveryResponseDeadline = DateTime.UtcNow.AddSeconds(6); + Status = UPnPStatus.Discovering; + + byte[] buffer = System.Text.Encoding.UTF8.GetBytes(str); + + this.logger.WriteInfo("Attempting UPnP discovery"); + + socket.SendTo(buffer, new IPEndPoint(NetUtility.GetBroadcastAddress(), 1900)); + } + + internal bool ExtractServiceUrl(string resp) + { + try + { + XmlDocument desc = new XmlDocument(); + using (var response = WebRequest.Create(resp).GetResponse()) + { + desc.Load(response.GetResponseStream()); + } + + XmlNamespaceManager nsMgr = new XmlNamespaceManager(desc.NameTable); + nsMgr.AddNamespace("tns", "urn:schemas-upnp-org:device-1-0"); + XmlNode typen = desc.SelectSingleNode("//tns:device/tns:deviceType/text()", nsMgr); + if (!typen.Value.Contains("InternetGatewayDevice")) + return false; + + serviceName = "WANIPConnection"; + XmlNode node = desc.SelectSingleNode("//tns:service[tns:serviceType=\"urn:schemas-upnp-org:service:" + serviceName + ":1\"]/tns:controlURL/text()", nsMgr); + if (node == null) + { + //try another service name + serviceName = "WANPPPConnection"; + node = desc.SelectSingleNode("//tns:service[tns:serviceType=\"urn:schemas-upnp-org:service:" + serviceName + ":1\"]/tns:controlURL/text()", nsMgr); + if (node == null) + return false; + } + + serviceUrl = CombineUrls(resp, node.Value); + this.logger.WriteInfo("UPnP service ready"); + Status = UPnPStatus.Available; + discoveryComplete.Set(); + return true; + } + catch (Exception e) + { + this.logger.WriteError("Exception while parsing UPnP Service URL: " + e.Message); + return false; + } + } + + private static string CombineUrls(string gatewayURL, string subURL) + { + // Is Control URL an absolute URL? + if (subURL.Contains("http:") || subURL.Contains(".")) + return subURL; + + gatewayURL = gatewayURL.Replace("http://", ""); // strip any protocol + int n = gatewayURL.IndexOf("/"); + if (n >= 0) + { + gatewayURL = gatewayURL.Substring(0, n); // Use first portion of URL + } + + return "http://" + gatewayURL + subURL; + } + + private bool CheckAvailability() + { + switch (Status) + { + case UPnPStatus.NotAvailable: + return false; + case UPnPStatus.Available: + return true; + case UPnPStatus.Discovering: + while (!discoveryComplete.WaitOne(DiscoveryTimeOutMs)) + { + if (DateTime.UtcNow > discoveryResponseDeadline) + { + Status = UPnPStatus.NotAvailable; + return false; + } + } + + return true; + } + + return false; + } + + /// <summary> + /// Add a forwarding rule to the router using UPnP + /// </summary> + /// <param name="externalPort">The external, WAN facing, port</param> + /// <param name="description">A description for the port forwarding rule</param> + /// <param name="internalPort">The port on the client machine to send traffic to</param> + /// <param name="durationSeconds">The lease duration on the port forwarding rule, in seconds. 0 for indefinite.</param> + public bool ForwardPort(int externalPort, string description, int internalPort = 0, int durationSeconds = 0) + { + if (!CheckAvailability()) + return false; + + if (internalPort == 0) + internalPort = externalPort; + + try + { + var client = NetUtility.GetMyAddress(out _); + if (client == null) + return false; + + SOAPRequest(serviceUrl, + $"<u:AddPortMapping xmlns:u=\"urn:schemas-upnp-org:service:{serviceName}:1\">" + + "<NewRemoteHost></NewRemoteHost>" + + $"<NewExternalPort>{externalPort}</NewExternalPort>" + + "<NewProtocol>UDP</NewProtocol>" + + $"<NewInternalPort>{internalPort}</NewInternalPort>" + + $"<NewInternalClient>{client}</NewInternalClient>" + + "<NewEnabled>1</NewEnabled>" + + $"<NewPortMappingDescription>{description}</NewPortMappingDescription>" + + $"<NewLeaseDuration>{durationSeconds}</NewLeaseDuration>" + + "</u:AddPortMapping>", + "AddPortMapping"); + + this.logger.WriteInfo("Sent UPnP port forward request."); + return true; + } + catch (Exception ex) + { + this.logger.WriteError("UPnP port forward failed: " + ex.Message); + return false; + } + } + + /// <summary> + /// Delete a forwarding rule from the router using UPnP + /// </summary> + /// <param name="externalPort">The external, 'internet facing', port</param> + public bool DeleteForwardingRule(int externalPort) + { + if (!CheckAvailability()) + return false; + + try + { + SOAPRequest(serviceUrl, + $"<u:DeletePortMapping xmlns:u=\"urn:schemas-upnp-org:service:{serviceName}:1\">" + + "<NewRemoteHost></NewRemoteHost>" + + $"<NewExternalPort>{externalPort}</NewExternalPort>" + + $"<NewProtocol>UDP</NewProtocol>" + + "</u:DeletePortMapping>", "DeletePortMapping"); + return true; + } + catch (Exception ex) + { + // m_peer.LogWarning("UPnP delete forwarding rule failed: " + ex.Message); + return false; + } + } + + /// <summary> + /// Retrieve the extern ip using UPnP + /// </summary> + public IPAddress GetExternalIP() + { + if (!CheckAvailability()) + return null; + try + { + XmlDocument xdoc = SOAPRequest(serviceUrl, "<u:GetExternalIPAddress xmlns:u=\"urn:schemas-upnp-org:service:" + serviceName + ":1\">" + + "</u:GetExternalIPAddress>", "GetExternalIPAddress"); + XmlNamespaceManager nsMgr = new XmlNamespaceManager(xdoc.NameTable); + nsMgr.AddNamespace("tns", "urn:schemas-upnp-org:device-1-0"); + string IP = xdoc.SelectSingleNode("//NewExternalIPAddress/text()", nsMgr).Value; + return IPAddress.Parse(IP); + } + catch (Exception ex) + { + // m_peer.LogWarning("Failed to get external IP: " + ex.Message); + return null; + } + } + + private XmlDocument SOAPRequest(string url, string soap, string function) + { + string req = +"<?xml version=\"1.0\"?>" + +"<s:Envelope xmlns:s=\"http://schemas.xmlsoap.org/soap/envelope/\" s:encodingStyle=\"http://schemas.xmlsoap.org/soap/encoding/\">" + +$"<s:Body>{soap}</s:Body>" + +"</s:Envelope>"; + + WebRequest r = HttpWebRequest.Create(url); + r.Headers.Add("SOAPACTION", $"\"urn:schemas-upnp-org:service:{serviceName}:1#{function}\""); + r.ContentType = "text/xml; charset=\"utf-8\""; + r.Method = "POST"; + + byte[] b = System.Text.Encoding.UTF8.GetBytes(req); + r.ContentLength = b.Length; + r.GetRequestStream().Write(b, 0, b.Length); + + using (WebResponse wres = r.GetResponse()) + { + XmlDocument resp = new XmlDocument(); + Stream ress = wres.GetResponseStream(); + resp.Load(ress); + return resp; + } + } + + public void Dispose() + { + this.discoveryComplete.Dispose(); + try { this.socket.Shutdown(SocketShutdown.Both); } catch { } + this.socket.Dispose(); + } + } +}
\ No newline at end of file diff --git a/Tools/Hazel-Networking/Hazel/Udp/SendOptionInternal.cs b/Tools/Hazel-Networking/Hazel/Udp/SendOptionInternal.cs new file mode 100644 index 0000000..74786d8 --- /dev/null +++ b/Tools/Hazel-Networking/Hazel/Udp/SendOptionInternal.cs @@ -0,0 +1,39 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; + + +namespace Hazel.Udp +{ + /// <summary> + /// Extra internal states for SendOption enumeration when using UDP. + /// </summary> + public enum UdpSendOption : byte + { + /// <summary> + /// Hello message for initiating communication. + /// </summary> + Hello = 8, + + /// <summary> + /// A single byte of continued existence + /// </summary> + Ping = 12, + + /// <summary> + /// Message for discontinuing communication. + /// </summary> + Disconnect = 9, + + /// <summary> + /// Message acknowledging the receipt of a message. + /// </summary> + Acknowledgement = 10, + + /// <summary> + /// Message that is part of a larger, fragmented message. + /// </summary> + Fragment = 11, + } +} diff --git a/Tools/Hazel-Networking/Hazel/Udp/UdpBroadcastListener.cs b/Tools/Hazel-Networking/Hazel/Udp/UdpBroadcastListener.cs new file mode 100644 index 0000000..13b8d0b --- /dev/null +++ b/Tools/Hazel-Networking/Hazel/Udp/UdpBroadcastListener.cs @@ -0,0 +1,157 @@ +using System; +using System.Collections.Generic; +using System.Net; +using System.Net.Sockets; +using System.Text; +using System.Threading; + +namespace Hazel.Udp +{ + public class BroadcastPacket + { + public string Data; + public DateTime ReceiveTime; + public IPEndPoint Sender; + + public BroadcastPacket(string data, IPEndPoint sender) + { + this.Data = data; + this.Sender = sender; + this.ReceiveTime = DateTime.Now; + } + + public string GetAddress() + { + return this.Sender.Address.ToString(); + } + } + + public class UdpBroadcastListener : IDisposable + { + private Socket socket; + private EndPoint endpoint; + private Action<string> logger; + + private byte[] buffer = new byte[1024]; + + private List<BroadcastPacket> packets = new List<BroadcastPacket>(); + + public bool Running { get; private set; } + + /// + public UdpBroadcastListener(int port, Action<string> logger = null) + { + this.logger = logger; + this.socket = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp); + this.socket.EnableBroadcast = true; + this.socket.MulticastLoopback = false; + this.endpoint = new IPEndPoint(IPAddress.Any, port); + this.socket.Bind(this.endpoint); + } + + /// + public void StartListen() + { + if (this.Running) return; + this.Running = true; + + try + { + EndPoint endpt = new IPEndPoint(IPAddress.Any, 0); + this.socket.BeginReceiveFrom(buffer, 0, buffer.Length, SocketFlags.None, ref endpt, this.HandleData, null); + } + catch (NullReferenceException) { } + catch (Exception e) + { + this.logger?.Invoke("BroadcastListener: " + e); + this.Dispose(); + } + } + + private void HandleData(IAsyncResult result) + { + this.Running = false; + + int numBytes; + EndPoint endpt = new IPEndPoint(IPAddress.Any, 0); + try + { + numBytes = this.socket.EndReceiveFrom(result, ref endpt); + } + catch (NullReferenceException) + { + // Already disposed + return; + } + catch (Exception e) + { + this.logger?.Invoke("BroadcastListener: " + e); + this.Dispose(); + return; + } + + if (numBytes < 3 + || buffer[0] != 4 || buffer[1] != 2) + { + this.StartListen(); + return; + } + + IPEndPoint ipEnd = (IPEndPoint)endpt; + string data = UTF8Encoding.UTF8.GetString(buffer, 2, numBytes - 2); + int dataHash = data.GetHashCode(); + + lock (packets) + { + bool found = false; + for (int i = 0; i < this.packets.Count; ++i) + { + var pkt = this.packets[i]; + if (pkt == null || pkt.Data == null) + { + this.packets.RemoveAt(i); + i--; + continue; + } + + if (pkt.Data.GetHashCode() == dataHash + && pkt.Sender.Equals(ipEnd)) + { + this.packets[i].ReceiveTime = DateTime.Now; + break; + } + } + + if (!found) + { + this.packets.Add(new BroadcastPacket(data, ipEnd)); + } + } + + this.StartListen(); + } + + /// + public BroadcastPacket[] GetPackets() + { + lock (this.packets) + { + var output = this.packets.ToArray(); + this.packets.Clear(); + return output; + } + } + + /// + public void Dispose() + { + if (this.socket != null) + { + try { this.socket.Shutdown(SocketShutdown.Both); } catch { } + try { this.socket.Close(); } catch { } + try { this.socket.Dispose(); } catch { } + this.socket = null; + } + } + } +}
\ No newline at end of file diff --git a/Tools/Hazel-Networking/Hazel/Udp/UdpBroadcaster.cs b/Tools/Hazel-Networking/Hazel/Udp/UdpBroadcaster.cs new file mode 100644 index 0000000..8877f86 --- /dev/null +++ b/Tools/Hazel-Networking/Hazel/Udp/UdpBroadcaster.cs @@ -0,0 +1,127 @@ +using Hazel.UPnP; +using System; +using System.Net; +using System.Net.Sockets; +using System.Text; + +namespace Hazel.Udp +{ + public class UdpBroadcaster : IDisposable + { + private SocketBroadcast[] socketBroadcasts; + private byte[] data; + private Action<string> logger; + + /// + public UdpBroadcaster(int port, Action<string> logger = null) + { + this.logger = logger; + + var addresses = NetUtility.GetAddressesFromNetworkInterfaces(AddressFamily.InterNetwork); + this.socketBroadcasts = new SocketBroadcast[addresses.Count > 0 ? addresses.Count : 1]; + + int count = 0; + foreach (var addressInformation in addresses) + { + Socket socket = CreateSocket(new IPEndPoint(addressInformation.Address, 0)); + IPAddress broadcast = NetUtility.GetBroadcastAddress(addressInformation); + + this.socketBroadcasts[count] = new SocketBroadcast(socket, new IPEndPoint(broadcast, port)); + count++; + } + if (count == 0) + { + Socket socket = CreateSocket(new IPEndPoint(IPAddress.Any, 0)); + + this.socketBroadcasts[0] = new SocketBroadcast(socket, new IPEndPoint(IPAddress.Broadcast, port)); + } + } + + private static Socket CreateSocket(IPEndPoint endPoint) + { + var socket = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp); + socket.EnableBroadcast = true; + socket.MulticastLoopback = false; + socket.Bind(endPoint); + + return socket; + } + + /// + public void SetData(string data) + { + int len = UTF8Encoding.UTF8.GetByteCount(data); + this.data = new byte[len + 2]; + this.data[0] = 4; + this.data[1] = 2; + + UTF8Encoding.UTF8.GetBytes(data, 0, data.Length, this.data, 2); + } + + /// + public void Broadcast() + { + if (this.data == null) + { + return; + } + + foreach (SocketBroadcast socketBroadcast in this.socketBroadcasts) + { + try + { + Socket socket = socketBroadcast.Socket; + socket.BeginSendTo(data, 0, data.Length, SocketFlags.None, socketBroadcast.Broadcast, this.FinishSendTo, socket); + } + catch (Exception e) + { + this.logger?.Invoke("BroadcastListener: " + e); + } + } + } + + private void FinishSendTo(IAsyncResult evt) + { + try + { + Socket socket = (Socket)evt.AsyncState; + socket.EndSendTo(evt); + } + catch (Exception e) + { + this.logger?.Invoke("BroadcastListener: " + e); + } + } + + /// + public void Dispose() + { + if (this.socketBroadcasts != null) + { + foreach (SocketBroadcast socketBroadcast in this.socketBroadcasts) + { + Socket socket = socketBroadcast.Socket; + if (socket != null) + { + try { socket.Shutdown(SocketShutdown.Both); } catch { } + try { socket.Close(); } catch { } + try { socket.Dispose(); } catch { } + } + } + Array.Clear(this.socketBroadcasts, 0, this.socketBroadcasts.Length); + } + } + + private struct SocketBroadcast + { + public Socket Socket; + public IPEndPoint Broadcast; + + public SocketBroadcast(Socket socket, IPEndPoint broadcast) + { + Socket = socket; + Broadcast = broadcast; + } + } + } +}
\ No newline at end of file diff --git a/Tools/Hazel-Networking/Hazel/Udp/UdpClientConnection.cs b/Tools/Hazel-Networking/Hazel/Udp/UdpClientConnection.cs new file mode 100644 index 0000000..f6da329 --- /dev/null +++ b/Tools/Hazel-Networking/Hazel/Udp/UdpClientConnection.cs @@ -0,0 +1,364 @@ +using System; +using System.Net; +using System.Net.Sockets; +using System.Threading; + + +namespace Hazel.Udp +{ + /// <summary> + /// Represents a client's connection to a server that uses the UDP protocol. + /// </summary> + /// <inheritdoc/> + public sealed class UdpClientConnection : UdpConnection + { + /// <summary> + /// The max size Hazel attempts to read from the network. + /// Defaults to 8096. + /// </summary> + /// <remarks> + /// 8096 is 5 times the standard modern MTU of 1500, so it's already too large imo. + /// If Hazel ever implements fragmented packets, then we might consider a larger value since combining 5 + /// packets into 1 reader would be realistic and would cause reallocations. That said, Hazel is not meant + /// for transferring large contiguous blocks of data, so... please don't? + /// </remarks> + public int ReceiveBufferSize = 8096; + + /// <summary> + /// The socket we're connected via. + /// </summary> + private Socket socket; + + /// <summary> + /// Reset event that is triggered when the connection is marked Connected. + /// </summary> + private ManualResetEvent connectWaitLock = new ManualResetEvent(false); + + private Timer reliablePacketTimer; + +#if DEBUG + public event Action<byte[], int> DataSentRaw; + public event Action<byte[], int> DataReceivedRaw; +#endif + + /// <summary> + /// Creates a new UdpClientConnection. + /// </summary> + /// <param name="remoteEndPoint">A <see cref="NetworkEndPoint"/> to connect to.</param> + public UdpClientConnection(ILogger logger, IPEndPoint remoteEndPoint, IPMode ipMode = IPMode.IPv4) + : base(logger) + { + this.EndPoint = remoteEndPoint; + this.IPMode = ipMode; + + this.socket = CreateSocket(ipMode); + + reliablePacketTimer = new Timer(ManageReliablePacketsInternal, null, 100, Timeout.Infinite); + this.InitializeKeepAliveTimer(); + } + + ~UdpClientConnection() + { + this.Dispose(false); + } + + private void ManageReliablePacketsInternal(object state) + { + base.ManageReliablePackets(); + try + { + reliablePacketTimer.Change(100, Timeout.Infinite); + } + catch { } + } + + /// <inheritdoc /> + protected override void WriteBytesToConnection(byte[] bytes, int length) + { +#if DEBUG + if (TestLagMs > 0) + { + ThreadPool.QueueUserWorkItem(a => { Thread.Sleep(this.TestLagMs); WriteBytesToConnectionReal(bytes, length); }); + } + else +#endif + { + WriteBytesToConnectionReal(bytes, length); + } + } + + private void WriteBytesToConnectionReal(byte[] bytes, int length) + { +#if DEBUG + DataSentRaw?.Invoke(bytes, length); +#endif + + try + { + this.Statistics.LogPacketSend(length); + socket.BeginSendTo( + bytes, + 0, + length, + SocketFlags.None, + EndPoint, + HandleSendTo, + null); + } + catch (NullReferenceException) { } + catch (ObjectDisposedException) + { + // Already disposed and disconnected... + } + catch (SocketException ex) + { + DisconnectInternal(HazelInternalErrors.SocketExceptionSend, "Could not send data as a SocketException occurred: " + ex.Message); + } + } + + private void HandleSendTo(IAsyncResult result) + { + try + { + socket.EndSendTo(result); + } + catch (NullReferenceException) { } + catch (ObjectDisposedException) + { + // Already disposed and disconnected... + } + catch (SocketException ex) + { + DisconnectInternal(HazelInternalErrors.SocketExceptionSend, "Could not send data as a SocketException occurred: " + ex.Message); + } + } + + /// <inheritdoc /> + public override void Connect(byte[] bytes = null, int timeout = 5000) + { + this.ConnectAsync(bytes); + + //Wait till hello packet is acknowledged and the state is set to Connected + bool timedOut = !WaitOnConnect(timeout); + + //If we timed out raise an exception + if (timedOut) + { + Dispose(); + throw new HazelException("Connection attempt timed out."); + } + } + + /// <inheritdoc /> + public override void ConnectAsync(byte[] bytes = null) + { + this.State = ConnectionState.Connecting; + + try + { + if (IPMode == IPMode.IPv4) + socket.Bind(new IPEndPoint(IPAddress.Any, 0)); + else + socket.Bind(new IPEndPoint(IPAddress.IPv6Any, 0)); + } + catch (SocketException e) + { + this.State = ConnectionState.NotConnected; + throw new HazelException("A SocketException occurred while binding to the port.", e); + } + + try + { + StartListeningForData(); + } + catch (ObjectDisposedException) + { + // If the socket's been disposed then we can just end there but make sure we're in NotConnected state. + // If we end up here I'm really lost... + this.State = ConnectionState.NotConnected; + return; + } + catch (SocketException e) + { + Dispose(); + throw new HazelException("A SocketException occurred while initiating a receive operation.", e); + } + + // Write bytes to the server to tell it hi (and to punch a hole in our NAT, if present) + // When acknowledged set the state to connected + SendHello(bytes, () => + { + this.State = ConnectionState.Connected; + this.InitializeKeepAliveTimer(); + }); + } + + /// <summary> + /// Instructs the listener to begin listening. + /// </summary> + void StartListeningForData() + { +#if DEBUG + if (this.TestLagMs > 0) + { + Thread.Sleep(this.TestLagMs); + } +#endif + + var msg = MessageReader.GetSized(this.ReceiveBufferSize); + try + { + socket.BeginReceive(msg.Buffer, 0, msg.Buffer.Length, SocketFlags.None, ReadCallback, msg); + } + catch + { + msg.Recycle(); + this.Dispose(); + } + } + + protected override void SetState(ConnectionState state) + { + try + { + // If the server disconnects you during the hello + // you can go straight from Connecting to NotConnected. + if (state == ConnectionState.Connected + || state == ConnectionState.NotConnected) + { + connectWaitLock.Set(); + } + else + { + connectWaitLock.Reset(); + } + } + catch (ObjectDisposedException) + { + } + } + + /// <summary> + /// Blocks until the Connection is connected. + /// </summary> + /// <param name="timeout">The number of milliseconds to wait before timing out.</param> + public bool WaitOnConnect(int timeout) + { + return connectWaitLock.WaitOne(timeout); + } + + /// <summary> + /// Called when data has been received by the socket. + /// </summary> + /// <param name="result">The asyncronous operation's result.</param> + void ReadCallback(IAsyncResult result) + { + var msg = (MessageReader)result.AsyncState; + + try + { + msg.Length = socket.EndReceive(result); + } + catch (SocketException e) + { + msg.Recycle(); + DisconnectInternal(HazelInternalErrors.SocketExceptionReceive, "Socket exception while reading data: " + e.Message); + return; + } + catch (Exception) + { + msg.Recycle(); + return; + } + + //Exit if no bytes read, we've failed. + if (msg.Length == 0) + { + msg.Recycle(); + DisconnectInternal(HazelInternalErrors.ReceivedZeroBytes, "Received 0 bytes"); + return; + } + + //Begin receiving again + try + { + StartListeningForData(); + } + catch (SocketException e) + { + DisconnectInternal(HazelInternalErrors.SocketExceptionReceive, "Socket exception during receive: " + e.Message); + } + catch (ObjectDisposedException) + { + //If the socket's been disposed then we can just end there. + return; + } + +#if DEBUG + if (this.TestDropRate > 0) + { + if ((this.testDropCount++ % this.TestDropRate) == 0) + { + return; + } + } + + DataReceivedRaw?.Invoke(msg.Buffer, msg.Length); +#endif + HandleReceive(msg, msg.Length); + } + + /// <summary> + /// Sends a disconnect message to the end point. + /// You may include optional disconnect data. The SendOption must be unreliable. + /// </summary> + protected override bool SendDisconnect(MessageWriter data = null) + { + lock (this) + { + if (this._state == ConnectionState.NotConnected) return false; + this.State = ConnectionState.NotConnected; // Use the property so we release the state lock + } + + var bytes = EmptyDisconnectBytes; + if (data != null && data.Length > 0) + { + if (data.SendOption != SendOption.None) throw new ArgumentException("Disconnect messages can only be unreliable."); + + bytes = data.ToByteArray(true); + bytes[0] = (byte)UdpSendOption.Disconnect; + } + + try + { + socket.SendTo( + bytes, + 0, + bytes.Length, + SocketFlags.None, + EndPoint); + } + catch { } + + return true; + } + + /// <inheritdoc /> + protected override void Dispose(bool disposing) + { + if (disposing) + { + SendDisconnect(); + } + + try { this.socket.Shutdown(SocketShutdown.Both); } catch { } + try { this.socket.Close(); } catch { } + try { this.socket.Dispose(); } catch { } + + this.reliablePacketTimer.Dispose(); + this.connectWaitLock.Dispose(); + + base.Dispose(disposing); + } + } +} diff --git a/Tools/Hazel-Networking/Hazel/Udp/UdpConnection.KeepAlive.cs b/Tools/Hazel-Networking/Hazel/Udp/UdpConnection.KeepAlive.cs new file mode 100644 index 0000000..75b4f1d --- /dev/null +++ b/Tools/Hazel-Networking/Hazel/Udp/UdpConnection.KeepAlive.cs @@ -0,0 +1,167 @@ +using System; +using System.Collections.Concurrent; +using System.Diagnostics; +using System.Threading; + + +namespace Hazel.Udp +{ + partial class UdpConnection + { + + /// <summary> + /// Class to hold packet data + /// </summary> + public class PingPacket : IRecyclable + { + private static readonly ObjectPool<PingPacket> PacketPool = new ObjectPool<PingPacket>(() => new PingPacket()); + + public readonly Stopwatch Stopwatch = new Stopwatch(); + + internal static PingPacket GetObject() + { + return PacketPool.GetObject(); + } + + public void Recycle() + { + Stopwatch.Stop(); + PacketPool.PutObject(this); + } + } + + internal ConcurrentDictionary<ushort, PingPacket> activePingPackets = new ConcurrentDictionary<ushort, PingPacket>(); + + /// <summary> + /// The interval from data being received or transmitted to a keepalive packet being sent in milliseconds. + /// </summary> + /// <remarks> + /// <para> + /// Keepalive packets serve to close connections when an endpoint abruptly disconnects and to ensure than any + /// NAT devices do not close their translation for our argument. By ensuring there is regular contact the + /// connection can detect and prevent these issues. + /// </para> + /// <para> + /// The default value is 10 seconds, set to System.Threading.Timeout.Infinite to disable keepalive packets. + /// </para> + /// </remarks> + public int KeepAliveInterval + { + get + { + return keepAliveInterval; + } + + set + { + keepAliveInterval = value; + ResetKeepAliveTimer(); + } + } + private int keepAliveInterval = 1500; + + public int MissingPingsUntilDisconnect { get; set; } = 6; + private volatile int pingsSinceAck = 0; + + /// <summary> + /// The timer creating keepalive pulses. + /// </summary> + private Timer keepAliveTimer; + + /// <summary> + /// Starts the keepalive timer. + /// </summary> + protected void InitializeKeepAliveTimer() + { + keepAliveTimer = new Timer( + HandleKeepAlive, + null, + keepAliveInterval, + keepAliveInterval + ); + } + + private void HandleKeepAlive(object state) + { + if (this.State != ConnectionState.Connected) return; + + if (this.pingsSinceAck >= this.MissingPingsUntilDisconnect) + { + this.DisposeKeepAliveTimer(); + this.DisconnectInternal(HazelInternalErrors.PingsWithoutResponse, $"Sent {this.pingsSinceAck} pings that remote has not responded to."); + return; + } + + try + { + this.pingsSinceAck++; + SendPing(); + } + catch + { + } + } + + // Pings are special, quasi-reliable packets. + // We send them to trigger responses that validate our connection is alive + // An unacked ping should never be the sole cause of a disconnect. + // Rather, the responses will reset our pingsSinceAck, enough unacked + // pings should cause a disconnect. + private void SendPing() + { + ushort id = (ushort)Interlocked.Increment(ref lastIDAllocated); + + byte[] bytes = new byte[3]; + bytes[0] = (byte)UdpSendOption.Ping; + bytes[1] = (byte)(id >> 8); + bytes[2] = (byte)id; + + PingPacket pkt; + if (!this.activePingPackets.TryGetValue(id, out pkt)) + { + pkt = PingPacket.GetObject(); + if (!this.activePingPackets.TryAdd(id, pkt)) + { + throw new Exception("This shouldn't be possible"); + } + } + + pkt.Stopwatch.Restart(); + + WriteBytesToConnection(bytes, bytes.Length); + + Statistics.LogReliableSend(0); + } + + /// <summary> + /// Resets the keepalive timer to zero. + /// </summary> + protected void ResetKeepAliveTimer() + { + try + { + keepAliveTimer?.Change(keepAliveInterval, keepAliveInterval); + } + catch { } + } + + /// <summary> + /// Disposes of the keep alive timer. + /// </summary> + private void DisposeKeepAliveTimer() + { + if (this.keepAliveTimer != null) + { + this.keepAliveTimer.Dispose(); + } + + foreach (var kvp in activePingPackets) + { + if (this.activePingPackets.TryRemove(kvp.Key, out var pkt)) + { + pkt.Recycle(); + } + } + } + } +}
\ No newline at end of file diff --git a/Tools/Hazel-Networking/Hazel/Udp/UdpConnection.Reliable.cs b/Tools/Hazel-Networking/Hazel/Udp/UdpConnection.Reliable.cs new file mode 100644 index 0000000..bed4738 --- /dev/null +++ b/Tools/Hazel-Networking/Hazel/Udp/UdpConnection.Reliable.cs @@ -0,0 +1,490 @@ +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Diagnostics; +using System.Threading; + +namespace Hazel.Udp +{ + partial class UdpConnection + { + private const int MinResendDelayMs = 50; + private const int MaxInitialResendDelayMs = 300; + private const int MaxAdditionalResendDelayMs = 1000; + + public readonly ObjectPool<Packet> PacketPool; + + /// <summary> + /// The starting timeout, in miliseconds, at which data will be resent. + /// </summary> + /// <remarks> + /// <para> + /// For reliable delivery data is resent at specified intervals unless an acknowledgement is received from the + /// receiving device. The ResendTimeout specifies the interval between the packets being resent, each time a packet + /// is resent the interval is increased for that packet until the duration exceeds the <see cref="DisconnectTimeoutMs"/> value. + /// </para> + /// <para> + /// Setting this to its default of 0 will mean the timeout is 2 times the value of the average ping, usually + /// resulting in a more dynamic resend that responds to endpoints on slower or faster connections. + /// </para> + /// </remarks> + public volatile int ResendTimeoutMs = 0; + + /// <summary> + /// Max number of times to resend. 0 == no limit + /// </summary> + public volatile int ResendLimit = 0; + + /// <summary> + /// A compounding multiplier to back off resend timeout. + /// Applied to ping before first timeout when ResendTimeout == 0. + /// </summary> + public volatile float ResendPingMultiplier = 2; + + /// <summary> + /// Holds the last ID allocated. + /// </summary> + private int lastIDAllocated = -1; + + /// <summary> + /// The packets of data that have been transmitted reliably and not acknowledged. + /// </summary> + internal ConcurrentDictionary<ushort, Packet> reliableDataPacketsSent = new ConcurrentDictionary<ushort, Packet>(); + + /// <summary> + /// Packet ids that have not been received, but are expected. + /// </summary> + private HashSet<ushort> reliableDataPacketsMissing = new HashSet<ushort>(); + + /// <summary> + /// The packet id that was received last. + /// </summary> + protected volatile ushort reliableReceiveLast = ushort.MaxValue; + + private object PingLock = new object(); + + /// <summary> + /// Returns the average ping to this endpoint. + /// </summary> + /// <remarks> + /// This returns the average ping for a one-way trip as calculated from the reliable packets that have been sent + /// and acknowledged by the endpoint. + /// </remarks> + private float _pingMs = 500; + + /// <summary> + /// The maximum times a message should be resent before marking the endpoint as disconnected. + /// </summary> + /// <remarks> + /// Reliable packets will be resent at an interval defined in <see cref="ResendTimeoutMs"/> for the number of times + /// specified here. Once a packet has been retransmitted this number of times and has not been acknowledged the + /// connection will be marked as disconnected and the <see cref="Connection.Disconnected">Disconnected</see> event + /// will be invoked. + /// </remarks> + public volatile int DisconnectTimeoutMs = 5000; + + /// <summary> + /// Class to hold packet data + /// </summary> + public class Packet : IRecyclable + { + public ushort Id; + private byte[] Data; + private readonly UdpConnection Connection; + private int Length; + + public int NextTimeoutMs; + public volatile bool Acknowledged; + + public Action AckCallback; + + public int Retransmissions; + public Stopwatch Stopwatch = new Stopwatch(); + + internal Packet(UdpConnection connection) + { + this.Connection = connection; + } + + internal void Set(ushort id, byte[] data, int length, int timeout, Action ackCallback) + { + this.Id = id; + this.Data = data; + this.Length = length; + + this.Acknowledged = false; + this.NextTimeoutMs = timeout; + this.AckCallback = ackCallback; + this.Retransmissions = 0; + + this.Stopwatch.Restart(); + } + + // Packets resent + public int Resend() + { + var connection = this.Connection; + if (!this.Acknowledged && connection != null) + { + long lifetimeMs = this.Stopwatch.ElapsedMilliseconds; + if (lifetimeMs >= connection.DisconnectTimeoutMs) + { + if (connection.reliableDataPacketsSent.TryRemove(this.Id, out Packet self)) + { + connection.DisconnectInternal(HazelInternalErrors.ReliablePacketWithoutResponse, $"Reliable packet {self.Id} (size={this.Length}) was not ack'd after {lifetimeMs}ms ({self.Retransmissions} resends)"); + + self.Recycle(); + } + + return 0; + } + + if (lifetimeMs >= this.NextTimeoutMs) + { + ++this.Retransmissions; + if (connection.ResendLimit != 0 + && this.Retransmissions > connection.ResendLimit) + { + if (connection.reliableDataPacketsSent.TryRemove(this.Id, out Packet self)) + { + connection.DisconnectInternal(HazelInternalErrors.ReliablePacketWithoutResponse, $"Reliable packet {self.Id} (size={this.Length}) was not ack'd after {self.Retransmissions} resends ({lifetimeMs}ms)"); + + self.Recycle(); + } + + return 0; + } + + this.NextTimeoutMs += (int)Math.Min(this.NextTimeoutMs * connection.ResendPingMultiplier, MaxAdditionalResendDelayMs); + try + { + connection.WriteBytesToConnection(this.Data, this.Length); + connection.Statistics.LogMessageResent(); + return 1; + } + catch (InvalidOperationException) + { + connection.DisconnectInternal(HazelInternalErrors.ConnectionDisconnected, "Could not resend data as connection is no longer connected"); + } + } + } + + return 0; + } + + /// <summary> + /// Returns this object back to the object pool from whence it came. + /// </summary> + public void Recycle() + { + this.Acknowledged = true; + + this.Connection.PacketPool.PutObject(this); + } + } + + internal int ManageReliablePackets() + { + int output = 0; + if (this.reliableDataPacketsSent.Count > 0) + { + foreach (var kvp in this.reliableDataPacketsSent) + { + Packet pkt = kvp.Value; + + try + { + output += pkt.Resend(); + } + catch { } + } + } + + return output; + } + + /// <summary> + /// Adds a 2 byte ID to the packet at offset and stores the packet reference for retransmission. + /// </summary> + /// <param name="buffer">The buffer to attach to.</param> + /// <param name="offset">The offset to attach at.</param> + /// <param name="ackCallback">The callback to make once the packet has been acknowledged.</param> + protected void AttachReliableID(byte[] buffer, int offset, Action ackCallback = null) + { + ushort id = (ushort)Interlocked.Increment(ref lastIDAllocated); + + buffer[offset] = (byte)(id >> 8); + buffer[offset + 1] = (byte)id; + + int resendDelayMs = this.ResendTimeoutMs; + if (resendDelayMs <= 0) + { + resendDelayMs = (_pingMs * this.ResendPingMultiplier).ClampToInt(MinResendDelayMs, MaxInitialResendDelayMs); + } + + Packet packet = this.PacketPool.GetObject(); + packet.Set( + id, + buffer, + buffer.Length, + resendDelayMs, + ackCallback); + + if (!reliableDataPacketsSent.TryAdd(id, packet)) + { + throw new Exception("That shouldn't be possible"); + } + } + + public static int ClampToInt(float value, int min, int max) + { + if (value < min) return min; + if (value > max) return max; + return (int)value; + } + + /// <summary> + /// Sends the bytes reliably and stores the send. + /// </summary> + /// <param name="sendOption"></param> + /// <param name="data">The byte array to write to.</param> + /// <param name="ackCallback">The callback to make once the packet has been acknowledged.</param> + private void ReliableSend(byte sendOption, byte[] data, Action ackCallback = null) + { + //Inform keepalive not to send for a while + ResetKeepAliveTimer(); + + byte[] bytes = new byte[data.Length + 3]; + + //Add message type + bytes[0] = sendOption; + + //Add reliable ID + AttachReliableID(bytes, 1, ackCallback); + + //Copy data into new array + Buffer.BlockCopy(data, 0, bytes, bytes.Length - data.Length, data.Length); + + //Write to connection + WriteBytesToConnection(bytes, bytes.Length); + + Statistics.LogReliableSend(data.Length); + } + + /// <summary> + /// Handles a reliable message being received and invokes the data event. + /// </summary> + /// <param name="message">The buffer received.</param> + private void ReliableMessageReceive(MessageReader message, int bytesReceived) + { + ushort id; + if (ProcessReliableReceive(message.Buffer, 1, out id)) + { + InvokeDataReceived(SendOption.Reliable, message, 3, bytesReceived); + } + else + { + message.Recycle(); + } + + Statistics.LogReliableReceive(message.Length - 3, message.Length); + } + + /// <summary> + /// Handles receives from reliable packets. + /// </summary> + /// <param name="bytes">The buffer containing the data.</param> + /// <param name="offset">The offset of the reliable header.</param> + /// <returns>Whether the packet was a new packet or not.</returns> + private bool ProcessReliableReceive(byte[] bytes, int offset, out ushort id) + { + byte b1 = bytes[offset]; + byte b2 = bytes[offset + 1]; + + //Get the ID form the packet + id = (ushort)((b1 << 8) + b2); + + /* + * It gets a little complicated here (note the fact I'm actually using a multiline comment for once...) + * + * In a simple world if our data is greater than the last reliable packet received (reliableReceiveLast) + * then it is guaranteed to be a new packet, if it's not we can see if we are missing that packet (lookup + * in reliableDataPacketsMissing). + * + * --------rrl############# (1) + * + * (where --- are packets received already and #### are packets that will be counted as new) + * + * Unfortunately if id becomes greater than 65535 it will loop back to zero so we will add a pointer that + * specifies any packets with an id behind it are also new (overwritePointer). + * + * ####op----------rrl##### (2) + * + * ------rll#########op---- (3) + * + * Anything behind than the reliableReceiveLast pointer (but greater than the overwritePointer is either a + * missing packet or something we've already received so when we change the pointers we need to make sure + * we keep note of what hasn't been received yet (reliableDataPacketsMissing). + * + * So... + */ + + bool result = true; + + lock (reliableDataPacketsMissing) + { + //Calculate overwritePointer + ushort overwritePointer = (ushort)(reliableReceiveLast - 32768); + + //Calculate if it is a new packet by examining if it is within the range + bool isNew; + if (overwritePointer < reliableReceiveLast) + isNew = id > reliableReceiveLast || id <= overwritePointer; //Figure (2) + else + isNew = id > reliableReceiveLast && id <= overwritePointer; //Figure (3) + + //If it's new or we've not received anything yet + if (isNew) + { + // Mark items between the most recent receive and the id received as missing + if (id > reliableReceiveLast) + { + for (ushort i = (ushort)(reliableReceiveLast + 1); i < id; i++) + { + reliableDataPacketsMissing.Add(i); + } + } + else + { + int cnt = (ushort.MaxValue - reliableReceiveLast) + id; + for (ushort i = 1; i <= cnt; ++i) + { + reliableDataPacketsMissing.Add((ushort)(i + reliableReceiveLast)); + } + } + + //Update the most recently received + reliableReceiveLast = id; + } + + //Else it could be a missing packet + else + { + //See if we're missing it, else this packet is a duplicate as so we return false + if (!reliableDataPacketsMissing.Remove(id)) + { + result = false; + } + } + } + + // Send an acknowledgement + SendAck(id); + + return result; + } + + /// <summary> + /// Handles acknowledgement packets to us. + /// </summary> + /// <param name="bytes">The buffer containing the data.</param> + private void AcknowledgementMessageReceive(byte[] bytes, int bytesReceived) + { + this.pingsSinceAck = 0; + + ushort id = (ushort)((bytes[1] << 8) + bytes[2]); + AcknowledgeMessageId(id); + + if (bytesReceived == 4) + { + byte recentPackets = bytes[3]; + for (int i = 1; i <= 8; ++i) + { + if ((recentPackets & 1) != 0) + { + AcknowledgeMessageId((ushort)(id - i)); + } + + recentPackets >>= 1; + } + } + + Statistics.LogAcknowledgementReceive(bytesReceived); + } + + private void AcknowledgeMessageId(ushort id) + { + // Dispose of timer and remove from dictionary + if (reliableDataPacketsSent.TryRemove(id, out Packet packet)) + { + this.Statistics.LogReliablePacketAcknowledged(); + float rt = packet.Stopwatch.ElapsedMilliseconds; + + packet.AckCallback?.Invoke(); + packet.Recycle(); + + lock (PingLock) + { + this._pingMs = this._pingMs * .7f + rt * .3f; + } + } + else if (this.activePingPackets.TryRemove(id, out PingPacket pingPkt)) + { + this.Statistics.LogReliablePacketAcknowledged(); + float rt = pingPkt.Stopwatch.ElapsedMilliseconds; + + pingPkt.Recycle(); + + lock (PingLock) + { + this._pingMs = this._pingMs * .7f + rt * .3f; + } + } + } + + /// <summary> + /// Sends an acknowledgement for a packet given its identification bytes. + /// </summary> + /// <param name="byte1">The first identification byte.</param> + /// <param name="byte2">The second identification byte.</param> + private void SendAck(ushort id) + { + byte recentPackets = 0; + lock (this.reliableDataPacketsMissing) + { + for (int i = 1; i <= 8; ++i) + { + if (!this.reliableDataPacketsMissing.Contains((ushort)(id - i))) + { + recentPackets |= (byte)(1 << (i - 1)); + } + } + } + + byte[] bytes = new byte[] + { + (byte)UdpSendOption.Acknowledgement, + (byte)(id >> 8), + (byte)(id >> 0), + recentPackets + }; + + try + { + WriteBytesToConnection(bytes, bytes.Length); + } + catch (InvalidOperationException) { } + } + + private void DisposeReliablePackets() + { + foreach (var kvp in reliableDataPacketsSent) + { + if (this.reliableDataPacketsSent.TryRemove(kvp.Key, out var pkt)) + { + pkt.Recycle(); + } + } + } + } +} diff --git a/Tools/Hazel-Networking/Hazel/Udp/UdpConnection.cs b/Tools/Hazel-Networking/Hazel/Udp/UdpConnection.cs new file mode 100644 index 0000000..e64576a --- /dev/null +++ b/Tools/Hazel-Networking/Hazel/Udp/UdpConnection.cs @@ -0,0 +1,259 @@ +using System; +using System.Net.Sockets; + +namespace Hazel.Udp +{ + /// <summary> + /// Represents a connection that uses the UDP protocol. + /// </summary> + /// <inheritdoc /> + public abstract partial class UdpConnection : NetworkConnection + { + public static readonly byte[] EmptyDisconnectBytes = new byte[] { (byte)UdpSendOption.Disconnect }; + + public override float AveragePingMs => this._pingMs; + protected readonly ILogger logger; + + + public UdpConnection(ILogger logger) : base() + { + this.logger = logger; + this.PacketPool = new ObjectPool<Packet>(() => new Packet(this)); + } + + internal static Socket CreateSocket(IPMode ipMode) + { + Socket socket; + if (ipMode == IPMode.IPv4) + { + socket = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp); + } + else + { + if (!Socket.OSSupportsIPv6) + throw new InvalidOperationException("IPV6 not supported!"); + + socket = new Socket(AddressFamily.InterNetworkV6, SocketType.Dgram, ProtocolType.Udp); + socket.SetSocketOption(SocketOptionLevel.IPv6, SocketOptionName.IPv6Only, false); + } + + try + { + socket.DontFragment = false; + } + catch { } + + try + { + const int SIO_UDP_CONNRESET = -1744830452; + socket.IOControl(SIO_UDP_CONNRESET, new byte[1], null); + } + catch { } // Only necessary on Windows + + return socket; + } + + /// <summary> + /// Writes the given bytes to the connection. + /// </summary> + /// <param name="bytes">The bytes to write.</param> + protected abstract void WriteBytesToConnection(byte[] bytes, int length); + + /// <inheritdoc/> + public override SendErrors Send(MessageWriter msg) + { + if (this._state != ConnectionState.Connected) + { + return SendErrors.Disconnected; + } + + try + { + byte[] buffer = new byte[msg.Length]; + Buffer.BlockCopy(msg.Buffer, 0, buffer, 0, msg.Length); + + switch (msg.SendOption) + { + case SendOption.Reliable: + ResetKeepAliveTimer(); + + AttachReliableID(buffer, 1); + WriteBytesToConnection(buffer, buffer.Length); + Statistics.LogReliableSend(buffer.Length - 3); + break; + + default: + WriteBytesToConnection(buffer, buffer.Length); + Statistics.LogUnreliableSend(buffer.Length - 1); + break; + } + } + catch (Exception e) + { + this.logger?.WriteError("Unknown exception while sending: " + e); + return SendErrors.Unknown; + } + + return SendErrors.None; + } + + /// <summary> + /// Handles the reliable/fragmented sending from this connection. + /// </summary> + /// <param name="data">The data being sent.</param> + /// <param name="sendOption">The <see cref="SendOption"/> specified as its byte value.</param> + /// <param name="ackCallback">The callback to invoke when this packet is acknowledged.</param> + /// <returns>The bytes that should actually be sent.</returns> + protected virtual void HandleSend(byte[] data, byte sendOption, Action ackCallback = null) + { + switch (sendOption) + { + case (byte)UdpSendOption.Ping: + case (byte)SendOption.Reliable: + case (byte)UdpSendOption.Hello: + ReliableSend(sendOption, data, ackCallback); + break; + + //Treat all else as unreliable + default: + UnreliableSend(sendOption, data); + break; + } + } + + /// <summary> + /// Handles the receiving of data. + /// </summary> + /// <param name="message">The buffer containing the bytes received.</param> + protected internal virtual void HandleReceive(MessageReader message, int bytesReceived) + { + ushort id; + switch (message.Buffer[0]) + { + //Handle reliable receives + case (byte)SendOption.Reliable: + ReliableMessageReceive(message, bytesReceived); + break; + + //Handle acknowledgments + case (byte)UdpSendOption.Acknowledgement: + AcknowledgementMessageReceive(message.Buffer, bytesReceived); + message.Recycle(); + break; + + //We need to acknowledge hello and ping messages but dont want to invoke any events! + case (byte)UdpSendOption.Ping: + ProcessReliableReceive(message.Buffer, 1, out id); + Statistics.LogHelloReceive(bytesReceived); + message.Recycle(); + break; + case (byte)UdpSendOption.Hello: + ProcessReliableReceive(message.Buffer, 1, out id); + Statistics.LogHelloReceive(bytesReceived); + message.Recycle(); + break; + + case (byte)UdpSendOption.Disconnect: + message.Offset = 1; + message.Position = 0; + DisconnectRemote("The remote sent a disconnect request", message); + message.Recycle(); + break; + + case (byte)SendOption.None: + InvokeDataReceived(SendOption.None, message, 1, bytesReceived); + Statistics.LogUnreliableReceive(bytesReceived - 1, bytesReceived); + break; + + // Treat everything else as garbage + default: + message.Recycle(); + + // TODO: A new stat for unused data + Statistics.LogUnreliableReceive(bytesReceived - 1, bytesReceived); + break; + } + } + + /// <summary> + /// Sends bytes using the unreliable UDP protocol. + /// </summary> + /// <param name="sendOption">The SendOption to attach.</param> + /// <param name="data">The data.</param> + void UnreliableSend(byte sendOption, byte[] data) + { + this.UnreliableSend(sendOption, data, 0, data.Length); + } + + /// <summary> + /// Sends bytes using the unreliable UDP protocol. + /// </summary> + /// <param name="data">The data.</param> + /// <param name="sendOption">The SendOption to attach.</param> + /// <param name="offset"></param> + /// <param name="length"></param> + void UnreliableSend(byte sendOption, byte[] data, int offset, int length) + { + byte[] bytes = new byte[length + 1]; + + //Add message type + bytes[0] = sendOption; + + //Copy data into new array + Buffer.BlockCopy(data, offset, bytes, bytes.Length - length, length); + + //Write to connection + WriteBytesToConnection(bytes, bytes.Length); + + Statistics.LogUnreliableSend(length); + } + + /// <summary> + /// Helper method to invoke the data received event. + /// </summary> + /// <param name="sendOption">The send option the message was received with.</param> + /// <param name="buffer">The buffer received.</param> + /// <param name="dataOffset">The offset of data in the buffer.</param> + void InvokeDataReceived(SendOption sendOption, MessageReader buffer, int dataOffset, int bytesReceived) + { + buffer.Offset = dataOffset; + buffer.Length = bytesReceived - dataOffset; + buffer.Position = 0; + + InvokeDataReceived(buffer, sendOption); + } + + /// <summary> + /// Sends a hello packet to the remote endpoint. + /// </summary> + /// <param name="acknowledgeCallback">The callback to invoke when the hello packet is acknowledged.</param> + protected void SendHello(byte[] bytes, Action acknowledgeCallback) + { + //First byte of handshake is version indicator so add data after + byte[] actualBytes; + if (bytes == null) + { + actualBytes = new byte[1]; + } + else + { + actualBytes = new byte[bytes.Length + 1]; + Buffer.BlockCopy(bytes, 0, actualBytes, 1, bytes.Length); + } + + HandleSend(actualBytes, (byte)UdpSendOption.Hello, acknowledgeCallback); + } + + /// <inheritdoc/> + protected override void Dispose(bool disposing) + { + if (disposing) + { + DisposeKeepAliveTimer(); + DisposeReliablePackets(); + } + + base.Dispose(disposing); + } + } +} diff --git a/Tools/Hazel-Networking/Hazel/Udp/UdpConnectionListener.cs b/Tools/Hazel-Networking/Hazel/Udp/UdpConnectionListener.cs new file mode 100644 index 0000000..c017a0f --- /dev/null +++ b/Tools/Hazel-Networking/Hazel/Udp/UdpConnectionListener.cs @@ -0,0 +1,339 @@ +using System; +using System.Collections.Concurrent; +using System.Linq; +using System.Net; +using System.Net.Sockets; +using System.Threading; + +namespace Hazel.Udp +{ + /// <summary> + /// Listens for new UDP connections and creates UdpConnections for them. + /// </summary> + /// <inheritdoc /> + public class UdpConnectionListener : NetworkConnectionListener + { + private const int SendReceiveBufferSize = 1024 * 1024; + private const int BufferSize = ushort.MaxValue; + + private Socket socket; + private ILogger Logger; + private Timer reliablePacketTimer; + + private ConcurrentDictionary<EndPoint, UdpServerConnection> allConnections = new ConcurrentDictionary<EndPoint, UdpServerConnection>(); + + public override double AveragePing => this.allConnections.Values.Sum(c => c.AveragePingMs) / this.allConnections.Count; + public override int ConnectionCount { get { return this.allConnections.Count; } } + public override int ReceiveQueueLength => throw new NotImplementedException(); + public override int SendQueueLength => throw new NotImplementedException(); + + /// <summary> + /// Creates a new UdpConnectionListener for the given <see cref="IPAddress"/>, port and <see cref="IPMode"/>. + /// </summary> + /// <param name="endPoint">The endpoint to listen on.</param> + public UdpConnectionListener(IPEndPoint endPoint, IPMode ipMode = IPMode.IPv4, ILogger logger = null) + { + this.Logger = logger; + this.EndPoint = endPoint; + this.IPMode = ipMode; + + this.socket = UdpConnection.CreateSocket(this.IPMode); + + socket.ReceiveBufferSize = SendReceiveBufferSize; + socket.SendBufferSize = SendReceiveBufferSize; + + reliablePacketTimer = new Timer(ManageReliablePackets, null, 100, Timeout.Infinite); + } + + ~UdpConnectionListener() + { + this.Dispose(false); + } + + private void ManageReliablePackets(object state) + { + foreach (var kvp in this.allConnections) + { + var sock = kvp.Value; + sock.ManageReliablePackets(); + } + + try + { + this.reliablePacketTimer.Change(100, Timeout.Infinite); + } + catch { } + } + + /// <inheritdoc /> + public override void Start() + { + try + { + socket.Bind(EndPoint); + } + catch (SocketException e) + { + throw new HazelException("Could not start listening as a SocketException occurred", e); + } + + StartListeningForData(); + } + + /// <summary> + /// Instructs the listener to begin listening. + /// </summary> + private void StartListeningForData() + { + EndPoint remoteEP = EndPoint; + + MessageReader message = null; + try + { + message = MessageReader.GetSized(this.ReceiveBufferSize); + socket.BeginReceiveFrom(message.Buffer, 0, message.Buffer.Length, SocketFlags.None, ref remoteEP, ReadCallback, message); + } + catch (SocketException sx) + { + message?.Recycle(); + + this.Logger?.WriteError("Socket Ex in StartListening: " + sx.Message); + + Thread.Sleep(10); + StartListeningForData(); + return; + } + catch (Exception ex) + { + message.Recycle(); + this.Logger?.WriteError("Stopped due to: " + ex.Message); + return; + } + } + + void ReadCallback(IAsyncResult result) + { + var message = (MessageReader)result.AsyncState; + int bytesReceived; + EndPoint remoteEndPoint = new IPEndPoint(this.EndPoint.Address, this.EndPoint.Port); + + //End the receive operation + try + { + bytesReceived = socket.EndReceiveFrom(result, ref remoteEndPoint); + + message.Offset = 0; + message.Length = bytesReceived; + } + catch (ObjectDisposedException) + { + message.Recycle(); + return; + } + catch (SocketException sx) + { + message.Recycle(); + if (sx.SocketErrorCode == SocketError.NotConnected) + { + this.InvokeInternalError(HazelInternalErrors.ConnectionDisconnected); + return; + } + + // Client no longer reachable, pretend it didn't happen + // TODO should this not inform the connection this client is lost??? + + // This thread suggests the IP is not passed out from WinSoc so maybe not possible + // http://stackoverflow.com/questions/2576926/python-socket-error-on-udp-data-receive-10054 + this.Logger?.WriteError($"Socket Ex {sx.SocketErrorCode} in ReadCallback: {sx.Message}"); + + Thread.Sleep(10); + StartListeningForData(); + return; + } + catch (Exception ex) + { + // Idk, maybe a null ref after dispose? + message.Recycle(); + this.Logger?.WriteError("Stopped due to: " + ex.Message); + return; + } + + // I'm a little concerned about a infinite loop here, but it seems like it's possible + // to get 0 bytes read on UDP without the socket being shut down. + if (bytesReceived == 0) + { + message.Recycle(); + this.Logger?.WriteInfo("Received 0 bytes"); + Thread.Sleep(10); + StartListeningForData(); + return; + } + + //Begin receiving again + StartListeningForData(); + + bool aware = true; + bool isHello = message.Buffer[0] == (byte)UdpSendOption.Hello; + + // If we're aware of this connection use the one already + // If this is a new client then connect with them! + UdpServerConnection connection; + if (!this.allConnections.TryGetValue(remoteEndPoint, out connection)) + { + lock (this.allConnections) + { + if (!this.allConnections.TryGetValue(remoteEndPoint, out connection)) + { + // Check for malformed connection attempts + if (!isHello) + { + message.Recycle(); + return; + } + + if (AcceptConnection != null) + { + if (!AcceptConnection((IPEndPoint)remoteEndPoint, message.Buffer, out var response)) + { + message.Recycle(); + if (response != null) + { + SendData(response, response.Length, remoteEndPoint); + } + + return; + } + } + + aware = false; + connection = new UdpServerConnection(this, (IPEndPoint)remoteEndPoint, this.IPMode, this.Logger); + if (!this.allConnections.TryAdd(remoteEndPoint, connection)) + { + throw new HazelException("Failed to add a connection. This should never happen."); + } + } + } + } + + // If it's a new connection invoke the NewConnection event. + // This needs to happen before handling the message because in localhost scenarios, the ACK and + // subsequent messages can happen before the NewConnection event sets up OnDataRecieved handlers + if (!aware) + { + // Skip header and hello byte; + message.Offset = 4; + message.Length = bytesReceived - 4; + message.Position = 0; + InvokeNewConnection(message, connection); + } + + // Inform the connection of the buffer (new connections need to send an ack back to client) + connection.HandleReceive(message, bytesReceived); + } + +#if DEBUG + public int TestDropRate = -1; + private int dropCounter = 0; +#endif + + /// <summary> + /// Sends data from the listener socket. + /// </summary> + /// <param name="bytes">The bytes to send.</param> + /// <param name="endPoint">The endpoint to send to.</param> + internal void SendData(byte[] bytes, int length, EndPoint endPoint) + { + if (length > bytes.Length) return; + +#if DEBUG + if (TestDropRate > 0) + { + if (Interlocked.Increment(ref dropCounter) % TestDropRate == 0) + { + return; + } + } +#endif + + try + { + socket.BeginSendTo( + bytes, + 0, + length, + SocketFlags.None, + endPoint, + SendCallback, + null); + + this.Statistics.AddBytesSent(length); + } + catch (SocketException e) + { + this.Logger?.WriteError("Could not send data as a SocketException occurred: " + e); + } + catch (ObjectDisposedException) + { + //Keep alive timer probably ran, ignore + return; + } + } + + private void SendCallback(IAsyncResult result) + { + try + { + socket.EndSendTo(result); + } + catch { } + } + + /// <summary> + /// Sends data from the listener socket. + /// </summary> + /// <param name="bytes">The bytes to send.</param> + /// <param name="endPoint">The endpoint to send to.</param> + internal void SendDataSync(byte[] bytes, int length, EndPoint endPoint) + { + try + { + socket.SendTo( + bytes, + 0, + length, + SocketFlags.None, + endPoint + ); + + this.Statistics.AddBytesSent(length); + } + catch { } + } + + /// <summary> + /// Removes a virtual connection from the list. + /// </summary> + /// <param name="endPoint">The endpoint of the virtual connection.</param> + internal void RemoveConnectionTo(EndPoint endPoint) + { + this.allConnections.TryRemove(endPoint, out var conn); + } + + /// <inheritdoc /> + protected override void Dispose(bool disposing) + { + foreach (var kvp in this.allConnections) + { + kvp.Value.Dispose(); + } + + try { this.socket.Shutdown(SocketShutdown.Both); } catch { } + try { this.socket.Close(); } catch { } + try { this.socket.Dispose(); } catch { } + + this.reliablePacketTimer.Dispose(); + + base.Dispose(disposing); + } + } +} diff --git a/Tools/Hazel-Networking/Hazel/Udp/UdpServerConnection.cs b/Tools/Hazel-Networking/Hazel/Udp/UdpServerConnection.cs new file mode 100644 index 0000000..ff5b29d --- /dev/null +++ b/Tools/Hazel-Networking/Hazel/Udp/UdpServerConnection.cs @@ -0,0 +1,108 @@ +using System; +using System.Net; + +namespace Hazel.Udp +{ + /// <summary> + /// Represents a servers's connection to a client that uses the UDP protocol. + /// </summary> + /// <inheritdoc/> + internal sealed class UdpServerConnection : UdpConnection + { + /// <summary> + /// The connection listener that we use the socket of. + /// </summary> + /// <remarks> + /// Udp server connections utilize the same socket in the listener for sends/receives, this is the listener that + /// created this connection and is hence the listener this conenction sends and receives via. + /// </remarks> + public UdpConnectionListener Listener { get; private set; } + + /// <summary> + /// Creates a UdpConnection for the virtual connection to the endpoint. + /// </summary> + /// <param name="listener">The listener that created this connection.</param> + /// <param name="endPoint">The endpoint that we are connected to.</param> + /// <param name="IPMode">The IPMode we are connected using.</param> + internal UdpServerConnection(UdpConnectionListener listener, IPEndPoint endPoint, IPMode IPMode, ILogger logger) + : base(logger) + { + this.Listener = listener; + this.EndPoint = endPoint; + this.IPMode = IPMode; + + State = ConnectionState.Connected; + this.InitializeKeepAliveTimer(); + } + + /// <inheritdoc /> + protected override void WriteBytesToConnection(byte[] bytes, int length) + { + this.Statistics.LogPacketSend(length); + Listener.SendData(bytes, length, EndPoint); + } + + /// <inheritdoc /> + /// <remarks> + /// This will always throw a HazelException. + /// </remarks> + public override void Connect(byte[] bytes = null, int timeout = 5000) + { + throw new InvalidOperationException("Cannot manually connect a UdpServerConnection, did you mean to use UdpClientConnection?"); + } + + /// <inheritdoc /> + /// <remarks> + /// This will always throw a HazelException. + /// </remarks> + public override void ConnectAsync(byte[] bytes = null) + { + throw new InvalidOperationException("Cannot manually connect a UdpServerConnection, did you mean to use UdpClientConnection?"); + } + + /// <summary> + /// Sends a disconnect message to the end point. + /// </summary> + protected override bool SendDisconnect(MessageWriter data = null) + { + lock (this) + { + if (this._state != ConnectionState.Connected) + { + return false; + } + + this._state = ConnectionState.NotConnected; + } + + var bytes = EmptyDisconnectBytes; + if (data != null && data.Length > 0) + { + if (data.SendOption != SendOption.None) throw new ArgumentException("Disconnect messages can only be unreliable."); + + bytes = data.ToByteArray(true); + bytes[0] = (byte)UdpSendOption.Disconnect; + } + + try + { + Listener.SendDataSync(bytes, bytes.Length, EndPoint); + } + catch { } + + return true; + } + + protected override void Dispose(bool disposing) + { + Listener.RemoveConnectionTo(EndPoint); + + if (disposing) + { + SendDisconnect(); + } + + base.Dispose(disposing); + } + } +} diff --git a/Tools/Hazel-Networking/Hazel/Udp/UnityUdpClientConnection.cs b/Tools/Hazel-Networking/Hazel/Udp/UnityUdpClientConnection.cs new file mode 100644 index 0000000..8e6063d --- /dev/null +++ b/Tools/Hazel-Networking/Hazel/Udp/UnityUdpClientConnection.cs @@ -0,0 +1,353 @@ +using System; +using System.Net; +using System.Net.Sockets; +using System.Threading; + + +namespace Hazel.Udp +{ + /// <summary> + /// Unity doesn't always get along with thread pools well, so this interface will hopefully suit that case better. + /// </summary> + /// <inheritdoc/> + public class UnityUdpClientConnection : UdpConnection + { + /// <summary> + /// The max size Hazel attempts to read from the network. + /// Defaults to 8096. + /// </summary> + /// <remarks> + /// 8096 is 5 times the standard modern MTU of 1500, so it's already too large imo. + /// If Hazel ever implements fragmented packets, then we might consider a larger value since combining 5 + /// packets into 1 reader would be realistic and would cause reallocations. That said, Hazel is not meant + /// for transferring large contiguous blocks of data, so... please don't? + /// </remarks> + public int ReceiveBufferSize = 8096; + + private Socket socket; + + public UnityUdpClientConnection(ILogger logger, IPEndPoint remoteEndPoint, IPMode ipMode = IPMode.IPv4) + : base(logger) + { + this.EndPoint = remoteEndPoint; + this.IPMode = ipMode; + + this.socket = CreateSocket(ipMode); + this.socket.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.ExclusiveAddressUse, true); + } + + ~UnityUdpClientConnection() + { + this.Dispose(false); + } + + public void FixedUpdate() + { + try + { + ResendPacketsIfNeeded(); + } + catch (Exception e) + { + this.logger.WriteError("FixedUpdate: " + e); + } + + try + { + ManageReliablePackets(); + } + catch (Exception e) + { + this.logger.WriteError("FixedUpdate: " + e); + } + } + + protected virtual void RestartConnection() + { + } + + protected virtual void ResendPacketsIfNeeded() + { + } + + /// <inheritdoc /> + protected override void WriteBytesToConnection(byte[] bytes, int length) + { +#if DEBUG + if (TestLagMs > 0) + { + ThreadPool.QueueUserWorkItem(a => { Thread.Sleep(this.TestLagMs); WriteBytesToConnectionReal(bytes, length); }); + } + else +#endif + { + WriteBytesToConnectionReal(bytes, length); + } + } + + private void WriteBytesToConnectionReal(byte[] bytes, int length) + { + try + { + this.Statistics.LogPacketSend(length); + socket.BeginSendTo( + bytes, + 0, + length, + SocketFlags.None, + EndPoint, + HandleSendTo, + null); + } + catch (NullReferenceException) { } + catch (ObjectDisposedException) + { + // Already disposed and disconnected... + } + catch (SocketException ex) + { + DisconnectInternal(HazelInternalErrors.SocketExceptionSend, "Could not send data as a SocketException occurred: " + ex.Message); + } + } + + /// <summary> + /// Synchronously writes the given bytes to the connection. + /// </summary> + /// <param name="bytes">The bytes to write.</param> + protected virtual void WriteBytesToConnectionSync(byte[] bytes, int length) + { + try + { + socket.SendTo( + bytes, + 0, + length, + SocketFlags.None, + EndPoint); + } + catch (NullReferenceException) { } + catch (ObjectDisposedException) + { + // Already disposed and disconnected... + } + catch (SocketException ex) + { + DisconnectInternal(HazelInternalErrors.SocketExceptionSend, "Could not send data as a SocketException occurred: " + ex.Message); + } + } + + private void HandleSendTo(IAsyncResult result) + { + try + { + socket.EndSendTo(result); + } + catch (NullReferenceException) { } + catch (ObjectDisposedException) + { + // Already disposed and disconnected... + } + catch (SocketException ex) + { + DisconnectInternal(HazelInternalErrors.SocketExceptionSend, "Could not send data as a SocketException occurred: " + ex.Message); + } + } + + public override void Connect(byte[] bytes = null, int timeout = 5000) + { + this.ConnectAsync(bytes); + for (int timer = 0; timer < timeout; timer += 100) + { + if (this.State != ConnectionState.Connecting) return; + Thread.Sleep(100); + + // I guess if we're gonna block in Unity, then let's assume no one will pump this for us. + this.FixedUpdate(); + } + } + + /// <inheritdoc /> + public override void ConnectAsync(byte[] bytes = null) + { + this.State = ConnectionState.Connecting; + + try + { + if (IPMode == IPMode.IPv4) + socket.Bind(new IPEndPoint(IPAddress.Any, 0)); + else + socket.Bind(new IPEndPoint(IPAddress.IPv6Any, 0)); + } + catch (SocketException e) + { + this.State = ConnectionState.NotConnected; + throw new HazelException("A SocketException occurred while binding to the port.", e); + } + + this.RestartConnection(); + + try + { + StartListeningForData(); + } + catch (ObjectDisposedException) + { + // If the socket's been disposed then we can just end there but make sure we're in NotConnected state. + // If we end up here I'm really lost... + this.State = ConnectionState.NotConnected; + return; + } + catch (SocketException e) + { + Dispose(); + throw new HazelException("A SocketException occurred while initiating a receive operation.", e); + } + + // Write bytes to the server to tell it hi (and to punch a hole in our NAT, if present) + // When acknowledged set the state to connected + SendHello(bytes, () => + { + this.InitializeKeepAliveTimer(); + this.State = ConnectionState.Connected; + }); + } + + /// <summary> + /// Instructs the listener to begin listening. + /// </summary> + void StartListeningForData() + { + var msg = MessageReader.GetSized(this.ReceiveBufferSize); + try + { + EndPoint ep = this.EndPoint; + socket.BeginReceiveFrom(msg.Buffer, 0, msg.Buffer.Length, SocketFlags.None, ref ep, ReadCallback, msg); + } + catch + { + msg.Recycle(); + this.Dispose(); + } + } + + /// <summary> + /// Called when data has been received by the socket. + /// </summary> + /// <param name="result">The asyncronous operation's result.</param> + void ReadCallback(IAsyncResult result) + { +#if DEBUG + if (this.TestLagMs > 0) + { + Thread.Sleep(this.TestLagMs); + } +#endif + + var msg = (MessageReader)result.AsyncState; + + try + { + EndPoint ep = this.EndPoint; + msg.Length = socket.EndReceiveFrom(result, ref ep); + } + catch (SocketException e) + { + msg.Recycle(); + DisconnectInternal(HazelInternalErrors.SocketExceptionReceive, "Socket exception while reading data: " + e.Message); + return; + } + catch (ObjectDisposedException) + { + // Weirdly, it seems that this method can be called twice on the same AsyncState when object is disposed... + // So this just keeps us from hitting Duplicate Add errors at the risk of if this is a platform + // specific bug, we leak a MessageReader while the socket is disposing. Not a bad trade off. + return; + } + catch (Exception) + { + msg.Recycle(); + return; + } + + //Exit if no bytes read, we've failed. + if (msg.Length == 0) + { + msg.Recycle(); + DisconnectInternal(HazelInternalErrors.ReceivedZeroBytes, "Received 0 bytes"); + return; + } + + //Begin receiving again + try + { + StartListeningForData(); + } + catch (SocketException e) + { + DisconnectInternal(HazelInternalErrors.SocketExceptionReceive, "Socket exception during receive: " + e.Message); + } + catch (ObjectDisposedException) + { + //If the socket's been disposed then we can just end there. + return; + } + +#if DEBUG + if (this.TestDropRate > 0) + { + if ((this.testDropCount++ % this.TestDropRate) == 0) + { + return; + } + } +#endif + + HandleReceive(msg, msg.Length); + } + + /// <summary> + /// Sends a disconnect message to the end point. + /// You may include optional disconnect data. The SendOption must be unreliable. + /// </summary> + protected override bool SendDisconnect(MessageWriter data = null) + { + lock (this) + { + if (this._state == ConnectionState.NotConnected) return false; + this._state = ConnectionState.NotConnected; + } + + var bytes = EmptyDisconnectBytes; + if (data != null && data.Length > 0) + { + if (data.SendOption != SendOption.None) throw new ArgumentException("Disconnect messages can only be unreliable."); + + bytes = data.ToByteArray(true); + bytes[0] = (byte)UdpSendOption.Disconnect; + } + + try + { + this.WriteBytesToConnectionSync(bytes, bytes.Length); + } + catch { } + + return true; + } + + /// <inheritdoc /> + protected override void Dispose(bool disposing) + { + if (disposing) + { + SendDisconnect(); + } + + try { this.socket.Shutdown(SocketShutdown.Both); } catch { } + try { this.socket.Close(); } catch { } + try { this.socket.Dispose(); } catch { } + + base.Dispose(disposing); + } + } +} |