aboutsummaryrefslogtreecommitdiff
path: root/Tools/Hazel-Networking/Hazel
diff options
context:
space:
mode:
Diffstat (limited to 'Tools/Hazel-Networking/Hazel')
-rw-r--r--Tools/Hazel-Networking/Hazel/ByteSpan.cs191
-rw-r--r--Tools/Hazel-Networking/Hazel/ByteSpanExtensions.cs131
-rw-r--r--Tools/Hazel-Networking/Hazel/Connection.cs234
-rw-r--r--Tools/Hazel-Networking/Hazel/ConnectionListener.cs160
-rw-r--r--Tools/Hazel-Networking/Hazel/ConnectionState.cs28
-rw-r--r--Tools/Hazel-Networking/Hazel/ConnectionStatistics.cs574
-rw-r--r--Tools/Hazel-Networking/Hazel/Crypto/AesGcm.cs369
-rw-r--r--Tools/Hazel-Networking/Hazel/Crypto/Const.cs82
-rw-r--r--Tools/Hazel-Networking/Hazel/Crypto/CryptoProvider.cs36
-rw-r--r--Tools/Hazel-Networking/Hazel/Crypto/DefaultAes.cs49
-rw-r--r--Tools/Hazel-Networking/Hazel/Crypto/IAes.cs27
-rw-r--r--Tools/Hazel-Networking/Hazel/Crypto/Sha256Stream.cs86
-rw-r--r--Tools/Hazel-Networking/Hazel/Crypto/SpanCryptoExtensions.cs36
-rw-r--r--Tools/Hazel-Networking/Hazel/Crypto/X25519.cs844
-rw-r--r--Tools/Hazel-Networking/Hazel/DataReceivedEventArgs.cs29
-rw-r--r--Tools/Hazel-Networking/Hazel/DisconnectedEventArgs.cs24
-rw-r--r--Tools/Hazel-Networking/Hazel/Dtls/AesGcmRecordProtection.cs147
-rw-r--r--Tools/Hazel-Networking/Hazel/Dtls/DtlsConnectionListener.cs1424
-rw-r--r--Tools/Hazel-Networking/Hazel/Dtls/DtlsUnityConnection.cs1246
-rw-r--r--Tools/Hazel-Networking/Hazel/Dtls/Handshake.cs734
-rw-r--r--Tools/Hazel-Networking/Hazel/Dtls/IHandshakeCipherSuite.cs63
-rw-r--r--Tools/Hazel-Networking/Hazel/Dtls/IRecordProtection.cs84
-rw-r--r--Tools/Hazel-Networking/Hazel/Dtls/NullRecordProtection.cs66
-rw-r--r--Tools/Hazel-Networking/Hazel/Dtls/PrfSha256.cs84
-rw-r--r--Tools/Hazel-Networking/Hazel/Dtls/Record.cs123
-rw-r--r--Tools/Hazel-Networking/Hazel/Dtls/ThreadedHmacHelper.cs159
-rw-r--r--Tools/Hazel-Networking/Hazel/Dtls/X25519EcdheRsaSha256.cs202
-rw-r--r--Tools/Hazel-Networking/Hazel/Extensions.cs34
-rw-r--r--Tools/Hazel-Networking/Hazel/FewerThreads/HazelThreadPool.cs44
-rw-r--r--Tools/Hazel-Networking/Hazel/FewerThreads/ThreadLimitedUdpConnectionListener.cs402
-rw-r--r--Tools/Hazel-Networking/Hazel/FewerThreads/ThreadLimitedUdpServerConnection.cs110
-rw-r--r--Tools/Hazel-Networking/Hazel/Hazel.csproj14
-rw-r--r--Tools/Hazel-Networking/Hazel/HazelException.cs24
-rw-r--r--Tools/Hazel-Networking/Hazel/IPMode.cs30
-rw-r--r--Tools/Hazel-Networking/Hazel/IRecyclable.cs29
-rw-r--r--Tools/Hazel-Networking/Hazel/ListenerStatistics.cs23
-rw-r--r--Tools/Hazel-Networking/Hazel/MessageReader.cs452
-rw-r--r--Tools/Hazel-Networking/Hazel/MessageWriter.cs365
-rw-r--r--Tools/Hazel-Networking/Hazel/NetworkConnection.cs117
-rw-r--r--Tools/Hazel-Networking/Hazel/NetworkConnectionListener.cs26
-rw-r--r--Tools/Hazel-Networking/Hazel/NewConnectionEventArgs.cs22
-rw-r--r--Tools/Hazel-Networking/Hazel/ObjectPool.cs108
-rw-r--r--Tools/Hazel-Networking/Hazel/SendErrors.cs15
-rw-r--r--Tools/Hazel-Networking/Hazel/SendOption.cs35
-rw-r--r--Tools/Hazel-Networking/Hazel/UPnP/ILogger.cs65
-rw-r--r--Tools/Hazel-Networking/Hazel/UPnP/NetUtility.cs158
-rw-r--r--Tools/Hazel-Networking/Hazel/UPnP/UPnPHelper.cs347
-rw-r--r--Tools/Hazel-Networking/Hazel/Udp/SendOptionInternal.cs39
-rw-r--r--Tools/Hazel-Networking/Hazel/Udp/UdpBroadcastListener.cs157
-rw-r--r--Tools/Hazel-Networking/Hazel/Udp/UdpBroadcaster.cs127
-rw-r--r--Tools/Hazel-Networking/Hazel/Udp/UdpClientConnection.cs364
-rw-r--r--Tools/Hazel-Networking/Hazel/Udp/UdpConnection.KeepAlive.cs167
-rw-r--r--Tools/Hazel-Networking/Hazel/Udp/UdpConnection.Reliable.cs490
-rw-r--r--Tools/Hazel-Networking/Hazel/Udp/UdpConnection.cs259
-rw-r--r--Tools/Hazel-Networking/Hazel/Udp/UdpConnectionListener.cs339
-rw-r--r--Tools/Hazel-Networking/Hazel/Udp/UdpServerConnection.cs108
-rw-r--r--Tools/Hazel-Networking/Hazel/Udp/UnityUdpClientConnection.cs353
57 files changed, 12055 insertions, 0 deletions
diff --git a/Tools/Hazel-Networking/Hazel/ByteSpan.cs b/Tools/Hazel-Networking/Hazel/ByteSpan.cs
new file mode 100644
index 0000000..7cfa3b5
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/ByteSpan.cs
@@ -0,0 +1,191 @@
+using System;
+
+namespace Hazel
+{
+ /// <summary>
+ /// This is a minimal implementation of `System.Span` in .NET 5.0
+ /// </summary>
+ public struct ByteSpan
+ {
+ private readonly byte[] array_;
+
+ /// <summary>
+ /// Createa a new span object containing an entire array
+ /// </summary>
+ public ByteSpan(byte[] array)
+ {
+ if (array == null)
+ {
+ this.array_ = null;
+ this.Offset = 0;
+ this.Length = 0;
+ }
+ else
+ {
+ this.array_ = array;
+ this.Offset = 0;
+ this.Length = array.Length;
+ }
+ }
+
+ /// <summary>
+ /// Creates a new span object containing a subset of an array
+ /// </summary>
+ public ByteSpan(byte[] array, int offset, int length)
+ {
+ if (array == null)
+ {
+ if (offset != 0)
+ {
+ throw new ArgumentException("Invalid offset", nameof(offset));
+ }
+ if (length != 0)
+ {
+ throw new ArgumentException("Invalid length", nameof(offset));
+ }
+
+ this.array_ = null;
+ this.Offset = 0;
+ this.Length = 0;
+ }
+ else
+ {
+ if (offset < 0 || offset > array.Length)
+ {
+ throw new ArgumentException("Invalid offset", nameof(offset));
+ }
+ if (length < 0)
+ {
+ throw new ArgumentException($"Invalid length: {length}", nameof(length));
+ }
+ if ((offset + length) > array.Length)
+ {
+ throw new ArgumentException($"Invalid length. Length: {length} Offset: {offset} Array size: {array.Length}", nameof(length));
+ }
+
+ this.array_ = array;
+ this.Offset = offset;
+ this.Length = length;
+ }
+ }
+
+ /// <summary>
+ /// Returns the underlying array.
+ ///
+ /// WARNING: This does not return the span, but the entire underlying storage block
+ /// </summary>
+ public byte[] GetUnderlyingArray()
+ {
+ return this.array_;
+ }
+
+ /// <summary>
+ /// Returns the offset into the underlying array
+ /// </summary>
+ public int Offset { get; }
+
+ /// <summary>
+ /// Returns the length of the current span
+ /// </summary>
+ public int Length { get; }
+
+ /// <summary>
+ /// Gets the span element at the specified index
+ /// </summary>
+ public byte this[int index]
+ {
+ get
+ {
+ if (index < 0 || index >= this.Length)
+ {
+ throw new IndexOutOfRangeException();
+ }
+
+ return this.array_[this.Offset + index];
+ }
+ set
+ {
+ if (index < 0 || index >= this.Length)
+ {
+ throw new IndexOutOfRangeException();
+ }
+
+ this.array_[this.Offset + index] = value;
+ }
+ }
+
+ /// <summary>
+ /// Create a new span that is a subset of this span [offset, this.Length-offset)
+ /// </summary>
+ public ByteSpan Slice(int offset)
+ {
+ return Slice(offset, this.Length - offset);
+ }
+
+ /// <summary>
+ /// Create a new span that is a subset of this span [offset, length)
+ /// </summary>
+ public ByteSpan Slice(int offset, int length)
+ {
+ return new ByteSpan(this.array_, this.Offset + offset, length);
+ }
+
+ /// <summary>
+ /// Copies the contents of the span to an array
+ /// </summary>
+ public void CopyTo(byte[] array, int offset)
+ {
+ CopyTo(new ByteSpan(array, offset, array.Length - offset));
+ }
+
+ /// <summary>
+ /// Copies the contents of the span to another span
+ /// </summary>
+ public void CopyTo(ByteSpan destination)
+ {
+ if (destination.Length < this.Length)
+ {
+ throw new ArgumentException("Destination span is shorter than source", nameof(destination));
+ }
+
+ if (Length > 0)
+ {
+ Buffer.BlockCopy(this.array_, this.Offset, destination.array_, destination.Offset, this.Length);
+ }
+ }
+
+ /// <summary>
+ /// Create a new array with the contents of this span
+ /// </summary>
+ public byte[] ToArray()
+ {
+ byte[] result = new byte[Length];
+ CopyTo(result);
+ return result;
+ }
+
+ public override string ToString()
+ {
+ return string.Join(" ", this.ToArray());
+ }
+
+ /// <summary>
+ /// Implicit conversion from byte[] -> ByteSpan
+ /// </summary>
+ public static implicit operator ByteSpan(byte[] array)
+ {
+ return new ByteSpan(array);
+ }
+
+ /// <summary>
+ /// Retuns an empty span object
+ /// </summary>
+ public static ByteSpan Empty
+ {
+ get
+ {
+ return new ByteSpan(null);
+ }
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/ByteSpanExtensions.cs b/Tools/Hazel-Networking/Hazel/ByteSpanExtensions.cs
new file mode 100644
index 0000000..3a9d1ac
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/ByteSpanExtensions.cs
@@ -0,0 +1,131 @@
+namespace Hazel
+{
+ /// <summary>
+ /// Extension functions for (en/de)coding integer values
+ /// </summary>
+ public static class ByteSpanBigEndianExtensions
+ {
+ // Write a 16-bit integer in big-endian format to output[0..2)
+ public static void WriteBigEndian16(this ByteSpan output, ushort value, int offset = 0)
+ {
+ output[offset + 0] = (byte)(value >> 8);
+ output[offset + 1] = (byte)(value >> 0);
+ }
+
+ // Write a 24-bit integer in big-endian format to output[0..3)
+ public static void WriteBigEndian24(this ByteSpan output, uint value, int offset = 0)
+ {
+ output[offset + 0] = (byte)(value >> 16);
+ output[offset + 1] = (byte)(value >> 8);
+ output[offset + 2] = (byte)(value >> 0);
+ }
+
+ // Write a 32-bit integer in big-endian format to output[0..4)
+ public static void WriteBigEndian32(this ByteSpan output, uint value, int offset)
+ {
+ output[offset + 0] = (byte)(value >> 24);
+ output[offset + 1] = (byte)(value >> 16);
+ output[offset + 2] = (byte)(value >> 8);
+ output[offset + 3] = (byte)(value >> 0);
+ }
+
+ // Write a 48-bit integer in big-endian format to output[0..6)
+ public static void WriteBigEndian48(this ByteSpan output, ulong value, int offset = 0)
+ {
+ output[offset + 0] = (byte)(value >> 40);
+ output[offset + 1] = (byte)(value >> 32);
+ output[offset + 2] = (byte)(value >> 24);
+ output[offset + 3] = (byte)(value >> 16);
+ output[offset + 4] = (byte)(value >> 8);
+ output[offset + 5] = (byte)(value >> 0);
+ }
+
+ // Write a 64-bit integer in big-endian format to output[0..8)
+ public static void WriteBigEndian64(this ByteSpan output, ulong value, int offset = 0)
+ {
+ output[offset + 0] = (byte)(value >> 56);
+ output[offset + 1] = (byte)(value >> 48);
+ output[offset + 2] = (byte)(value >> 40);
+ output[offset + 3] = (byte)(value >> 32);
+ output[offset + 4] = (byte)(value >> 24);
+ output[offset + 5] = (byte)(value >> 16);
+ output[offset + 6] = (byte)(value >> 8);
+ output[offset + 7] = (byte)(value >> 0);
+ }
+
+ // Read a 16-bit integer in big-endian format from input[0..2)
+ public static ushort ReadBigEndian16(this ByteSpan input, int offset = 0)
+ {
+ ushort value = 0;
+ value |= (ushort)(input[offset + 0] << 8);
+ value |= (ushort)(input[offset + 1] << 0);
+ return value;
+ }
+
+ // Read a 24-bit integer in big-endian format from input[0..3)
+ public static uint ReadBigEndian24(this ByteSpan input, int offset = 0)
+ {
+ uint value = 0;
+ value |= (uint)input[offset + 0] << 16;
+ value |= (uint)input[offset + 1] << 8;
+ value |= (uint)input[offset + 2] << 0;
+ return value;
+ }
+
+ // Read a 48-bit integer in big-endian format from input[0..3)
+ public static ulong ReadBigEndian48(this ByteSpan input, int offset = 0)
+ {
+ ulong value = 0;
+ value |= (ulong)input[offset + 0] << 40;
+ value |= (ulong)input[offset + 1] << 32;
+ value |= (ulong)input[offset + 2] << 24;
+ value |= (ulong)input[offset + 3] << 16;
+ value |= (ulong)input[offset + 4] << 8;
+ value |= (ulong)input[offset + 5] << 0;
+ return value;
+ }
+ }
+
+ public static class ByteSpanLittleEndianExtensions
+ {
+ // Read a 24-bit integer in little-endian format from input[0..3)
+ public static uint ReadLittleEndian24(this ByteSpan input, int offset = 0)
+ {
+ uint value = 0;
+ value |= (uint)input[offset + 0];
+ value |= (uint)input[offset + 1] << 8;
+ value |= (uint)input[offset + 2] << 16;
+ return value;
+ }
+
+ // Read a 24-bit integer in little-endian format from input[0..4)
+ public static uint ReadLittleEndian32(this ByteSpan input, int offset = 0)
+ {
+ uint value = 0;
+ value |= (uint)input[offset + 0];
+ value |= (uint)input[offset + 1] << 8;
+ value |= (uint)input[offset + 2] << 16;
+ value |= (uint)input[offset + 3] << 24;
+ return value;
+ }
+
+ /// <summary>
+ /// Reuse an existing span if there is enough space,
+ /// otherwise allocate new storage
+ /// </summary>
+ /// <param name="source">
+ /// Source span we should attempt to reuse
+ /// </param>
+ /// <param name="requiredSize">Required size (bytes)</param>
+ public static ByteSpan ReuseSpanIfPossible(this ByteSpan source, int requiredSize)
+ {
+ if (source.Length >= requiredSize)
+ {
+ return source.Slice(0, requiredSize);
+ }
+
+ return new byte[requiredSize];
+ }
+
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/Connection.cs b/Tools/Hazel-Networking/Hazel/Connection.cs
new file mode 100644
index 0000000..da2f59a
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/Connection.cs
@@ -0,0 +1,234 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Net.Sockets;
+using System.Net;
+using System.Threading;
+
+namespace Hazel
+{
+ /// <summary>
+ /// Base class for all connections.
+ /// </summary>
+ /// <remarks>
+ /// <para>
+ /// Connection is the base class for all connections that Hazel can make. It provides common functionality and a
+ /// standard interface to allow connections to be swapped easily.
+ /// </para>
+ /// <para>
+ /// Any class inheriting from Connection should provide the 3 standard guarantees that Hazel provides:
+ /// <list type="bullet">
+ /// <item>
+ /// <description>Thread Safe</description>
+ /// </item>
+ /// <item>
+ /// <description>Connection Orientated</description>
+ /// </item>
+ /// <item>
+ /// <description>Packet/Message Based</description>
+ /// </item>
+ /// </list>
+ /// </para>
+ /// </remarks>
+ /// <threadsafety static="true" instance="true"/>
+ public abstract class Connection : IDisposable
+ {
+ /// <summary>
+ /// Called when a message has been received.
+ /// </summary>
+ /// <remarks>
+ /// <para>
+ /// DataReceived is invoked everytime a message is received from the end point of this connection, the message
+ /// that was received can be found in the <see cref="DataReceivedEventArgs"/> alongside other information from the
+ /// event.
+ /// </para>
+ /// <include file="DocInclude/common.xml" path="docs/item[@name='Event_Thread_Safety_Warning']/*" />
+ /// </remarks>
+ /// <example>
+ /// <code language="C#" source="DocInclude/TcpClientExample.cs"/>
+ /// </example>
+ public event Action<DataReceivedEventArgs> DataReceived;
+
+ public int TestLagMs = -1;
+ public int TestDropRate = 0;
+ protected int testDropCount = 0;
+
+ /// <summary>
+ /// Called when the end point disconnects or an error occurs.
+ /// </summary>
+ /// <remarks>
+ /// <para>
+ /// Disconnected is invoked when the connection is closed due to an exception occuring or because the remote
+ /// end point disconnected. If it was invoked due to an exception occuring then the exception is available
+ /// in the <see cref="DisconnectedEventArgs"/> passed with the event.
+ /// </para>
+ /// <include file="DocInclude/common.xml" path="docs/item[@name='Event_Thread_Safety_Warning']/*" />
+ /// </remarks>
+ /// <example>
+ /// <code language="C#" source="DocInclude/TcpClientExample.cs"/>
+ /// </example>
+ public event EventHandler<DisconnectedEventArgs> Disconnected;
+
+ /// <summary>
+ /// The remote end point of this Connection.
+ /// </summary>
+ /// <remarks>
+ /// This is the end point that this connection is connected to (i.e. the other device). This returns an abstract
+ /// <see cref="ConnectionEndPoint"/> which can then be cast to an appropriate end point depending on the
+ /// connection type.
+ /// </remarks>
+ public IPEndPoint EndPoint { get; protected set; }
+
+ public IPMode IPMode { get; protected set; }
+
+ /// <summary>
+ /// The traffic statistics about this Connection.
+ /// </summary>
+ /// <remarks>
+ /// Contains statistics about the number of messages and bytes sent and received by this connection.
+ /// </remarks>
+ public ConnectionStatistics Statistics { get; protected set; }
+
+ /// <summary>
+ /// The state of this connection.
+ /// </summary>
+ /// <remarks>
+ /// All implementers should be aware that when this is set to ConnectionState.Connected it will
+ /// release all threads that are blocked on <see cref="WaitOnConnect"/>.
+ /// </remarks>
+ public ConnectionState State
+ {
+ get
+ {
+ return this._state;
+ }
+
+ protected set
+ {
+ this._state = value;
+ this.SetState(value);
+ }
+ }
+
+ protected ConnectionState _state;
+ protected virtual void SetState(ConnectionState state) { }
+
+ /// <summary>
+ /// Constructor that initializes the ConnecitonStatistics object.
+ /// </summary>
+ /// <remarks>
+ /// This constructor initialises <see cref="Statistics"/> with empty statistics and sets <see cref="State"/> to
+ /// <see cref="ConnectionState.NotConnected"/>.
+ /// </remarks>
+ protected Connection()
+ {
+ this.Statistics = new ConnectionStatistics();
+ this.State = ConnectionState.NotConnected;
+ }
+
+ /// <summary>
+ /// Sends a number of bytes to the end point of the connection using the specified <see cref="SendOption"/>.
+ /// </summary>
+ /// <param name="msg">The message to send.</param>
+ public abstract SendErrors Send(MessageWriter msg);
+
+ /// <summary>
+ /// Connects the connection to a server and begins listening.
+ /// This method blocks and may thrown if there is a problem connecting.
+ /// </summary>
+ /// <param name="bytes">The bytes of data to send in the handshake.</param>
+ /// <param name="timeout">The number of milliseconds to wait before giving up on the connect attempt.</param>
+ public abstract void Connect(byte[] bytes = null, int timeout = 5000);
+
+ /// <summary>
+ /// Connects the connection to a server and begins listening.
+ /// This method does not block.
+ /// </summary>
+ /// <param name="bytes">The bytes of data to send in the handshake.</param>
+ public abstract void ConnectAsync(byte[] bytes = null);
+
+ /// <summary>
+ /// Invokes the DataReceived event.
+ /// </summary>
+ /// <param name="msg">The bytes received.</param>
+ /// <param name="sendOption">The <see cref="SendOption"/> the message was received with.</param>
+ /// <remarks>
+ /// Invokes the <see cref="DataReceived"/> event on this connection to alert subscribers a new message has been
+ /// received. The bytes and the send option that the message was sent with should be passed in to give to the
+ /// subscribers.
+ /// </remarks>
+ protected void InvokeDataReceived(MessageReader msg, SendOption sendOption)
+ {
+ // Make a copy to avoid race condition between null check and invocation
+ Action<DataReceivedEventArgs> handler = DataReceived;
+ if (handler != null)
+ {
+ try
+ {
+ handler(new DataReceivedEventArgs(this, msg, sendOption));
+ }
+ catch { }
+ }
+ else
+ {
+ msg.Recycle();
+ }
+ }
+
+ /// <summary>
+ /// Invokes the Disconnected event.
+ /// </summary>
+ /// <param name="e">The exception, if any, that occurred to cause this.</param>
+ /// <param name="reader">Extra disconnect data</param>
+ /// <remarks>
+ /// Invokes the <see cref="Disconnected"/> event to alert subscribres this connection has been disconnected either
+ /// by the end point or because an error occurred. If an error occurred the error should be passed in in order to
+ /// pass to the subscribers, otherwise null can be passed in.
+ /// </remarks>
+ protected void InvokeDisconnected(string e, MessageReader reader)
+ {
+ // Make a copy to avoid race condition between null check and invocation
+ EventHandler<DisconnectedEventArgs> handler = Disconnected;
+ if (handler != null)
+ {
+ DisconnectedEventArgs args = new DisconnectedEventArgs(e, reader);
+ try
+ {
+ handler(this, args);
+ }
+ catch
+ {
+ }
+ }
+ }
+
+ /// <summary>
+ /// For times when you want to force the disconnect handler to fire as well as close it.
+ /// If you only want to close it, just use Dispose.
+ /// </summary>
+ public abstract void Disconnect(string reason, MessageWriter writer = null);
+
+ /// <summary>
+ /// Disposes of this NetworkConnection.
+ /// </summary>
+ public void Dispose()
+ {
+ Dispose(true);
+ GC.SuppressFinalize(this);
+ }
+
+ /// <summary>
+ /// Disposes of this NetworkConnection.
+ /// </summary>
+ /// <param name="disposing">Are we currently disposing?</param>
+ protected virtual void Dispose(bool disposing)
+ {
+ if (disposing)
+ {
+ this.DataReceived = null;
+ this.Disconnected = null;
+ }
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/ConnectionListener.cs b/Tools/Hazel-Networking/Hazel/ConnectionListener.cs
new file mode 100644
index 0000000..f952847
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/ConnectionListener.cs
@@ -0,0 +1,160 @@
+using System;
+using System.Net;
+
+namespace Hazel
+{
+ /// <summary>
+ /// Base class for all connection listeners.
+ /// </summary>
+ /// <remarks>
+ /// <para>
+ /// ConnectionListeners are server side objects that listen for clients and create matching server side connections
+ /// for each client in a similar way to TCP does. These connections should be ready for communication immediately.
+ /// </para>
+ /// <para>
+ /// Each time a client connects the <see cref="NewConnection"/> event will be invoked to alert all subscribers to
+ /// the new connection. A disconnected event is then present on the <see cref="Connection"/> that is passed to the
+ /// subscribers.
+ /// </para>
+ /// </remarks>
+ /// <threadsafety static="true" instance="true"/>
+ public abstract class ConnectionListener : IDisposable
+ {
+ /// <summary>
+ /// The max size Hazel attempts to read from the network.
+ /// Defaults to 8096.
+ /// </summary>
+ /// <remarks>
+ /// 8096 is 5 times the standard modern MTU of 1500, so it's already too large imo.
+ /// If Hazel ever implements fragmented packets, then we might consider a larger value since combining 5
+ /// packets into 1 reader would be realistic and would cause reallocations. That said, Hazel is not meant
+ /// for transferring large contiguous blocks of data, so... please don't?
+ /// </remarks>
+ public int ReceiveBufferSize = 8096;
+
+ public readonly ListenerStatistics Statistics = new ListenerStatistics();
+
+ public abstract double AveragePing { get; }
+ public abstract int ConnectionCount { get; }
+ public abstract int SendQueueLength { get; }
+ public abstract int ReceiveQueueLength { get; }
+
+ /// <summary>
+ /// A callback for early connection rejection.
+ /// * Return false to reject connection.
+ /// * A null response is ok, we just won't send anything.
+ /// </summary>
+ public AcceptConnectionCheck AcceptConnection;
+ public delegate bool AcceptConnectionCheck(IPEndPoint endPoint, byte[] input, out byte[] response);
+
+ /// <summary>
+ /// Invoked when a new client connects.
+ /// </summary>
+ /// <remarks>
+ /// <para>
+ /// NewConnection is invoked each time a client connects to the listener. The
+ /// <see cref="NewConnectionEventArgs"/> contains the new <see cref="Connection"/> for communication with this
+ /// client.
+ /// </para>
+ /// <para>
+ /// Hazel may or may not store connections so it is your responsibility to keep track and properly Dispose of
+ /// connections to your server.
+ /// </para>
+ /// <include file="DocInclude/common.xml" path="docs/item[@name='Event_Thread_Safety_Warning']/*" />
+ /// </remarks>
+ /// <example>
+ /// <code language="C#" source="DocInclude/TcpListenerExample.cs"/>
+ /// </example>
+ public event Action<NewConnectionEventArgs> NewConnection;
+
+ /// <summary>
+ /// Invoked when an internal error causes the listener to be unable to continue handling messages.
+ /// </summary>
+ /// <remarks>
+ /// Support for this is still pretty limited. At the time of writing, only iOS devices need this in one case:
+ /// When iOS suspends an app, it might also free our socket while not allowing Unity to run in the background.
+ /// When Unity resumes, it can't know that time passed or the socket is freed, so we used to continuously throw internal errors.
+ /// </remarks>
+ public event Action<HazelInternalErrors> OnInternalError;
+
+ /// <summary>
+ /// Makes this connection listener begin listening for connections.
+ /// </summary>
+ /// <remarks>
+ /// <para>
+ /// This instructs the listener to begin listening for new clients connecting to the server. When a new client
+ /// connects the <see cref="NewConnection"/> event will be invoked containing the connection to the new client.
+ /// </para>
+ /// <para>
+ /// To stop listening you should call <see cref="Dispose()"/>.
+ /// </para>
+ /// </remarks>
+ /// <example>
+ /// <code language="C#" source="DocInclude/TcpListenerExample.cs"/>
+ /// </example>
+ public abstract void Start();
+
+ /// <summary>
+ /// Invokes the NewConnection event with the supplied connection.
+ /// </summary>
+ /// <param name="msg">The user sent bytes that were received as part of the handshake.</param>
+ /// <param name="connection">The connection to pass in the arguments.</param>
+ /// <remarks>
+ /// Implementers should call this to invoke the <see cref="NewConnection"/> event before data is received so that
+ /// subscribers do not miss any data that may have been sent immediately after connecting.
+ /// </remarks>
+ protected void InvokeNewConnection(MessageReader msg, Connection connection)
+ {
+ // Make a copy to avoid race condition between null check and invocation
+ Action<NewConnectionEventArgs> handler = NewConnection;
+ if (handler != null)
+ {
+ try
+ {
+ handler(new NewConnectionEventArgs(msg, connection));
+ }
+ catch (Exception e)
+ {
+ }
+ }
+ }
+
+
+ /// <summary>
+ /// Invokes the InternalError event with the supplied reason.
+ /// </summary>
+ protected void InvokeInternalError(HazelInternalErrors reason)
+ {
+ // Make a copy to avoid race condition between null check and invocation
+ Action<HazelInternalErrors> handler = this.OnInternalError;
+ if (handler != null)
+ {
+ try
+ {
+ handler(reason);
+ }
+ catch
+ {
+ }
+ }
+ }
+
+ /// <summary>
+ /// Call to dispose of the connection listener.
+ /// </summary>
+ public void Dispose()
+ {
+ Dispose(true);
+ }
+
+ /// <summary>
+ /// Called when the object is being disposed.
+ /// </summary>
+ /// <param name="disposing">Are we disposing?</param>
+ protected virtual void Dispose(bool disposing)
+ {
+ this.NewConnection = null;
+ this.OnInternalError = null;
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/ConnectionState.cs b/Tools/Hazel-Networking/Hazel/ConnectionState.cs
new file mode 100644
index 0000000..5d3f5c9
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/ConnectionState.cs
@@ -0,0 +1,28 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+
+namespace Hazel
+{
+ /// <summary>
+ /// Represents the state a <see cref="Connection"/> is currently in.
+ /// </summary>
+ public enum ConnectionState
+ {
+ /// <summary>
+ /// The Connection has either not been established yet or has been disconnected.
+ /// </summary>
+ NotConnected,
+
+ /// <summary>
+ /// The Connection is currently connecting to an endpoint.
+ /// </summary>
+ Connecting,
+
+ /// <summary>
+ /// The Connection is connected and data can be transfered.
+ /// </summary>
+ Connected,
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/ConnectionStatistics.cs b/Tools/Hazel-Networking/Hazel/ConnectionStatistics.cs
new file mode 100644
index 0000000..f2c3ed9
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/ConnectionStatistics.cs
@@ -0,0 +1,574 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading;
+
+
+namespace Hazel
+{
+ /// <summary>
+ /// Holds statistics about the traffic through a <see cref="Connection"/>.
+ /// </summary>
+ /// <threadsafety static="true" instance="true"/>
+ public class ConnectionStatistics
+ {
+ private const int ExpectedMTU = 1200;
+
+ /// <summary>
+ /// The total number of messages sent.
+ /// </summary>
+ public int MessagesSent
+ {
+ get
+ {
+ return UnreliableMessagesSent + ReliableMessagesSent + FragmentedMessagesSent + AcknowledgementMessagesSent + HelloMessagesSent;
+ }
+ }
+
+ private int packetsSent;
+ public int PacketsSent => this.packetsSent;
+
+ private int reliablePacketsAcknowledged;
+ public int ReliablePacketsAcknowledged => this.reliablePacketsAcknowledged;
+
+ /// <summary>
+ /// The number of messages sent larger than 576 bytes. This is smaller than most default MTUs.
+ /// </summary>
+ /// <remarks>
+ /// This is the number of unreliable messages that were sent from the <see cref="Connection"/>, incremented
+ /// each time that LogUnreliableSend is called by the Connection. Messages that caused an error are not
+ /// counted and messages are only counted once all other operations in the send are complete.
+ /// </remarks>
+ public int FragmentableMessagesSent
+ {
+ get
+ {
+ return fragmentableMessagesSent;
+ }
+ }
+
+ /// <summary>
+ /// The number of messages sent larger than 576 bytes.
+ /// </summary>
+ int fragmentableMessagesSent;
+
+ /// <summary>
+ /// The number of unreliable messages sent.
+ /// </summary>
+ /// <remarks>
+ /// This is the number of unreliable messages that were sent from the <see cref="Connection"/>, incremented
+ /// each time that LogUnreliableSend is called by the Connection. Messages that caused an error are not
+ /// counted and messages are only counted once all other operations in the send are complete.
+ /// </remarks>
+ public int UnreliableMessagesSent
+ {
+ get
+ {
+ return unreliableMessagesSent;
+ }
+ }
+
+ /// <summary>
+ /// The number of unreliable messages sent.
+ /// </summary>
+ int unreliableMessagesSent;
+
+ /// <summary>
+ /// The number of reliable messages sent.
+ /// </summary>
+ /// <remarks>
+ /// This is the number of reliable messages that were sent from the <see cref="Connection"/>, incremented
+ /// each time that LogReliableSend is called by the Connection. Messages that caused an error are not
+ /// counted and messages are only counted once all other operations in the send are complete.
+ /// </remarks>
+ public int ReliableMessagesSent
+ {
+ get
+ {
+ return reliableMessagesSent;
+ }
+ }
+
+ /// <summary>
+ /// The number of unreliable messages sent.
+ /// </summary>
+ int reliableMessagesSent;
+
+ /// <summary>
+ /// The number of fragmented messages sent.
+ /// </summary>
+ /// <remarks>
+ /// This is the number of fragmented messages that were sent from the <see cref="Connection"/>, incremented
+ /// each time that LogFragmentedSend is called by the Connection. Messages that caused an error are not
+ /// counted and messages are only counted once all other operations in the send are complete.
+ /// </remarks>
+ public int FragmentedMessagesSent
+ {
+ get
+ {
+ return fragmentedMessagesSent;
+ }
+ }
+
+ /// <summary>
+ /// The number of fragmented messages sent.
+ /// </summary>
+ int fragmentedMessagesSent;
+
+ /// <summary>
+ /// The number of acknowledgement messages sent.
+ /// </summary>
+ /// <remarks>
+ /// This is the number of acknowledgements that were sent from the <see cref="Connection"/>, incremented
+ /// each time that LogAcknowledgementSend is called by the Connection. Messages that caused an error are not
+ /// counted and messages are only counted once all other operations in the send are complete.
+ /// </remarks>
+ public int AcknowledgementMessagesSent
+ {
+ get
+ {
+ return acknowledgementMessagesSent;
+ }
+ }
+
+ /// <summary>
+ /// The number of acknowledgement messages sent.
+ /// </summary>
+ int acknowledgementMessagesSent;
+
+ /// <summary>
+ /// The number of hello messages sent.
+ /// </summary>
+ /// <remarks>
+ /// This is the number of hello messages that were sent from the <see cref="Connection"/>, incremented
+ /// each time that LogHelloSend is called by the Connection. Messages that caused an error are not
+ /// counted and messages are only counted once all other operations in the send are complete.
+ /// </remarks>
+ public int HelloMessagesSent
+ {
+ get
+ {
+ return helloMessagesSent;
+ }
+ }
+
+ /// <summary>
+ /// The number of hello messages sent.
+ /// </summary>
+ int helloMessagesSent;
+
+ /// <summary>
+ /// The number of bytes of data sent.
+ /// </summary>
+ /// <remarks>
+ /// <para>
+ /// This is the number of bytes of data (i.e. user bytes) that were sent from the <see cref="Connection"/>,
+ /// accumulated each time that LogSend is called by the Connection. Messages that caused an error are not
+ /// counted and messages are only counted once all other operations in the send are complete.
+ /// </para>
+ /// <para>
+ /// For the number of bytes including protocol bytes see <see cref="TotalBytesSent"/>.
+ /// </para>
+ /// </remarks>
+ public long DataBytesSent
+ {
+ get
+ {
+ return Interlocked.Read(ref dataBytesSent);
+ }
+ }
+
+ /// <summary>
+ /// The number of bytes of data sent.
+ /// </summary>
+ long dataBytesSent;
+
+ /// <summary>
+ /// The number of bytes sent in total.
+ /// </summary>
+ /// <remarks>
+ /// <para>
+ /// This is the total number of bytes (the data bytes plus protocol bytes) that were sent from the
+ /// <see cref="Connection"/>, accumulated each time that LogSend is called by the Connection. Messages that
+ /// caused an error are not counted and messages are only counted once all other operations in the send are
+ /// complete.
+ /// </para>
+ /// <para>
+ /// For the number of data bytes excluding protocol bytes see <see cref="DataBytesSent"/>.
+ /// </para>
+ /// </remarks>
+ public long TotalBytesSent
+ {
+ get
+ {
+ return Interlocked.Read(ref totalBytesSent);
+ }
+ }
+
+ /// <summary>
+ /// The number of bytes sent in total.
+ /// </summary>
+ long totalBytesSent;
+
+ /// <summary>
+ /// The total number of messages received.
+ /// </summary>
+ public int MessagesReceived
+ {
+ get
+ {
+ return UnreliableMessagesReceived + ReliableMessagesReceived + FragmentedMessagesReceived + AcknowledgementMessagesReceived + helloMessagesReceived;
+ }
+ }
+
+ /// <summary>
+ /// The number of unreliable messages received.
+ /// </summary>
+ /// <remarks>
+ /// This is the number of unreliable messages that were received by the <see cref="Connection"/>, incremented
+ /// each time that LogUnreliableReceive is called by the Connection. Messages are counted before the receive event is invoked.
+ /// </remarks>
+ public int UnreliableMessagesReceived
+ {
+ get
+ {
+ return unreliableMessagesReceived;
+ }
+ }
+
+ /// <summary>
+ /// The number of unreliable messages received.
+ /// </summary>
+ int unreliableMessagesReceived;
+
+ /// <summary>
+ /// The number of reliable messages received.
+ /// </summary>
+ /// <remarks>
+ /// This is the number of reliable messages that were received by the <see cref="Connection"/>, incremented
+ /// each time that LogReliableReceive is called by the Connection. Messages are counted before the receive event is invoked.
+ /// </remarks>
+ public int ReliableMessagesReceived
+ {
+ get
+ {
+ return reliableMessagesReceived;
+ }
+ }
+
+ /// <summary>
+ /// The number of reliable messages received.
+ /// </summary>
+ int reliableMessagesReceived;
+
+ /// <summary>
+ /// The number of fragmented messages received.
+ /// </summary>
+ /// <remarks>
+ /// This is the number of fragmented messages that were received by the <see cref="Connection"/>, incremented
+ /// each time that LogFragmentedReceive is called by the Connection. Messages are counted before the receive event is invoked.
+ /// </remarks>
+ public int FragmentedMessagesReceived
+ {
+ get
+ {
+ return fragmentedMessagesReceived;
+ }
+ }
+
+ /// <summary>
+ /// The number of fragmented messages received.
+ /// </summary>
+ int fragmentedMessagesReceived;
+
+ /// <summary>
+ /// The number of acknowledgement messages received.
+ /// </summary>
+ /// <remarks>
+ /// This is the number of acknowledgement messages that were received by the <see cref="Connection"/>, incremented
+ /// each time that LogAcknowledgemntReceive is called by the Connection. Messages are counted before the receive event is invoked.
+ /// </remarks>
+ public int AcknowledgementMessagesReceived
+ {
+ get
+ {
+ return acknowledgementMessagesReceived;
+ }
+ }
+
+ /// <summary>
+ /// The number of acknowledgement messages received.
+ /// </summary>
+ int acknowledgementMessagesReceived;
+
+ /// <summary>
+ /// The number of ping messages received.
+ /// </summary>
+ /// <remarks>
+ /// This is the number of hello messages that were received by the <see cref="Connection"/>, incremented
+ /// each time that LogHelloReceive is called by the Connection. Messages are counted before the receive event is invoked.
+ /// </remarks>
+ public int PingMessagesReceived
+ {
+ get
+ {
+ return pingMessagesReceived;
+ }
+ }
+
+ /// <summary>
+ /// The number of hello messages received.
+ /// </summary>
+ int pingMessagesReceived;
+
+ /// <summary>
+ /// The number of hello messages received.
+ /// </summary>
+ /// <remarks>
+ /// This is the number of hello messages that were received by the <see cref="Connection"/>, incremented
+ /// each time that LogHelloReceive is called by the Connection. Messages are counted before the receive event is invoked.
+ /// </remarks>
+ public int HelloMessagesReceived
+ {
+ get
+ {
+ return helloMessagesReceived;
+ }
+ }
+
+ /// <summary>
+ /// The number of hello messages received.
+ /// </summary>
+ int helloMessagesReceived;
+
+ /// <summary>
+ /// The number of bytes of data received.
+ /// </summary>
+ /// <remarks>
+ /// <para>
+ /// This is the number of bytes of data (i.e. user bytes) that were received by the <see cref="Connection"/>,
+ /// accumulated each time that LogReceive is called by the Connection. Messages are counted before the receive
+ /// event is invoked.
+ /// </para>
+ /// <para>
+ /// For the number of bytes including protocol bytes see <see cref="TotalBytesReceived"/>.
+ /// </para>
+ /// </remarks>
+ public long DataBytesReceived
+ {
+ get
+ {
+ return Interlocked.Read(ref dataBytesReceived);
+ }
+ }
+
+ /// <summary>
+ /// The number of bytes of data received.
+ /// </summary>
+ long dataBytesReceived;
+
+ /// <summary>
+ /// The number of bytes received in total.
+ /// </summary>
+ /// <remarks>
+ /// <para>
+ /// This is the total number of bytes (the data bytes plus protocol bytes) that were received by the
+ /// <see cref="Connection"/>, accumulated each time that LogReceive is called by the Connection. Messages are
+ /// counted before the receive event is invoked.
+ /// </para>
+ /// <para>
+ /// For the number of data bytes excluding protocol bytes see <see cref="DataBytesReceived"/>.
+ /// </para>
+ /// </remarks>
+ public long TotalBytesReceived
+ {
+ get
+ {
+ return Interlocked.Read(ref totalBytesReceived);
+ }
+ }
+
+ /// <summary>
+ /// The number of bytes received in total.
+ /// </summary>
+ long totalBytesReceived;
+
+ public int MessagesResent { get { return messagesResent; } }
+ int messagesResent;
+
+ /// <summary>
+ /// Logs the sending of an unreliable data packet in the statistics.
+ /// </summary>
+ /// <param name="dataLength">The number of bytes of data sent.</param>
+ /// <remarks>
+ /// This should be called after the data has been sent and should only be called for data that is sent sucessfully.
+ /// </remarks>
+ internal void LogUnreliableSend(int dataLength)
+ {
+ Interlocked.Increment(ref unreliableMessagesSent);
+ Interlocked.Add(ref dataBytesSent, dataLength);
+
+ }
+
+ /// <param name="totalLength">The total number of bytes sent.</param>
+ internal void LogPacketSend(int totalLength)
+ {
+ Interlocked.Increment(ref this.packetsSent);
+ Interlocked.Add(ref totalBytesSent, totalLength);
+
+ if (totalLength > ExpectedMTU)
+ {
+ Interlocked.Increment(ref fragmentableMessagesSent);
+ }
+ }
+
+ /// <summary>
+ /// Logs the sending of a reliable data packet in the statistics.
+ /// </summary>
+ /// <param name="dataLength">The number of bytes of data sent.</param>
+ /// <remarks>
+ /// This should be called after the data has been sent and should only be called for data that is sent sucessfully.
+ /// </remarks>
+ internal void LogReliableSend(int dataLength)
+ {
+ Interlocked.Increment(ref reliableMessagesSent);
+ Interlocked.Add(ref dataBytesSent, dataLength);
+ }
+
+ /// <summary>
+ /// Logs the sending of a fragmented data packet in the statistics.
+ /// </summary>
+ /// <param name="dataLength">The number of bytes of data sent.</param>
+ /// <param name="totalLength">The total number of bytes sent.</param>
+ /// <remarks>
+ /// This should be called after the data has been sent and should only be called for data that is sent sucessfully.
+ /// </remarks>
+ internal void LogFragmentedSend(int dataLength)
+ {
+ Interlocked.Increment(ref fragmentedMessagesSent);
+ Interlocked.Add(ref dataBytesSent, dataLength);
+ }
+
+ /// <summary>
+ /// Logs the sending of a acknowledgement data packet in the statistics.
+ /// </summary>
+ /// <param name="totalLength">The total number of bytes sent.</param>
+ /// <remarks>
+ /// This should be called after the data has been sent and should only be called for data that is sent sucessfully.
+ /// </remarks>
+ internal void LogAcknowledgementSend()
+ {
+ Interlocked.Increment(ref acknowledgementMessagesSent);
+ }
+
+ /// <summary>
+ /// Logs the sending of a hellp data packet in the statistics.
+ /// </summary>
+ /// <param name="totalLength">The total number of bytes sent.</param>
+ /// <remarks>
+ /// This should be called after the data has been sent and should only be called for data that is sent sucessfully.
+ /// </remarks>
+ internal void LogHelloSend()
+ {
+ Interlocked.Increment(ref helloMessagesSent);
+ }
+
+ /// <summary>
+ /// Logs the receiving of an unreliable data packet in the statistics.
+ /// </summary>
+ /// <param name="dataLength">The number of bytes of data received.</param>
+ /// <param name="totalLength">The total number of bytes received.</param>
+ /// <remarks>
+ /// This should be called before the received event is invoked so it is up to date for subscribers to that event.
+ /// </remarks>
+ internal void LogUnreliableReceive(int dataLength, int totalLength)
+ {
+ Interlocked.Increment(ref unreliableMessagesReceived);
+ Interlocked.Add(ref dataBytesReceived, dataLength);
+ Interlocked.Add(ref totalBytesReceived, totalLength);
+ }
+
+ /// <summary>
+ /// Logs the receiving of a reliable data packet in the statistics.
+ /// </summary>
+ /// <param name="dataLength">The number of bytes of data received.</param>
+ /// <param name="totalLength">The total number of bytes received.</param>
+ /// <remarks>
+ /// This should be called before the received event is invoked so it is up to date for subscribers to that event.
+ /// </remarks>
+ internal void LogReliableReceive(int dataLength, int totalLength)
+ {
+ Interlocked.Increment(ref reliableMessagesReceived);
+ Interlocked.Add(ref dataBytesReceived, dataLength);
+ Interlocked.Add(ref totalBytesReceived, totalLength);
+ }
+
+ /// <summary>
+ /// Logs the receiving of a fragmented data packet in the statistics.
+ /// </summary>
+ /// <param name="dataLength">The number of bytes of data received.</param>
+ /// <param name="totalLength">The total number of bytes received.</param>
+ /// <remarks>
+ /// This should be called before the received event is invoked so it is up to date for subscribers to that event.
+ /// </remarks>
+ internal void LogFragmentedReceive(int dataLength, int totalLength)
+ {
+ Interlocked.Increment(ref fragmentedMessagesReceived);
+ Interlocked.Add(ref dataBytesReceived, dataLength);
+ Interlocked.Add(ref totalBytesReceived, totalLength);
+ }
+
+ /// <summary>
+ /// Logs the receiving of an acknowledgement data packet in the statistics.
+ /// </summary>
+ /// <param name="totalLength">The total number of bytes received.</param>
+ /// <remarks>
+ /// This should be called before the received event is invoked so it is up to date for subscribers to that event.
+ /// </remarks>
+ internal void LogAcknowledgementReceive(int totalLength)
+ {
+ Interlocked.Increment(ref acknowledgementMessagesReceived);
+ Interlocked.Add(ref totalBytesReceived, totalLength);
+ }
+
+ /// <summary>
+ /// Logs the unique acknowledgement of a ping or reliable data packet.
+ /// </summary>
+ internal void LogReliablePacketAcknowledged()
+ {
+ Interlocked.Increment(ref this.reliablePacketsAcknowledged);
+ }
+
+ /// <summary>
+ /// Logs the receiving of a hello data packet in the statistics.
+ /// </summary>
+ /// <param name="totalLength">The total number of bytes received.</param>
+ /// <remarks>
+ /// This should be called before the received event is invoked so it is up to date for subscribers to that event.
+ /// </remarks>
+ internal void LogPingReceive(int totalLength)
+ {
+ Interlocked.Increment(ref pingMessagesReceived);
+ Interlocked.Add(ref totalBytesReceived, totalLength);
+ }
+
+ /// <summary>
+ /// Logs the receiving of a hello data packet in the statistics.
+ /// </summary>
+ /// <param name="totalLength">The total number of bytes received.</param>
+ /// <remarks>
+ /// This should be called before the received event is invoked so it is up to date for subscribers to that event.
+ /// </remarks>
+ internal void LogHelloReceive(int totalLength)
+ {
+ Interlocked.Increment(ref helloMessagesReceived);
+ Interlocked.Add(ref totalBytesReceived, totalLength);
+ }
+
+ internal void LogMessageResent()
+ {
+ Interlocked.Increment(ref messagesResent);
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/Crypto/AesGcm.cs b/Tools/Hazel-Networking/Hazel/Crypto/AesGcm.cs
new file mode 100644
index 0000000..bfbbc01
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/Crypto/AesGcm.cs
@@ -0,0 +1,369 @@
+using System;
+using System.Diagnostics;
+using System.Security.Cryptography;
+
+namespace Hazel.Crypto
+{
+ /// <summary>
+ /// Implementation of AEAD_AES128_GCM based on:
+ /// * RFC 5116 [1]
+ /// * NIST SP 800-38d [2]
+ ///
+ /// [1] https://tools.ietf.org/html/rfc5116
+ /// [2] https://nvlpubs.nist.gov/nistpubs/Legacy/SP/nistspecialpublication800-38d.pdf
+ ///
+ /// Adapted from: https://gist.github.com/mendsley/777e6bd9ae7eddcb2b0c0fe18247dc60
+ /// </summary>
+ public class Aes128Gcm : IDisposable
+ {
+ public const int KeySize = 16;
+ public const int NonceSize = 12;
+ public const int CiphertextOverhead = TagSize;
+
+ private const int TagSize = 16;
+
+ private readonly IAes encryptor_;
+
+ private readonly ByteSpan hashSubkey_;
+ private readonly ByteSpan blockJ_;
+ private readonly ByteSpan blockS_;
+ private readonly ByteSpan blockZ_;
+ private readonly ByteSpan blockV_;
+ private readonly ByteSpan blockScratch_;
+
+ /// <summary>
+ /// Creates a new instance of an AEAD_AES128_GCM cipher
+ /// </summary>
+ /// <param name="key">Symmetric key</param>
+ public Aes128Gcm(ByteSpan key)
+ {
+ if (key.Length != KeySize)
+ {
+ throw new ArgumentException("Invalid key length", nameof(key));
+ }
+
+ // Create the AES block cipher
+ this.encryptor_ = CryptoProvider.CreateAes(key);
+
+ // Allocate scratch space
+ ByteSpan scratchSpace = new byte[96];
+ this.hashSubkey_ = scratchSpace.Slice(0, 16);
+ this.blockJ_ = scratchSpace.Slice(16, 16);
+ this.blockS_ = scratchSpace.Slice(32, 16);
+ this.blockZ_ = scratchSpace.Slice(48, 16);
+ this.blockV_ = scratchSpace.Slice(64, 16);
+ this.blockScratch_ = scratchSpace.Slice(80, 16);
+
+ // Create the GHASH subkey by encrypting the 0-block
+ this.encryptor_.EncryptBlock(this.hashSubkey_, this.hashSubkey_);
+ }
+
+ /// <summary>
+ /// Encryptes the specified plaintext and generates an authentication
+ /// tag for the provided additional data. Returns the byte array
+ /// containg both the ciphertext and authentication tag.
+ /// </summary>
+ /// <param name="output">
+ /// Array in which to encode the encrypted ciphertext and
+ /// authentication tag. This array must be large enough to hold
+ /// `plaintext.Lengh + CiphertextOverhead` bytes.
+ /// </param>
+ /// <param name="nonce">Unique value for this message</param>
+ /// <param name="plaintext">Plaintext data to encrypt</param>
+ /// <param name="associatedData">
+ /// Additional data used to authenticate the message
+ /// </param>
+ public void Seal(ByteSpan output, ByteSpan nonce, ByteSpan plaintext, ByteSpan associatedData)
+ {
+ if (nonce.Length != NonceSize)
+ {
+ throw new ArgumentException("Invalid nonce size", nameof(nonce));
+ }
+ if (output.Length < plaintext.Length + CiphertextOverhead)
+ {
+ throw new ArgumentException("Invalid output size", nameof(output));
+ }
+
+ // Create the initial counter block
+ nonce.CopyTo(this.blockJ_);
+
+ // Encrypt the plaintext to output
+ GCTR(output, this.blockJ_, 2, plaintext);
+
+ // Generate and append the authentication tag
+ int tagOffset = plaintext.Length;
+ GenerateAuthenticationTag(output.Slice(tagOffset), output.Slice(0, tagOffset), associatedData);
+ }
+
+ /// <summary>
+ /// Validates the authentication tag against the provided additional
+ /// data, then decrypts the cipher text returning the original
+ /// plaintext.
+ /// </summary>
+ /// <param name="nonce">
+ /// The unique value used to seal this message
+ /// </param>
+ /// <param name="ciphertext">
+ /// Combined ciphertext and authentication tag
+ /// </param>
+ /// <param name="associatedData">
+ /// Additional data used to authenticate the message
+ /// </param>
+ /// <param name="output">
+ /// On successful validation and decryprion, Open writes the original
+ /// plaintext to output. Must contain enough space to hold
+ /// `ciphertext.Length - CiphertextOverhead` bytes.
+ /// </param>
+ /// <returns>
+ /// True if the data was validated and successfully decrypted.
+ /// Otherwise, false.
+ /// </returns>
+ public bool Open(ByteSpan output, ByteSpan nonce, ByteSpan ciphertext, ByteSpan associatedData)
+ {
+ if (nonce.Length != NonceSize)
+ {
+ throw new ArgumentException("Invalid nonce size", nameof(nonce));
+ }
+ if (ciphertext.Length < CiphertextOverhead)
+ {
+ throw new ArgumentException("Invalid ciphertext size", nameof(ciphertext));
+ }
+ else if (output.Length < ciphertext.Length - CiphertextOverhead)
+ {
+ throw new ArgumentException("Invalid output size", nameof(output));
+ }
+
+ // Split ciphertext into actual ciphertext and authentication
+ // tag components.
+ ByteSpan authenticationTag = ciphertext.Slice(ciphertext.Length - TagSize);
+ ciphertext = ciphertext.Slice(0, ciphertext.Length - TagSize);
+
+ // Create the initial counter block
+ nonce.CopyTo(this.blockJ_);
+
+ // Verify the tags match
+ GenerateAuthenticationTag(this.blockScratch_, ciphertext, associatedData);
+ if (0 == Const.ConstantCompareSpans(this.blockScratch_, authenticationTag))
+ {
+ return false;
+ }
+
+ // Decrypt the cipher text to output
+ GCTR(output, this.blockJ_, 2, ciphertext);
+ return true;
+ }
+
+ /// <summary>
+ /// Release resources acquired by the cipher
+ /// </summary>
+ public void Dispose()
+ {
+ this.encryptor_.Dispose();
+ }
+
+ // Generate the authentication tag for a ciphertext+associated data
+ void GenerateAuthenticationTag(ByteSpan output, ByteSpan ciphertext, ByteSpan associatedData)
+ {
+ Debug.Assert(output.Length >= 16);
+
+ // Hash `Associated data || Ciphertext || len(AssociatedD data) || len(Ciphertext)`
+ // into `blockS`
+ {
+ // Clear hash output block
+ SetSpanToZeros(this.blockS_);
+
+ // Write associated data blocks to hash
+ int fullBlocks = associatedData.Length / 16;
+ GHASH(this.blockS_, associatedData, fullBlocks);
+ if (fullBlocks * 16 < associatedData.Length)
+ {
+ SetSpanToZeros(this.blockScratch_);
+ associatedData.Slice(fullBlocks * 16).CopyTo(this.blockScratch_);
+ GHASH(this.blockS_, this.blockScratch_, 1);
+ }
+
+ // Write ciphertext blocks to hash
+ fullBlocks = ciphertext.Length / 16;
+ GHASH(this.blockS_, ciphertext, fullBlocks);
+ if (fullBlocks * 16 < ciphertext.Length)
+ {
+ SetSpanToZeros(this.blockScratch_);
+ ciphertext.Slice(fullBlocks * 16).CopyTo(this.blockScratch_);
+ GHASH(this.blockS_, this.blockScratch_, 1);
+ }
+
+ // Write bit sizes to hash
+ ulong associatedDataLengthInBits = (ulong)(8 * associatedData.Length);
+ ulong ciphertextDataLengthInBits = (ulong)(8 * ciphertext.Length);
+ this.blockScratch_.WriteBigEndian64(associatedDataLengthInBits);
+ this.blockScratch_.WriteBigEndian64(ciphertextDataLengthInBits, 8);
+
+ GHASH(this.blockS_, this.blockScratch_, 1);
+ }
+
+ // Encrypt the tag. GCM requires this because `GASH` is not
+ // cryptographically secure. An attacker could derive our hash
+ // subkey `hashSubkey_` from an unencrypted tag.
+ GCTR(output, this.blockJ_, 1, this.blockS_);
+ }
+
+ // Run the GCTR cipher
+ void GCTR(ByteSpan output, ByteSpan counterBlock, uint counter, ByteSpan data)
+ {
+ Debug.Assert(counterBlock.Length == 16);
+ Debug.Assert(output.Length >= data.Length);
+
+ // Loop through plaintext blocks
+ int writeIndex = 0;
+ int numBlocks = (data.Length + 15) / 16;
+ for (int ii = 0; ii != numBlocks; ++ii)
+ {
+ // Encode counter into block
+ // CB[1] = J0
+ // CB[i] = inc[32](CB[i-1])
+ counterBlock.WriteBigEndian32(counter, 12);
+ ++counter;
+
+ // CIPH[k](CB[i])
+ this.encryptor_.EncryptBlock(counterBlock.Slice(0, 16), this.blockScratch_);
+
+ // Y[i] = X[i] xor CIPH[k](CB[i])
+ for (int jj = 0; jj != 16 && writeIndex < data.Length; ++jj, ++writeIndex)
+ {
+ output[writeIndex] = (byte)(data[writeIndex] ^ this.blockScratch_[jj]);
+ }
+ }
+ }
+
+ // Run the GHASH function
+ void GHASH(ByteSpan output, ByteSpan data, int numBlocks)
+ {
+ ///TODO(mendsley): See Ref[6] for opitmizations of GHASH on both hardware and software
+ ///
+ ///[6] D. McGrew, J. Viega, The Galois/Counter Mode of Operation (GCM), Natl. Inst. Stand.
+ ///Technol. [Web page], http://www.csrc.nist.gov/groups/ST/toolkit/BCM/documents/
+ ///proposedmodes / gcm / gcm - revised - spec.pdf, May 31, 2005.
+
+ Debug.Assert(output.Length == 16);
+ Debug.Assert(data.Length >= numBlocks * 16);
+
+ int readIndex = 0;
+ for (int ii = 0; ii != numBlocks; ++ii)
+ {
+ for (int jj = 0; jj != 16; ++jj, ++readIndex)
+ {
+ // Y[ii-1] xor X[ii]
+ output[jj] ^= data[readIndex];
+ }
+
+ // Y[ii] = (Y[ii-1] xor X[ii]) · H
+ MultiplyGF128Elements(output, this.hashSubkey_, this.blockZ_, this.blockV_);
+ }
+ }
+
+ // Multiply two Galois field elements `X` and `Y` together and store
+ // the result in `X` such that at the end of the function:
+ // X = X·Y
+ static void MultiplyGF128Elements(ByteSpan X, ByteSpan Y, ByteSpan scratchZ, ByteSpan scratchV)
+ {
+ Debug.Assert(X.Length == 16);
+ Debug.Assert(Y.Length == 16);
+ Debug.Assert(scratchZ.Length == 16);
+ Debug.Assert(scratchV.Length == 16);
+
+ // Galois (finite) fields represented by GF(p) define a set of
+ // closed algebraic operations. For AES128_GCM we'll be dealing
+ // with the GF(2^128) field.
+ //
+ // We treat each incoming 16 byte block as a polynomial in field
+ // and define multiplication between two polynomials as the
+ // polynomial product reduced by (mod) the field polynomial:
+ // 1 + x + x^2 + x^7 + x^128
+ //
+ // Field polynomials are represented by a 128 bit string. Bit n is
+ // the coefficient of the x^n term. We use little-endian bit
+ // ordering (not to be confused with byte ordering) for these
+ // coefficients. E.g. X[0] & 0x00000001 represents the 7th bit in
+ // the bit string defined by X, _not_ the 0th bit.
+ //
+
+ // What follows is a modified version of the "peasant's algorithm"
+ // to multiply two numbers:
+ //
+ // Z contains the accumulated product
+ // V is a copy of Y (so we can modify it via shifting).
+ //
+ // We calculate Z = X·V as follows
+ // We loop through each of the 128 bits in X maintaining the
+ // following loop invariant: X·V + Z = the final product
+ //
+ // On each iteration `ii`:
+ //
+ // If the `ii`th bit of `X` is set, add the add the polynomial
+ // in `V` to `X`: `X[n] = X[n] ^ V[n]`
+ //
+ // Double V (Shift one bit right since we're storing little
+ // endian bit). This has the effect of multiplying V by the
+ // polynomial `x`. We track the unrepresentable coefficient
+ // of `x^128` by storing the most significant bit before the
+ // shift `V[15] >> 7` as `carry`
+ //
+ // Check if we've overflowed our multiplication. If overflow
+ // occurred, there will be a non-zero coefficient for the
+ // `x^128` term in the step above `carry`
+ //
+ // If we have overflowed, our polynomial is exactly of degree
+ // 129 (since we're only multiplying by `x`). We reduce the
+ // polynomial back into degree 128 by adding our field's
+ // irreducible polynomial: 1 + x + x^2 + x^7 + x^128. This
+ // reduction cancels out the x^128 term (x^128 + x^128 in GF(2)
+ // is zero). Therefore this modulo can be achieved by simply
+ // adding the irreducible polynomial to the new value of `V`. The
+ // irreducible polynomial is represented by the bit string:
+ // `11100001` followed by 120 `0`s. We can add this value to `V`
+ // by: `V[0] = V[0] ^ 0xE1`.
+ SetSpanToZeros(scratchZ);
+ X.CopyTo(scratchV);
+
+ for (int ii = 0; ii != 128; ++ii)
+ {
+ int bitIndex = 7 - (ii % 8);
+ if ((Y[ii / 8] & (1 << bitIndex)) != 0)
+ {
+ for (int jj = 0; jj != 16; ++jj)
+ {
+ scratchZ[jj] ^= scratchV[jj];
+ }
+ }
+
+ bool carry = false;
+ for (int jj = 0; jj != 16; ++jj)
+ {
+ bool newCarry = (scratchV[jj] & 0x01) != 0;
+ scratchV[jj] >>= 1;
+ if (carry)
+ {
+ scratchV[jj] |= 0x80;
+ }
+ carry = newCarry;
+ }
+
+ if (carry)
+ {
+ scratchV[0] ^= 0xE1;
+ }
+ }
+
+ scratchZ.CopyTo(X);
+ }
+
+ // Set the contents of a span to all zero
+ static void SetSpanToZeros(ByteSpan span)
+ {
+ for (int ii = 0, nn = span.Length; ii != nn; ++ii)
+ {
+ span[ii] = 0;
+ }
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/Crypto/Const.cs b/Tools/Hazel-Networking/Hazel/Crypto/Const.cs
new file mode 100644
index 0000000..4dfef47
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/Crypto/Const.cs
@@ -0,0 +1,82 @@
+using System.Diagnostics;
+
+namespace Hazel.Crypto
+{
+ public static class Const
+ {
+
+ /// <summary>
+ /// Compare two bytes for equality.
+ ///
+ /// This takes care to always use a constant amount of time to prevent
+ /// leaking information through side-channel attacks.
+ ///
+ /// This is aceived by collapsing the xor bits down into a single bit.
+ ///
+ /// Ported from:
+ /// https://github.com/mendsley/tiny/blob/master/include/tiny/crypto/constant.h
+ /// </summary>
+ /// <returns>
+ /// Returns `1` is the two bytes or equivalent. Otherwise, returns `0`
+ /// </returns>
+ public static byte ConstantCompareByte(byte a, byte b)
+ {
+ byte result = (byte)(~(a ^ b));
+
+ // collapse bits down to the LSB
+ result &= (byte)(result >> 4);
+ result &= (byte)(result >> 2);
+ result &= (byte)(result >> 1);
+
+ return result;
+ }
+
+ /// <summary>
+ /// Compare two equal length spans for equality.
+ ///
+ /// This takes care to always use a constant amount of time to prevent
+ /// leaking information through side-channel attacks.
+ ///
+ /// Ported from:
+ /// https://github.com/mendsley/tiny/blob/master/include/tiny/crypto/constant.h
+ /// </summary>
+ /// <returns>
+ /// Returns `1` if the spans are equivalent. Others, returns `0`.
+ /// </returns>
+ public static byte ConstantCompareSpans(ByteSpan a, ByteSpan b)
+ {
+ Debug.Assert(a.Length == b.Length);
+
+ byte value = 0;
+ for (int ii = 0, nn = a.Length; ii != nn; ++ii)
+ {
+ value |= (byte)(a[ii] ^ b[ii]);
+ }
+
+ return ConstantCompareByte(value, 0);
+ }
+
+ /// <summary>
+ /// Compare a span against an all zero span
+ ///
+ /// This takes care to always use a constant amount of time to prevent
+ /// leaking information through side-channel attacks.
+ ///
+ /// Ported from:
+ /// https://github.com/mendsley/tiny/blob/master/include/tiny/crypto/constant.h
+ /// </summary>
+ /// <returns>
+ /// Returns `1` if the spans is all zeros. Others, returns `0`.
+ /// </returns>
+ public static byte ConstantCompareZeroSpan(ByteSpan a)
+ {
+ byte value = 0;
+ for (int ii = 0, nn = a.Length; ii != nn; ++ii)
+ {
+ value |= (byte)(a[ii] ^ 0);
+ }
+
+ return ConstantCompareByte(value, 0);
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/Crypto/CryptoProvider.cs b/Tools/Hazel-Networking/Hazel/Crypto/CryptoProvider.cs
new file mode 100644
index 0000000..2c56c70
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/Crypto/CryptoProvider.cs
@@ -0,0 +1,36 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+
+namespace Hazel.Crypto
+{
+ public static class CryptoProvider
+ {
+ public delegate IAes CreateAesOverrideDelegate(ByteSpan key);
+
+ /// <summary>
+ /// Override the default AES creation function
+ /// </summary>
+ public static CreateAesOverrideDelegate OverrideCreateAes = null;
+
+ /// <summary>
+ /// Create a new AES cipher
+ /// </summary>
+ /// <param name="key">Encrtyption key</param>
+ public static IAes CreateAes(ByteSpan key)
+ {
+ if (OverrideCreateAes != null)
+ {
+ IAes result = OverrideCreateAes(key);
+ if (null != result)
+ {
+ return result;
+ }
+ }
+
+ return new DefaultAes(key);
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/Crypto/DefaultAes.cs b/Tools/Hazel-Networking/Hazel/Crypto/DefaultAes.cs
new file mode 100644
index 0000000..da72fb8
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/Crypto/DefaultAes.cs
@@ -0,0 +1,49 @@
+using System;
+using System.Security.Cryptography;
+
+namespace Hazel.Crypto
+{
+ /// <summary>
+ /// AES provider using the default System.Security.Cryptography implementation
+ /// </summary>
+ public class DefaultAes : IAes
+ {
+ private readonly ICryptoTransform encryptor_;
+
+ /// <summary>
+ /// Create a new default instance of the AES block cipher
+ /// </summary>
+ /// <param name="key">Encryption key</param>
+ public DefaultAes(ByteSpan key)
+ {
+ // Create the AES block cipher
+ using (Aes aes = Aes.Create())
+ {
+ aes.KeySize = key.Length * 8;
+ aes.BlockSize = aes.KeySize;
+ aes.Mode = CipherMode.ECB;
+ aes.Padding = PaddingMode.Zeros;
+ aes.Key = key.ToArray();
+
+ this.encryptor_ = aes.CreateEncryptor();
+ }
+ }
+
+ /// <inheritdoc/>
+ public void Dispose()
+ {
+ this.encryptor_.Dispose();
+ }
+
+ /// <inheritdoc/>
+ public int EncryptBlock(ByteSpan inputSpan, ByteSpan outputSpan)
+ {
+ if (inputSpan.Length != outputSpan.Length)
+ {
+ throw new ArgumentException($"ouputSpan length ({outputSpan.Length}) does not match inputSpan length ({inputSpan.Length})", nameof(outputSpan));
+ }
+
+ return this.encryptor_.TransformBlock(inputSpan.GetUnderlyingArray(), inputSpan.Offset, inputSpan.Length, outputSpan.GetUnderlyingArray(), outputSpan.Offset);
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/Crypto/IAes.cs b/Tools/Hazel-Networking/Hazel/Crypto/IAes.cs
new file mode 100644
index 0000000..6c494cd
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/Crypto/IAes.cs
@@ -0,0 +1,27 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+
+namespace Hazel.Crypto
+{
+ /// <summary>
+ /// AES encryption interface
+ /// </summary>
+ public interface IAes : IDisposable
+ {
+ /// <summary>
+ /// Encrypts the specified region of the input byte array and copies
+ /// the resulting transform to the specified region of the output
+ /// array.
+ /// </summary>
+ /// <param name="inputSpan">The input for which to encrypt</param>
+ /// <param name="outputSpan">
+ /// The otput to which to write the encrypted data. This span can
+ /// overlap with `inputSpan`.
+ /// </param>
+ /// <returns>The number of bytes written</returns>
+ int EncryptBlock(ByteSpan inputSpan, ByteSpan outputSpan);
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/Crypto/Sha256Stream.cs b/Tools/Hazel-Networking/Hazel/Crypto/Sha256Stream.cs
new file mode 100644
index 0000000..1903693
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/Crypto/Sha256Stream.cs
@@ -0,0 +1,86 @@
+using System;
+using System.Security.Cryptography;
+
+namespace Hazel.Crypto
+{
+ /// <summary>
+ /// Streams data into a SHA256 digest
+ /// </summary>
+ public class Sha256Stream : IDisposable
+ {
+ /// <summary>
+ /// Size of the SHA256 digest in bytes
+ /// </summary>
+ public const int DigestSize = 32;
+
+ private SHA256 hash = SHA256.Create();
+ private bool isHashFinished = false;
+
+ struct EmptyArray
+ {
+ public static readonly byte[] Value = new byte[0];
+ }
+
+ /// <summary>
+ /// Create a new instance of a SHA256 stream
+ /// </summary>
+ public Sha256Stream()
+ {
+ }
+
+ /// <summary>
+ /// Release resources associated with the stream
+ /// </summary>
+ public void Dispose()
+ {
+ this.hash?.Dispose();
+ this.hash = null;
+
+ GC.SuppressFinalize(this);
+ }
+
+ /// <summary>
+ /// Reset the stream to its initial state
+ /// </summary>
+ public void Reset()
+ {
+ this.hash?.Dispose();
+ this.hash = SHA256.Create();
+ this.isHashFinished = false;
+ }
+
+ /// <summary>
+ /// Add data to the stream
+ /// </summary>
+ public void AddData(ByteSpan data)
+ {
+ while (data.Length > 0)
+ {
+ int offset = this.hash.TransformBlock(data.GetUnderlyingArray(), data.Offset, data.Length, null, 0);
+ data = data.Slice(offset);
+ }
+ }
+
+ /// <summary>
+ /// Calculate the final hash of the stream data
+ /// </summary>
+ /// <param name="output">
+ /// Target span to which the hash will be written
+ /// </param>
+ public void CopyOrCalculateFinalHash(ByteSpan output)
+ {
+ if (output.Length != DigestSize)
+ {
+ throw new ArgumentException($"Expected a span of {DigestSize} bytes. Got a span of {output.Length} bytes", nameof(output));
+ }
+
+ if (this.isHashFinished == false)
+ {
+ this.hash.TransformFinalBlock(EmptyArray.Value, 0, 0);
+ this.isHashFinished = true;
+ }
+
+ new ByteSpan(this.hash.Hash).CopyTo(output);
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/Crypto/SpanCryptoExtensions.cs b/Tools/Hazel-Networking/Hazel/Crypto/SpanCryptoExtensions.cs
new file mode 100644
index 0000000..03164ec
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/Crypto/SpanCryptoExtensions.cs
@@ -0,0 +1,36 @@
+using System;
+using System.Security.Cryptography;
+
+namespace Hazel.Crypto
+{
+ public static class SpanCryptoExtensions
+ {
+ /// <summary>
+ /// Clear a span's contents to zero
+ /// </summary>
+ public static void SecureClear(this ByteSpan span)
+ {
+ if (span.Length > 0)
+ {
+ Array.Clear(span.GetUnderlyingArray(), span.Offset, span.Length);
+ }
+ }
+
+ /// <summary>
+ /// Fill a byte span with random data
+ /// </summary>
+ /// <param name="random">Entropy source</param>
+ public static void FillWithRandom(this ByteSpan span, RandomNumberGenerator random)
+ {
+ if (span.Offset == 0 && span.Length == span.GetUnderlyingArray().Length)
+ {
+ random.GetBytes(span.GetUnderlyingArray());
+ return;
+ }
+
+ byte[] temp = new byte[span.Length];
+ random.GetBytes(temp);
+ new ByteSpan(temp).CopyTo(span);
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/Crypto/X25519.cs b/Tools/Hazel-Networking/Hazel/Crypto/X25519.cs
new file mode 100644
index 0000000..3f4624b
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/Crypto/X25519.cs
@@ -0,0 +1,844 @@
+using System;
+using System.Diagnostics;
+
+namespace Hazel.Crypto
+{
+ /// <summary>
+ /// The x25519 key agreement algorithm
+ /// </summary>
+ public static class X25519
+ {
+ public const int KeySize = 32;
+
+ /// <summary>
+ /// Element in the GF(2^255 - 19) field
+ /// </summary>
+ public partial struct FieldElement
+ {
+ public int x0, x1, x2, x3, x4;
+ public int x5, x6, x7, x8, x9;
+ };
+
+ private static readonly byte[] BasePoint = {9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
+
+ /// <summary>
+ /// Performs the core x25519 function: Multiplying an EC point by a scalar value
+ /// </summary>
+ public static bool Func(ByteSpan output, ByteSpan scalar, ByteSpan point)
+ {
+ InternalFunc(output, scalar, point);
+ if (Const.ConstantCompareZeroSpan(output) == 1)
+ {
+ return false;
+ }
+
+ return true;
+ }
+
+ /// <summary>
+ /// Multiplies the base x25519 point by the provided scalar value
+ /// </summary>
+ public static void Func(ByteSpan output, ByteSpan scalar)
+ {
+ InternalFunc(output, scalar, BasePoint);
+ }
+
+ // The FieldElement code below is ported from the original
+ // public domain reference implemtation of X25519
+ // by D. J. Bernstien
+ //
+ // See: https://cr.yp.to/ecdh.html
+
+ private static void InternalFunc(ByteSpan output, ByteSpan scalar, ByteSpan point)
+ {
+ if (output.Length != KeySize)
+ {
+ throw new ArgumentException("Invalid output size", nameof(output));
+ }
+ else if (scalar.Length != KeySize)
+ {
+ throw new ArgumentException("Invalid scalar size", nameof(scalar));
+ }
+ else if (point.Length != KeySize)
+ {
+ throw new ArgumentException("Invalid point size", nameof(point));
+ }
+
+ // copy the scalar so we can properly mask it
+ ByteSpan maskedScalar = new byte[32];
+ scalar.CopyTo(maskedScalar);
+ maskedScalar[0] &= 248;
+ maskedScalar[31] &= 127;
+ maskedScalar[31] |= 64;
+
+ FieldElement x1 = FieldElement.FromBytes(point);
+ FieldElement x2 = FieldElement.One();
+ FieldElement x3 = x1;
+ FieldElement z2 = FieldElement.Zero();
+ FieldElement z3 = FieldElement.One();
+
+ FieldElement tmp0 = new FieldElement();
+ FieldElement tmp1 = new FieldElement();
+
+ int swap = 0;
+ for (int pos = 254; pos >= 0; --pos)
+ {
+ int b = (int)maskedScalar[pos / 8] >> (int)(pos % 8);
+ b &= 1;
+ swap ^= b;
+
+ FieldElement.ConditionalSwap(ref x2, ref x3, swap);
+ FieldElement.ConditionalSwap(ref z2, ref z3, swap);
+ swap = b;
+
+ FieldElement.Sub(ref tmp0, ref x3, ref z3);
+ FieldElement.Sub(ref tmp1, ref x2, ref z2);
+ FieldElement.Add(ref x2, ref x2, ref z2);
+ FieldElement.Add(ref z2, ref x3, ref z3);
+ FieldElement.Multiply(ref z3, ref tmp0, ref x2);
+ FieldElement.Multiply(ref z2, ref z2, ref tmp1);
+ FieldElement.Square(ref tmp0, ref tmp1);
+ FieldElement.Square(ref tmp1, ref x2);
+ FieldElement.Add(ref x3, ref z3, ref z2);
+ FieldElement.Sub(ref z2, ref z3, ref z2);
+ FieldElement.Multiply(ref x2, ref tmp1, ref tmp0);
+ FieldElement.Sub(ref tmp1, ref tmp1, ref tmp0);
+ FieldElement.Square(ref z2, ref z2);
+ FieldElement.Multiply121666(ref z3, ref tmp1);
+ FieldElement.Square(ref x3, ref x3);
+ FieldElement.Add(ref tmp0, ref tmp0, ref z3);
+ FieldElement.Multiply(ref z3, ref x1, ref z2);
+ FieldElement.Multiply(ref z2, ref tmp1, ref tmp0);
+ }
+
+ FieldElement.ConditionalSwap(ref x2, ref x3, swap);
+ FieldElement.ConditionalSwap(ref z2, ref z3, swap);
+
+ FieldElement.Invert(ref z2, ref z2);
+ FieldElement.Multiply(ref x2, ref x2, ref z2);
+ x2.CopyTo(output);
+ }
+
+
+ /// <summary>
+ /// Mathematical operators over GF(2^255 - 19)
+ /// </summary>
+ partial struct FieldElement
+ {
+ /// <summary>
+ /// Convert a byte array to a field element
+ /// </summary>
+ public static FieldElement FromBytes(ByteSpan bytes)
+ {
+ Debug.Assert(bytes.Length >= KeySize);
+
+ long tmp0 = (long)bytes.ReadLittleEndian32();
+ long tmp1 = (long)bytes.ReadLittleEndian24(4) << 6;
+ long tmp2 = (long)bytes.ReadLittleEndian24(7) << 5;
+ long tmp3 = (long)bytes.ReadLittleEndian24(10) << 3;
+ long tmp4 = (long)bytes.ReadLittleEndian24(13) << 2;
+ long tmp5 = (long)bytes.ReadLittleEndian32(16);
+ long tmp6 = (long)bytes.ReadLittleEndian24(20) << 7;
+ long tmp7 = (long)bytes.ReadLittleEndian24(23) << 5;
+ long tmp8 = (long)bytes.ReadLittleEndian24(26) << 4;
+ long tmp9 = (long)(bytes.ReadLittleEndian24(29) & 0x007FFFFF) << 2;
+
+ long carry9 = (tmp9 + (1L<<24)) >> 25;
+ tmp0 += carry9 * 19;
+ tmp9 -= carry9 << 25;
+ long carry1 = (tmp1 + (1L<<24)) >> 25;
+ tmp2 += carry1;
+ tmp1 -= carry1 << 25;
+ long carry3 = (tmp3 + (1L<<24)) >> 25;
+ tmp4 += carry3;
+ tmp3 -= carry3 << 25;
+ long carry5 = (tmp5 + (1L<<24)) >> 25;
+ tmp6 += carry5;
+ tmp5 -= carry5 << 25;
+ long carry7 = (tmp7 + (1L<<24)) >> 25;
+ tmp8 += carry7;
+ tmp7 -= carry7 << 25;
+
+ long carry0 = (tmp0 + (1L<<25)) >> 26;
+ tmp1 += carry0;
+ tmp0 -= carry0 << 26;
+ long carry2 = (tmp2 + (1L<<25)) >> 26;
+ tmp3 += carry2;
+ tmp2 -= carry2 << 26;
+ long carry4 = (tmp4 + (1L<<25)) >> 26;
+ tmp5 += carry4;
+ tmp4 -= carry4 << 26;
+ long carry6 = (tmp6 + (1L<<25)) >> 26;
+ tmp7 += carry6;
+ tmp6 -= carry6 << 26;
+ long carry8 = (tmp8 + (1L<<25)) >> 26;
+ tmp9 += carry8;
+ tmp8 -= carry8 << 26;
+
+ return new FieldElement
+ {
+ x0 = (int)tmp0,
+ x1 = (int)tmp1,
+ x2 = (int)tmp2,
+ x3 = (int)tmp3,
+ x4 = (int)tmp4,
+ x5 = (int)tmp5,
+ x6 = (int)tmp6,
+ x7 = (int)tmp7,
+ x8 = (int)tmp8,
+ x9 = (int)tmp9,
+ };
+ }
+
+ /// <summary>
+ /// Convert the field element to a byte array
+ /// </summary>
+ public void CopyTo(ByteSpan output)
+ {
+ Debug.Assert(output.Length >= 32);
+
+ long q = (19 * this.x9 + (1L << 24)) >> 25;
+ q = ((long)this.x0 + q) >> 26;
+ q = ((long)this.x1 + q) >> 25;
+ q = ((long)this.x2 + q) >> 26;
+ q = ((long)this.x3 + q) >> 25;
+ q = ((long)this.x4 + q) >> 26;
+ q = ((long)this.x5 + q) >> 25;
+ q = ((long)this.x6 + q) >> 26;
+ q = ((long)this.x7 + q) >> 25;
+ q = ((long)this.x8 + q) >> 26;
+ q = ((long)this.x9 + q) >> 25;
+
+ this.x0 = (int)((long)this.x0 + (19L * q));
+
+ int carry0 = (int)(this.x0 >> 26);
+ this.x1 = (int)((int)this.x1 + carry0);
+ this.x0 = (int)((int)this.x0 - (carry0 << 26));
+ int carry1 = (int)(this.x1 >> 25);
+ this.x2 = (int)((int)this.x2 + carry1);
+ this.x1 = (int)((int)this.x1 - (carry1 << 25));
+ int carry2 = (int)(this.x2 >> 26);
+ this.x3 = (int)((int)this.x3 + carry2);
+ this.x2 = (int)((int)this.x2 - (carry2 << 26));
+ int carry3 = (int)(this.x3 >> 25);
+ this.x4 = (int)((int)this.x4 + carry3);
+ this.x3 = (int)((int)this.x3 - (carry3 << 25));
+ int carry4 = (int)(this.x4 >> 26);
+ this.x5 = (int)((int)this.x5 + carry4);
+ this.x4 = (int)((int)this.x4 - (carry4 << 26));
+ int carry5 = (int)(this.x5 >> 25);
+ this.x6 = (int)((int)this.x6 + carry5);
+ this.x5 = (int)((int)this.x5 - (carry5 << 25));
+ int carry6 = (int)(this.x6 >> 26);
+ this.x7 = (int)((int)this.x7 + carry6);
+ this.x6 = (int)((int)this.x6 - (carry6 << 26));
+ int carry7 = (int)(this.x7 >> 25);
+ this.x8 = (int)((int)this.x8 + carry7);
+ this.x7 = (int)((int)this.x7 - (carry7 << 25));
+ int carry8 = (int)(this.x8 >> 26);
+ this.x9 = (int)((int)this.x9 + carry8);
+ this.x8 = (int)((int)this.x8 - (carry8 << 26));
+ int carry9 = (int)(this.x9 >> 25);
+ this.x9 = (int)((int)this.x9 - (carry9 << 25));
+
+ output[ 0] = (byte)(this.x0 >> 0);
+ output[ 1] = (byte)(this.x0 >> 8);
+ output[ 2] = (byte)(this.x0 >> 16);
+ output[ 3] = (byte)((this.x0 >> 24) | (this.x1 << 2));
+ output[ 4] = (byte)(this.x1 >> 6);
+ output[ 5] = (byte)(this.x1 >> 14);
+ output[ 6] = (byte)((this.x1 >> 22) | (this.x2 << 3));
+ output[ 7] = (byte)(this.x2 >> 5);
+ output[ 8] = (byte)(this.x2 >> 13);
+ output[ 9] = (byte)((this.x2 >> 21) | (this.x3 << 5));
+ output[10] = (byte)(this.x3 >> 3);
+ output[11] = (byte)(this.x3 >> 11);
+ output[12] = (byte)((this.x3 >> 19) | (this.x4 << 6));
+ output[13] = (byte)(this.x4 >> 2);
+ output[14] = (byte)(this.x4 >> 10);
+ output[15] = (byte)(this.x4 >> 18);
+ output[16] = (byte)(this.x5 >> 0);
+ output[17] = (byte)(this.x5 >> 8);
+ output[18] = (byte)(this.x5 >> 16);
+ output[19] = (byte)((this.x5 >> 24) | (this.x6 << 1));
+ output[20] = (byte)(this.x6 >> 7);
+ output[21] = (byte)(this.x6 >> 15);
+ output[22] = (byte)((this.x6 >> 23) | (this.x7 << 3));
+ output[23] = (byte)(this.x7 >> 5);
+ output[24] = (byte)(this.x7 >> 13);
+ output[25] = (byte)((this.x7 >> 21) | (this.x8 << 4));
+ output[26] = (byte)(this.x8 >> 4);
+ output[27] = (byte)(this.x8 >> 12);
+ output[28] = (byte)((this.x8 >> 20) | (this.x9 << 6));
+ output[29] = (byte)(this.x9 >> 2);
+ output[30] = (byte)(this.x9 >> 10);
+ output[31] = (byte)(this.x9 >> 18);
+ }
+
+ /// <summary>
+ /// Set the field element to `0`
+ /// </summary>
+ public static FieldElement Zero()
+ {
+ return new FieldElement();
+ }
+
+ /// <summary>
+ /// Set the field element to `1`
+ /// </summary>
+ public static FieldElement One()
+ {
+ FieldElement result = Zero();
+ result.x0 = 1;
+ return result;
+ }
+
+ /// <summary>
+ /// Add two field elements
+ /// </summary>
+ public static void Add(ref FieldElement output, ref FieldElement a, ref FieldElement b)
+ {
+ output.x0 = a.x0 + b.x0;
+ output.x1 = a.x1 + b.x1;
+ output.x2 = a.x2 + b.x2;
+ output.x3 = a.x3 + b.x3;
+ output.x4 = a.x4 + b.x4;
+ output.x5 = a.x5 + b.x5;
+ output.x6 = a.x6 + b.x6;
+ output.x7 = a.x7 + b.x7;
+ output.x8 = a.x8 + b.x8;
+ output.x9 = a.x9 + b.x9;
+ }
+
+ /// <summary>
+ /// Subtract two field elements
+ /// </summary>
+ public static void Sub(ref FieldElement output, ref FieldElement a, ref FieldElement b)
+ {
+ output.x0 = a.x0 - b.x0;
+ output.x1 = a.x1 - b.x1;
+ output.x2 = a.x2 - b.x2;
+ output.x3 = a.x3 - b.x3;
+ output.x4 = a.x4 - b.x4;
+ output.x5 = a.x5 - b.x5;
+ output.x6 = a.x6 - b.x6;
+ output.x7 = a.x7 - b.x7;
+ output.x8 = a.x8 - b.x8;
+ output.x9 = a.x9 - b.x9;
+ }
+
+ /// <summary>
+ /// Multiply two field elements
+ /// </summary>
+ public static void Multiply(ref FieldElement output, ref FieldElement a, ref FieldElement b)
+ {
+ int b1_19 = 19 * b.x1;
+ int b2_19 = 19 * b.x2;
+ int b3_19 = 19 * b.x3;
+ int b4_19 = 19 * b.x4;
+ int b5_19 = 19 * b.x5;
+ int b6_19 = 19 * b.x6;
+ int b7_19 = 19 * b.x7;
+ int b8_19 = 19 * b.x8;
+ int b9_19 = 19 * b.x9;
+
+ int a1_2 = 2 * a.x1;
+ int a3_2 = 2 * a.x3;
+ int a5_2 = 2 * a.x5;
+ int a7_2 = 2 * a.x7;
+ int a9_2 = 2 * a.x9;
+
+ long a0b0 = (long)a.x0 * (long)b.x0;
+ long a0b1 = (long)a.x0 * (long)b.x1;
+ long a0b2 = (long)a.x0 * (long)b.x2;
+ long a0b3 = (long)a.x0 * (long)b.x3;
+ long a0b4 = (long)a.x0 * (long)b.x4;
+ long a0b5 = (long)a.x0 * (long)b.x5;
+ long a0b6 = (long)a.x0 * (long)b.x6;
+ long a0b7 = (long)a.x0 * (long)b.x7;
+ long a0b8 = (long)a.x0 * (long)b.x8;
+ long a0b9 = (long)a.x0 * (long)b.x9;
+ long a1b0 = (long)a.x1 * (long)b.x0;
+ long a1b1_2 = (long)a1_2 * (long)b.x1;
+ long a1b2 = (long)a.x1 * (long)b.x2;
+ long a1b3_2 = (long)a1_2 * (long)b.x3;
+ long a1b4 = (long)a.x1 * (long)b.x4;
+ long a1b5_2 = (long)a1_2 * (long)b.x5;
+ long a1b6 = (long)a.x1 * (long)b.x6;
+ long a1b7_2 = (long)a1_2 * (long)b.x7;
+ long a1b8 = (long)a.x1 * (long)b.x8;
+ long a1b9_38 = (long)a1_2 * (long)b9_19;
+ long a2b0 = (long)a.x2 * (long)b.x0;
+ long a2b1 = (long)a.x2 * (long)b.x1;
+ long a2b2 = (long)a.x2 * (long)b.x2;
+ long a2b3 = (long)a.x2 * (long)b.x3;
+ long a2b4 = (long)a.x2 * (long)b.x4;
+ long a2b5 = (long)a.x2 * (long)b.x5;
+ long a2b6 = (long)a.x2 * (long)b.x6;
+ long a2b7 = (long)a.x2 * (long)b.x7;
+ long a2b8_19 = (long)a.x2 * (long)b8_19;
+ long a2b9_19 = (long)a.x2 * (long)b9_19;
+ long a3b0 = (long)a.x3 * (long)b.x0;
+ long a3b1_2 = (long)a3_2 * (long)b.x1;
+ long a3b2 = (long)a.x3 * (long)b.x2;
+ long a3b3_2 = (long)a3_2 * (long)b.x3;
+ long a3b4 = (long)a.x3 * (long)b.x4;
+ long a3b5_2 = (long)a3_2 * (long)b.x5;
+ long a3b6 = (long)a.x3 * (long)b.x6;
+ long a3b7_38 = (long)a3_2 * (long)b7_19;
+ long a3b8_19 = (long)a.x3 * (long)b8_19;
+ long a3b9_38 = (long)a3_2 * (long)b9_19;
+ long a4b0 = (long)a.x4 * (long)b.x0;
+ long a4b1 = (long)a.x4 * (long)b.x1;
+ long a4b2 = (long)a.x4 * (long)b.x2;
+ long a4b3 = (long)a.x4 * (long)b.x3;
+ long a4b4 = (long)a.x4 * (long)b.x4;
+ long a4b5 = (long)a.x4 * (long)b.x5;
+ long a4b6_19 = (long)a.x4 * (long)b6_19;
+ long a4b7_19 = (long)a.x4 * (long)b7_19;
+ long a4b8_19 = (long)a.x4 * (long)b8_19;
+ long a4b9_19 = (long)a.x4 * (long)b9_19;
+ long a5b0 = (long)a.x5 * (long)b.x0;
+ long a5b1_2 = (long)a5_2 * (long)b.x1;
+ long a5b2 = (long)a.x5 * (long)b.x2;
+ long a5b3_2 = (long)a5_2 * (long)b.x3;
+ long a5b4 = (long)a.x5 * (long)b.x4;
+ long a5b5_38 = (long)a5_2 * (long)b5_19;
+ long a5b6_19 = (long)a.x5 * (long)b6_19;
+ long a5b7_38 = (long)a5_2 * (long)b7_19;
+ long a5b8_19 = (long)a.x5 * (long)b8_19;
+ long a5b9_38 = (long)a5_2 * (long)b9_19;
+ long a6b0 = (long)a.x6 * (long)b.x0;
+ long a6b1 = (long)a.x6 * (long)b.x1;
+ long a6b2 = (long)a.x6 * (long)b.x2;
+ long a6b3 = (long)a.x6 * (long)b.x3;
+ long a6b4_19 = (long)a.x6 * (long)b4_19;
+ long a6b5_19 = (long)a.x6 * (long)b5_19;
+ long a6b6_19 = (long)a.x6 * (long)b6_19;
+ long a6b7_19 = (long)a.x6 * (long)b7_19;
+ long a6b8_19 = (long)a.x6 * (long)b8_19;
+ long a6b9_19 = (long)a.x6 * (long)b9_19;
+ long a7b0 = (long)a.x7 * (long)b.x0;
+ long a7b1_2 = (long)a7_2 * (long)b.x1;
+ long a7b2 = (long)a.x7 * (long)b.x2;
+ long a7b3_38 = (long)a7_2 * (long)b3_19;
+ long a7b4_19 = (long)a.x7 * (long)b4_19;
+ long a7b5_38 = (long)a7_2 * (long)b5_19;
+ long a7b6_19 = (long)a.x7 * (long)b6_19;
+ long a7b7_38 = (long)a7_2 * (long)b7_19;
+ long a7b8_19 = (long)a.x7 * (long)b8_19;
+ long a7b9_38 = (long)a7_2 * (long)b9_19;
+ long a8b0 = (long)a.x8 * (long)b.x0;
+ long a8b1 = (long)a.x8 * (long)b.x1;
+ long a8b2_19 = (long)a.x8 * (long)b2_19;
+ long a8b3_19 = (long)a.x8 * (long)b3_19;
+ long a8b4_19 = (long)a.x8 * (long)b4_19;
+ long a8b5_19 = (long)a.x8 * (long)b5_19;
+ long a8b6_19 = (long)a.x8 * (long)b6_19;
+ long a8b7_19 = (long)a.x8 * (long)b7_19;
+ long a8b8_19 = (long)a.x8 * (long)b8_19;
+ long a8b9_19 = (long)a.x8 * (long)b9_19;
+ long a9b0 = (long)a.x9 * (long)b.x0;
+ long a9b1_38 = (long)a9_2 * (long)b1_19;
+ long a9b2_19 = (long)a.x9 * (long)b2_19;
+ long a9b3_38 = (long)a9_2 * (long)b3_19;
+ long a9b4_19 = (long)a.x9 * (long)b4_19;
+ long a9b5_38 = (long)a9_2 * (long)b5_19;
+ long a9b6_19 = (long)a.x9 * (long)b6_19;
+ long a9b7_38 = (long)a9_2 * (long)b7_19;
+ long a9b8_19 = (long)a.x9 * (long)b8_19;
+ long a9b9_38 = (long)a9_2 * (long)b9_19;
+
+ long h0 = a0b0 + a1b9_38 + a2b8_19 + a3b7_38 + a4b6_19 + a5b5_38 + a6b4_19 + a7b3_38 + a8b2_19 + a9b1_38;
+ long h1 = a0b1 + a1b0 + a2b9_19 + a3b8_19 + a4b7_19 + a5b6_19 + a6b5_19 + a7b4_19 + a8b3_19 + a9b2_19;
+ long h2 = a0b2 + a1b1_2 + a2b0 + a3b9_38 + a4b8_19 + a5b7_38 + a6b6_19 + a7b5_38 + a8b4_19 + a9b3_38;
+ long h3 = a0b3 + a1b2 + a2b1 + a3b0 + a4b9_19 + a5b8_19 + a6b7_19 + a7b6_19 + a8b5_19 + a9b4_19;
+ long h4 = a0b4 + a1b3_2 + a2b2 + a3b1_2 + a4b0 + a5b9_38 + a6b8_19 + a7b7_38 + a8b6_19 + a9b5_38;
+ long h5 = a0b5 + a1b4 + a2b3 + a3b2 + a4b1 + a5b0 + a6b9_19 + a7b8_19 + a8b7_19 + a9b6_19;
+ long h6 = a0b6 + a1b5_2 + a2b4 + a3b3_2 + a4b2 + a5b1_2 + a6b0 + a7b9_38 + a8b8_19 + a9b7_38;
+ long h7 = a0b7 + a1b6 + a2b5 + a3b4 + a4b3 + a5b2 + a6b1 + a7b0 + a8b9_19 + a9b8_19;
+ long h8 = a0b8 + a1b7_2 + a2b6 + a3b5_2 + a4b4 + a5b3_2 + a6b2 + a7b1_2 + a8b0 + a9b9_38;
+ long h9 = a0b9 + a1b8 + a2b7 + a3b6 + a4b5 + a5b4 + a6b3 + a7b2 + a8b1 + a9b0;
+
+ long carry0 = (h0 + (1L << 25)) >> 26;
+ h1 += carry0;
+ h0 -= carry0 << 26;
+ long carry4 = (h4 + (1L << 25)) >> 26;
+ h5 += carry4;
+ h4 -= carry4 << 26;
+
+ long carry1 = (h1 + (1L << 24)) >> 25;
+ h2 += carry1;
+ h1 -= carry1 << 25;
+ long carry5 = (h5 + (1L << 24)) >> 25;
+ h6 += carry5;
+ h5 -= carry5 << 25;
+
+ long carry2 = (h2 + (1L << 25)) >> 26;
+ h3 += carry2;
+ h2 -= carry2 << 26;
+ long carry6 = (h6 + (1L << 25)) >> 26;
+ h7 += carry6;
+ h6 -= carry6 << 26;
+
+ long carry3 = (h3 + (1L << 24)) >> 25;
+ h4 += carry3;
+ h3 -= carry3 << 25;
+ long carry7 = (h7 + (1L << 24)) >> 25;
+ h8 += carry7;
+ h7 -= carry7 << 25;
+
+ carry4 = (h4 + (1L << 25)) >> 26;
+ h5 += carry4;
+ h4 -= carry4 << 26;
+ long carry8 = (h8 + (1L << 25)) >> 26;
+ h9 += carry8;
+ h8 -= carry8 << 26;
+
+ long carry9 = (h9 + (1L << 24)) >> 25;
+ h0 += carry9 * 19;
+ h9 -= carry9 << 25;
+
+ carry0 = (h0 + (1L << 25)) >> 26;
+ h1 += carry0;
+ h0 -= carry0 << 26;
+
+ output.x0 = (int)h0;
+ output.x1 = (int)h1;
+ output.x2 = (int)h2;
+ output.x3 = (int)h3;
+ output.x4 = (int)h4;
+ output.x5 = (int)h5;
+ output.x6 = (int)h6;
+ output.x7 = (int)h7;
+ output.x8 = (int)h8;
+ output.x9 = (int)h9;
+ }
+
+ /// <summary>
+ /// Square a field element
+ /// </summary>
+ public static void Square(ref FieldElement output, ref FieldElement a)
+ {
+ int a0_2 = a.x0 * 2;
+ int a1_2 = a.x1 * 2;
+ int a2_2 = a.x2 * 2;
+ int a3_2 = a.x3 * 2;
+ int a4_2 = a.x4 * 2;
+ int a5_2 = a.x5 * 2;
+ int a6_2 = a.x6 * 2;
+ int a7_2 = a.x7 * 2;
+
+ int a5_38 = a.x5 * 38;
+ int a6_19 = a.x6 * 19;
+ int a7_38 = a.x7 * 38;
+ int a8_19 = a.x8 * 19;
+ int a9_38 = a.x9 * 38;
+
+ long a0a0 = (long)a.x0 * (long)a.x0;
+ long a0a1_2 = (long)a0_2 * (long)a.x1;
+ long a0a2_2 = (long)a0_2 * (long)a.x2;
+ long a0a3_2 = (long)a0_2 * (long)a.x3;
+ long a0a4_2 = (long)a0_2 * (long)a.x4;
+ long a0a5_2 = (long)a0_2 * (long)a.x5;
+ long a0a6_2 = (long)a0_2 * (long)a.x6;
+ long a0a7_2 = (long)a0_2 * (long)a.x7;
+ long a0a8_2 = (long)a0_2 * (long)a.x8;
+ long a0a9_2 = (long)a0_2 * (long)a.x9;
+ long a1a1_2 = (long)a1_2 * (long)a.x1;
+ long a1a2_2 = (long)a1_2 * (long)a.x2;
+ long a1a3_4 = (long)a1_2 * (long)a3_2;
+ long a1a4_2 = (long)a1_2 * (long)a.x4;
+ long a1a5_4 = (long)a1_2 * (long)a5_2;
+ long a1a6_2 = (long)a1_2 * (long)a.x6;
+ long a1a7_4 = (long)a1_2 * (long)a7_2;
+ long a1a8_2 = (long)a1_2 * (long)a.x8;
+ long a1a9_76 = (long)a1_2 * (long)a9_38;
+ long a2a2 = (long)a.x2 * (long)a.x2;
+ long a2a3_2 = (long)a2_2 * (long)a.x3;
+ long a2a4_2 = (long)a2_2 * (long)a.x4;
+ long a2a5_2 = (long)a2_2 * (long)a.x5;
+ long a2a6_2 = (long)a2_2 * (long)a.x6;
+ long a2a7_2 = (long)a2_2 * (long)a.x7;
+ long a2a8_38 = (long)a2_2 * (long)a8_19;
+ long a2a9_38 = (long)a.x2 * (long)a9_38;
+ long a3a3_2 = (long)a3_2 * (long)a.x3;
+ long a3a4_2 = (long)a3_2 * (long)a.x4;
+ long a3a5_4 = (long)a3_2 * (long)a5_2;
+ long a3a6_2 = (long)a3_2 * (long)a.x6;
+ long a3a7_76 = (long)a3_2 * (long)a7_38;
+ long a3a8_38 = (long)a3_2 * (long)a8_19;
+ long a3a9_76 = (long)a3_2 * (long)a9_38;
+ long a4a4 = (long)a.x4 * (long)a.x4;
+ long a4a5_2 = (long)a4_2 * (long)a.x5;
+ long a4a6_38 = (long)a4_2 * (long)a6_19;
+ long a4a7_38 = (long)a.x4 * (long)a7_38;
+ long a4a8_38 = (long)a4_2 * (long)a8_19;
+ long a4a9_38 = (long)a.x4 * (long)a9_38;
+ long a5a5_38 = (long)a.x5 * (long)a5_38;
+ long a5a6_38 = (long)a5_2 * (long)a6_19;
+ long a5a7_76 = (long)a5_2 * (long)a7_38;
+ long a5a8_38 = (long)a5_2 * (long)a8_19;
+ long a5a9_76 = (long)a5_2 * (long)a9_38;
+ long a6a6_19 = (long)a.x6 * (long)a6_19;
+ long a6a7_38 = (long)a.x6 * (long)a7_38;
+ long a6a8_38 = (long)a6_2 * (long)a8_19;
+ long a6a9_38 = (long)a.x6 * (long)a9_38;
+ long a7a7_38 = (long)a.x7 * (long)a7_38;
+ long a7a8_38 = (long)a7_2 * (long)a8_19;
+ long a7a9_76 = (long)a7_2 * (long)a9_38;
+ long a8a8_19 = (long)a.x8 * (long)a8_19;
+ long a8a9_38 = (long)a.x8 * (long)a9_38;
+ long a9a9_38 = (long)a.x9 * (long)a9_38;
+
+ long h0 = a0a0 + a1a9_76 + a2a8_38 + a3a7_76 + a4a6_38 + a5a5_38;
+ long h1 = a0a1_2 + a2a9_38 + a3a8_38 + a4a7_38 + a5a6_38;
+ long h2 = a0a2_2 + a1a1_2 + a3a9_76 + a4a8_38 + a5a7_76 + a6a6_19;
+ long h3 = a0a3_2 + a1a2_2 + a4a9_38 + a5a8_38 + a6a7_38;
+ long h4 = a0a4_2 + a1a3_4 + a2a2 + a5a9_76 + a6a8_38 + a7a7_38;
+ long h5 = a0a5_2 + a1a4_2 + a2a3_2 + a6a9_38 + a7a8_38;
+ long h6 = a0a6_2 + a1a5_4 + a2a4_2 + a3a3_2 + a7a9_76 + a8a8_19;
+ long h7 = a0a7_2 + a1a6_2 + a2a5_2 + a3a4_2 + a8a9_38;
+ long h8 = a0a8_2 + a1a7_4 + a2a6_2 + a3a5_4 + a4a4 + a9a9_38;
+ long h9 = a0a9_2 + a1a8_2 + a2a7_2 + a3a6_2 + a4a5_2;
+
+ long carry0 = (h0 + (1L << 25)) >> 26;
+ h1 += carry0;
+ h0 -= carry0 << 26;
+ long carry4 = (h4 + (1L << 25)) >> 26;
+ h5 += carry4;
+ h4 -= carry4 << 26;
+
+ long carry1 = (h1 + (1L << 24)) >> 25;
+ h2 += carry1;
+ h1 -= carry1 << 25;
+ long carry5 = (h5 + (1L << 24)) >> 25;
+ h6 += carry5;
+ h5 -= carry5 << 25;
+
+ long carry2 = (h2 + (1L << 25)) >> 26;
+ h3 += carry2;
+ h2 -= carry2 << 26;
+ long carry6 = (h6 + (1L << 25)) >> 26;
+ h7 += carry6;
+ h6 -= carry6 << 26;
+
+ long carry3 = (h3 + (1L << 24)) >> 25;
+ h4 += carry3;
+ h3 -= carry3 << 25;
+ long carry7 = (h7 + (1L << 24)) >> 25;
+ h8 += carry7;
+ h7 -= carry7 << 25;
+
+ carry4 = (h4 + (1L << 25)) >> 26;
+ h5 += carry4;
+ h4 -= carry4 << 26;
+ long carry8 = (h8 + (1L << 25)) >> 26;
+ h9 += carry8;
+ h8 -= carry8 << 26;
+
+ long carry9 = (h9 + (1L << 24)) >> 25;
+ h0 += carry9 * 19;
+ h9 -= carry9 << 25;
+
+ carry0 = (h0 + (1L << 25)) >> 26;
+ h1 += carry0;
+ h0 -= carry0 << 26;
+
+ output.x0 = (int)h0;
+ output.x1 = (int)h1;
+ output.x2 = (int)h2;
+ output.x3 = (int)h3;
+ output.x4 = (int)h4;
+ output.x5 = (int)h5;
+ output.x6 = (int)h6;
+ output.x7 = (int)h7;
+ output.x8 = (int)h8;
+ output.x9 = (int)h9;
+ }
+
+ /// <summary>
+ /// Multiplay a field element by 121666
+ /// </summary>
+ public static void Multiply121666(ref FieldElement output, ref FieldElement a)
+ {
+ long h0 = (long)a.x0 * 121666L;
+ long h1 = (long)a.x1 * 121666L;
+ long h2 = (long)a.x2 * 121666L;
+ long h3 = (long)a.x3 * 121666L;
+ long h4 = (long)a.x4 * 121666L;
+ long h5 = (long)a.x5 * 121666L;
+ long h6 = (long)a.x6 * 121666L;
+ long h7 = (long)a.x7 * 121666L;
+ long h8 = (long)a.x8 * 121666L;
+ long h9 = (long)a.x9 * 121666L;
+
+ long carry9 = (h9 + (1L<<24)) >> 25;
+ h0 += carry9 * 19;
+ h9 -= carry9 << 25;
+ long carry1 = (h1 + (1L<<24)) >> 25;
+ h2 += carry1;
+ h1 -= carry1 << 25;
+ long carry3 = (h3 + (1L<<24)) >> 25;
+ h4 += carry3;
+ h3 -= carry3 << 25;
+ long carry5 = (h5 + (1L<<24)) >> 25;
+ h6 += carry5;
+ h5 -= carry5 << 25;
+ long carry7 = (h7 + (1L<<24)) >> 25;
+ h8 += carry7;
+ h7 -= carry7 << 25;
+
+ long carry0 = (h0 + (1L << 25)) >> 26;
+ h1 += carry0;
+ h0 -= carry0 << 26;
+ long carry2 = (h2 + (1L << 25)) >> 26;
+ h3 += carry2;
+ h2 -= carry2 << 26;
+ long carry4 = (h4 + (1L << 25)) >> 26;
+ h5 += carry4;
+ h4 -= carry4 << 26;
+ long carry6 = (h6 + (1L << 25)) >> 26;
+ h7 += carry6;
+ h6 -= carry6 << 26;
+ long carry8 = (h8 + (1L << 25)) >> 26;
+ h9 += carry8;
+ h8 -= carry8 << 26;
+
+ output.x0 = (int)h0;
+ output.x1 = (int)h1;
+ output.x2 = (int)h2;
+ output.x3 = (int)h3;
+ output.x4 = (int)h4;
+ output.x5 = (int)h5;
+ output.x6 = (int)h6;
+ output.x7 = (int)h7;
+ output.x8 = (int)h8;
+ output.x9 = (int)h9;
+ }
+
+ /// <summary>
+ /// Invert a field element
+ /// </summary>
+ public static void Invert(ref FieldElement output, ref FieldElement a)
+ {
+ FieldElement t0 = new FieldElement();
+ Square(ref t0, ref a);
+
+ FieldElement t1 = new FieldElement();
+ Square(ref t1, ref t0);
+ Square(ref t1, ref t1);
+
+ FieldElement t2= new FieldElement();
+ Multiply(ref t1, ref a, ref t1);
+ Multiply(ref t0, ref t0, ref t1);
+ Square(ref t2, ref t0);
+ //Square(ref t2, ref t2);
+
+ Multiply(ref t1, ref t1, ref t2);
+ Square(ref t2, ref t1);
+ for (int ii = 1; ii < 5; ++ii)
+ {
+ Square(ref t2, ref t2);
+ }
+
+ Multiply(ref t1, ref t2, ref t1);
+ Square(ref t2, ref t1);
+ for (int ii = 1; ii < 10; ++ii)
+ {
+ Square(ref t2, ref t2);
+ }
+
+ FieldElement t3 = new FieldElement();
+ Multiply(ref t2, ref t2, ref t1);
+ Square(ref t3, ref t2);
+ for (int ii = 1; ii < 20; ++ii)
+ {
+ Square(ref t3, ref t3);
+ }
+
+ Multiply(ref t2, ref t3, ref t2);
+ Square(ref t2, ref t2);
+ for (int ii = 1; ii < 10; ++ii)
+ {
+ Square(ref t2, ref t2);
+ }
+
+ Multiply(ref t1, ref t2, ref t1);
+ Square(ref t2, ref t1);
+ for (int ii = 1; ii < 50; ++ii)
+ {
+ Square(ref t2, ref t2);
+ }
+
+ Multiply(ref t2, ref t2, ref t1);
+ Square(ref t3, ref t2);
+ for (int ii = 1; ii < 100; ++ii)
+ {
+ Square(ref t3, ref t3);
+ }
+
+ Multiply(ref t2, ref t3, ref t2);
+ Square(ref t2, ref t2);
+ for (int ii = 1; ii < 50; ++ii)
+ {
+ Square(ref t2, ref t2);
+ }
+
+ Multiply(ref t1, ref t2, ref t1);
+ Square(ref t1, ref t1);
+ for (int ii = 1; ii < 5; ++ii)
+ {
+ Square(ref t1, ref t1);
+ }
+
+ Multiply(ref output, ref t1, ref t0);
+ }
+
+ /// <summary>
+ /// Swaps `a` and `b` if `swap` is 1
+ /// </summary>
+ public static void ConditionalSwap(ref FieldElement a, ref FieldElement b, int swap)
+ {
+ Debug.Assert(swap == 0 || swap == 1);
+ swap = -swap;
+
+ FieldElement temp = new FieldElement
+ {
+ x0 = swap & (a.x0 ^ b.x0),
+ x1 = swap & (a.x1 ^ b.x1),
+ x2 = swap & (a.x2 ^ b.x2),
+ x3 = swap & (a.x3 ^ b.x3),
+ x4 = swap & (a.x4 ^ b.x4),
+ x5 = swap & (a.x5 ^ b.x5),
+ x6 = swap & (a.x6 ^ b.x6),
+ x7 = swap & (a.x7 ^ b.x7),
+ x8 = swap & (a.x8 ^ b.x8),
+ x9 = swap & (a.x9 ^ b.x9),
+ };
+
+ a.x0 ^= temp.x0;
+ a.x1 ^= temp.x1;
+ a.x2 ^= temp.x2;
+ a.x3 ^= temp.x3;
+ a.x4 ^= temp.x4;
+ a.x5 ^= temp.x5;
+ a.x6 ^= temp.x6;
+ a.x7 ^= temp.x7;
+ a.x8 ^= temp.x8;
+ a.x9 ^= temp.x9;
+
+ b.x0 ^= temp.x0;
+ b.x1 ^= temp.x1;
+ b.x2 ^= temp.x2;
+ b.x3 ^= temp.x3;
+ b.x4 ^= temp.x4;
+ b.x5 ^= temp.x5;
+ b.x6 ^= temp.x6;
+ b.x7 ^= temp.x7;
+ b.x8 ^= temp.x8;
+ b.x9 ^= temp.x9;
+ }
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/DataReceivedEventArgs.cs b/Tools/Hazel-Networking/Hazel/DataReceivedEventArgs.cs
new file mode 100644
index 0000000..35609fc
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/DataReceivedEventArgs.cs
@@ -0,0 +1,29 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+
+namespace Hazel
+{
+ public struct DataReceivedEventArgs
+ {
+ public readonly Connection Sender;
+
+ /// <summary>
+ /// The bytes received from the client.
+ /// </summary>
+ public readonly MessageReader Message;
+
+ /// <summary>
+ /// The <see cref="SendOption"/> the data was sent with.
+ /// </summary>
+ public readonly SendOption SendOption;
+
+ public DataReceivedEventArgs(Connection sender, MessageReader msg, SendOption sendOption)
+ {
+ this.Sender = sender;
+ this.Message = msg;
+ this.SendOption = sendOption;
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/DisconnectedEventArgs.cs b/Tools/Hazel-Networking/Hazel/DisconnectedEventArgs.cs
new file mode 100644
index 0000000..a7fb05c
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/DisconnectedEventArgs.cs
@@ -0,0 +1,24 @@
+using System;
+
+namespace Hazel
+{
+ public class DisconnectedEventArgs : EventArgs
+ {
+ /// <summary>
+ /// Optional disconnect reason. May be null.
+ /// </summary>
+ public readonly string Reason;
+
+ /// <summary>
+ /// Optional data sent with a disconnect message. May be null.
+ /// You must not recycle this. If you need the message outside of a callback, you should copy it.
+ /// </summary>
+ public readonly MessageReader Message;
+
+ public DisconnectedEventArgs(string reason, MessageReader message)
+ {
+ this.Reason = reason;
+ this.Message = message;
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/Dtls/AesGcmRecordProtection.cs b/Tools/Hazel-Networking/Hazel/Dtls/AesGcmRecordProtection.cs
new file mode 100644
index 0000000..65df39e
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/Dtls/AesGcmRecordProtection.cs
@@ -0,0 +1,147 @@
+using Hazel.Crypto;
+using System;
+using System.Diagnostics;
+
+namespace Hazel.Dtls
+{
+ /// <summary>
+ /// *_AES_128_GCM_* cipher suite
+ /// </summary>
+ public class Aes128GcmRecordProtection: IRecordProtection
+ {
+ private const int ImplicitNonceSize = 4;
+ private const int ExplicitNonceSize = 8;
+
+ private readonly Aes128Gcm serverWriteCipher;
+ private readonly Aes128Gcm clientWriteCipher;
+
+ private readonly ByteSpan serverWriteIV;
+ private readonly ByteSpan clientWriteIV;
+
+ /// <summary>
+ /// Create a new instance of the AES128_GCM record protection
+ /// </summary>
+ /// <param name="masterSecret">Shared secret</param>
+ /// <param name="serverRandom">Server random data</param>
+ /// <param name="clientRandom">Client random data</param>
+ public Aes128GcmRecordProtection(ByteSpan masterSecret, ByteSpan serverRandom, ByteSpan clientRandom)
+ {
+ ByteSpan combinedRandom = new byte[serverRandom.Length + clientRandom.Length];
+ serverRandom.CopyTo(combinedRandom);
+ clientRandom.CopyTo(combinedRandom.Slice(serverRandom.Length));
+
+ // Expand master_secret to encryption keys
+ const int ExpandedSize = 0
+ + 0 // mac_key_length
+ + 0 // mac_key_length
+ + Aes128Gcm.KeySize // enc_key_length
+ + Aes128Gcm.KeySize // enc_key_length
+ + ImplicitNonceSize // fixed_iv_length
+ + ImplicitNonceSize // fixed_iv_length
+ ;
+
+ ByteSpan expandedKey = new byte[ExpandedSize];
+ PrfSha256.ExpandSecret(expandedKey, masterSecret, PrfLabel.KEY_EXPANSION, combinedRandom);
+
+ ByteSpan clientWriteKey = expandedKey.Slice(0, Aes128Gcm.KeySize);
+ ByteSpan serverWriteKey = expandedKey.Slice(Aes128Gcm.KeySize, Aes128Gcm.KeySize);
+ this.clientWriteIV = expandedKey.Slice(2 * Aes128Gcm.KeySize, ImplicitNonceSize);
+ this.serverWriteIV = expandedKey.Slice(2 * Aes128Gcm.KeySize + ImplicitNonceSize, ImplicitNonceSize);
+
+ this.serverWriteCipher = new Aes128Gcm(serverWriteKey);
+ this.clientWriteCipher = new Aes128Gcm(clientWriteKey);
+ }
+
+ /// <inheritdoc />
+ public void Dispose()
+ {
+ this.serverWriteCipher.Dispose();
+ this.clientWriteCipher.Dispose();
+ }
+
+ /// <inheritdoc />
+ private static int GetEncryptedSizeImpl(int dataSize)
+ {
+ return dataSize + Aes128Gcm.CiphertextOverhead;
+ }
+
+ /// <inheritdoc />
+ public int GetEncryptedSize(int dataSize)
+ {
+ return GetEncryptedSizeImpl(dataSize);
+ }
+
+ private static int GetDecryptedSizeImpl(int dataSize)
+ {
+ return dataSize - Aes128Gcm.CiphertextOverhead;
+ }
+
+ /// <inheritdoc />
+ public int GetDecryptedSize(int dataSize)
+ {
+ return GetDecryptedSizeImpl(dataSize);
+ }
+
+ /// <inheritdoc />
+ public void EncryptServerPlaintext(ByteSpan output, ByteSpan input, ref Record record)
+ {
+ EncryptPlaintext(output, input, ref record, this.serverWriteCipher, this.serverWriteIV);
+ }
+
+ /// <inheritdoc />
+ public void EncryptClientPlaintext(ByteSpan output, ByteSpan input, ref Record record)
+ {
+ EncryptPlaintext(output, input, ref record, this.clientWriteCipher, this.clientWriteIV);
+ }
+
+ private static void EncryptPlaintext(ByteSpan output, ByteSpan input, ref Record record, Aes128Gcm cipher, ByteSpan writeIV)
+ {
+ Debug.Assert(output.Length >= GetEncryptedSizeImpl(input.Length));
+
+ // Build GCM nonce (authenticated data)
+ ByteSpan nonce = new byte[ImplicitNonceSize + ExplicitNonceSize];
+ writeIV.CopyTo(nonce);
+ nonce.WriteBigEndian16(record.Epoch, ImplicitNonceSize);
+ nonce.WriteBigEndian48(record.SequenceNumber, ImplicitNonceSize + 2);
+
+ // Serialize record as additional data
+ Record plaintextRecord = record;
+ plaintextRecord.Length = (ushort)input.Length;
+ ByteSpan associatedData = new byte[Record.Size];
+ plaintextRecord.Encode(associatedData);
+
+ cipher.Seal(output, nonce, input, associatedData);
+ }
+
+ /// <inheritdoc />
+ public bool DecryptCiphertextFromServer(ByteSpan output, ByteSpan input, ref Record record)
+ {
+ return DecryptCiphertext(output, input, ref record, this.serverWriteCipher, this.serverWriteIV);
+ }
+
+ /// <inheritdoc />
+ public bool DecryptCiphertextFromClient(ByteSpan output, ByteSpan input, ref Record record)
+ {
+ return DecryptCiphertext(output, input, ref record, this.clientWriteCipher, this.clientWriteIV);
+ }
+
+ private static bool DecryptCiphertext(ByteSpan output, ByteSpan input, ref Record record, Aes128Gcm cipher, ByteSpan writeIV)
+ {
+ Debug.Assert(output.Length >= GetDecryptedSizeImpl(input.Length));
+
+ // Build GCM nonce (authenticated data)
+ ByteSpan nonce = new byte[ImplicitNonceSize + ExplicitNonceSize];
+ writeIV.CopyTo(nonce);
+ nonce.WriteBigEndian16(record.Epoch, ImplicitNonceSize);
+ nonce.WriteBigEndian48(record.SequenceNumber, ImplicitNonceSize + 2);
+
+ // Serialize record as additional data
+ Record plaintextRecord = record;
+ plaintextRecord.Length = (ushort)GetDecryptedSizeImpl(input.Length);
+ ByteSpan associatedData = new byte[Record.Size];
+ plaintextRecord.Encode(associatedData);
+
+ return cipher.Open(output, nonce, input, associatedData);
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/Dtls/DtlsConnectionListener.cs b/Tools/Hazel-Networking/Hazel/Dtls/DtlsConnectionListener.cs
new file mode 100644
index 0000000..61f41d3
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/Dtls/DtlsConnectionListener.cs
@@ -0,0 +1,1424 @@
+using System;
+using System.Collections.Concurrent;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Net;
+using System.Security.Cryptography;
+using System.Security.Cryptography.X509Certificates;
+using System.Threading;
+using Hazel.Udp.FewerThreads;
+using Hazel.Crypto;
+
+namespace Hazel.Dtls
+{
+ /// <summary>
+ /// Listens for new UDP-DTLS connections and creates UdpConnections for them.
+ /// </summary>
+ /// <inheritdoc />
+ public class DtlsConnectionListener : ThreadLimitedUdpConnectionListener
+ {
+ private const int MaxCertFragmentSizeV0 = 1200;
+
+ // Min MTU - UDP+IP header - 1 (for good measure. :))
+ private const int MaxCertFragmentSizeV1 = 576 - 32 - 1;
+
+ /// <summary>
+ /// Current state of handshake sequence
+ /// </summary>
+ enum HandshakeState
+ {
+ ExpectingHello,
+ ExpectingClientKeyExchange,
+ ExpectingChangeCipherSpec,
+ ExpectingFinish
+ }
+
+ /// <summary>
+ /// State to manage the current epoch `N`
+ /// </summary>
+ struct CurrentEpoch
+ {
+ public ulong NextOutgoingSequence;
+
+ public ulong NextExpectedSequence;
+ public ulong PreviousSequenceWindowBitmask;
+
+ public IRecordProtection RecordProtection;
+ public IRecordProtection PreviousRecordProtection;
+
+ // Need to keep these around so we can re-transmit our
+ // last handshake record flight
+ public ByteSpan ExpectedClientFinishedVerification;
+ public ByteSpan ServerFinishedVerification;
+ public ulong NextOutgoingSequenceForPreviousEpoch;
+ }
+
+ /// <summary>
+ /// State to manage the transition from the current
+ /// epoch `N` to epoch `N+1`
+ /// </summary>
+ struct NextEpoch
+ {
+ public ushort Epoch;
+
+ public HandshakeState State;
+ public CipherSuite SelectedCipherSuite;
+
+ public ulong NextOutgoingSequence;
+
+ public IHandshakeCipherSuite Handshake;
+ public IRecordProtection RecordProtection;
+
+ public ByteSpan ClientRandom;
+ public ByteSpan ServerRandom;
+
+ public Sha256Stream VerificationStream;
+
+ public ByteSpan ClientVerification;
+ public ByteSpan ServerVerification;
+
+ }
+
+ /// <summary>
+ /// Per-peer state
+ /// </summary>
+ sealed class PeerData : IDisposable
+ {
+ public ushort Epoch;
+ public bool CanHandleApplicationData;
+
+ public HazelDtlsSessionInfo Session;
+
+ public CurrentEpoch CurrentEpoch;
+ public NextEpoch NextEpoch;
+
+ public ConnectionId ConnectionId;
+
+ public readonly List<ByteSpan> QueuedApplicationDataMessage = new List<ByteSpan>();
+ public readonly ConcurrentBag<MessageReader> ApplicationData = new ConcurrentBag<MessageReader>();
+ public readonly ProtocolVersion ProtocolVersion;
+
+ public DateTime StartOfNegotiation;
+
+ public PeerData(ConnectionId connectionId, ulong nextExpectedSequenceNumber, ProtocolVersion protocolVersion)
+ {
+ ByteSpan block = new byte[2 * Finished.Size];
+ this.CurrentEpoch.ServerFinishedVerification = block.Slice(0, Finished.Size);
+ this.CurrentEpoch.ExpectedClientFinishedVerification = block.Slice(Finished.Size, Finished.Size);
+ this.ProtocolVersion = protocolVersion;
+
+ ResetPeer(connectionId, nextExpectedSequenceNumber);
+ }
+
+ public void ResetPeer(ConnectionId connectionId, ulong nextExpectedSequenceNumber)
+ {
+ Dispose();
+
+ this.Epoch = 0;
+ this.CanHandleApplicationData = false;
+ this.QueuedApplicationDataMessage.Clear();
+
+ this.CurrentEpoch.NextOutgoingSequence = 2; // Account for our ClientHelloVerify
+ this.CurrentEpoch.NextExpectedSequence = nextExpectedSequenceNumber;
+ this.CurrentEpoch.PreviousSequenceWindowBitmask = 0;
+ this.CurrentEpoch.RecordProtection = NullRecordProtection.Instance;
+ this.CurrentEpoch.PreviousRecordProtection = null;
+ this.CurrentEpoch.ServerFinishedVerification.SecureClear();
+ this.CurrentEpoch.ExpectedClientFinishedVerification.SecureClear();
+
+ this.NextEpoch.State = HandshakeState.ExpectingHello;
+ this.NextEpoch.RecordProtection = null;
+ this.NextEpoch.Handshake = null;
+ this.NextEpoch.ClientRandom = new byte[Random.Size];
+ this.NextEpoch.ServerRandom = new byte[Random.Size];
+ this.NextEpoch.VerificationStream = new Sha256Stream();
+ this.NextEpoch.ClientVerification = new byte[Finished.Size];
+ this.NextEpoch.ServerVerification = new byte[Finished.Size];
+
+ this.ConnectionId = connectionId;
+
+ this.StartOfNegotiation = DateTime.UtcNow;
+ }
+
+ public void Dispose()
+ {
+ this.CurrentEpoch.RecordProtection?.Dispose();
+ this.CurrentEpoch.PreviousRecordProtection?.Dispose();
+ this.NextEpoch.RecordProtection?.Dispose();
+ this.NextEpoch.Handshake?.Dispose();
+ this.NextEpoch.VerificationStream?.Dispose();
+
+ while (this.ApplicationData.TryTake(out var msg))
+ {
+ try
+ {
+ msg.Recycle();
+ }
+ catch { }
+ }
+ }
+ }
+
+ private RandomNumberGenerator random;
+
+ // Private key component of certificate's public key
+ private ByteSpan encodedCertificate;
+ private RSA certificatePrivateKey;
+
+ // HMAC key to validate ClientHello cookie
+ private ThreadedHmacHelper hmacHelper;
+ private HMAC CurrentCookieHmac {
+ get
+ {
+ return hmacHelper.GetCurrentCookieHmacsForThread();
+ }
+ }
+ private HMAC PreviousCookieHmac
+ {
+ get
+ {
+ return hmacHelper.GetPreviousCookieHmacsForThread();
+ }
+ }
+
+ private ConcurrentStack<ConnectionId> staleConnections = new ConcurrentStack<ConnectionId>();
+ private readonly ConcurrentDictionary<IPEndPoint, PeerData> existingPeers = new ConcurrentDictionary<IPEndPoint, PeerData>();
+ public int PeerCount => this.existingPeers.Count;
+
+ // TODO: Move these into an DtlsErrorStatistics class
+ public int NonPeerNonHelloPacketsDropped;
+ public int NonVerifiedFinishedHandshake;
+ public int NonPeerVerifyHelloRequests;
+ public int PeerVerifyHelloRequests;
+
+ private int connectionSerial_unsafe = 0;
+
+ private Timer staleConnectionUpkeep;
+
+ /// <summary>
+ /// Create a new instance of the DTLS listener
+ /// </summary>
+ /// <param name="numWorkers"></param>
+ /// <param name="endPoint"></param>
+ /// <param name="logger"></param>
+ /// <param name="ipMode"></param>
+ public DtlsConnectionListener(int numWorkers, IPEndPoint endPoint, ILogger logger, IPMode ipMode = IPMode.IPv4)
+ : base(numWorkers, endPoint, logger, ipMode)
+ {
+ this.random = RandomNumberGenerator.Create();
+
+ this.staleConnectionUpkeep = new Timer(this.HandleStaleConnections, null, 2500, 1000);
+ this.hmacHelper = new ThreadedHmacHelper(logger);
+ }
+
+ /// <inheritdoc />
+ protected override void Dispose(bool disposing)
+ {
+ base.Dispose(disposing);
+
+ this.staleConnectionUpkeep.Dispose();
+
+ this.random?.Dispose();
+ this.random = null;
+
+ this.hmacHelper?.Dispose();
+ this.hmacHelper = null;
+
+ foreach (var pair in this.existingPeers)
+ {
+ pair.Value.Dispose();
+ }
+ this.existingPeers.Clear();
+ }
+
+ /// <summary>
+ /// Set the certificate key pair for the listener
+ /// </summary>
+ /// <param name="certificate">Certificate for the server</param>
+ public void SetCertificate(X509Certificate2 certificate)
+ {
+ if (!certificate.HasPrivateKey)
+ {
+ throw new ArgumentException("Certificate must have a private key attached", nameof(certificate));
+ }
+
+ RSA privateKey = certificate.GetRSAPrivateKey();
+ if (privateKey == null)
+ {
+ throw new ArgumentException("Certificate must be signed by an RSA key", nameof(certificate));
+ }
+
+ this.certificatePrivateKey?.Dispose();
+ this.certificatePrivateKey = privateKey;
+
+ this.encodedCertificate = Certificate.Encode(certificate);
+ }
+
+ /// <summary>
+ /// Handle an incoming datagram from the network.
+ ///
+ /// This is primarily a wrapper around ProcessIncomingMessage
+ /// to ensure `reader.Recycle()` is always called
+ /// </summary>
+ protected override void ReadCallback(MessageReader reader, IPEndPoint peerAddress, ConnectionId connectionId)
+ {
+ try
+ {
+ ByteSpan message = new ByteSpan(reader.Buffer, reader.Offset + reader.Position, reader.BytesRemaining);
+ this.ProcessIncomingMessage(message, peerAddress);
+ }
+ finally
+ {
+ reader.Recycle();
+ }
+ }
+
+ /// <summary>
+ /// Handle an incoming datagram from the network
+ /// </summary>
+ private void ProcessIncomingMessage(ByteSpan message, IPEndPoint peerAddress)
+ {
+ PeerData peer = null;
+ if (!this.existingPeers.TryGetValue(peerAddress, out peer))
+ {
+ lock (this.existingPeers)
+ {
+ if (!this.existingPeers.TryGetValue(peerAddress, out peer))
+ {
+ HandleNonPeerRecord(message, peerAddress);
+ return;
+ }
+ }
+ }
+
+ ConnectionId peerConnectionId;
+
+ lock (peer)
+ {
+ peerConnectionId = peer.ConnectionId;
+
+ // Each incoming packet may contain multiple DTLS
+ // records
+ while (message.Length > 0)
+ {
+ Record record;
+ if (!Record.Parse(out record, peer.ProtocolVersion, message))
+ {
+ this.Logger.WriteError($"Dropping malformed record from `{peerAddress}`");
+ return;
+ }
+ message = message.Slice(Record.Size);
+
+ if (message.Length < record.Length)
+ {
+ this.Logger.WriteError($"Dropping malformed record from `{peerAddress}` Length({record.Length}) AvailableBytes({message.Length})");
+ return;
+ }
+
+ ByteSpan recordPayload = message.Slice(0, record.Length);
+ message = message.Slice(record.Length);
+
+ // Early-out and drop ApplicationData records
+ if (record.ContentType == ContentType.ApplicationData && !peer.CanHandleApplicationData)
+ {
+ this.Logger.WriteInfo($"Dropping ApplicationData record from `{peerAddress}` Cannot process yet");
+ continue;
+ }
+
+ // Drop records from a different epoch
+ if (record.Epoch != peer.Epoch)
+ {
+ // Handle existing client negotiating a new connection
+ if (record.Epoch == 0 && record.ContentType == ContentType.Handshake)
+ {
+ ByteSpan handshakePayload = recordPayload;
+
+ Handshake handshake;
+ if (!Handshake.Parse(out handshake, recordPayload))
+ {
+ this.Logger.WriteError($"Dropping malformed re-negotiation Handshake from `{peerAddress}`");
+ continue;
+ }
+ handshakePayload = handshakePayload.Slice(Handshake.Size);
+
+ if (handshake.FragmentOffset != 0 || handshake.Length != handshake.FragmentLength)
+ {
+ this.Logger.WriteError($"Dropping fragmented re-negotiation Handshake from `{peerAddress}`");
+ continue;
+ }
+ else if (handshake.MessageType != HandshakeType.ClientHello)
+ {
+ this.Logger.WriteVerbose($"Dropping non-ClientHello re-negotiation Handshake from `{peerAddress}`");
+ continue;
+ }
+ else if (handshakePayload.Length < handshake.Length)
+ {
+ this.Logger.WriteError($"Dropping malformed re-negotiation Handshake from `{peerAddress}`: Length({handshake.Length}) AvailableBytes({handshakePayload.Length})");
+ }
+
+ if (!this.HandleClientHello(peer, peerAddress, ref record, ref handshake, recordPayload, handshakePayload))
+ {
+ return;
+ }
+ continue;
+ }
+
+ this.Logger.WriteVerbose($"Dropping bad-epoch record from `{peerAddress}` RecordEpoch({record.Epoch}) CurrentEpoch({peer.Epoch})");
+ continue;
+ }
+
+ // Prevent replay attacks by dropping records
+ // we've already processed
+ int windowIndex = (int)(peer.CurrentEpoch.NextExpectedSequence - record.SequenceNumber - 1);
+ ulong windowMask = 1ul << windowIndex;
+ if (record.SequenceNumber < peer.CurrentEpoch.NextExpectedSequence)
+ {
+ if (windowIndex >= 64)
+ {
+ this.Logger.WriteInfo($"Dropping too-old record from `{peerAddress}` Sequence({record.SequenceNumber}) Expected({peer.CurrentEpoch.NextExpectedSequence})");
+ continue;
+ }
+
+ if ((peer.CurrentEpoch.PreviousSequenceWindowBitmask & windowMask) != 0)
+ {
+ this.Logger.WriteInfo($"Dropping duplicate record from `{peerAddress}`");
+ continue;
+ }
+ }
+
+ // Validate record authenticity
+ int decryptedSize = peer.CurrentEpoch.RecordProtection.GetDecryptedSize(recordPayload.Length);
+ if (decryptedSize < 0)
+ {
+ this.Logger.WriteInfo($"Dropping malformed record: Length {recordPayload.Length} Decrypted length: {decryptedSize}");
+ continue;
+ }
+
+ ByteSpan decryptedPayload = recordPayload.ReuseSpanIfPossible(decryptedSize);
+ ProtocolVersion protocolVersion = peer.ProtocolVersion;
+
+ if (!peer.CurrentEpoch.RecordProtection.DecryptCiphertextFromClient(decryptedPayload, recordPayload, ref record))
+ {
+ this.Logger.WriteVerbose($"Dropping non-authentic {record.ContentType} record from `{peerAddress}`");
+ return;
+ }
+
+ recordPayload = decryptedPayload;
+
+ // Update our squence number bookeeping
+ if (record.SequenceNumber >= peer.CurrentEpoch.NextExpectedSequence)
+ {
+ int windowShift = (int)(record.SequenceNumber + 1 - peer.CurrentEpoch.NextExpectedSequence);
+ peer.CurrentEpoch.PreviousSequenceWindowBitmask <<= windowShift;
+ peer.CurrentEpoch.NextExpectedSequence = record.SequenceNumber + 1;
+ }
+ else
+ {
+ peer.CurrentEpoch.PreviousSequenceWindowBitmask |= windowMask;
+ }
+
+ // This is handy for debugging, but too verbose even for verbose.
+ // this.Logger.WriteVerbose($"Record type {record.ContentType} ({peer.NextEpoch.State})");
+ switch (record.ContentType)
+ {
+ case ContentType.ChangeCipherSpec:
+ if (peer.NextEpoch.State != HandshakeState.ExpectingChangeCipherSpec)
+ {
+ this.Logger.WriteError($"Dropping unexpected ChangeChiperSpec record from `{peerAddress}` State({peer.NextEpoch.State})");
+ break;
+ }
+ else if (peer.NextEpoch.RecordProtection == null)
+ {
+ ///NOTE(mendsley): This _should_ not
+ /// happen on a well-formed server.
+ Debug.Assert(false, "How did we receive a ChangeCipherSpec message without a pending record protection instance?");
+
+ this.Logger.WriteError($"Dropping ChangeCipherSpec message from `{peerAddress}`: No pending record protection");
+ break;
+ }
+
+ if (!ChangeCipherSpec.Parse(recordPayload))
+ {
+ this.Logger.WriteError($"Dropping malformed ChangeCipherSpec message from `{peerAddress}`");
+ break;
+ }
+
+ // Migrate to the next epoch
+ peer.Epoch = peer.NextEpoch.Epoch;
+ peer.CanHandleApplicationData = false; // Need a Finished message
+ peer.CurrentEpoch.NextOutgoingSequenceForPreviousEpoch = peer.CurrentEpoch.NextOutgoingSequence;
+ peer.CurrentEpoch.PreviousRecordProtection?.Dispose();
+ peer.CurrentEpoch.PreviousRecordProtection = peer.CurrentEpoch.RecordProtection;
+ peer.CurrentEpoch.RecordProtection = peer.NextEpoch.RecordProtection;
+ peer.CurrentEpoch.NextOutgoingSequence = 1;
+ peer.CurrentEpoch.NextExpectedSequence = 1;
+ peer.CurrentEpoch.PreviousSequenceWindowBitmask = 0;
+ peer.NextEpoch.ClientVerification.CopyTo(peer.CurrentEpoch.ExpectedClientFinishedVerification);
+ peer.NextEpoch.ServerVerification.CopyTo(peer.CurrentEpoch.ServerFinishedVerification);
+
+ peer.NextEpoch.State = HandshakeState.ExpectingHello;
+ peer.NextEpoch.Handshake?.Dispose();
+ peer.NextEpoch.Handshake = null;
+ peer.NextEpoch.NextOutgoingSequence = 1;
+ peer.NextEpoch.RecordProtection = null;
+ peer.NextEpoch.VerificationStream.Reset();
+ peer.NextEpoch.ClientVerification.SecureClear();
+ peer.NextEpoch.ServerVerification.SecureClear();
+ break;
+
+ case ContentType.Alert:
+ this.Logger.WriteError($"Dropping unsupported Alert record from `{peerAddress}`");
+ break;
+
+ case ContentType.Handshake:
+ if (!ProcessHandshake(peer, peerAddress, ref record, recordPayload))
+ {
+ return;
+ }
+ break;
+
+ case ContentType.ApplicationData:
+ // Forward data to the application
+ MessageReader reader = MessageReader.GetSized(recordPayload.Length);
+ reader.Length = recordPayload.Length;
+ recordPayload.CopyTo(reader.Buffer);
+
+ peer.ApplicationData.Add(reader);
+ break;
+ }
+ }
+ }
+
+ // The peer lock must be exited before leaving the DtlsConnectionListener context to prevent deadlocks
+ // because ApplicationData processing may reenter this context
+ while (peer.ApplicationData.TryTake(out var appMsg))
+ {
+ base.ReadCallback(appMsg, peerAddress, peerConnectionId);
+ }
+ }
+
+ /// <summary>
+ /// Process an incoming Handshake protocol message
+ /// </summary>
+ /// <param name="peer">Originating peer</param>
+ /// <param name="peerAddress">Peer's network address</param>
+ /// <param name="record">Parent record</param>
+ /// <param name="message">Record payload</param>
+ /// <returns>
+ /// True if further processing of the underlying datagram
+ /// should be continues. Otherwise, false.
+ /// </returns>
+ private bool ProcessHandshake(PeerData peer, IPEndPoint peerAddress, ref Record record, ByteSpan message)
+ {
+ // Each record may have multiple handshake payloads
+ while (message.Length > 0)
+ {
+ ByteSpan originalMessage = message;
+
+ Handshake handshake;
+ if (!Handshake.Parse(out handshake, message))
+ {
+ this.Logger.WriteError($"Dropping malformed Handshake message from `{peerAddress}`");
+ return false;
+ }
+ message = message.Slice(Handshake.Size);
+
+ if (message.Length < handshake.Length)
+ {
+ this.Logger.WriteError($"Dropping malformed Handshake message from `{peerAddress}`");
+ return false;
+ }
+
+ ByteSpan payload = message.Slice(0, (int)message.Length);
+ message = message.Slice((int)handshake.Length);
+ originalMessage = originalMessage.Slice(0, Handshake.Size + (int)handshake.Length);
+
+ // We do not support fragmented handshake messages
+ // from the client
+ if (handshake.FragmentOffset != 0 || handshake.FragmentLength != handshake.Length)
+ {
+ this.Logger.WriteError($"Dropping fragmented Handshake message from `{peerAddress}` Offset({handshake.FragmentOffset}) FragmentLength({handshake.FragmentLength}) Length({handshake.Length})");
+ continue;
+ }
+
+ ByteSpan packet;
+ ByteSpan writer;
+
+#if DEBUG
+ this.Logger.WriteVerbose($"Received handshake {handshake.MessageType} ({peer.NextEpoch.State})");
+#endif
+ switch (handshake.MessageType)
+ {
+ case HandshakeType.ClientHello:
+ if (!this.HandleClientHello(peer, peerAddress, ref record, ref handshake, originalMessage, payload))
+ {
+ return false;
+ }
+ break;
+
+ case HandshakeType.ClientKeyExchange:
+ if (peer.NextEpoch.State != HandshakeState.ExpectingClientKeyExchange)
+ {
+ this.Logger.WriteError($"Dropping unexpected ClientKeyExchange message form `{peerAddress}` State({peer.NextEpoch.State})");
+ continue;
+ }
+ else if (handshake.MessageSequence != 5)
+ {
+ this.Logger.WriteError($"Dropping bad-sequence ClientKeyExchange message from `{peerAddress}` MessageSequence({handshake.MessageSequence})");
+ continue;
+ }
+
+ ByteSpan sharedSecret = new byte[peer.NextEpoch.Handshake.SharedKeySize()];
+ if (!peer.NextEpoch.Handshake.VerifyClientMessageAndGenerateSharedKey(sharedSecret, payload))
+ {
+ this.Logger.WriteError($"Dropping malformed ClientKeyExchange message from `{peerAddress}`");
+ return false;
+ }
+
+ // Record incoming ClientKeyExchange message
+ // to verification stream
+ peer.NextEpoch.VerificationStream.AddData(originalMessage);
+
+ ByteSpan randomSeed = new byte[2 * Random.Size];
+ peer.NextEpoch.ClientRandom.CopyTo(randomSeed);
+ peer.NextEpoch.ServerRandom.CopyTo(randomSeed.Slice(Random.Size));
+
+ const int MasterSecretSize = 48;
+ ByteSpan masterSecret = new byte[MasterSecretSize];
+ PrfSha256.ExpandSecret(
+ masterSecret
+ , sharedSecret
+ , PrfLabel.MASTER_SECRET
+ , randomSeed
+ );
+
+ // Create the record protection for the upcoming epoch
+ switch (peer.NextEpoch.SelectedCipherSuite)
+ {
+ case CipherSuite.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256:
+ peer.NextEpoch.RecordProtection = new Aes128GcmRecordProtection(
+ masterSecret
+ , peer.NextEpoch.ServerRandom
+ , peer.NextEpoch.ClientRandom);
+ break;
+
+ default:
+ Debug.Assert(false, $"How did we agree to a cipher suite {peer.NextEpoch.SelectedCipherSuite} we can't create?");
+ this.Logger.WriteError($"Dropping ClientKeyExchange message from `{peerAddress}` Unsuppored cipher suite");
+ return false;
+ }
+
+ // Generate verification signatures
+ ByteSpan handshakeStreamHash = new byte[Sha256Stream.DigestSize];
+ peer.NextEpoch.VerificationStream.CopyOrCalculateFinalHash(handshakeStreamHash);
+
+ PrfSha256.ExpandSecret(
+ peer.NextEpoch.ClientVerification
+ , masterSecret
+ , PrfLabel.CLIENT_FINISHED
+ , handshakeStreamHash
+ );
+ PrfSha256.ExpandSecret(
+ peer.NextEpoch.ServerVerification
+ , masterSecret
+ , PrfLabel.SERVER_FINISHED
+ , handshakeStreamHash
+ );
+
+
+ // Update handshake state
+ masterSecret.SecureClear();
+ peer.NextEpoch.State = HandshakeState.ExpectingChangeCipherSpec;
+ break;
+
+ case HandshakeType.Finished:
+ // Unlike other handshake messages, this is
+ // for the current epoch - not the next epoch
+
+ // Cannot process a Finished message for
+ // epoch 0
+ if (peer.Epoch == 0)
+ {
+ this.Logger.WriteError($"Dropping Finished message for 0-epoch from `{peerAddress}`");
+ continue;
+ }
+ // Cannot process a Finished message when we
+ // are negotiating the next epoch
+ else if (peer.NextEpoch.State != HandshakeState.ExpectingHello)
+ {
+ this.Logger.WriteError($"Dropping Finished message while negotiating new epoch from `{peerAddress}`");
+ continue;
+ }
+ // Cannot process a Finished message without
+ // verify data
+ else if (peer.CurrentEpoch.ExpectedClientFinishedVerification.Length != Finished.Size || peer.CurrentEpoch.ServerFinishedVerification.Length != Finished.Size)
+ {
+ ///NOTE(mendsley): This _should_ not
+ /// happen on a well-formed server.
+ Debug.Assert(false, "How do we have an established non-zero epoch without verify data?");
+
+ this.Logger.WriteError($"Dropping Finished message (no verify data) from `{peerAddress}`");
+ return false;
+ }
+ // Cannot process a Finished message without
+ // record protection for the previous epoch
+ else if (peer.CurrentEpoch.PreviousRecordProtection == null)
+ {
+ ///NOTE(mendsley): This _should_ not
+ /// happen on a well-formed server.
+ Debug.Assert(false, "How do we have an established non-zero epoch with record protection for the previous epoch?");
+
+ this.Logger.WriteError($"Dropping Finished message from `{peerAddress}`: No previous epoch record protection");
+ return false;
+ }
+
+ // Verify message sequence
+ if (handshake.MessageSequence != 6)
+ {
+ this.Logger.WriteError($"Dropping bad-sequence Finished message from `{peerAddress}` MessageSequence({handshake.MessageSequence})");
+ continue;
+ }
+
+ // Verify the client has the correct
+ // handshake sequence
+ if (payload.Length != Finished.Size)
+ {
+ this.Logger.WriteError($"Dropping malformed Finished message from `{peerAddress}`");
+ return false;
+ }
+ else if (1 != Crypto.Const.ConstantCompareSpans(payload, peer.CurrentEpoch.ExpectedClientFinishedVerification))
+ {
+
+#if DEBUG
+ this.Logger.WriteError($"Dropping non-verified Finished Handshake from `{peerAddress}`");
+#else
+ Interlocked.Increment(ref this.NonVerifiedFinishedHandshake);
+#endif
+
+ // Abort the connection here
+ //
+ // The client is either broken, or
+ // doen not agree on our epoch settings.
+ //
+ // Either way, there is not a feasible
+ // way to progress the connection.
+ MarkConnectionAsStale(peer.ConnectionId);
+ this.existingPeers.TryRemove(peerAddress, out _);
+
+ return false;
+ }
+
+ ProtocolVersion protocolVersion = peer.ProtocolVersion;
+
+ // Describe our ChangeCipherSpec+Finished
+ Handshake outgoingHandshake = new Handshake();
+ outgoingHandshake.MessageType = HandshakeType.Finished;
+ outgoingHandshake.Length = Finished.Size;
+ outgoingHandshake.MessageSequence = 7;
+ outgoingHandshake.FragmentOffset = 0;
+ outgoingHandshake.FragmentLength = outgoingHandshake.Length;
+
+ Record changeCipherSpecRecord = new Record();
+ changeCipherSpecRecord.ContentType = ContentType.ChangeCipherSpec;
+ changeCipherSpecRecord.ProtocolVersion = protocolVersion;
+ changeCipherSpecRecord.Epoch = (ushort)(peer.Epoch - 1);
+ changeCipherSpecRecord.SequenceNumber = peer.CurrentEpoch.NextOutgoingSequenceForPreviousEpoch;
+ changeCipherSpecRecord.Length = (ushort)peer.CurrentEpoch.PreviousRecordProtection.GetEncryptedSize(ChangeCipherSpec.Size);
+ ++peer.CurrentEpoch.NextOutgoingSequenceForPreviousEpoch;
+
+ int plaintextFinishedPayloadSize = Handshake.Size + (int)outgoingHandshake.Length;
+ Record finishedRecord = new Record();
+ finishedRecord.ContentType = ContentType.Handshake;
+ finishedRecord.ProtocolVersion = protocolVersion;
+ finishedRecord.Epoch = peer.Epoch;
+ finishedRecord.SequenceNumber = peer.CurrentEpoch.NextOutgoingSequence;
+ finishedRecord.Length = (ushort)peer.CurrentEpoch.RecordProtection.GetEncryptedSize(plaintextFinishedPayloadSize);
+ ++peer.CurrentEpoch.NextOutgoingSequence;
+
+ // Encode the flight into wire format
+ packet = new byte[Record.Size + changeCipherSpecRecord.Length + Record.Size + finishedRecord.Length];
+ writer = packet;
+ changeCipherSpecRecord.Encode(writer);
+ writer = writer.Slice(Record.Size);
+ ChangeCipherSpec.Encode(writer);
+
+ ByteSpan startOfFinishedRecord = packet.Slice(Record.Size + changeCipherSpecRecord.Length);
+ writer = startOfFinishedRecord;
+ finishedRecord.Encode(writer);
+ writer = writer.Slice(Record.Size);
+ outgoingHandshake.Encode(writer);
+ writer = writer.Slice(Handshake.Size);
+ peer.CurrentEpoch.ServerFinishedVerification.CopyTo(writer);
+
+ // Protect the ChangeChipherSpec record
+ peer.CurrentEpoch.PreviousRecordProtection.EncryptServerPlaintext(
+ packet.Slice(Record.Size, changeCipherSpecRecord.Length),
+ packet.Slice(Record.Size, ChangeCipherSpec.Size),
+ ref changeCipherSpecRecord
+ );
+
+ // Protect the Finished Handshake record
+ peer.CurrentEpoch.RecordProtection.EncryptServerPlaintext(
+ startOfFinishedRecord.Slice(Record.Size, finishedRecord.Length),
+ startOfFinishedRecord.Slice(Record.Size, plaintextFinishedPayloadSize),
+ ref finishedRecord
+ );
+
+ // Current epoch can now handle application data
+ peer.CanHandleApplicationData = true;
+
+ base.QueueRawData(packet, peerAddress);
+ break;
+
+ // Drop messages that we do not support
+ case HandshakeType.CertificateVerify:
+ this.Logger.WriteError($"Dropping unsupported Handshake message from `{peerAddress}` MessageType({handshake.MessageType})");
+ continue;
+
+ // Drop messages that originate from the server
+ case HandshakeType.HelloRequest:
+ case HandshakeType.ServerHello:
+ case HandshakeType.HelloVerifyRequest:
+ case HandshakeType.Certificate:
+ case HandshakeType.ServerKeyExchange:
+ case HandshakeType.CertificateRequest:
+ case HandshakeType.ServerHelloDone:
+ this.Logger.WriteError($"Dropping server Handshake message from `{peerAddress}` MessageType({handshake.MessageType})");
+ continue;
+ }
+ }
+
+ return true;
+ }
+
+ /// <summary>
+ /// Handle a ClientHello message for a peer
+ /// </summary>
+ /// <param name="peer">Originating peer</param>
+ /// <param name="peerAddress">Peer address</param>
+ /// <param name="record">Parent record</param>
+ /// <param name="handshake">Parent Handshake header</param>
+ /// <param name="payload">Handshake payload</param>
+ private bool HandleClientHello(PeerData peer, IPEndPoint peerAddress, ref Record record, ref Handshake handshake, ByteSpan originalMessage, ByteSpan payload)
+ {
+ // Verify message sequence
+ if (handshake.MessageSequence != 0)
+ {
+ this.Logger.WriteError($"Dropping bad-sequence ClientHello from `{peerAddress}` MessageSequence({handshake.MessageSequence})`");
+ return true;
+ }
+
+ // Make sure we can handle a ClientHello message
+ if (peer.NextEpoch.State != HandshakeState.ExpectingHello && peer.NextEpoch.State != HandshakeState.ExpectingClientKeyExchange)
+ {
+ // Always handle ClientHello for epoch 0
+ if (record.Epoch != 0)
+ {
+ this.Logger.WriteError($"Dropping ClientHello from `{peer}` Not expecting ClientHello");
+ return true;
+ }
+ }
+
+ ProtocolVersion protocolVersion = peer.ProtocolVersion;
+ if (!ClientHello.Parse(out ClientHello clientHello, protocolVersion, payload))
+ {
+ this.Logger.WriteError($"Dropping malformed ClientHello Handshake message from `{peerAddress}`");
+ return false;
+ }
+
+ // Find an acceptable cipher suite we can use
+ CipherSuite selectedCipherSuite = CipherSuite.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256;
+ if (!clientHello.ContainsCipherSuite(CipherSuite.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256) || !clientHello.ContainsCurve(NamedCurve.x25519))
+ {
+ this.Logger.WriteError($"Dropping ClientHello from `{peerAddress}` No compatible cipher suite");
+ return false;
+ }
+
+ // If this message was not signed by us,
+ // request a signed message before doing anything else
+ if (!HelloVerifyRequest.VerifyCookie(clientHello.Cookie, peerAddress, this.CurrentCookieHmac))
+ {
+ if (!HelloVerifyRequest.VerifyCookie(clientHello.Cookie, peerAddress, this.PreviousCookieHmac))
+ {
+ ulong outgoingSequence = 1;
+ IRecordProtection recordProtection = NullRecordProtection.Instance;
+ if (record.Epoch != 0)
+ {
+ outgoingSequence = peer.CurrentEpoch.NextExpectedSequence;
+ ++peer.CurrentEpoch.NextOutgoingSequenceForPreviousEpoch;
+
+ recordProtection = peer.CurrentEpoch.RecordProtection;
+ }
+
+#if DEBUG
+ this.Logger.WriteError($"Sending HelloVerifyRequest to peer `{peerAddress}`");
+#else
+ Interlocked.Increment(ref this.PeerVerifyHelloRequests);
+#endif
+ this.SendHelloVerifyRequest(peerAddress, outgoingSequence, record.Epoch, recordProtection, protocolVersion);
+ return true;
+ }
+ }
+
+ // Client is initiating a brand new connection. We need
+ // to destroy the existing connection and establish a
+ // new session.
+ if (record.Epoch == 0 && peer.Epoch != 0)
+ {
+ ConnectionId oldConnectionId = peer.ConnectionId;
+ peer.ResetPeer(this.AllocateConnectionId(peerAddress), record.SequenceNumber + 1);
+
+ // Inform the parent layer that the existing
+ // connection should be abandoned.
+ MarkConnectionAsStale(oldConnectionId);
+ }
+
+ // Determine if this is an original message, or a retransmission
+ bool recordMessagesForVerifyData = false;
+ if (peer.NextEpoch.State == HandshakeState.ExpectingHello)
+ {
+ // Create our handhake cipher suite
+ IHandshakeCipherSuite handshakeCipherSuite = null;
+ switch (selectedCipherSuite)
+ {
+ case CipherSuite.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256:
+ if (clientHello.ContainsCurve(NamedCurve.x25519))
+ {
+ handshakeCipherSuite = new X25519EcdheRsaSha256(this.random);
+ }
+ else
+ {
+ this.Logger.WriteError($"Dropping ClientHello from `{peerAddress}` Could not create TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 cipher suite");
+ return false;
+ }
+
+ break;
+
+ default:
+ this.Logger.WriteError($"Dropping ClientHello from `{peerAddress}` Could not create handshake cipher suite");
+ return false;
+ }
+
+ peer.Session = clientHello.Session;
+
+ // Update the state of our epoch transition
+ peer.NextEpoch.Epoch = (ushort)(record.Epoch + 1);
+ peer.NextEpoch.State = HandshakeState.ExpectingClientKeyExchange;
+ peer.NextEpoch.SelectedCipherSuite = selectedCipherSuite;
+ peer.NextEpoch.Handshake = handshakeCipherSuite;
+ clientHello.Random.CopyTo(peer.NextEpoch.ClientRandom);
+ peer.NextEpoch.ServerRandom.FillWithRandom(this.random);
+ recordMessagesForVerifyData = true;
+
+#if DEBUG
+ this.Logger.WriteVerbose($"ClientRandom: {peer.NextEpoch.ClientRandom} ServerRandom: {peer.NextEpoch.ServerRandom}");
+#endif
+
+ // Copy the original ClientHello
+ // handshake to our verification stream
+ peer.NextEpoch.VerificationStream.AddData(
+ originalMessage.Slice(
+ 0
+ , Handshake.Size + (int)handshake.Length
+ )
+ );
+ }
+
+ // The initial record flight from the server
+ // contains the following Handshake messages:
+ // * ServerHello
+ // * Certificate
+ // * ServerKeyExchange
+ // * ServerHelloDone
+ //
+ // The Certificate message is almost always
+ // too large to fit into a single datagram,
+ // so it is pre-fragmented
+ // (see `SetCertificates`). Therefore, we
+ // need to send multiple record packets for
+ // this flight.
+ //
+ // The first record contains the ServerHello
+ // handshake message, as well as the first
+ // portion of the Certificate message.
+ //
+ // We then send a record packet until the
+ // entire Certificate message has been sent
+ // to the client.
+ //
+ // The final record packet contains the
+ // ServerKeyExchange and the ServerHelloDone
+ // messages.
+
+ // Describe first record of the flight
+ ServerHello serverHello = new ServerHello();
+ serverHello.ServerProtocolVersion = protocolVersion;
+ serverHello.Random = peer.NextEpoch.ServerRandom;
+ serverHello.CipherSuite = selectedCipherSuite;
+
+ Handshake serverHelloHandshake = new Handshake();
+ serverHelloHandshake.MessageType = HandshakeType.ServerHello;
+ serverHelloHandshake.Length = ServerHello.MinSize;
+ serverHelloHandshake.MessageSequence = 1;
+ serverHelloHandshake.FragmentOffset = 0;
+ serverHelloHandshake.FragmentLength = serverHelloHandshake.Length;
+
+ int maxCertFragmentSize = peer.Session.Version == 0 ? MaxCertFragmentSizeV0 : MaxCertFragmentSizeV1;
+
+ // The first certificate data needs to leave room for
+ // * Record header
+ // * ServerHello header
+ // * ServerHello payload
+ // * Certificate header
+
+ var certificateData = this.encodedCertificate;
+ int initialCertPadding = Record.Size + Handshake.Size + serverHello.Size + Handshake.Size;
+ int certInitialFragmentSize = Math.Min(certificateData.Length, maxCertFragmentSize - initialCertPadding);
+
+ Handshake certificateHandshake = new Handshake();
+ certificateHandshake.MessageType = HandshakeType.Certificate;
+ certificateHandshake.Length = (uint)certificateData.Length;
+ certificateHandshake.MessageSequence = 2;
+ certificateHandshake.FragmentOffset = 0;
+ certificateHandshake.FragmentLength = (uint)certInitialFragmentSize;
+
+ int initialRecordPayloadSize = 0
+ + Handshake.Size + serverHello.Size
+ + Handshake.Size + (int)certificateHandshake.FragmentLength
+ ;
+ Record initialRecord = new Record();
+ initialRecord.ContentType = ContentType.Handshake;
+ initialRecord.ProtocolVersion = protocolVersion;
+ initialRecord.Epoch = peer.Epoch;
+ initialRecord.SequenceNumber = peer.CurrentEpoch.NextOutgoingSequence;
+ initialRecord.Length = (ushort)peer.CurrentEpoch.RecordProtection.GetEncryptedSize(initialRecordPayloadSize);
+ ++peer.CurrentEpoch.NextOutgoingSequence;
+
+ // Convert initial record of the flight to
+ // wire format
+ ByteSpan packet = new byte[Record.Size + initialRecord.Length];
+ ByteSpan writer = packet;
+ initialRecord.Encode(writer);
+ writer = writer.Slice(Record.Size);
+ serverHelloHandshake.Encode(writer);
+ writer = writer.Slice(Handshake.Size);
+ serverHello.Encode(writer);
+ writer = writer.Slice(ServerHello.MinSize);
+ certificateHandshake.Encode(writer);
+ writer = writer.Slice(Handshake.Size);
+ certificateData.Slice(0, certInitialFragmentSize).CopyTo(writer);
+ certificateData = certificateData.Slice(certInitialFragmentSize);
+
+ // Protect initial record of the flight
+ peer.CurrentEpoch.RecordProtection.EncryptServerPlaintext(
+ packet.Slice(Record.Size, initialRecord.Length),
+ packet.Slice(Record.Size, initialRecordPayloadSize),
+ ref initialRecord
+ );
+
+ base.QueueRawData(packet, peerAddress);
+
+ // Record record payload for verification
+ if (recordMessagesForVerifyData)
+ {
+ Handshake fullCeritficateHandshake = certificateHandshake;
+ fullCeritficateHandshake.FragmentLength = fullCeritficateHandshake.Length;
+
+ packet = new byte[Handshake.Size + ServerHello.MinSize + Handshake.Size];
+ writer = packet;
+ serverHelloHandshake.Encode(writer);
+ writer = writer.Slice(Handshake.Size);
+ serverHello.Encode(writer);
+ writer = writer.Slice(ServerHello.MinSize);
+ fullCeritficateHandshake.Encode(writer);
+ writer = writer.Slice(Handshake.Size);
+
+ peer.NextEpoch.VerificationStream.AddData(packet);
+ peer.NextEpoch.VerificationStream.AddData(this.encodedCertificate);
+ }
+
+ // Process additional certificate records
+ // Subsequent certificate data needs to leave room for
+ // * Record header
+ // * Certificate header
+ const int CertPadding = Record.Size + Handshake.Size;
+ while (certificateData.Length > 0)
+ {
+ int certFragmentSize = Math.Min(certificateData.Length, maxCertFragmentSize - CertPadding);
+
+ certificateHandshake.FragmentOffset += certificateHandshake.FragmentLength;
+ certificateHandshake.FragmentLength = (uint)certFragmentSize;
+
+ int additionalRecordPayloadSize = Handshake.Size + (int)certificateHandshake.FragmentLength;
+ Record additionalRecord = new Record();
+ additionalRecord.ContentType = ContentType.Handshake;
+ additionalRecord.ProtocolVersion = protocolVersion;
+ additionalRecord.Epoch = peer.Epoch;
+ additionalRecord.SequenceNumber = peer.CurrentEpoch.NextOutgoingSequence;
+ additionalRecord.Length = (ushort)peer.CurrentEpoch.RecordProtection.GetEncryptedSize(additionalRecordPayloadSize);
+ ++peer.CurrentEpoch.NextOutgoingSequence;
+
+ // Convert record to wire format
+ packet = new byte[Record.Size + additionalRecord.Length];
+ writer = packet;
+ additionalRecord.Encode(writer);
+ writer = writer.Slice(Record.Size);
+ certificateHandshake.Encode(writer);
+ writer = writer.Slice(Handshake.Size);
+ certificateData.Slice(0, certFragmentSize).CopyTo(writer);
+
+ certificateData = certificateData.Slice(certFragmentSize);
+
+ // Protect record
+ peer.CurrentEpoch.RecordProtection.EncryptServerPlaintext(
+ packet.Slice(Record.Size, additionalRecord.Length),
+ packet.Slice(Record.Size, additionalRecordPayloadSize),
+ ref additionalRecord
+ );
+
+ base.QueueRawData(packet, peerAddress);
+ }
+
+ // Describe final record of the flight
+ Handshake serverKeyExchangeHandshake = new Handshake();
+ serverKeyExchangeHandshake.MessageType = HandshakeType.ServerKeyExchange;
+ serverKeyExchangeHandshake.Length = (uint)peer.NextEpoch.Handshake.CalculateServerMessageSize(this.certificatePrivateKey);
+ serverKeyExchangeHandshake.MessageSequence = 3;
+ serverKeyExchangeHandshake.FragmentOffset = 0;
+ serverKeyExchangeHandshake.FragmentLength = serverKeyExchangeHandshake.Length;
+
+ Handshake serverHelloDoneHandshake = new Handshake();
+ serverHelloDoneHandshake.MessageType = HandshakeType.ServerHelloDone;
+ serverHelloDoneHandshake.Length = 0;
+ serverHelloDoneHandshake.MessageSequence = 4;
+ serverHelloDoneHandshake.FragmentOffset = 0;
+ serverHelloDoneHandshake.FragmentLength = 0;
+
+ int finalRecordPayloadSize = 0
+ + Handshake.Size + (int)serverKeyExchangeHandshake.Length
+ + Handshake.Size + (int)serverHelloDoneHandshake.Length
+ ;
+ Record finalRecord = new Record();
+ finalRecord.ContentType = ContentType.Handshake;
+ finalRecord.ProtocolVersion = protocolVersion;
+ finalRecord.Epoch = peer.Epoch;
+ finalRecord.SequenceNumber = peer.CurrentEpoch.NextOutgoingSequence;
+ finalRecord.Length = (ushort)peer.CurrentEpoch.RecordProtection.GetEncryptedSize(finalRecordPayloadSize);
+ ++peer.CurrentEpoch.NextOutgoingSequence;
+
+ // Convert final record of the flight to wire
+ // format
+ packet = new byte[Record.Size + finalRecord.Length];
+ writer = packet;
+ finalRecord.Encode(writer);
+ writer = writer.Slice(Record.Size);
+ serverKeyExchangeHandshake.Encode(writer);
+ writer = writer.Slice(Handshake.Size);
+ peer.NextEpoch.Handshake.EncodeServerKeyExchangeMessage(writer, this.certificatePrivateKey);
+ writer = writer.Slice((int)serverKeyExchangeHandshake.Length);
+ serverHelloDoneHandshake.Encode(writer);
+
+ // Record record payload for verification
+ if (recordMessagesForVerifyData)
+ {
+ peer.NextEpoch.VerificationStream.AddData(
+ packet.Slice(
+ packet.Offset + Record.Size
+ , finalRecordPayloadSize
+ )
+ );
+ }
+
+ // Protect final record of the flight
+ peer.CurrentEpoch.RecordProtection.EncryptServerPlaintext(
+ packet.Slice(Record.Size, finalRecord.Length),
+ packet.Slice(Record.Size, finalRecordPayloadSize),
+ ref finalRecord
+ );
+
+ base.QueueRawData(packet, peerAddress);
+
+ return true;
+ }
+
+ /// <summary>
+ /// Handle an incoming packet that is not tied to an existing peer
+ /// </summary>
+ /// <param name="message">Incoming datagram</param>
+ /// <param name="peerAddress">Originating address</param>
+ private void HandleNonPeerRecord(ByteSpan message, IPEndPoint peerAddress)
+ {
+ Record record;
+ if (!Record.Parse(out record, expectedProtocolVersion: null, message))
+ {
+ this.Logger.WriteError($"Dropping malformed record from non-peer `{peerAddress}`");
+ return;
+ }
+ message = message.Slice(Record.Size);
+
+ // The protocol only supports receiving a single record
+ // from a non-peer.
+ if (record.Length != message.Length)
+ {
+ // NOTE(mendsley): This isn't always fatal.
+ // However, this is an indication that something
+ // fishy is going on. In the best case, there's a
+ // bug on the client or in the UDP stack (some
+ // stacks don't both to verify the checksum). In the
+ // worst case we're dealing with a malicious actor.
+ // In the malicious case, we'll end up dropping the
+ // connection later in the process.
+ if (message.Length < record.Length)
+ {
+ this.Logger.WriteInfo($"Dropping bad record from non-peer `{peerAddress}`. Msg length {message.Length} < {record.Length}");
+ return;
+ }
+ }
+
+ // We only accept zero-epoch records from non-peers
+ if (record.Epoch != 0)
+ {
+ ///NOTE(mendsley): Not logging anything here, as
+ /// this could easily be latent data arriving from a
+ /// recently disconnected peer.
+ return;
+ }
+
+ // We only accept Handshake protocol messages from non-peers
+ if (record.ContentType != ContentType.Handshake)
+ {
+ this.Logger.WriteError($"Dropping non-handhsake message from non-peer `{peerAddress}`");
+ return;
+ }
+
+ ByteSpan originalMessage = message;
+
+ Handshake handshake;
+ if (!Handshake.Parse(out handshake, message))
+ {
+ this.Logger.WriteError($"Dropping malformed handshake message from non-peer `{peerAddress}`");
+ return;
+ }
+
+ // We only accept ClientHello messages from non-peers
+ if (handshake.MessageType != HandshakeType.ClientHello)
+ {
+#if DEBUG
+ this.Logger.WriteError($"Dropping non-ClientHello ({handshake.MessageType}) message from non-peer `{peerAddress}`");
+#else
+ Interlocked.Increment(ref this.NonPeerNonHelloPacketsDropped);
+#endif
+ return;
+ }
+ message = message.Slice(Handshake.Size);
+
+ if (!ClientHello.Parse(out ClientHello clientHello, expectedProtocolVersion: null, message))
+ {
+ this.Logger.WriteError($"Dropping malformed ClientHello message from non-peer `{peerAddress}`");
+ return;
+ }
+
+ // If this ClientHello is not signed by us, request the
+ // client send us a signed message
+ if (!HelloVerifyRequest.VerifyCookie(clientHello.Cookie, peerAddress, this.CurrentCookieHmac))
+ {
+ if (!HelloVerifyRequest.VerifyCookie(clientHello.Cookie, peerAddress, this.PreviousCookieHmac))
+ {
+#if DEBUG
+ this.Logger.WriteVerbose($"Sending HelloVerifyRequest to non-peer `{peerAddress}`");
+#else
+ Interlocked.Increment(ref this.NonPeerVerifyHelloRequests);
+#endif
+ this.SendHelloVerifyRequest(peerAddress, 1, 0, NullRecordProtection.Instance, clientHello.ClientProtocolVersion);
+ return;
+ }
+ }
+
+ // Allocate state for the new peer and register it
+ PeerData peer = new PeerData(this.AllocateConnectionId(peerAddress), record.SequenceNumber + 1, clientHello.ClientProtocolVersion);
+ this.ProcessHandshake(peer, peerAddress, ref record, originalMessage);
+ this.existingPeers[peerAddress] = peer;
+ }
+
+ //Send a HelloVerifyRequest handshake message to a peer
+ private void SendHelloVerifyRequest(IPEndPoint peerAddress, ulong recordSequence, ushort epoch, IRecordProtection recordProtection, ProtocolVersion protocolVersion)
+ {
+ Handshake handshake = new Handshake();
+ handshake.MessageType = HandshakeType.HelloVerifyRequest;
+ handshake.Length = HelloVerifyRequest.Size;
+ handshake.MessageSequence = 0;
+ handshake.FragmentOffset = 0;
+ handshake.FragmentLength = handshake.Length;
+
+ int plaintextPayloadSize = Handshake.Size + (int)handshake.Length;
+
+ Record record = new Record();
+ record.ContentType = ContentType.Handshake;
+ record.ProtocolVersion = protocolVersion;
+ record.Epoch = epoch;
+ record.SequenceNumber = recordSequence;
+ record.Length = (ushort)recordProtection.GetEncryptedSize(plaintextPayloadSize);
+
+ // Encode record to wire format
+ ByteSpan packet = new byte[Record.Size + record.Length];
+ ByteSpan writer = packet;
+ record.Encode(writer);
+ writer = writer.Slice(Record.Size);
+ handshake.Encode(writer);
+ writer = writer.Slice(Handshake.Size);
+ HelloVerifyRequest.Encode(writer, peerAddress, this.CurrentCookieHmac, protocolVersion);
+
+ // Protect record payload
+ recordProtection.EncryptServerPlaintext(
+ packet.Slice(Record.Size, record.Length),
+ packet.Slice(Record.Size, plaintextPayloadSize),
+ ref record
+ );
+
+ base.QueueRawData(packet, peerAddress);
+ }
+
+ /// <summary>
+ /// Handle a requrest to send a datagram to the network
+ /// </summary>
+ protected override void QueueRawData(ByteSpan span, IPEndPoint remoteEndPoint)
+ {
+ PeerData peer;
+ if (!this.existingPeers.TryGetValue(remoteEndPoint, out peer))
+ {
+ // Drop messages if we don't know how to send them
+ return;
+ }
+
+ lock (peer)
+ {
+ // If we're negotiating a new epoch, queue data
+ if (peer.Epoch == 0 || peer.NextEpoch.State != HandshakeState.ExpectingHello)
+ {
+ ByteSpan copyOfSpan = new byte[span.Length];
+ span.CopyTo(copyOfSpan);
+
+ peer.QueuedApplicationDataMessage.Add(copyOfSpan);
+ return;
+ }
+
+ ProtocolVersion protocolVersion = peer.ProtocolVersion;
+
+ // Send any queued application data now
+ for (int ii = 0, nn = peer.QueuedApplicationDataMessage.Count; ii != nn; ++ii)
+ {
+ ByteSpan queuedSpan = peer.QueuedApplicationDataMessage[ii];
+
+ Record outgoingRecord = new Record();
+ outgoingRecord.ContentType = ContentType.ApplicationData;
+ outgoingRecord.ProtocolVersion = protocolVersion;
+ outgoingRecord.Epoch = peer.Epoch;
+ outgoingRecord.SequenceNumber = peer.CurrentEpoch.NextOutgoingSequence;
+ outgoingRecord.Length = (ushort)peer.CurrentEpoch.RecordProtection.GetEncryptedSize(queuedSpan.Length);
+ ++peer.CurrentEpoch.NextOutgoingSequence;
+
+ // Encode the record to wire format
+ ByteSpan packet = new byte[Record.Size + outgoingRecord.Length];
+ ByteSpan writer = packet;
+ outgoingRecord.Encode(writer);
+ writer = writer.Slice(Record.Size);
+ queuedSpan.CopyTo(writer);
+
+ // Protect the record
+ peer.CurrentEpoch.RecordProtection.EncryptServerPlaintext(
+ packet.Slice(Record.Size, outgoingRecord.Length),
+ packet.Slice(Record.Size, queuedSpan.Length),
+ ref outgoingRecord
+ );
+
+ base.QueueRawData(packet, remoteEndPoint);
+ }
+ peer.QueuedApplicationDataMessage.Clear();
+
+ {
+ Record outgoingRecord = new Record();
+ outgoingRecord.ContentType = ContentType.ApplicationData;
+ outgoingRecord.ProtocolVersion = protocolVersion;
+ outgoingRecord.Epoch = peer.Epoch;
+ outgoingRecord.SequenceNumber = peer.CurrentEpoch.NextOutgoingSequence;
+ outgoingRecord.Length = (ushort)peer.CurrentEpoch.RecordProtection.GetEncryptedSize(span.Length);
+ ++peer.CurrentEpoch.NextOutgoingSequence;
+
+ // Encode the record to wire format
+ ByteSpan packet = new byte[Record.Size + outgoingRecord.Length];
+ ByteSpan writer = packet;
+ outgoingRecord.Encode(writer);
+ writer = writer.Slice(Record.Size);
+ span.CopyTo(writer);
+
+ // Protect the record
+ peer.CurrentEpoch.RecordProtection.EncryptServerPlaintext(
+ packet.Slice(Record.Size, outgoingRecord.Length),
+ packet.Slice(Record.Size, span.Length),
+ ref outgoingRecord
+ );
+
+ base.QueueRawData(packet, remoteEndPoint);
+ }
+ }
+ }
+
+ private void HandleStaleConnections(object _)
+ {
+ TimeSpan maxAge = TimeSpan.FromSeconds(2.5f);
+ DateTime now = DateTime.UtcNow;
+ foreach (KeyValuePair<IPEndPoint, PeerData> kvp in this.existingPeers)
+ {
+ PeerData peer = kvp.Value;
+ lock (peer)
+ {
+ if (peer.Epoch == 0 || peer.NextEpoch.State != HandshakeState.ExpectingHello)
+ {
+ TimeSpan negotiationAge = now - peer.StartOfNegotiation;
+ if (negotiationAge > maxAge)
+ {
+ MarkConnectionAsStale(peer.ConnectionId);
+ }
+ }
+ }
+ }
+
+ ConnectionId connectionId;
+ while (this.staleConnections.TryPop(out connectionId))
+ {
+ ThreadLimitedUdpServerConnection connection;
+ if (this.allConnections.TryGetValue(connectionId, out connection))
+ {
+ connection.Disconnect("Stale Connection", null);
+ }
+ }
+ }
+
+ protected void MarkConnectionAsStale(ConnectionId connectionId)
+ {
+ if (this.allConnections.ContainsKey(connectionId))
+ {
+ this.staleConnections.Push(connectionId);
+ }
+ }
+
+ /// <inheritdoc />
+ internal override void RemovePeerRecord(ConnectionId connectionId)
+ {
+ if (this.existingPeers.TryRemove(connectionId.EndPoint, out var peer))
+ {
+ peer.Dispose();
+ }
+ }
+
+ /// <summary>
+ /// Allocate a new connection id
+ /// </summary>
+ private ConnectionId AllocateConnectionId(IPEndPoint endPoint)
+ {
+ int rawSerialId = Interlocked.Increment(ref this.connectionSerial_unsafe);
+ return ConnectionId.Create(endPoint, rawSerialId);
+ }
+
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/Dtls/DtlsUnityConnection.cs b/Tools/Hazel-Networking/Hazel/Dtls/DtlsUnityConnection.cs
new file mode 100644
index 0000000..4da2051
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/Dtls/DtlsUnityConnection.cs
@@ -0,0 +1,1246 @@
+using Hazel.Crypto;
+using Hazel.Udp;
+using System;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Net;
+using System.Security.Cryptography;
+using System.Security.Cryptography.X509Certificates;
+
+namespace Hazel.Dtls
+{
+ /// <summary>
+ /// Connects to a UDP-DTLS server
+ /// </summary>
+ /// <inheritdoc />
+ public class DtlsUnityConnection : UnityUdpClientConnection
+ {
+ /// <summary>
+ /// Current state of the handshake sequence
+ /// </summary>
+ enum HandshakeState
+ {
+ Initializing,
+
+ ExpectingServerHello,
+ ExpectingCertificate,
+ ExpectingServerKeyExchange,
+ ExpectingServerHelloDone,
+
+ ExpectingChangeCipherSpec,
+ ExpectingFinished,
+
+ Established,
+ }
+
+ /// <summary>
+ /// State data for the current epoch
+ /// </summary>
+ struct CurrentEpoch
+ {
+ public ulong NextOutgoingSequence;
+
+ public ulong NextExpectedSequence;
+ public ulong PreviousSequenceWindowBitmask;
+
+ public IRecordProtection RecordProtection;
+ }
+
+ struct FragmentRange
+ {
+ public int Offset;
+ public int Length;
+ }
+
+ /// <summary>
+ /// State data for the next epoch
+ /// </summary>
+ struct NextEpoch
+ {
+ public ushort Epoch;
+
+ public HandshakeState State;
+
+ public ulong NextOutgoingSequence;
+
+ public DateTime NegotiationStartTime;
+ public DateTime NextPacketResendTime;
+ public int PacketResendCount;
+
+ public CipherSuite SelectedCipherSuite;
+ public IRecordProtection RecordProtection;
+ public IHandshakeCipherSuite Handshake;
+ public ByteSpan Cookie;
+ public Sha256Stream VerificationStream;
+ public RSA ServerPublicKey;
+
+ public ByteSpan ClientRandom;
+ public ByteSpan ServerRandom;
+
+ public ByteSpan MasterSecret;
+ public ByteSpan ServerVerification;
+
+ public List<FragmentRange> CertificateFragments;
+ public ByteSpan CertificatePayload;
+ }
+
+ struct QueuedAppData
+ {
+ public byte[] Bytes;
+ public byte SendOption;
+ public Action AckCallback;
+ }
+
+ private const ProtocolVersion DtlsVersion = ProtocolVersion.UDP;
+
+ internal byte HazelSessionVersion = HazelDtlsSessionInfo.CurrentClientSessionVersion;
+
+ private readonly object syncRoot = new object();
+ private readonly RandomNumberGenerator random = RandomNumberGenerator.Create();
+
+ private ushort epoch;
+ private CurrentEpoch currentEpoch;
+ private NextEpoch nextEpoch;
+ private TimeSpan handshakeResendTimeout = TimeSpan.FromMilliseconds(200);
+
+ private readonly Queue<QueuedAppData> queuedApplicationData = new Queue<QueuedAppData>();
+
+ private X509Certificate2Collection serverCertificates = new X509Certificate2Collection();
+
+ public bool HandshakeComplete
+ {
+ get
+ {
+ lock (this.syncRoot)
+ {
+ return this.nextEpoch.State == HandshakeState.Established;
+ }
+ }
+ }
+
+ /// <summary>
+ /// Create a new instance of the DTLS connection
+ /// </summary>
+ /// <inheritdoc />
+ public DtlsUnityConnection(ILogger logger, IPEndPoint remoteEndPoint, IPMode ipMode = IPMode.IPv4)
+ : base(logger, remoteEndPoint, ipMode)
+ {
+ this.nextEpoch.ServerRandom = new byte[Random.Size];
+ this.nextEpoch.ClientRandom = new byte[Random.Size];
+ this.nextEpoch.ServerVerification = new byte[Finished.Size];
+ this.nextEpoch.CertificateFragments = new List<FragmentRange>();
+
+ this.ResetConnectionState();
+ }
+
+ /// <inheritdoc />
+ protected override void Dispose(bool disposing)
+ {
+ base.Dispose(disposing);
+
+ lock (this.syncRoot)
+ {
+ this.ResetConnectionState();
+ }
+ }
+
+ /// <summary>
+ /// Set the list of valid server certificates
+ /// </summary>
+ /// <param name="certificateCollection">
+ /// List of certificates of authentic servers
+ /// </param>
+ public void SetValidServerCertificates(X509Certificate2Collection certificateCollection)
+ {
+ lock (this.syncRoot)
+ {
+ foreach (X509Certificate2 certificate in certificateCollection)
+ {
+ if (!(certificate.PublicKey.Key is RSA))
+ {
+ throw new ArgumentException("Certificate must be signed with an RSA key", nameof(certificateCollection));
+ }
+ }
+
+ this.serverCertificates = certificateCollection;
+ }
+ }
+
+ /// <summary>
+ /// Set the packet resend timer for handshake messages
+ /// </summary>
+ public void SetHandshakeResendTimeout(TimeSpan timeout)
+ {
+ lock (this.syncRoot)
+ {
+ this.handshakeResendTimeout = timeout;
+ }
+ }
+
+ /// <summary>
+ /// Reset existing connection state
+ /// </summary>
+ private void ResetConnectionState()
+ {
+ this.currentEpoch.NextOutgoingSequence = 1;
+ this.currentEpoch.NextExpectedSequence = 1;
+ this.currentEpoch.PreviousSequenceWindowBitmask = 0;
+ this.currentEpoch.RecordProtection?.Dispose();
+ this.currentEpoch.RecordProtection = NullRecordProtection.Instance;
+
+ this.nextEpoch.Epoch = 1;
+ this.nextEpoch.State = HandshakeState.Initializing;
+ this.nextEpoch.NextOutgoingSequence = 1;
+ this.nextEpoch.NegotiationStartTime = DateTime.MinValue;
+ this.nextEpoch.NextPacketResendTime = DateTime.MinValue;
+ this.nextEpoch.SelectedCipherSuite = CipherSuite.TLS_NULL_WITH_NULL_NULL;
+ this.nextEpoch.RecordProtection?.Dispose();
+ this.nextEpoch.RecordProtection = null;
+ this.nextEpoch.Handshake?.Dispose();
+ this.nextEpoch.Handshake = null;
+ this.nextEpoch.Cookie = ByteSpan.Empty;
+ this.nextEpoch.VerificationStream?.Dispose();
+ this.nextEpoch.VerificationStream = new Sha256Stream();
+ this.nextEpoch.ServerPublicKey = null;
+ this.nextEpoch.ServerRandom.SecureClear();
+ this.nextEpoch.ClientRandom.SecureClear();
+ this.nextEpoch.MasterSecret.SecureClear();
+ this.nextEpoch.ServerVerification.SecureClear();
+ this.nextEpoch.CertificateFragments.Clear();
+ this.nextEpoch.CertificatePayload = ByteSpan.Empty;
+
+ this.epoch = 0;
+ while (this.queuedApplicationData.TryDequeue(out _)) ;
+ }
+
+ /// <summary>
+ /// Abort the existing connection and restart the process
+ /// </summary>
+ protected override void RestartConnection()
+ {
+ lock (this.syncRoot)
+ {
+ this.ResetConnectionState();
+ this.nextEpoch.ClientRandom.FillWithRandom(this.random);
+ this.SendClientHello(isRetransmit: false);
+ }
+
+ base.RestartConnection();
+ }
+
+ /// <inheritdoc />
+ protected override void ResendPacketsIfNeeded()
+ {
+ lock (this.syncRoot)
+ {
+ // Check if we need to resend handshake message
+ if (this.nextEpoch.State != HandshakeState.Established)
+ {
+ DateTime now = DateTime.UtcNow;
+ if (now >= this.nextEpoch.NextPacketResendTime)
+ {
+ double negotiationDurationMs = (now - this.nextEpoch.NegotiationStartTime).TotalMilliseconds;
+ this.nextEpoch.PacketResendCount++;
+
+ if ((this.ResendLimit > 0 && this.nextEpoch.PacketResendCount > this.ResendLimit)
+ || negotiationDurationMs > this.DisconnectTimeoutMs)
+ {
+ this.DisconnectInternal(HazelInternalErrors.DtlsNegotiationFailed, $"DTLS negotiation failed after {this.nextEpoch.PacketResendCount} resends ({(int)negotiationDurationMs} ms).");
+ }
+ else
+ {
+ switch (this.nextEpoch.State)
+ {
+ case HandshakeState.ExpectingServerHello:
+ case HandshakeState.ExpectingCertificate:
+ case HandshakeState.ExpectingServerKeyExchange:
+ case HandshakeState.ExpectingServerHelloDone:
+ this.SendClientHello(isRetransmit: true);
+ break;
+
+ case HandshakeState.ExpectingChangeCipherSpec:
+ case HandshakeState.ExpectingFinished:
+ this.SendClientKeyExchangeFlight(isRetransmit: true);
+ break;
+
+ case HandshakeState.Established:
+ default:
+ break;
+ }
+ }
+ }
+ }
+ }
+
+ base.ResendPacketsIfNeeded();
+ }
+
+ /// <summary>
+ /// Flush any queued application data packets
+ /// </summary>
+ private void FlushQueuedApplicationData()
+ {
+ while (this.queuedApplicationData.TryDequeue(out var queuedData))
+ {
+ base.HandleSend(queuedData.Bytes, queuedData.SendOption, queuedData.AckCallback);
+ }
+ }
+
+ /// <summary>
+ /// Request from the application to write data to the DTLS
+ /// stream. If appropriate, returns a byte span to send to
+ /// the wire.
+ /// </summary>
+ /// <param name="bytes">Plaintext bytes to write</param>
+ /// <param name="length">Length of the bytes to write</param>
+ /// <returns>
+ /// Encrypted data to put on the wire if appropriate,
+ /// otherwise an empty span
+ /// </returns>
+ private ByteSpan WriteBytesToConnectionInternal(byte[] bytes, int length)
+ {
+ lock (this.syncRoot)
+ {
+ Record outgoinRecord = new Record();
+ outgoinRecord.ContentType = ContentType.ApplicationData;
+ outgoinRecord.ProtocolVersion = DtlsVersion;
+ outgoinRecord.Epoch = this.epoch;
+ outgoinRecord.SequenceNumber = this.currentEpoch.NextOutgoingSequence;
+ outgoinRecord.Length = (ushort)this.currentEpoch.RecordProtection.GetEncryptedSize(length);
+ ++this.currentEpoch.NextOutgoingSequence;
+
+ // Encode the record to wire format
+ ByteSpan packet = new byte[Record.Size + outgoinRecord.Length];
+ ByteSpan writer = packet;
+ outgoinRecord.Encode(writer);
+ writer = writer.Slice(Record.Size);
+ new ByteSpan(bytes, 0, length).CopyTo(writer);
+
+ // Protect the record
+ this.currentEpoch.RecordProtection.EncryptClientPlaintext(
+ packet.Slice(Record.Size, outgoinRecord.Length),
+ packet.Slice(Record.Size, length),
+ ref outgoinRecord
+ );
+
+ return packet;
+ }
+ }
+
+ protected override void HandleSend(byte[] data, byte sendOption, Action ackCallback = null)
+ {
+ lock (this.syncRoot)
+ {
+ // If we're negotiating a new epoch, queue data
+ if (this.nextEpoch.State != HandshakeState.Established)
+ {
+ this.queuedApplicationData.Enqueue(new QueuedAppData
+ {
+ Bytes = data,
+ SendOption = sendOption,
+ AckCallback = ackCallback
+ });
+
+ return;
+ }
+ }
+
+ base.HandleSend(data, sendOption, ackCallback);
+ }
+
+ /// <inheritdoc />
+ protected override void WriteBytesToConnection(byte[] bytes, int length)
+ {
+ ByteSpan wireData = this.WriteBytesToConnectionInternal(bytes, length);
+ if (wireData.Length > 0)
+ {
+ Debug.Assert(wireData.Offset == 0, "Got a non-zero write data offset");
+ base.WriteBytesToConnection(wireData.GetUnderlyingArray(), wireData.Length);
+ }
+ }
+
+ /// <inheritdoc />
+ protected override void WriteBytesToConnectionSync(byte[] bytes, int length)
+ {
+ ByteSpan wireData = this.WriteBytesToConnectionInternal(bytes, length);
+ if (wireData.Length > 0)
+ {
+ Debug.Assert(wireData.Offset == 0, "Got a non-zero write data offset");
+ base.WriteBytesToConnectionSync(wireData.GetUnderlyingArray(), wireData.Length);
+ }
+ }
+
+ /// <inheritdoc />
+ protected internal override void HandleReceive(MessageReader reader, int bytesReceived)
+ {
+ ByteSpan message = new ByteSpan(reader.Buffer, reader.Offset + reader.Position, reader.BytesRemaining);
+ lock (this.syncRoot)
+ {
+ this.HandleReceive(message);
+ }
+
+ reader.Recycle();
+ }
+
+ /// <summary>
+ /// Handle an incoming datagram
+ /// </summary>
+ /// <param name="span">Bytes of the datagram</param>
+ private void HandleReceive(ByteSpan span)
+ {
+ // Each incoming packet may contain multiple DTLS
+ // records
+ while (span.Length > 0)
+ {
+ Record record;
+ if (!Record.Parse(out record, DtlsVersion, span))
+ {
+ this.logger.WriteError("Dropping malformed record");
+ return;
+ }
+ span = span.Slice(Record.Size);
+
+ if (span.Length < record.Length)
+ {
+ this.logger.WriteError($"Dropping malformed record. Length({record.Length}) Available Bytes({span.Length})");
+ return;
+ }
+
+ ByteSpan recordPayload = span.Slice(0, record.Length);
+ span = span.Slice(record.Length);
+
+ // Early out and drop ApplicationData records
+ if (record.ContentType == ContentType.ApplicationData && this.nextEpoch.State != HandshakeState.Established)
+ {
+ this.logger.WriteError("Dropping ApplicationData record. Cannot process yet");
+ continue;
+ }
+
+ // Drop records from a different epoch
+ if (record.Epoch != this.epoch)
+ {
+ this.logger.WriteError($"Dropping bad-epoch record. RecordEpoch({record.Epoch}) Epoch({this.epoch})");
+ continue;
+ }
+
+ // Prevent replay attacks by dropping records
+ // we've already processed
+ int windowIndex = (int)(this.currentEpoch.NextExpectedSequence - record.SequenceNumber - 1);
+ ulong windowMask = 1ul << windowIndex;
+ if (record.SequenceNumber < this.currentEpoch.NextExpectedSequence)
+ {
+ if (windowIndex >= 64)
+ {
+ this.logger.WriteError($"Dropping too-old record: Sequnce({record.SequenceNumber}) Expected({this.currentEpoch.NextExpectedSequence})");
+ continue;
+ }
+
+ if ((this.currentEpoch.PreviousSequenceWindowBitmask & windowMask) != 0)
+ {
+ this.logger.WriteWarning("Dropping duplicate record");
+ continue;
+ }
+ }
+
+ // Verify record authenticity
+ int decryptedSize = this.currentEpoch.RecordProtection.GetDecryptedSize(recordPayload.Length);
+ ByteSpan decryptedPayload = recordPayload.ReuseSpanIfPossible(decryptedSize);
+
+ if (!this.currentEpoch.RecordProtection.DecryptCiphertextFromServer(decryptedPayload, recordPayload, ref record))
+ {
+ this.logger.WriteError("Dropping non-authentic record");
+ return;
+ }
+
+ recordPayload = decryptedPayload;
+
+ // Update out sequence number bookkeeping
+ if (record.SequenceNumber >= this.currentEpoch.NextExpectedSequence)
+ {
+ int windowShift = (int)(record.SequenceNumber + 1 - this.currentEpoch.NextExpectedSequence);
+ this.currentEpoch.PreviousSequenceWindowBitmask <<= windowShift;
+ this.currentEpoch.NextExpectedSequence = record.SequenceNumber + 1;
+ }
+ else
+ {
+ this.currentEpoch.PreviousSequenceWindowBitmask |= windowMask;
+ }
+
+ // This is handy for debugging, but too verbose even for verbose.
+ // this.logger.WriteVerbose($"Content type was {record.ContentType} ({this.nextEpoch.State})");
+ switch (record.ContentType)
+ {
+ case ContentType.ChangeCipherSpec:
+ if (this.nextEpoch.State != HandshakeState.ExpectingChangeCipherSpec)
+ {
+ this.logger.WriteError($"Dropping unexpected ChangeCipherSpec State({this.nextEpoch.State})");
+ break;
+ }
+ else if (this.nextEpoch.RecordProtection == null)
+ {
+ ///NOTE(mendsley): This _should_ not
+ /// happen on a well-formed client.
+ Debug.Assert(false, "How did we receive a ChangeCipherSpec message without a pending record protection instance?");
+ break;
+ }
+
+ if (!ChangeCipherSpec.Parse(recordPayload))
+ {
+ this.logger.WriteError("Dropping malformed ChangeCipherSpec message");
+ break;
+ }
+
+ // Migrate to the next epoch
+ this.epoch = this.nextEpoch.Epoch;
+ this.currentEpoch.RecordProtection = this.nextEpoch.RecordProtection;
+ this.currentEpoch.NextOutgoingSequence = this.nextEpoch.NextOutgoingSequence;
+ this.currentEpoch.NextExpectedSequence = 1;
+ this.currentEpoch.PreviousSequenceWindowBitmask = 0;
+
+ this.nextEpoch.State = HandshakeState.ExpectingFinished;
+ this.nextEpoch.SelectedCipherSuite = CipherSuite.TLS_NULL_WITH_NULL_NULL;
+ this.nextEpoch.RecordProtection = null;
+ this.nextEpoch.Handshake?.Dispose();
+ this.nextEpoch.Cookie = ByteSpan.Empty;
+ this.nextEpoch.VerificationStream.Reset();
+ this.nextEpoch.ServerPublicKey = null;
+ this.nextEpoch.ServerRandom.SecureClear();
+ this.nextEpoch.ClientRandom.SecureClear();
+ this.nextEpoch.MasterSecret.SecureClear();
+ break;
+
+ case ContentType.Alert:
+ this.logger.WriteError("Dropping unsupported alert record");
+ continue;
+
+ case ContentType.Handshake:
+ if (!ProcessHandshake(ref record, recordPayload))
+ {
+ return;
+ }
+ break;
+
+ case ContentType.ApplicationData:
+ // Forward data to the application
+ MessageReader reader = MessageReader.GetSized(recordPayload.Length);
+ reader.Length = recordPayload.Length;
+ recordPayload.CopyTo(reader.Buffer);
+
+ base.HandleReceive(reader, recordPayload.Length);
+ break;
+ }
+ }
+ }
+
+ /// <summary>
+ /// Process an incoming Handshake protocol message
+ /// </summary>
+ /// <param name="record">Parent record</param>
+ /// <param name="message">Record payload</param>
+ /// <returns>
+ /// True if further processing of the underlying datagram
+ /// should be continues. Otherwise, false.
+ /// </returns>
+ private bool ProcessHandshake(ref Record record, ByteSpan message)
+ {
+ // Each record may have multiple Handshake messages
+ while (message.Length > 0)
+ {
+ ByteSpan originalPayload = message;
+
+ Handshake handshake;
+ if (!Handshake.Parse(out handshake, message))
+ {
+ this.logger.WriteError("Dropping malformed handshake message");
+ return false;
+ }
+ message = message.Slice(Handshake.Size);
+
+ // Check for fragmented messages
+ if (handshake.FragmentOffset != 0 || handshake.FragmentLength != handshake.Length)
+ {
+ // We only support fragmentation on Certificate messages
+ if (handshake.MessageType != HandshakeType.Certificate)
+ {
+ this.logger.WriteError($"Dropping fragmented handshake message Type({handshake.MessageType}) Offset({handshake.FragmentOffset}) FragmentLength({handshake.FragmentLength}) Length({handshake.Length})");
+ continue;
+ }
+
+ if (message.Length < handshake.FragmentLength)
+ {
+ this.logger.WriteError($"Dropping malformed fragmented handshake message: AvailableBytes({message.Length}) Size({handshake.FragmentLength})");
+ return false;
+ }
+
+ originalPayload = originalPayload.Slice(0, (int)(Handshake.Size + handshake.FragmentLength));
+ message = message.Slice((int)handshake.FragmentLength);
+ }
+ else
+ {
+ if (message.Length < handshake.Length)
+ {
+ this.logger.WriteError($"Dropping malformed handshake message: AvailableBytes({message.Length}) Size({handshake.Length})");
+ return false;
+ }
+
+ originalPayload = originalPayload.Slice(0, (int)(Handshake.Size + handshake.Length));
+ message = message.Slice((int)handshake.Length);
+ }
+
+ ByteSpan payload = originalPayload.Slice(Handshake.Size);
+
+#if DEBUG
+ this.logger.WriteVerbose($"Handshake record was {handshake.MessageType} (Frag: {handshake.FragmentOffset}) ({this.nextEpoch.State})");
+#endif
+ switch (handshake.MessageType)
+ {
+ case HandshakeType.HelloVerifyRequest:
+ if (this.nextEpoch.State != HandshakeState.ExpectingServerHello)
+ {
+ this.logger.WriteError($"Dropping unexpected HelloVerifyRequest handshake message State({this.nextEpoch.State})");
+ continue;
+ }
+ else if (handshake.MessageSequence != 0)
+ {
+ this.logger.WriteError($"Dropping bad-sequence HelloVerifyRequest MessageSequence({handshake.MessageSequence})");
+ continue;
+ }
+
+ HelloVerifyRequest helloVerifyRequest;
+ if (!HelloVerifyRequest.Parse(out helloVerifyRequest, DtlsVersion, payload))
+ {
+ this.logger.WriteError("Dropping malformed HelloVerifyRequest handshake message");
+ continue;
+ }
+
+ // If the cookie differs, save it and restart the handshake
+ if (this.nextEpoch.Cookie.Length == helloVerifyRequest.Cookie.Length
+ && Const.ConstantCompareSpans(this.nextEpoch.Cookie, helloVerifyRequest.Cookie) == 1)
+ {
+ this.logger.WriteWarning("Dropping duplicate HelloVerifyRequest handshake message");
+ continue;
+ }
+
+ this.nextEpoch.Cookie = new byte[helloVerifyRequest.Cookie.Length];
+ helloVerifyRequest.Cookie.CopyTo(this.nextEpoch.Cookie);
+ this.nextEpoch.ClientRandom.FillWithRandom(this.random);
+
+ // We don't need to resend here. We already have the cookie so we already sent it once.
+ this.SendClientHello(isRetransmit: false);
+
+ break;
+
+ case HandshakeType.ServerHello:
+ if (this.nextEpoch.State != HandshakeState.ExpectingServerHello)
+ {
+ this.logger.WriteError($"Dropping unexpected ServerHello handshake message State({this.nextEpoch.State})");
+ continue;
+ }
+ else if (handshake.MessageSequence != 1)
+ {
+ this.logger.WriteError($"Dropping bad-sequence ServerHello MessageSequence({handshake.MessageSequence})");
+ continue;
+ }
+
+ ServerHello serverHello;
+ if (!ServerHello.Parse(out serverHello, payload))
+ {
+ this.logger.WriteError("Dropping malformed ServerHello message");
+ continue;
+ }
+
+ switch (serverHello.CipherSuite)
+ {
+ case CipherSuite.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256:
+ this.nextEpoch.Handshake = new X25519EcdheRsaSha256(this.random);
+ break;
+
+ default:
+ this.logger.WriteError($"Dropping malformed ServerHello message. Unsupported CipherSuite({serverHello.CipherSuite})");
+ continue;
+ }
+
+ // Save server parameters
+ this.nextEpoch.SelectedCipherSuite = serverHello.CipherSuite;
+ serverHello.Random.CopyTo(this.nextEpoch.ServerRandom);
+ this.nextEpoch.State = HandshakeState.ExpectingCertificate;
+ this.nextEpoch.CertificateFragments.Clear();
+ this.nextEpoch.CertificatePayload = ByteSpan.Empty;
+
+#if DEBUG
+ this.logger.WriteVerbose($"ClientRandom: {this.nextEpoch.ClientRandom} ServerRandom: {this.nextEpoch.ServerRandom}");
+#endif
+
+ // Append ServerHelllo message to the verification stream
+ this.nextEpoch.VerificationStream.AddData(originalPayload);
+ break;
+
+ case HandshakeType.Certificate:
+ if (this.nextEpoch.State != HandshakeState.ExpectingCertificate)
+ {
+ this.logger.WriteError($"Dropping unexpected Certificate handshake message State({this.nextEpoch.State})");
+ continue;
+ }
+ else if (handshake.MessageSequence != 2)
+ {
+ this.logger.WriteError($"Dropping bad-sequence Certificate MessageSequence({handshake.MessageSequence})");
+ continue;
+ }
+
+ // If this is a fragmented message
+ if (handshake.FragmentLength != handshake.Length)
+ {
+ if (this.nextEpoch.CertificatePayload.Length != handshake.Length)
+ {
+ this.nextEpoch.CertificatePayload = new byte[handshake.Length];
+ this.nextEpoch.CertificateFragments.Clear();
+ }
+
+ // Should we add this fragment?
+ // According to the RFC 9147 Section 5.5, we are supposed to be tolerant of overlapping segments
+ // But if we... weren't... Hazel isn't going to change the fragment sizes. So would it really hurt?
+ // So let's just ignore that and assume that the sender always wants to send the same fragments.
+ if (IsFragmentOverlapping(this.nextEpoch.CertificateFragments, handshake.FragmentOffset, handshake.FragmentLength))
+ {
+ continue;
+ }
+
+ payload.CopyTo(this.nextEpoch.CertificatePayload.Slice((int)handshake.FragmentOffset, (int)handshake.FragmentLength));
+ this.nextEpoch.CertificateFragments.Add(new FragmentRange {Offset = (int)handshake.FragmentOffset, Length = (int)handshake.FragmentLength });
+ this.nextEpoch.CertificateFragments.Sort((FragmentRange lhs, FragmentRange rhs) => {
+ return lhs.Offset.CompareTo(rhs.Offset);
+ });
+
+ // Have we completed the message?
+ int currentOffset = 0;
+ bool valid = true;
+ foreach (FragmentRange range in this.nextEpoch.CertificateFragments)
+ {
+ if (range.Offset != currentOffset)
+ {
+ valid = false;
+ break;
+ }
+
+ currentOffset += range.Length;
+ }
+
+ if (currentOffset != this.nextEpoch.CertificatePayload.Length)
+ {
+ valid = false;
+ }
+
+ // Still waiting on more fragments?
+ if (!valid)
+ {
+ continue;
+ }
+
+ // Replace the message payload, and continue
+ this.nextEpoch.CertificateFragments.Clear();
+ payload = this.nextEpoch.CertificatePayload;
+ }
+
+ X509Certificate2 certificate;
+ if (!Certificate.Parse(out certificate, payload))
+ {
+ this.logger.WriteError("Dropping malformed Certificate message");
+ continue;
+ }
+
+ // Verify the certificate is authenticate
+ if (!this.serverCertificates.Contains(certificate))
+ {
+ this.logger.WriteError("Dropping malformed Certificate message: Certificate not authentic");
+ continue;
+ }
+
+ RSA publicKey = certificate.PublicKey.Key as RSA;
+ if (publicKey == null)
+ {
+ this.logger.WriteError("Dropping malfomed Certificate message: Certificate is not RSA signed");
+ continue;
+ }
+
+ // Add the final Certificate message to the verification stream
+ Handshake fullCertificateHandhake = handshake;
+ fullCertificateHandhake.FragmentOffset = 0;
+ fullCertificateHandhake.FragmentLength = fullCertificateHandhake.Length;
+
+ ByteSpan serializedCertificateHandshake = new byte[Handshake.Size];
+ fullCertificateHandhake.Encode(serializedCertificateHandshake);
+ this.nextEpoch.VerificationStream.AddData(serializedCertificateHandshake);
+ this.nextEpoch.VerificationStream.AddData(payload);
+
+ this.nextEpoch.ServerPublicKey = publicKey;
+ this.nextEpoch.State = HandshakeState.ExpectingServerKeyExchange;
+ break;
+
+ case HandshakeType.ServerKeyExchange:
+ if (this.nextEpoch.State != HandshakeState.ExpectingServerKeyExchange)
+ {
+ this.logger.WriteError($"Dropping unexpected ServerKeyExchange handshake message State({this.nextEpoch.State})");
+ continue;
+ }
+ else if (this.nextEpoch.ServerPublicKey == null)
+ {
+ ///NOTE(mendsley): This _should_ not
+ /// happen on a well-formed client
+ Debug.Assert(false, "How are we processing a ServerKeyExchange message without a server public key?");
+
+ this.logger.WriteError($"Dropping unexpected ServerKeyExchange handshake message: No server public key");
+ continue;
+ }
+ else if (this.nextEpoch.Handshake == null)
+ {
+ ///NOTE(mendsley): This _should_ not
+ /// happen on a well-formed client
+ Debug.Assert(false, "How did we receive a ServerKeyExchange message without a handshake instance?");
+
+ this.logger.WriteError($"Dropping unexpected ServerKeyExchange handshake message: No key agreement interface");
+ continue;
+ }
+ else if (handshake.MessageSequence != 3)
+ {
+ this.logger.WriteError($"Dropping bad-sequence ServerKeyExchange MessageSequence({handshake.MessageSequence})");
+ continue;
+ }
+
+ ByteSpan sharedSecret = new byte[this.nextEpoch.Handshake.SharedKeySize()];
+ if (!this.nextEpoch.Handshake.VerifyServerMessageAndGenerateSharedKey(sharedSecret, payload, this.nextEpoch.ServerPublicKey))
+ {
+ this.logger.WriteError("Dropping malformed ServerKeyExchangeMessage");
+ return false;
+ }
+
+ // Generate the session master secret
+ ByteSpan randomSeed = new byte[2 * Random.Size];
+ this.nextEpoch.ClientRandom.CopyTo(randomSeed);
+ this.nextEpoch.ServerRandom.CopyTo(randomSeed.Slice(Random.Size));
+
+ const int MasterSecretSize = 48;
+ ByteSpan masterSecret = new byte[MasterSecretSize];
+ PrfSha256.ExpandSecret(
+ masterSecret
+ , sharedSecret
+ , PrfLabel.MASTER_SECRET
+ , randomSeed
+ );
+
+ // Create record protection for the upcoming epoch
+ switch (this.nextEpoch.SelectedCipherSuite)
+ {
+ case CipherSuite.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256:
+ this.nextEpoch.RecordProtection = new Aes128GcmRecordProtection(
+ masterSecret
+ , this.nextEpoch.ServerRandom
+ , this.nextEpoch.ClientRandom
+ );
+ break;
+
+ default:
+ ///NOTE(mendsley): this _should_ not
+ /// happen on a well-formed client.
+ Debug.Assert(false, "SeverHello processing already approved this ciphersuite");
+
+ this.logger.WriteError($"Dropping malformed ServerKeyExchangeMessage: Could not create record protection");
+ return false;
+ }
+
+ this.nextEpoch.State = HandshakeState.ExpectingServerHelloDone;
+ this.nextEpoch.MasterSecret = masterSecret;
+
+ // Append ServerKeyExchange to the verification stream
+ this.nextEpoch.VerificationStream.AddData(originalPayload);
+ break;
+
+ case HandshakeType.ServerHelloDone:
+ if (this.nextEpoch.State != HandshakeState.ExpectingServerHelloDone)
+ {
+ this.logger.WriteError($"Dropping unexpected ServerHelloDone handshake message State({this.nextEpoch.State})");
+ continue;
+ }
+ else if (handshake.MessageSequence != 4)
+ {
+ this.logger.WriteError($"Dropping bad-sequence ServerHelloDone MessageSequence({handshake.MessageSequence})");
+ continue;
+ }
+
+ this.nextEpoch.State = HandshakeState.ExpectingChangeCipherSpec;
+
+ // Append ServerHelloDone to the verification stream
+ this.nextEpoch.VerificationStream.AddData(originalPayload);
+
+ this.SendClientKeyExchangeFlight(isRetransmit: false);
+ break;
+
+ case HandshakeType.Finished:
+ if (this.nextEpoch.State != HandshakeState.ExpectingFinished)
+ {
+ this.logger.WriteError($"Dropping unexpected Finished handshake message State({this.nextEpoch.State})");
+ continue;
+ }
+ else if (payload.Length != Finished.Size)
+ {
+ this.logger.WriteError($"Dropping malformed Finished handshake message Size({payload.Length})");
+ continue;
+ }
+ else if (handshake.MessageSequence != 7)
+ {
+ this.logger.WriteError($"Dropping bad-sequence Finished MessageSequence({handshake.MessageSequence})");
+ continue;
+ }
+
+ // Verify the digest from the server
+ if (1 != Crypto.Const.ConstantCompareSpans(payload, this.nextEpoch.ServerVerification))
+ {
+ this.logger.WriteError("Dropping non-verified Finished handshake message");
+ return false;
+ }
+
+ ++this.nextEpoch.Epoch;
+ this.nextEpoch.State = HandshakeState.Established;
+ this.nextEpoch.NegotiationStartTime = DateTime.MinValue;
+ this.nextEpoch.NextPacketResendTime = DateTime.MinValue;
+ this.nextEpoch.ServerVerification.SecureClear();
+ this.nextEpoch.MasterSecret.SecureClear();
+
+ this.FlushQueuedApplicationData();
+ break;
+
+ // Drop messages we do not support
+ case HandshakeType.CertificateRequest:
+ case HandshakeType.HelloRequest:
+ this.logger.WriteError($"Dropping unsupported handshake message MessageType({handshake.MessageType})");
+ break;
+
+ // Drop messages that originate from the client
+ case HandshakeType.ClientHello:
+ case HandshakeType.ClientKeyExchange:
+ case HandshakeType.CertificateVerify:
+ this.logger.WriteError($"Dropping client handshake message MessageType({handshake.MessageType})");
+ break;
+ }
+ }
+
+ return true;
+ }
+
+ private bool IsFragmentOverlapping(List<FragmentRange> fragments, uint newOffset, uint newLength)
+ {
+ foreach (var frag in fragments)
+ {
+ // New fragment overlaps an existing one
+ if (newOffset <= frag.Offset
+ && frag.Offset < newOffset + newLength)
+ {
+ return true;
+ }
+
+ // Existing fragment overlaps this new one
+ if (frag.Offset <= newOffset
+ && newOffset < frag.Offset + frag.Length)
+ {
+ return true;
+ }
+ }
+
+ return false;
+ }
+
+ /// <summary>
+ /// Send (resend) a ClientHello message to the server
+ /// </summary>
+ protected virtual void SendClientHello(bool isRetransmit)
+ {
+#if DEBUG
+ var verb = isRetransmit ? "Resending" : "Sending";
+ this.logger.WriteVerbose($"{verb} ClientHello in state: {this.nextEpoch.State}. Epoch: {this.epoch} Cookie: {this.nextEpoch.Cookie} Random: {this.nextEpoch.ClientRandom}");
+#endif
+
+ // Describe our ClientHello flight
+ ClientHello clientHello = new ClientHello();
+ clientHello.ClientProtocolVersion = DtlsVersion;
+ clientHello.Random = this.nextEpoch.ClientRandom;
+ clientHello.Cookie = this.nextEpoch.Cookie;
+ clientHello.Session = new HazelDtlsSessionInfo(this.HazelSessionVersion);
+ clientHello.CipherSuites = new byte[2];
+ clientHello.CipherSuites.WriteBigEndian16((ushort)CipherSuite.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256);
+ clientHello.SupportedCurves = new byte[2];
+ clientHello.SupportedCurves.WriteBigEndian16((ushort)NamedCurve.x25519);
+
+ Handshake handshake = new Handshake();
+ handshake.MessageType = HandshakeType.ClientHello;
+ handshake.Length = (uint)clientHello.CalculateSize();
+ handshake.MessageSequence = 0;
+ handshake.FragmentOffset = 0;
+ handshake.FragmentLength = handshake.Length;
+
+ // Describe the record
+ int plaintextLength = (int)(Handshake.Size + handshake.Length);
+ Record outgoingRecord = new Record();
+ outgoingRecord.ContentType = ContentType.Handshake;
+ outgoingRecord.ProtocolVersion = DtlsVersion;
+ outgoingRecord.Epoch = this.epoch;
+ outgoingRecord.SequenceNumber = this.currentEpoch.NextOutgoingSequence;
+ outgoingRecord.Length = (ushort)this.currentEpoch.RecordProtection.GetEncryptedSize(plaintextLength);
+ ++this.currentEpoch.NextOutgoingSequence;
+
+ // Convert the record to wire format
+ ByteSpan packet = new byte[Record.Size + outgoingRecord.Length];
+ ByteSpan writer = packet;
+ outgoingRecord.Encode(packet);
+ writer = writer.Slice(Record.Size);
+ handshake.Encode(writer);
+ writer = writer.Slice(Handshake.Size);
+ clientHello.Encode(writer);
+
+ // If this is our first valid attempt at contacting the server:
+ // - Reset our verification stream
+ // - Write ClientHello to the verification stream
+ // - We next expect a ServerHello
+ //
+ // ClientHello+Cookie triggers many sequential packets in response
+ // It's important to make forward progress as the packets may be reordered in-flight
+ // But with enough resends, we will read them all in an appropriate order
+ if (!isRetransmit)
+ {
+ this.nextEpoch.VerificationStream.Reset();
+ this.nextEpoch.VerificationStream.AddData(
+ packet.Slice(Record.Size, Handshake.Size + (int)handshake.Length)
+ );
+
+ this.nextEpoch.State = HandshakeState.ExpectingServerHello;
+ }
+
+ // Protect the record
+ this.currentEpoch.RecordProtection.EncryptClientPlaintext(
+ packet.Slice(Record.Size, outgoingRecord.Length),
+ packet.Slice(Record.Size, plaintextLength),
+ ref outgoingRecord
+ );
+
+ if (this.nextEpoch.NegotiationStartTime == DateTime.MinValue) this.nextEpoch.NegotiationStartTime = DateTime.UtcNow;
+ this.nextEpoch.NextPacketResendTime = DateTime.UtcNow + this.handshakeResendTimeout;
+
+ base.WriteBytesToConnection(packet.GetUnderlyingArray(), packet.Length);
+ }
+
+ protected void Test_SendClientHello(Func<ClientHello, ByteSpan, ByteSpan> encodeCallback)
+ {
+ // Reset our verification stream
+ this.nextEpoch.VerificationStream.Reset();
+
+ // Describe our ClientHello flight
+ ClientHello clientHello = new ClientHello();
+ clientHello.Random = this.nextEpoch.ClientRandom;
+ clientHello.Cookie = this.nextEpoch.Cookie;
+ clientHello.CipherSuites = new byte[2];
+ clientHello.CipherSuites.WriteBigEndian16((ushort)CipherSuite.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256);
+ clientHello.SupportedCurves = new byte[2];
+ clientHello.SupportedCurves.WriteBigEndian16((ushort)NamedCurve.x25519);
+
+ Handshake handshake = new Handshake();
+ handshake.MessageType = HandshakeType.ClientHello;
+ handshake.Length = (uint)clientHello.CalculateSize();
+ handshake.MessageSequence = 0;
+ handshake.FragmentOffset = 0;
+ handshake.FragmentLength = handshake.Length;
+
+ // Describe the record
+ int plaintextLength = (int)(Handshake.Size + handshake.Length);
+ Record outgoingRecord = new Record();
+ outgoingRecord.ContentType = ContentType.Handshake;
+ outgoingRecord.ProtocolVersion = DtlsVersion;
+ outgoingRecord.Epoch = this.epoch;
+ outgoingRecord.SequenceNumber = this.currentEpoch.NextOutgoingSequence;
+ outgoingRecord.Length = (ushort)this.currentEpoch.RecordProtection.GetEncryptedSize(plaintextLength);
+ ++this.currentEpoch.NextOutgoingSequence;
+
+ // Convert the record to wire format
+ ByteSpan packet = new byte[Record.Size + outgoingRecord.Length];
+ ByteSpan writer = packet;
+ outgoingRecord.Encode(packet);
+ writer = writer.Slice(Record.Size);
+ handshake.Encode(writer);
+ writer = writer.Slice(Handshake.Size);
+
+ writer = encodeCallback(clientHello, writer);
+
+ // Write ClientHello to the verification stream
+ this.nextEpoch.VerificationStream.AddData(
+ packet.Slice(
+ Record.Size
+ , Handshake.Size + (int)handshake.Length
+ )
+ );
+
+ // Protect the record
+ this.currentEpoch.RecordProtection.EncryptClientPlaintext(
+ packet.Slice(Record.Size, outgoingRecord.Length),
+ packet.Slice(Record.Size, plaintextLength),
+ ref outgoingRecord
+ );
+
+ this.nextEpoch.State = HandshakeState.ExpectingServerHello;
+ if (this.nextEpoch.NegotiationStartTime == DateTime.MinValue) this.nextEpoch.NegotiationStartTime = DateTime.UtcNow;
+ this.nextEpoch.NextPacketResendTime = DateTime.UtcNow + this.handshakeResendTimeout;
+ base.WriteBytesToConnection(packet.GetUnderlyingArray(), packet.Length);
+ }
+
+ /// <summary>
+ /// Send (resend) the ClientKeyExchange flight
+ /// </summary>
+ /// <param name="isRetransmit">
+ /// True if this is a retransmit of the flight. Otherwise,
+ /// false
+ /// </param>
+ protected virtual void SendClientKeyExchangeFlight(bool isRetransmit)
+ {
+#if DEBUG
+ var verb = isRetransmit ? "Resending" : "Sending";
+ this.logger.WriteVerbose($"{verb} ClientKeyExchangeFlight in state: {this.nextEpoch.State}");
+#endif
+ if (this.nextEpoch.State == HandshakeState.Established)
+ {
+ return;
+ }
+
+ // Describe our flight
+ Handshake keyExchangeHandshake = new Handshake();
+ keyExchangeHandshake.MessageType = HandshakeType.ClientKeyExchange;
+ keyExchangeHandshake.Length = (ushort)this.nextEpoch.Handshake.CalculateClientMessageSize();
+ keyExchangeHandshake.MessageSequence = 5;
+ keyExchangeHandshake.FragmentOffset = 0;
+ keyExchangeHandshake.FragmentLength = keyExchangeHandshake.Length;
+
+ Record keyExchangeRecord = new Record();
+ keyExchangeRecord.ContentType = ContentType.Handshake;
+ keyExchangeRecord.ProtocolVersion = DtlsVersion;
+ keyExchangeRecord.Epoch = this.epoch;
+ keyExchangeRecord.SequenceNumber = this.currentEpoch.NextOutgoingSequence;
+ keyExchangeRecord.Length = (ushort)this.currentEpoch.RecordProtection.GetEncryptedSize(Handshake.Size + (int)keyExchangeHandshake.Length);
+ ++this.currentEpoch.NextOutgoingSequence;
+
+ Record changeCipherSpecRecord = new Record();
+ changeCipherSpecRecord.ContentType = ContentType.ChangeCipherSpec;
+ changeCipherSpecRecord.ProtocolVersion = DtlsVersion;
+ changeCipherSpecRecord.Epoch = this.epoch;
+ changeCipherSpecRecord.SequenceNumber = this.currentEpoch.NextOutgoingSequence;
+ changeCipherSpecRecord.Length = (ushort)this.currentEpoch.RecordProtection.GetEncryptedSize(ChangeCipherSpec.Size);
+ ++this.currentEpoch.NextOutgoingSequence;
+
+ Handshake finishedHandshake = new Handshake();
+ finishedHandshake.MessageType = HandshakeType.Finished;
+ finishedHandshake.Length = Finished.Size;
+ finishedHandshake.MessageSequence = 6;
+ finishedHandshake.FragmentOffset = 0;
+ finishedHandshake.FragmentLength = finishedHandshake.Length;
+
+ Record finishedRecord = new Record();
+ finishedRecord.ContentType = ContentType.Handshake;
+ finishedRecord.ProtocolVersion = DtlsVersion;
+ finishedRecord.Epoch = this.nextEpoch.Epoch;
+ finishedRecord.SequenceNumber = this.nextEpoch.NextOutgoingSequence;
+ finishedRecord.Length = (ushort)this.nextEpoch.RecordProtection.GetEncryptedSize(Handshake.Size + (int)finishedHandshake.Length);
+ ++this.nextEpoch.NextOutgoingSequence;
+
+ // Encode flight to wire format
+ int packetLength = 0
+ + Record.Size + keyExchangeRecord.Length
+ + Record.Size + changeCipherSpecRecord.Length
+ + Record.Size + finishedRecord.Length;
+ ;
+ ByteSpan packet = new byte[packetLength];
+ ByteSpan writer = packet;
+
+ keyExchangeRecord.Encode(writer);
+ writer = writer.Slice(Record.Size);
+ keyExchangeHandshake.Encode(writer);
+ writer = writer.Slice(Handshake.Size);
+ this.nextEpoch.Handshake.EncodeClientKeyExchangeMessage(writer);
+
+ ByteSpan startOfChangeCipherSpecRecord = packet.Slice(Record.Size + keyExchangeRecord.Length);
+ writer = startOfChangeCipherSpecRecord;
+ changeCipherSpecRecord.Encode(writer);
+ writer = writer.Slice(Record.Size);
+ ChangeCipherSpec.Encode(writer);
+ writer = writer.Slice(ChangeCipherSpec.Size);
+
+ ByteSpan startOfFinishedRecord = startOfChangeCipherSpecRecord.Slice(Record.Size + changeCipherSpecRecord.Length);
+ writer = startOfFinishedRecord;
+ finishedRecord.Encode(writer);
+ writer = writer.Slice(Record.Size);
+ finishedHandshake.Encode(writer);
+ writer = writer.Slice(Handshake.Size);
+
+ // Interject here to writer our client key exchange
+ // message into the verification stream
+ if (!isRetransmit)
+ {
+ this.nextEpoch.VerificationStream.AddData(
+ packet.Slice(
+ Record.Size
+ , Handshake.Size + (int)keyExchangeHandshake.Length
+ )
+ );
+ }
+
+ // Calculate the hash of the verification stream
+ ByteSpan handshakeHash = new byte[Sha256Stream.DigestSize];
+ this.nextEpoch.VerificationStream.CopyOrCalculateFinalHash(handshakeHash);
+
+ // Expand our master secret into Finished digests for the client and server
+ PrfSha256.ExpandSecret(
+ this.nextEpoch.ServerVerification
+ , this.nextEpoch.MasterSecret
+ , PrfLabel.SERVER_FINISHED
+ , handshakeHash
+ );
+
+ PrfSha256.ExpandSecret(
+ writer.Slice(0, Finished.Size)
+ , this.nextEpoch.MasterSecret
+ , PrfLabel.CLIENT_FINISHED
+ , handshakeHash
+ );
+ writer = writer.Slice(Finished.Size);
+
+ // Protect the ClientKeyExchange record
+ this.currentEpoch.RecordProtection.EncryptClientPlaintext(
+ packet.Slice(Record.Size, keyExchangeRecord.Length),
+ packet.Slice(Record.Size, Handshake.Size + (int)keyExchangeHandshake.Length),
+ ref keyExchangeRecord
+ );
+
+ // Protect the ChangeCipherSpec record
+ this.currentEpoch.RecordProtection.EncryptClientPlaintext(
+ startOfChangeCipherSpecRecord.Slice(Record.Size, changeCipherSpecRecord.Length),
+ startOfChangeCipherSpecRecord.Slice(Record.Size, ChangeCipherSpec.Size),
+ ref changeCipherSpecRecord
+ );
+
+ // Protect the Finished record
+ this.nextEpoch.RecordProtection.EncryptClientPlaintext(
+ startOfFinishedRecord.Slice(Record.Size, finishedRecord.Length),
+ startOfFinishedRecord.Slice(Record.Size, Handshake.Size + (int)finishedHandshake.Length),
+ ref finishedRecord
+ );
+
+ this.nextEpoch.State = HandshakeState.ExpectingChangeCipherSpec;
+ this.nextEpoch.NextPacketResendTime = DateTime.UtcNow + this.handshakeResendTimeout;
+#if DEBUG
+ if (DropClientKeyExchangeFlight())
+ {
+ return;
+ }
+#endif
+ base.WriteBytesToConnection(packet.GetUnderlyingArray(), packet.Length);
+ }
+
+ protected virtual bool DropClientKeyExchangeFlight()
+ {
+ return false;
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/Dtls/Handshake.cs b/Tools/Hazel-Networking/Hazel/Dtls/Handshake.cs
new file mode 100644
index 0000000..f840053
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/Dtls/Handshake.cs
@@ -0,0 +1,734 @@
+using System;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Net;
+using System.Security.Cryptography;
+using System.Security.Cryptography.X509Certificates;
+
+namespace Hazel.Dtls
+{
+ /// <summary>
+ /// Handshake message type
+ /// </summary>
+ public enum HandshakeType : byte
+ {
+ HelloRequest = 0,
+ ClientHello = 1,
+ ServerHello = 2,
+ HelloVerifyRequest = 3,
+ Certificate = 11,
+ ServerKeyExchange = 12,
+ CertificateRequest = 13,
+ ServerHelloDone = 14,
+ CertificateVerify = 15,
+ ClientKeyExchange = 16,
+ Finished = 20,
+ }
+
+ /// <summary>
+ /// List of cipher suites
+ /// </summary>
+ public enum CipherSuite
+ {
+ TLS_NULL_WITH_NULL_NULL = 0x0000,
+ TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 = 0xC02F,
+ }
+
+ /// <summary>
+ /// List of compression methods
+ /// </summary>
+ public enum CompressionMethod : byte
+ {
+ Null = 0,
+ }
+
+ /// <summary>
+ /// Extension type
+ /// </summary>
+ public enum ExtensionType : ushort
+ {
+ EllipticCurves = 10,
+ }
+
+ /// <summary>
+ /// Named curves
+ /// </summary>
+ public enum NamedCurve : ushort
+ {
+ Reserved = 0,
+ secp256r1 = 23,
+ x25519 = 29,
+ }
+
+ /// <summary>
+ /// Elliptic curve type
+ /// </summary>
+ public enum ECCurveType : byte
+ {
+ NamedCurve = 3,
+ }
+
+ /// <summary>
+ /// Hash algorithms
+ /// </summary>
+ public enum HashAlgorithm : byte
+ {
+ None = 0,
+ Sha256 = 4,
+ }
+
+ /// <summary>
+ /// Signature algorithms
+ /// </summary>
+ public enum SignatureAlgorithm : byte
+ {
+ Anonymous = 0,
+ RSA = 1,
+ ECDSA = 3,
+ }
+
+ /// <summary>
+ /// Random state for entropy
+ /// </summary>
+ public struct Random
+ {
+ public const int Size = 0
+ + 4 // gmt_unix_time
+ + 28 // random_bytes
+ ;
+ }
+
+ /// <summary>
+ /// Encode/decode handshake protocol header
+ /// </summary>
+ public struct Handshake
+ {
+ public HandshakeType MessageType;
+ public uint Length;
+ public ushort MessageSequence;
+ public uint FragmentOffset;
+ public uint FragmentLength;
+
+ public const int Size = 12;
+
+ /// <summary>
+ /// Parse a Handshake protocol header from wire format
+ /// </summary>
+ /// <returns>True if we successfully decode a handshake header. Otherwise false</returns>
+ public static bool Parse(out Handshake header, ByteSpan span)
+ {
+ header = new Handshake();
+
+ if (span.Length < Size)
+ {
+ return false;
+ }
+
+ header.MessageType = (HandshakeType)span[0];
+ header.Length = span.ReadBigEndian24(1);
+ header.MessageSequence = span.ReadBigEndian16(4);
+ header.FragmentOffset = span.ReadBigEndian24(6);
+ header.FragmentLength = span.ReadBigEndian24(9);
+ return true;
+ }
+
+ /// <summary>
+ /// Encode the Handshake protocol header to wire format
+ /// </summary>
+ /// <param name="span"></param>
+ public void Encode(ByteSpan span)
+ {
+ span[0] = (byte)this.MessageType;
+ span.WriteBigEndian24(this.Length, 1);
+ span.WriteBigEndian16(this.MessageSequence, 4);
+ span.WriteBigEndian24(this.FragmentOffset, 6);
+ span.WriteBigEndian24(this.FragmentLength, 9);
+ }
+ }
+
+ /// <summary>
+ /// Encode/decode ClientHello Handshake message
+ /// </summary>
+ public struct ClientHello
+ {
+ public ProtocolVersion ClientProtocolVersion;
+ public ByteSpan Random;
+ public ByteSpan Cookie;
+ public HazelDtlsSessionInfo Session;
+ public ByteSpan CipherSuites;
+ public ByteSpan SupportedCurves;
+
+ public const int MinSize = 0
+ + 2 // client_version
+ + Dtls.Random.Size // random
+ + 1 // session_id (size)
+ + 1 // cookie (size)
+ + 2 // cipher_suites (size)
+ + 1 // compression_methods (size)
+ + 1 // compression_method[0] (NULL)
+
+ + 2 // extensions size
+
+ + 0 // NamedCurveList extensions[0]
+ + 2 // extensions[0].extension_type
+ + 2 // extensions[0].extension_data (length)
+ + 2 // extensions[0].named_curve_list (size)
+ ;
+
+ /// <summary>
+ /// Calculate the size in bytes required for the ClientHello payload
+ /// </summary>
+ /// <returns></returns>
+ public int CalculateSize()
+ {
+ return MinSize
+ + this.Session.PayloadSize
+ + this.Cookie.Length
+ + this.CipherSuites.Length
+ + this.SupportedCurves.Length
+ ;
+ }
+
+ /// <summary>
+ /// Parse a Handshake ClientHello payload from wire format
+ /// </summary>
+ /// <returns>True if we successfully decode the ClientHello message. Otherwise false</returns>
+ public static bool Parse(out ClientHello result, ProtocolVersion? expectedProtocolVersion, ByteSpan span)
+ {
+ result = new ClientHello();
+ if (span.Length < MinSize)
+ {
+ return false;
+ }
+
+ result.ClientProtocolVersion = (ProtocolVersion)span.ReadBigEndian16();
+ if (expectedProtocolVersion.HasValue && result.ClientProtocolVersion != expectedProtocolVersion.Value)
+ {
+ return false;
+ }
+
+ span = span.Slice(2);
+
+ result.Random = span.Slice(0, Dtls.Random.Size);
+ span = span.Slice(Dtls.Random.Size);
+
+ if (!HazelDtlsSessionInfo.Parse(out result.Session, span))
+ {
+ return false;
+ }
+
+ span = span.Slice(result.Session.FullSize);
+
+ byte cookieSize = span[0];
+ if (span.Length < 1 + cookieSize)
+ {
+ return false;
+ }
+ result.Cookie = span.Slice(1, cookieSize);
+ span = span.Slice(1 + cookieSize);
+
+ ushort cipherSuiteSize = span.ReadBigEndian16();
+ if (span.Length < 2 + cipherSuiteSize)
+ {
+ return false;
+ }
+ else if (cipherSuiteSize % 2 != 0)
+ {
+ return false;
+ }
+ result.CipherSuites = span.Slice(2, cipherSuiteSize);
+ span = span.Slice(2 + cipherSuiteSize);
+
+ int compressionMethodsSize = span[0];
+ bool foundNullCompressionMethod = false;
+ for (int ii = 0; ii != compressionMethodsSize; ++ii)
+ {
+ if (span[1+ii] == (byte)CompressionMethod.Null)
+ {
+ foundNullCompressionMethod = true;
+ break;
+ }
+ }
+
+ if (!foundNullCompressionMethod
+ || span.Length < 1 + compressionMethodsSize)
+ {
+ return false;
+ }
+
+ span = span.Slice(1 + compressionMethodsSize);
+
+ // Parse extensions
+ if (span.Length > 0)
+ {
+ if (span.Length < 2)
+ {
+ return false;
+ }
+
+ ushort extensionsSize = span.ReadBigEndian16();
+ span = span.Slice(2);
+ if (span.Length != extensionsSize)
+ {
+ return false;
+ }
+
+ while (span.Length > 0)
+ {
+ // Parse extension header
+ if (span.Length < 4)
+ {
+ return false;
+ }
+
+ ExtensionType extensionType = (ExtensionType)span.ReadBigEndian16(0);
+ ushort extensionLength = span.ReadBigEndian16(2);
+
+ if (span.Length < 4 + extensionLength)
+ {
+ return false;
+ }
+
+ ByteSpan extensionData = span.Slice(4, extensionLength);
+ span = span.Slice(4 + extensionLength);
+ result.ParseExtension(extensionType, extensionData);
+ }
+ }
+
+ return true;
+ }
+
+ /// <summary>
+ /// Decode a ClientHello extension
+ /// </summary>
+ /// <param name="extensionType">Extension type</param>
+ /// <param name="extensionData">Extension data</param>
+ private void ParseExtension(ExtensionType extensionType, ByteSpan extensionData)
+ {
+ switch (extensionType)
+ {
+ case ExtensionType.EllipticCurves:
+ if (extensionData.Length % 2 != 0)
+ {
+ break;
+ }
+ else if (extensionData.Length < 2)
+ {
+ break;
+ }
+
+ ushort namedCurveSize = extensionData.ReadBigEndian16(0);
+ if (namedCurveSize % 2 != 0)
+ {
+ break;
+ }
+
+ this.SupportedCurves = extensionData.Slice(2, namedCurveSize);
+ break;
+ }
+ }
+
+ /// <summary>
+ /// Determines if the ClientHello message advertises support
+ /// for the specified cipher suite
+ /// </summary>
+ public bool ContainsCipherSuite(CipherSuite cipherSuite)
+ {
+ ByteSpan iterator = this.CipherSuites;
+ while (iterator.Length >= 2)
+ {
+ if (iterator.ReadBigEndian16() == (ushort)cipherSuite)
+ {
+ return true;
+ }
+
+ iterator = iterator.Slice(2);
+ }
+
+ return false;
+ }
+
+ /// <summary>
+ /// Determines if the ClientHello message advertises support
+ /// for the specified curve
+ /// </summary>
+ public bool ContainsCurve(NamedCurve curve)
+ {
+ ByteSpan iterator = this.SupportedCurves;
+ while (iterator.Length >= 2)
+ {
+ if (iterator.ReadBigEndian16() == (ushort)curve)
+ {
+ return true;
+ }
+
+ iterator = iterator.Slice(2);
+ }
+
+ return false;
+ }
+
+ /// <summary>
+ /// Encode Handshake ClientHello payload to wire format
+ /// </summary>
+ public void Encode(ByteSpan span)
+ {
+ span.WriteBigEndian16((ushort)this.ClientProtocolVersion);
+ span = span.Slice(2);
+
+ Debug.Assert(this.Random.Length == Dtls.Random.Size);
+ this.Random.CopyTo(span);
+ span = span.Slice(Dtls.Random.Size);
+
+ this.Session.Encode(span);
+ span = span.Slice(this.Session.FullSize);
+
+ span[0] = (byte)this.Cookie.Length;
+ this.Cookie.CopyTo(span.Slice(1));
+ span = span.Slice(1 + this.Cookie.Length);
+
+ span.WriteBigEndian16((ushort)this.CipherSuites.Length);
+ this.CipherSuites.CopyTo(span.Slice(2));
+ span = span.Slice(2 + this.CipherSuites.Length);
+
+ span[0] = 1;
+ span[1] = (byte)CompressionMethod.Null;
+ span = span.Slice(2);
+
+ // Extensions size
+ span.WriteBigEndian16((ushort)(6 + this.SupportedCurves.Length));
+ span = span.Slice(2);
+
+ // Supported curves extension
+ span.WriteBigEndian16((ushort)ExtensionType.EllipticCurves);
+ span.WriteBigEndian16((ushort)(2 + this.SupportedCurves.Length), 2);
+ span.WriteBigEndian16((ushort)this.SupportedCurves.Length, 4);
+ this.SupportedCurves.CopyTo(span.Slice(6));
+ }
+ }
+
+ /// <summary>
+ /// Encode/Decode session information in ClientHello
+ /// </summary>
+ public struct HazelDtlsSessionInfo
+ {
+ public const byte CurrentClientSessionSize = 1;
+ public const byte CurrentClientSessionVersion = 1;
+
+ public byte FullSize => (byte)(1 + this.PayloadSize);
+ public byte PayloadSize;
+ public byte Version;
+
+ public HazelDtlsSessionInfo(byte version)
+ {
+ this.Version = version;
+ switch (version)
+ {
+ case 0: // Does not write version byte
+ this.PayloadSize = 0;
+ return;
+ case 1: // Writes version byte only
+ this.PayloadSize = 1;
+ return;
+ }
+
+ throw new ArgumentOutOfRangeException("Unimplemented Hazel session version");
+ }
+
+ public void Encode(ByteSpan writer)
+ {
+ writer[0] = this.PayloadSize;
+
+ if (this.Version > 0)
+ {
+ writer[1] = this.Version;
+ }
+ }
+
+ public static bool Parse(out HazelDtlsSessionInfo result, ByteSpan reader)
+ {
+ result = new HazelDtlsSessionInfo();
+ if (reader.Length < 1)
+ {
+ return false;
+ }
+
+ result.PayloadSize = reader[0];
+
+ // Back compat, length may be zero, version defaults to 0.
+ if (result.PayloadSize == 0)
+ {
+ result.Version = 0;
+ return true;
+ }
+
+ // Forward compat, if length > 1, ignore the rest
+ result.Version = reader[1];
+ return true;
+ }
+ }
+
+ /// <summary>
+ /// Encode/decode Handshake HelloVerifyRequest message
+ /// </summary>
+ public struct HelloVerifyRequest
+ {
+ public const int CookieSize = 20;
+ public const int Size = 0
+ + 2 // server_version
+ + 1 // cookie (size)
+ + CookieSize // cookie (data)
+ ;
+
+ public ProtocolVersion ServerProtocolVersion;
+ public ByteSpan Cookie;
+
+ /// <summary>
+ /// Parse a Handshake HelloVerifyRequest payload from wire
+ /// format
+ /// </summary>
+ /// <returns>
+ /// True if we successfully decode the HelloVerifyRequest
+ /// message. Otherwise false.
+ /// </returns>
+ public static bool Parse(out HelloVerifyRequest result, ProtocolVersion? expectedProtocolVersion, ByteSpan span)
+ {
+ result = new HelloVerifyRequest();
+ if (span.Length < 3)
+ {
+ return false;
+ }
+
+ result.ServerProtocolVersion = (ProtocolVersion)span.ReadBigEndian16(0);
+ if (expectedProtocolVersion.HasValue && result.ServerProtocolVersion != expectedProtocolVersion.Value)
+ {
+ return false;
+ }
+
+ byte cookieSize = span[2];
+ span = span.Slice(3);
+
+ if (span.Length < cookieSize)
+ {
+ return false;
+ }
+
+ result.Cookie = span;
+ return true;
+ }
+
+ /// <summary>
+ /// Encode a HelloVerifyRequest payload to wire format
+ /// </summary>
+ /// <param name="peerAddress">Address of the remote peer</param>
+ /// <param name="hmac">Listener HMAC signature provider</param>
+ public static void Encode(ByteSpan span, EndPoint peerAddress, HMAC hmac, ProtocolVersion protocolVersion)
+ {
+ ByteSpan cookie = ComputeAddressMac(peerAddress, hmac);
+
+ span.WriteBigEndian16((ushort)protocolVersion);
+ span[2] = (byte)CookieSize;
+ cookie.CopyTo(span.Slice(3));
+ }
+
+ /// <summary>
+ /// Generate an HMAC for a peer address
+ /// </summary>
+ /// <param name="peerAddress">Address of the remote peer</param>
+ /// <param name="hmac">Listener HMAC signature provider</param>
+ public static ByteSpan ComputeAddressMac(EndPoint peerAddress, HMAC hmac)
+ {
+ SocketAddress address = peerAddress.Serialize();
+ byte[] data = new byte[address.Size];
+ for (int ii = 0, nn = data.Length; ii != nn; ++ii)
+ {
+ data[ii] = address[ii];
+ }
+
+ ///NOTE(mendsley): Lame that we need to allocate+copy here
+ ByteSpan signature = hmac.ComputeHash(data);
+ return signature.Slice(0, CookieSize);
+ }
+
+ /// <summary>
+ /// Verify a client's cookie was signed by our listener
+ /// </summary>
+ /// <param name="cookie">Wire format cookie</param>
+ /// <param name="peerAddress">Address of the remote peer</param>
+ /// <param name="hmac">Listener HMAC signature provider</param>
+ /// <returns>True if the cookie is valid. Otherwise false</returns>
+ public static bool VerifyCookie(ByteSpan cookie, EndPoint peerAddress, HMAC hmac)
+ {
+ if (cookie.Length != CookieSize)
+ {
+ return false;
+ }
+
+ ByteSpan expectedHash = ComputeAddressMac(peerAddress, hmac);
+ if (expectedHash.Length != cookie.Length)
+ {
+ return false;
+ }
+
+ return (1 == Crypto.Const.ConstantCompareSpans(cookie, expectedHash));
+ }
+ }
+
+ /// <summary>
+ /// Encode/decode Handshake ServerHello message
+ /// </summary>
+ public struct ServerHello
+ {
+ public ProtocolVersion ServerProtocolVersion;
+ public ByteSpan Random;
+ public CipherSuite CipherSuite;
+ public HazelDtlsSessionInfo Session;
+
+ public const int MinSize = 0
+ + 2 // server_version
+ + Dtls.Random.Size // random
+ + 1 // session_id (size)
+ + 2 // cipher_suite
+ + 1 // compression_method
+ ;
+
+ public int Size => MinSize + Session.PayloadSize;
+
+ /// <summary>
+ /// Parse a Handshake ServerHello payload from wire format
+ /// </summary>
+ /// <returns>
+ /// True if we successfully decode the ServerHello
+ /// message. Otherwise false.
+ /// </returns>
+ public static bool Parse(out ServerHello result, ByteSpan span)
+ {
+ result = new ServerHello();
+ if (span.Length < MinSize)
+ {
+ return false;
+ }
+
+ result.ServerProtocolVersion = (ProtocolVersion)span.ReadBigEndian16();
+ span = span.Slice(2);
+
+ result.Random = span.Slice(0, Dtls.Random.Size);
+ span = span.Slice(Dtls.Random.Size);
+
+ if (!HazelDtlsSessionInfo.Parse(out result.Session, span))
+ {
+ return false;
+ }
+
+ span = span.Slice(result.Session.FullSize);
+
+ result.CipherSuite = (CipherSuite)span.ReadBigEndian16();
+ span = span.Slice(2);
+
+ CompressionMethod compressionMethod = (CompressionMethod)span[0];
+ if (compressionMethod != CompressionMethod.Null)
+ {
+ return false;
+ }
+
+ return true;
+ }
+
+ /// <summary>
+ /// Encode Handshake ServerHello to wire format
+ /// </summary>
+ public void Encode(ByteSpan span)
+ {
+ Debug.Assert(this.Random.Length == Dtls.Random.Size);
+
+ span.WriteBigEndian16((ushort)this.ServerProtocolVersion, 0);
+ span = span.Slice(2);
+
+ this.Random.CopyTo(span);
+ span = span.Slice(Dtls.Random.Size);
+
+ this.Session.Encode(span);
+ span = span.Slice(this.Session.FullSize);
+
+ span.WriteBigEndian16((ushort)this.CipherSuite);
+ span = span.Slice(2);
+
+ span[0] = (byte)CompressionMethod.Null;
+ }
+ }
+
+ /// <summary>
+ /// Encode/decode Handshake Certificate message
+ /// </summary>
+ public struct Certificate
+ {
+ /// <summary>
+ /// Encode a certificate to wire formate
+ /// </summary>
+ public static ByteSpan Encode(X509Certificate2 certificate)
+ {
+ ByteSpan certData = certificate.GetRawCertData();
+ int totalSize = certData.Length + 3 + 3;
+
+ ByteSpan result = new byte[totalSize];
+
+ ByteSpan writer = result;
+ writer.WriteBigEndian24((uint)certData.Length + 3);
+ writer = writer.Slice(3);
+ writer.WriteBigEndian24((uint)certData.Length);
+ writer = writer.Slice(3);
+
+ certData.CopyTo(writer);
+ return result;
+ }
+
+ /// <summary>
+ /// Parse a Handshake Certificate payload from wire format
+ /// </summary>
+ /// <returns>True if we successfully decode the Certificate message. Otherwise false</returns>
+ public static bool Parse(out X509Certificate2 certificate, ByteSpan span)
+ {
+ certificate = null;
+ if (span.Length < 6)
+ {
+ return false;
+ }
+
+ uint totalSize = span.ReadBigEndian24();
+ span = span.Slice(3);
+
+ if (span.Length < totalSize)
+ {
+ return false;
+ }
+
+ uint certificateSize = span.ReadBigEndian24();
+ span = span.Slice(3);
+ if (span.Length < certificateSize)
+ {
+ return false;
+ }
+
+ byte[] rawData = new byte[certificateSize];
+ span.CopyTo(rawData, 0);
+ try
+ {
+ certificate = new X509Certificate2(rawData);
+ }
+ catch (Exception)
+ {
+ return false;
+ }
+
+ return true;
+ }
+ }
+
+ /// <summary>
+ /// Encode/decode Handshake Finished message
+ /// </summary>
+ public struct Finished
+ {
+ public const int Size = 12;
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/Dtls/IHandshakeCipherSuite.cs b/Tools/Hazel-Networking/Hazel/Dtls/IHandshakeCipherSuite.cs
new file mode 100644
index 0000000..eedd977
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/Dtls/IHandshakeCipherSuite.cs
@@ -0,0 +1,63 @@
+using System;
+
+namespace Hazel.Dtls
+{
+ /// <summary>
+ /// DTLS cipher suite interface for the handshake portion of
+ /// the connection.
+ /// </summary>
+ public interface IHandshakeCipherSuite : IDisposable
+ {
+ /// <summary>
+ /// Gets the size of the shared key
+ /// </summary>
+ /// <returns>Size of the shared key in bytes </returns>
+ int SharedKeySize();
+
+ /// <summary>
+ /// Calculate the size of the ServerKeyExchnage message
+ /// </summary>
+ /// <param name="privateKey">
+ /// Private key that will be used to sign the message
+ /// </param>
+ /// <returns>Size of the message in bytes</returns>
+ int CalculateServerMessageSize(object privateKey);
+
+ /// <summary>
+ /// Encodes the ServerKeyExchange message
+ /// </summary>
+ /// <param name="privateKey">Private key to use for signing</param>
+ void EncodeServerKeyExchangeMessage(ByteSpan output, object privateKey);
+
+ /// <summary>
+ /// Verifies the authenticity of a server key exchange
+ /// message and calculates the shared secret.
+ /// </summary>
+ /// <returns>
+ /// True if the authenticity has been validated and a shared key
+ /// was generated. Otherwise, false.
+ /// </returns>
+ bool VerifyServerMessageAndGenerateSharedKey(ByteSpan output, ByteSpan serverKeyExchangeMessage, object publicKey);
+
+ /// <summary>
+ /// Calculate the size of the ClientKeyExchange message
+ /// </summary>
+ /// <returns>Size of the message in bytes</returns>
+ int CalculateClientMessageSize();
+
+ /// <summary>
+ /// Encodes the ClientKeyExchangeMessage
+ /// </summary>
+ void EncodeClientKeyExchangeMessage(ByteSpan output);
+
+ /// <summary>
+ /// Verifies the validity of a client key exchange message
+ /// and calculats the hsared secret.
+ /// </summary>
+ /// <returns>
+ /// True if the client exchange message is valid and a
+ /// shared key was generated. Otherwise, false.
+ /// </returns>
+ bool VerifyClientMessageAndGenerateSharedKey(ByteSpan output, ByteSpan clientKeyExchangeMessage);
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/Dtls/IRecordProtection.cs b/Tools/Hazel-Networking/Hazel/Dtls/IRecordProtection.cs
new file mode 100644
index 0000000..cbee1b0
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/Dtls/IRecordProtection.cs
@@ -0,0 +1,84 @@
+using System;
+
+namespace Hazel.Dtls
+{
+ /// <summary>
+ /// DTLS cipher suite interface for protection of record payload.
+ /// </summary>
+ public interface IRecordProtection : IDisposable
+ {
+ /// <summary>
+ /// Calculate the size of an encrypted plaintext
+ /// </summary>
+ /// <param name="dataSize">Size of plaintext in bytes</param>
+ /// <returns>Size of encrypted ciphertext in bytes</returns>
+ int GetEncryptedSize(int dataSize);
+
+ /// <summary>
+ /// Calculate the size of decrypted ciphertext
+ /// </summary>
+ /// <param name="dataSize">Size of ciphertext in bytes</param>
+ /// <returns>Size of decrypted plaintext in bytes</returns>
+ int GetDecryptedSize(int dataSize);
+
+ /// <summary>
+ /// Encrypt a plaintext intput with server keys
+ ///
+ /// Output may overlap with input.
+ /// </summary>
+ /// <param name="output">Output ciphertext</param>
+ /// <param name="input">Input plaintext</param>
+ /// <param name="record">Parent DTLS record</param>
+ void EncryptServerPlaintext(ByteSpan output, ByteSpan input, ref Record record);
+
+ /// <summary>
+ /// Encrypt a plaintext intput with client keys
+ ///
+ /// Output may overlap with input.
+ /// </summary>
+ /// <param name="output">Output ciphertext</param>
+ /// <param name="input">Input plaintext</param>
+ /// <param name="record">Parent DTLS record</param>
+ void EncryptClientPlaintext(ByteSpan output, ByteSpan input, ref Record record);
+
+ /// <summary>
+ /// Decrypt a ciphertext intput with server keys
+ ///
+ /// Output may overlap with input.
+ /// </summary>
+ /// <param name="output">Output plaintext</param>
+ /// <param name="input">Input ciphertext</param>
+ /// <param name="record">Parent DTLS record</param>
+ /// <returns>True if the input was authenticated and decrypted. Otherwise false</returns>
+ bool DecryptCiphertextFromServer(ByteSpan output, ByteSpan input, ref Record record);
+
+ /// <summary>
+ /// Decrypt a ciphertext intput with client keys
+ ///
+ /// Output may overlap with input.
+ /// </summary>
+ /// <param name="output">Output plaintext</param>
+ /// <param name="input">Input ciphertext</param>
+ /// <param name="record">Parent DTLS record</param>
+ /// <returns>True if the input was authenticated and decrypted. Otherwise false</returns>
+ bool DecryptCiphertextFromClient(ByteSpan output, ByteSpan input, ref Record record);
+ }
+
+ /// <summary>
+ /// Factory to create record protection from cipher suite identifiers
+ /// </summary>
+ public sealed class RecordProtectionFactory
+ {
+ public static IRecordProtection Create(CipherSuite cipherSuite, ByteSpan masterSecret, ByteSpan serverRandom, ByteSpan clientRandom)
+ {
+ switch (cipherSuite)
+ {
+ case CipherSuite.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256:
+ return new Aes128GcmRecordProtection(masterSecret, serverRandom, clientRandom);
+
+ default:
+ return null;
+ }
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/Dtls/NullRecordProtection.cs b/Tools/Hazel-Networking/Hazel/Dtls/NullRecordProtection.cs
new file mode 100644
index 0000000..76fa132
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/Dtls/NullRecordProtection.cs
@@ -0,0 +1,66 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+
+namespace Hazel.Dtls
+{
+ /// <summary>
+ /// Passthrough record protection implementaion
+ /// </summary>
+ public class NullRecordProtection : IRecordProtection
+ {
+ public readonly static NullRecordProtection Instance = new NullRecordProtection();
+
+ public void Dispose()
+ {
+ }
+
+ public int GetEncryptedSize(int dataSize)
+ {
+ return dataSize;
+ }
+
+ public int GetDecryptedSize(int dataSize)
+ {
+ return dataSize;
+ }
+
+ public void EncryptServerPlaintext(ByteSpan output, ByteSpan input, ref Record record)
+ {
+ CopyMaybeOverlappingSpans(output, input);
+ }
+
+ public void EncryptClientPlaintext(ByteSpan output, ByteSpan input, ref Record record)
+ {
+ CopyMaybeOverlappingSpans(output, input);
+ }
+
+ public bool DecryptCiphertextFromServer(ByteSpan output, ByteSpan input, ref Record record)
+ {
+ CopyMaybeOverlappingSpans(output, input);
+ return true;
+ }
+
+ public bool DecryptCiphertextFromClient(ByteSpan output, ByteSpan input, ref Record record)
+ {
+ CopyMaybeOverlappingSpans(output, input);
+ return true;
+ }
+
+ private static void CopyMaybeOverlappingSpans(ByteSpan output, ByteSpan input)
+ {
+ // Early out if the ranges `output` is equal to `input`
+ if (output.GetUnderlyingArray() == input.GetUnderlyingArray())
+ {
+ if (output.Offset == input.Offset && output.Length == input.Length)
+ {
+ return;
+ }
+ }
+
+ input.CopyTo(output);
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/Dtls/PrfSha256.cs b/Tools/Hazel-Networking/Hazel/Dtls/PrfSha256.cs
new file mode 100644
index 0000000..1fa7f17
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/Dtls/PrfSha256.cs
@@ -0,0 +1,84 @@
+using System.Text;
+using System.Security.Cryptography;
+
+namespace Hazel.Dtls
+{
+ /// <summary>
+ /// Common Psuedorandom Function labels for TLS
+ /// </summary>
+ public struct PrfLabel
+ {
+ public static readonly ByteSpan MASTER_SECRET = LabelToBytes("master secert");
+ public static readonly ByteSpan KEY_EXPANSION = LabelToBytes("key expansion");
+ public static readonly ByteSpan CLIENT_FINISHED = LabelToBytes("client finished");
+ public static readonly ByteSpan SERVER_FINISHED = LabelToBytes("server finished");
+
+ /// <summary>
+ /// Convert a text label to a byte sequence
+ /// </summary>
+ public static ByteSpan LabelToBytes(string label)
+ {
+ return Encoding.ASCII.GetBytes(label);
+ }
+ }
+
+ /// <summary>
+ /// The P_SHA256 Psuedorandom Function
+ /// </summary>
+ public struct PrfSha256
+ {
+ /// <summary>
+ /// Expand a secret key
+ /// </summary>
+ /// <param name="output">Output span. Length determines how much data to generate</param>
+ /// <param name="key">Original key to expand</param>
+ /// <param name="label">Label (treated as a salt)</param>
+ /// <param name="initialSeed">Seed for expansion (treated as a salt)</param>
+ public static void ExpandSecret(ByteSpan output, ByteSpan key, string label, ByteSpan initialSeed)
+ {
+ ExpandSecret(output, key, PrfLabel.LabelToBytes(label), initialSeed);
+ }
+
+ /// <summary>
+ /// Expand a secret key
+ /// </summary>
+ /// <param name="output">Output span. Length determines how much data to generate</param>
+ /// <param name="key">Original key to expand</param>
+ /// <param name="label">Label (treated as a salt)</param>
+ /// <param name="initialSeed">Seed for expansion (treated as a salt)</param>
+ public static void ExpandSecret(ByteSpan output, ByteSpan key, ByteSpan label, ByteSpan initialSeed)
+ {
+ ByteSpan writer = output;
+
+ byte[] roundSeed = new byte[label.Length + initialSeed.Length];
+ label.CopyTo(roundSeed);
+ initialSeed.CopyTo(roundSeed, label.Length);
+
+ byte[] hashA = roundSeed;
+
+ using (HMACSHA256 hmac = new HMACSHA256(key.ToArray()))
+ {
+ byte[] input = new byte[hmac.OutputBlockSize + roundSeed.Length];
+ new ByteSpan(roundSeed).CopyTo(input, hmac.OutputBlockSize);
+
+ while (writer.Length > 0)
+ {
+ // Update hashA
+ hashA = hmac.ComputeHash(hashA);
+
+ // generate hash input
+ new ByteSpan(hashA).CopyTo(input);
+
+ ByteSpan roundOutput = hmac.ComputeHash(input);
+ if (roundOutput.Length > writer.Length)
+ {
+ roundOutput = roundOutput.Slice(0, writer.Length);
+ }
+
+ roundOutput.CopyTo(writer);
+ writer = writer.Slice(roundOutput.Length);
+ }
+ }
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/Dtls/Record.cs b/Tools/Hazel-Networking/Hazel/Dtls/Record.cs
new file mode 100644
index 0000000..23eaa95
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/Dtls/Record.cs
@@ -0,0 +1,123 @@
+namespace Hazel.Dtls
+{
+ /// <summary>
+ /// DTLS version constants
+ /// </summary>
+ public enum ProtocolVersion : ushort
+ {
+ /// <summary>
+ /// Use to obfuscate DTLS as regular UDP packets
+ /// </summary>
+ UDP = 0,
+
+ /// <summary>
+ /// DTLS 1.2
+ /// </summary>
+ DTLS1_2 = 0xFEFD,
+ }
+
+ /// <summary>
+ /// DTLS record content type
+ /// </summary>
+ public enum ContentType : byte
+ {
+ ChangeCipherSpec = 20,
+ Alert = 21,
+ Handshake = 22,
+ ApplicationData = 23,
+ }
+
+ /// <summary>
+ /// Encode/decode DTLS record header
+ /// </summary>
+ public struct Record
+ {
+ public ContentType ContentType;
+ public ProtocolVersion ProtocolVersion;
+ public ushort Epoch;
+ public ulong SequenceNumber;
+ public ushort Length;
+
+ public const int Size = 13;
+
+ /// <summary>
+ /// Parse a DTLS record from wire format
+ /// </summary>
+ /// <returns>True if we successfully parse the record header. Otherwise false</returns>
+ public static bool Parse(out Record record, ProtocolVersion? expectedProtocolVersion, ByteSpan span)
+ {
+ record = new Record();
+
+ if (span.Length < Size)
+ {
+ return false;
+ }
+
+ record.ContentType = (ContentType)span[0];
+ record.ProtocolVersion = (ProtocolVersion)span.ReadBigEndian16(1);
+ record.Epoch = span.ReadBigEndian16(3);
+ record.SequenceNumber = span.ReadBigEndian48(5);
+ record.Length = span.ReadBigEndian16(11);
+
+ if (expectedProtocolVersion.HasValue && record.ProtocolVersion != expectedProtocolVersion.Value)
+ {
+ return false;
+ }
+
+ return true;
+ }
+
+ /// <summary>
+ /// Encode a DTLS record to wire format
+ /// </summary>
+ public void Encode(ByteSpan span)
+ {
+ span[0] = (byte)this.ContentType;
+ span.WriteBigEndian16((ushort)this.ProtocolVersion, 1);
+ span.WriteBigEndian16(this.Epoch, 3);
+ span.WriteBigEndian48(this.SequenceNumber, 5);
+ span.WriteBigEndian16(this.Length, 11);
+ }
+ }
+
+ public struct ChangeCipherSpec
+ {
+ public const int Size = 1;
+
+ enum Value : byte
+ {
+ ChangeCipherSpec = 1,
+ }
+
+ /// <summary>
+ /// Parse a ChangeCipherSpec record from wire format
+ /// </summary>
+ /// <returns>
+ /// True if we successfully parse the ChangeCipherSpec
+ /// record. Otherwise, false.
+ /// </returns>
+ public static bool Parse(ByteSpan span)
+ {
+ if (span.Length != 1)
+ {
+ return false;
+ }
+
+ Value value = (Value)span[0];
+ if (value != Value.ChangeCipherSpec)
+ {
+ return false;
+ }
+
+ return true;
+ }
+
+ /// <summary>
+ /// Encode a ChangeCipherSpec record to wire format
+ /// </summary>
+ public static void Encode(ByteSpan span)
+ {
+ span[0] = (byte)Value.ChangeCipherSpec;
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/Dtls/ThreadedHmacHelper.cs b/Tools/Hazel-Networking/Hazel/Dtls/ThreadedHmacHelper.cs
new file mode 100644
index 0000000..38da061
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/Dtls/ThreadedHmacHelper.cs
@@ -0,0 +1,159 @@
+using System;
+using System.Collections.Concurrent;
+using System.Security.Cryptography;
+using System.Threading;
+
+namespace Hazel.Dtls
+{
+ internal class ThreadedHmacHelper : IDisposable
+ {
+ private class ThreadHmacs
+ {
+ public HMAC currentHmac;
+ public HMAC previousHmac;
+ public HMAC hmacToDispose;
+ }
+
+ private static readonly int CookieHmacRotationTimeout = (int)TimeSpan.FromHours(1.0).TotalMilliseconds;
+
+ private readonly ILogger logger;
+ private readonly ConcurrentDictionary<int, ThreadHmacs> hmacs;
+ private Timer rotateKeyTimer;
+ private RandomNumberGenerator cryptoRandom;
+ private byte[] currentHmacKey;
+
+ public ThreadedHmacHelper(ILogger logger)
+ {
+ this.hmacs = new ConcurrentDictionary<int, ThreadHmacs>();
+ this.rotateKeyTimer = new Timer(RotateKeys, null, CookieHmacRotationTimeout, CookieHmacRotationTimeout);
+ this.cryptoRandom = RandomNumberGenerator.Create();
+
+ this.logger = logger;
+ SetHmacKey();
+ }
+
+ /// <summary>
+ /// [ThreadSafe] Get the current cookie hmac for the current thread.
+ /// </summary>
+ public HMAC GetCurrentCookieHmacsForThread()
+ {
+ return GetHmacsForThread().currentHmac;
+ }
+
+ /// <summary>
+ /// [ThreadSafe] Get the previous cookie hmac for the current thread.
+ /// </summary>
+ public HMAC GetPreviousCookieHmacsForThread()
+ {
+ return GetHmacsForThread().previousHmac;
+ }
+
+ public void Dispose()
+ {
+ ManualResetEvent signalRotateKeyTimerEnded = new ManualResetEvent(false);
+ this.rotateKeyTimer.Dispose(signalRotateKeyTimerEnded);
+ signalRotateKeyTimerEnded.WaitOne();
+ signalRotateKeyTimerEnded.Dispose();
+ signalRotateKeyTimerEnded = null;
+ this.rotateKeyTimer = null;
+
+ this.cryptoRandom.Dispose();
+ this.cryptoRandom = null;
+
+ foreach (var threadIdToHmac in this.hmacs)
+ {
+ ThreadHmacs threadHmacs = threadIdToHmac.Value;
+ threadHmacs.currentHmac?.Dispose();
+ threadHmacs.currentHmac = null;
+ threadHmacs.previousHmac?.Dispose();
+ threadHmacs.previousHmac = null;
+ threadHmacs.hmacToDispose?.Dispose();
+ threadHmacs.hmacToDispose = null;
+ }
+
+ this.hmacs.Clear();
+ }
+
+ private ThreadHmacs GetHmacsForThread()
+ {
+ int threadId = Thread.CurrentThread.ManagedThreadId;
+
+ if (!this.hmacs.TryGetValue(threadId, out ThreadHmacs threadHmacs))
+ {
+ threadHmacs = CreateNewThreadHmacs();
+
+ if (!this.hmacs.TryAdd(threadId, threadHmacs))
+ {
+ this.logger.WriteError($"Cannot add threadHmacs for thread {threadId} during GetHmacsForThread! Should never happen!");
+ }
+ }
+
+ return threadHmacs;
+ }
+
+ /// <summary>
+ /// Rotates the hmacs of all active threads
+ /// </summary>
+ private void RotateKeys(object _)
+ {
+ SetHmacKey();
+
+ foreach (var threadIds in this.hmacs)
+ {
+ RotateKey(threadIds.Key);
+ }
+ }
+
+ /// <summary>
+ /// Rotate hmacs of single thread
+ /// </summary>
+ /// <param name="threadId">Managed thread Id of thread calling this method.</param>
+ private void RotateKey(int threadId)
+ {
+ ThreadHmacs threadHmacs;
+
+ if (!this.hmacs.TryGetValue(threadId, out threadHmacs))
+ {
+ this.logger.WriteError($"Cannot find thread {threadId} in hmacs during rotation! Should never happen!");
+ return;
+ }
+
+ // No thread should still have a reference to hmacToDispose, which should now have a lifetime of > 1 hour
+ threadHmacs.hmacToDispose?.Dispose();
+ threadHmacs.hmacToDispose = threadHmacs.previousHmac;
+ threadHmacs.previousHmac = threadHmacs.currentHmac;
+ threadHmacs.currentHmac = CreateNewCookieHMAC();
+ }
+
+ private ThreadHmacs CreateNewThreadHmacs()
+ {
+ return new ThreadHmacs
+ {
+ previousHmac = CreateNewCookieHMAC(),
+ currentHmac = CreateNewCookieHMAC()
+ };
+ }
+
+ /// <summary>
+ /// Create a new cookie HMAC signer
+ /// </summary>
+ private HMAC CreateNewCookieHMAC()
+ {
+ const string HMACProvider = "System.Security.Cryptography.HMACSHA1";
+ HMAC hmac = HMAC.Create(HMACProvider);
+ hmac.Key = this.currentHmacKey;
+ return hmac;
+ }
+
+ /// <summary>
+ /// Creates a new cryptographically secure random Hmac key
+ /// </summary>
+ private void SetHmacKey()
+ {
+ // MSDN recommends 64 bytes key for HMACSHA-1
+ byte[] newKey = new byte[64];
+ this.cryptoRandom.GetBytes(newKey);
+ this.currentHmacKey = newKey;
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/Dtls/X25519EcdheRsaSha256.cs b/Tools/Hazel-Networking/Hazel/Dtls/X25519EcdheRsaSha256.cs
new file mode 100644
index 0000000..f567252
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/Dtls/X25519EcdheRsaSha256.cs
@@ -0,0 +1,202 @@
+using Hazel.Crypto;
+using System;
+using System.Diagnostics;
+using System.Security.Cryptography;
+
+namespace Hazel.Dtls
+{
+ /// <summary>
+ /// ECDHE_RSA_*_256 cipher suite
+ /// </summary>
+ public class X25519EcdheRsaSha256 : IHandshakeCipherSuite
+ {
+ private readonly ByteSpan privateAgreementKey;
+ private SHA256 sha256 = SHA256.Create();
+
+ /// <summary>
+ /// Create a new instance of the x25519 key exchange
+ /// </summary>
+ /// <param name="random">Random data source</param>
+ public X25519EcdheRsaSha256(RandomNumberGenerator random)
+ {
+ byte[] buffer = new byte[X25519.KeySize];
+ random.GetBytes(buffer);
+ this.privateAgreementKey = buffer;
+ }
+
+ /// <inheritdoc />
+ public void Dispose()
+ {
+ this.sha256?.Dispose();
+ this.sha256 = null;
+ }
+
+ /// <inheritdoc />
+ public int SharedKeySize()
+ {
+ return X25519.KeySize;
+ }
+
+ /// <summary>
+ /// Calculate the server message size given an RSA key size
+ /// </summary>
+ /// <param name="keySize">
+ /// Size of the private key (in bits)
+ /// </param>
+ /// <returns>
+ /// Size of the ServerKeyExchange message in bytes
+ /// </returns>
+ private static int CalculateServerMessageSize(int keySize)
+ {
+ int signatureSize = keySize / 8;
+
+ return 0
+ + 1 // ECCurveType ServerKeyExchange.params.curve_params.curve_type
+ + 2 // NamedCurve ServerKeyExchange.params.curve_params.namedcurve
+ + 1 + X25519.KeySize // ECPoint ServerKeyExchange.params.public
+ + 1 // HashAlgorithm ServerKeyExchange.algorithm.hash
+ + 1 // SignatureAlgorithm ServerKeyExchange.signed_params.algorithm.signature
+ + 2 // ServerKeyExchange.signed_params.size
+ + signatureSize // ServerKeyExchange.signed_params.opaque
+ ;
+ }
+
+ /// <inheritdoc />
+ public int CalculateServerMessageSize(object privateKey)
+ {
+ RSA rsaPrivateKey = privateKey as RSA;
+ if (rsaPrivateKey == null)
+ {
+ throw new ArgumentException("Invalid private key", nameof(privateKey));
+ }
+
+ return CalculateServerMessageSize(rsaPrivateKey.KeySize);
+ }
+
+ /// <inheritdoc />
+ public void EncodeServerKeyExchangeMessage(ByteSpan output, object privateKey)
+ {
+ RSA rsaPrivateKey = privateKey as RSA;
+ if (rsaPrivateKey == null)
+ {
+ throw new ArgumentException("Invalid private key", nameof(privateKey));
+ }
+
+ output[0] = (byte)ECCurveType.NamedCurve;
+ output.WriteBigEndian16((ushort)NamedCurve.x25519, 1);
+ output[3] = (byte)X25519.KeySize;
+ X25519.Func(output.Slice(4, X25519.KeySize), this.privateAgreementKey);
+
+ // Hash the key parameters
+ byte[] paramterDigest = this.sha256.ComputeHash(output.GetUnderlyingArray(), output.Offset, 4 + X25519.KeySize);
+
+ // Sign the paramter digest
+ RSAPKCS1SignatureFormatter signer = new RSAPKCS1SignatureFormatter(rsaPrivateKey);
+ signer.SetHashAlgorithm("SHA256");
+ ByteSpan signature = signer.CreateSignature(paramterDigest);
+
+ Debug.Assert(signature.Length == rsaPrivateKey.KeySize/8);
+ output[4 + X25519.KeySize] = (byte)HashAlgorithm.Sha256;
+ output[5 + X25519.KeySize] = (byte)SignatureAlgorithm.RSA;
+ output.Slice(6+X25519.KeySize).WriteBigEndian16((ushort)signature.Length);
+ signature.CopyTo(output.Slice(8+X25519.KeySize));
+ }
+
+ /// <inheritdoc />
+ public bool VerifyServerMessageAndGenerateSharedKey(ByteSpan output, ByteSpan serverKeyExchangeMessage, object publicKey)
+ {
+ RSA rsaPublicKey = publicKey as RSA;
+ if (rsaPublicKey == null)
+ {
+ return false;
+ }
+ else if (output.Length != X25519.KeySize)
+ {
+ return false;
+ }
+
+ // Verify message is compatible with this cipher suite
+ if (serverKeyExchangeMessage.Length != CalculateServerMessageSize(rsaPublicKey.KeySize))
+ {
+ return false;
+ }
+ else if (serverKeyExchangeMessage[0] != (byte)ECCurveType.NamedCurve)
+ {
+ return false;
+ }
+ else if (serverKeyExchangeMessage.ReadBigEndian16(1) != (ushort)NamedCurve.x25519)
+ {
+ return false;
+ }
+ else if (serverKeyExchangeMessage[3] != X25519.KeySize)
+ {
+ return false;
+ }
+ else if (serverKeyExchangeMessage[4 + X25519.KeySize] != (byte)HashAlgorithm.Sha256)
+ {
+ return false;
+ }
+ else if (serverKeyExchangeMessage[5 + X25519.KeySize] != (byte)SignatureAlgorithm.RSA)
+ {
+ return false;
+ }
+
+ ByteSpan keyParameters = serverKeyExchangeMessage.Slice(0, 4+X25519.KeySize);
+ ByteSpan othersPublicKey = keyParameters.Slice(4);
+ ushort signatureSize = serverKeyExchangeMessage.ReadBigEndian16(6 + X25519.KeySize);
+ ByteSpan signature = serverKeyExchangeMessage.Slice(4+keyParameters.Length);
+
+ if (signatureSize != signature.Length)
+ {
+ return false;
+ }
+
+ // Hash the key parameters
+ byte[] parameterDigest = this.sha256.ComputeHash(keyParameters.GetUnderlyingArray(), keyParameters.Offset, keyParameters.Length);
+
+ // Verify the signature
+ RSAPKCS1SignatureDeformatter verifier = new RSAPKCS1SignatureDeformatter(rsaPublicKey);
+ verifier.SetHashAlgorithm("SHA256");
+ if (!verifier.VerifySignature(parameterDigest, signature.ToArray()))
+ {
+ return false;
+ }
+
+ // Signature has been validated, generate the shared key
+ return X25519.Func(output, this.privateAgreementKey, othersPublicKey);
+ }
+
+ private static int ClientMessageSize = 0
+ + 1 + X25519.KeySize // ECPoint ClientKeyExchange.ecdh_Yc
+ ;
+
+ /// <inheritdoc />
+ public int CalculateClientMessageSize()
+ {
+ return ClientMessageSize;
+ }
+
+ /// <inheritdoc />
+ public void EncodeClientKeyExchangeMessage(ByteSpan output)
+ {
+ output[0] = (byte)X25519.KeySize;
+ X25519.Func(output.Slice(1, X25519.KeySize), this.privateAgreementKey);
+ }
+
+ /// <inheritdoc />
+ public bool VerifyClientMessageAndGenerateSharedKey(ByteSpan output, ByteSpan clientKeyExchangeMessage)
+ {
+ if (clientKeyExchangeMessage.Length != ClientMessageSize)
+ {
+ return false;
+ }
+ else if (clientKeyExchangeMessage[0] != (byte)X25519.KeySize)
+ {
+ return false;
+ }
+
+ ByteSpan othersPublicKey = clientKeyExchangeMessage.Slice(1);
+ return X25519.Func(output, this.privateAgreementKey, othersPublicKey);
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/Extensions.cs b/Tools/Hazel-Networking/Hazel/Extensions.cs
new file mode 100644
index 0000000..dd3a1bc
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/Extensions.cs
@@ -0,0 +1,34 @@
+using System.Collections.Generic;
+
+namespace Hazel
+{
+ public static class Extensions
+ {
+ public static void Swap<T>(this IList<T> self, int idx0, int idx1)
+ {
+ var temp = self[idx0];
+ self[idx0] = self[idx1];
+ self[idx1] = temp;
+ }
+
+ public static int ClampToInt(this float value, int min, int max)
+ {
+ int output = (int)value;
+ if (output < min) output = min;
+ else if (output > max) output = max;
+ return output;
+ }
+
+ public static bool TryDequeue<T>(this Queue<T> self, out T item)
+ {
+ if (self.Count > 0)
+ {
+ item = self.Dequeue();
+ return true;
+ }
+
+ item = default;
+ return false;
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/FewerThreads/HazelThreadPool.cs b/Tools/Hazel-Networking/Hazel/FewerThreads/HazelThreadPool.cs
new file mode 100644
index 0000000..fb36b00
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/FewerThreads/HazelThreadPool.cs
@@ -0,0 +1,44 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading;
+using System.Threading.Tasks;
+
+namespace Hazel
+{
+ internal class HazelThreadPool
+ {
+ private Thread[] threads;
+
+ public HazelThreadPool(int numThreads, ThreadStart action)
+ {
+ this.threads = new Thread[numThreads];
+ for (int i = 0; i < this.threads.Length; ++i)
+ {
+ this.threads[i] = new Thread(action);
+ }
+ }
+
+ public void Start()
+ {
+ for (int i = 0; i < this.threads.Length; ++i)
+ {
+ this.threads[i].Start();
+ }
+ }
+
+ public void Join()
+ {
+ for (int i = 0; i < this.threads.Length; ++i)
+ {
+ var thread = this.threads[i];
+ try
+ {
+ thread.Join();
+ }
+ catch { }
+ }
+ }
+ }
+} \ No newline at end of file
diff --git a/Tools/Hazel-Networking/Hazel/FewerThreads/ThreadLimitedUdpConnectionListener.cs b/Tools/Hazel-Networking/Hazel/FewerThreads/ThreadLimitedUdpConnectionListener.cs
new file mode 100644
index 0000000..e37be45
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/FewerThreads/ThreadLimitedUdpConnectionListener.cs
@@ -0,0 +1,402 @@
+using System;
+using System.Collections.Concurrent;
+using System.Linq;
+using System.Net;
+using System.Net.Sockets;
+using System.Threading;
+
+namespace Hazel.Udp.FewerThreads
+{
+ /// <summary>
+ /// Listens for new UDP connections and creates UdpConnections for them.
+ /// </summary>
+ /// <inheritdoc />
+ public class ThreadLimitedUdpConnectionListener : NetworkConnectionListener
+ {
+ private struct SendMessageInfo
+ {
+ public ByteSpan Span;
+ public IPEndPoint Recipient;
+ }
+
+ private struct ReceiveMessageInfo
+ {
+ public MessageReader Message;
+ public IPEndPoint Sender;
+ public ConnectionId ConnectionId;
+ }
+
+ private const int SendReceiveBufferSize = 1024 * 1024;
+
+ private Socket socket;
+ protected ILogger Logger;
+
+ private Thread reliablePacketThread;
+ private Thread receiveThread;
+ private Thread sendThread;
+ private HazelThreadPool processThreads;
+
+ public bool ReceiveThreadRunning => this.receiveThread.ThreadState == ThreadState.Running;
+
+ public struct ConnectionId : IEquatable<ConnectionId>
+ {
+ public IPEndPoint EndPoint;
+ public int Serial;
+
+ public static ConnectionId Create(IPEndPoint endPoint, int serial)
+ {
+ return new ConnectionId{
+ EndPoint = endPoint,
+ Serial = serial,
+ };
+ }
+
+ public bool Equals(ConnectionId other)
+ {
+ return this.Serial == other.Serial
+ && this.EndPoint.Equals(other.EndPoint)
+ ;
+ }
+
+ public override bool Equals(object obj)
+ {
+ if (obj is ConnectionId)
+ {
+ return this.Equals((ConnectionId)obj);
+ }
+
+ return false;
+ }
+
+ public override int GetHashCode()
+ {
+ ///NOTE(mendsley): We're only hashing the endpoint
+ /// here, as the common case will have one
+ /// connection per address+port tuple.
+ return this.EndPoint.GetHashCode();
+ }
+ }
+
+ protected ConcurrentDictionary<ConnectionId, ThreadLimitedUdpServerConnection> allConnections = new ConcurrentDictionary<ConnectionId, ThreadLimitedUdpServerConnection>();
+
+ private BlockingCollection<ReceiveMessageInfo> receiveQueue;
+ private BlockingCollection<SendMessageInfo> sendQueue = new BlockingCollection<SendMessageInfo>();
+
+ public int MaxAge
+ {
+ get
+ {
+ var now = DateTime.UtcNow;
+ TimeSpan max = new TimeSpan();
+ foreach (var con in allConnections.Values)
+ {
+ var val = now - con.CreationTime;
+ if (val > max) max = val;
+ }
+
+ return (int)max.TotalSeconds;
+ }
+ }
+
+ public override double AveragePing => this.allConnections.Values.Sum(c => c.AveragePingMs) / this.allConnections.Count;
+ public override int ConnectionCount { get { return this.allConnections.Count; } }
+ public override int SendQueueLength { get { return this.sendQueue.Count; } }
+ public override int ReceiveQueueLength { get { return this.receiveQueue.Count; } }
+
+ private bool isActive;
+
+ public ThreadLimitedUdpConnectionListener(int numWorkers, IPEndPoint endPoint, ILogger logger, IPMode ipMode = IPMode.IPv4)
+ {
+ this.Logger = logger;
+ this.EndPoint = endPoint;
+ this.IPMode = ipMode;
+
+ this.receiveQueue = new BlockingCollection<ReceiveMessageInfo>(10000);
+
+ this.socket = UdpConnection.CreateSocket(this.IPMode);
+ this.socket.ExclusiveAddressUse = true;
+ this.socket.Blocking = false;
+
+ this.socket.ReceiveBufferSize = SendReceiveBufferSize;
+ this.socket.SendBufferSize = SendReceiveBufferSize;
+
+ this.reliablePacketThread = new Thread(ManageReliablePackets);
+ this.sendThread = new Thread(SendLoop);
+ this.receiveThread = new Thread(ReceiveLoop);
+ this.processThreads = new HazelThreadPool(numWorkers, ProcessingLoop);
+ }
+
+ ~ThreadLimitedUdpConnectionListener()
+ {
+ this.Dispose(false);
+ }
+
+ // This is just for booting people after they've been connected a certain amount of time...
+ public void DisconnectOldConnections(TimeSpan maxAge, MessageWriter disconnectMessage)
+ {
+ var now = DateTime.UtcNow;
+ foreach (var conn in this.allConnections.Values)
+ {
+ if (now - conn.CreationTime > maxAge)
+ {
+ conn.Disconnect("Stale Connection", disconnectMessage);
+ }
+ }
+ }
+
+ private void ManageReliablePackets()
+ {
+ while (this.isActive)
+ {
+ foreach (var kvp in this.allConnections)
+ {
+ var sock = kvp.Value;
+ sock.ManageReliablePackets();
+ }
+
+ Thread.Sleep(100);
+ }
+ }
+
+ public override void Start()
+ {
+ try
+ {
+ socket.Bind(EndPoint);
+ }
+ catch (SocketException e)
+ {
+ throw new HazelException("Could not start listening as a SocketException occurred", e);
+ }
+
+ this.isActive = true;
+ this.reliablePacketThread.Start();
+ this.sendThread.Start();
+ this.receiveThread.Start();
+ this.processThreads.Start();
+ }
+
+ private void ReceiveLoop()
+ {
+ while (this.isActive)
+ {
+ if (this.socket.Poll(1000, SelectMode.SelectRead))
+ {
+ if (!isActive) break;
+
+ EndPoint remoteEP = new IPEndPoint(this.EndPoint.Address, this.EndPoint.Port);
+ var message = MessageReader.GetSized(this.ReceiveBufferSize);
+ try
+ {
+ message.Length = socket.ReceiveFrom(message.Buffer, 0, message.Buffer.Length, SocketFlags.None, ref remoteEP);
+ }
+ catch (SocketException sx)
+ {
+ message.Recycle();
+ if (sx.SocketErrorCode == SocketError.NotConnected)
+ {
+ this.InvokeInternalError(HazelInternalErrors.ConnectionDisconnected);
+ return;
+ }
+
+ this.Logger.WriteError("Socket Ex in ReceiveLoop: " + sx.Message);
+ continue;
+ }
+ catch (Exception ex)
+ {
+ message.Recycle();
+ this.Logger.WriteError("Stopped due to: " + ex.Message);
+ return;
+ }
+
+ ConnectionId connectionId = ConnectionId.Create((IPEndPoint)remoteEP, 0);
+ this.ProcessIncomingMessageFromOtherThread(message, (IPEndPoint)remoteEP, connectionId);
+ }
+ }
+ }
+
+ private void ProcessingLoop()
+ {
+ foreach (ReceiveMessageInfo msg in this.receiveQueue.GetConsumingEnumerable())
+ {
+ try
+ {
+ this.ReadCallback(msg.Message, msg.Sender, msg.ConnectionId);
+ }
+ catch
+ {
+
+ }
+ }
+ }
+
+ protected void ProcessIncomingMessageFromOtherThread(MessageReader message, IPEndPoint remoteEndPoint, ConnectionId connectionId)
+ {
+ var info = new ReceiveMessageInfo() { Message = message, Sender = remoteEndPoint, ConnectionId = connectionId };
+ if (!this.receiveQueue.TryAdd(info))
+ {
+ this.Statistics.AddReceiveThreadBlocking();
+ this.receiveQueue.Add(info);
+ }
+ }
+
+ private void SendLoop()
+ {
+ foreach (SendMessageInfo msg in this.sendQueue.GetConsumingEnumerable())
+ {
+ try
+ {
+ if (this.socket.Poll(Timeout.Infinite, SelectMode.SelectWrite))
+ {
+ this.socket.SendTo(msg.Span.GetUnderlyingArray(), msg.Span.Offset, msg.Span.Length, SocketFlags.None, msg.Recipient);
+ this.Statistics.AddBytesSent(msg.Span.Length - msg.Span.Offset);
+ }
+ else
+ {
+ this.Logger.WriteError("Socket is no longer able to send");
+ break;
+ }
+ }
+ catch (Exception e)
+ {
+ this.Logger.WriteError("Error in loop while sending: " + e.Message);
+ Thread.Sleep(1);
+ }
+ }
+ }
+
+ protected virtual void ReadCallback(MessageReader message, IPEndPoint remoteEndPoint, ConnectionId connectionId)
+ {
+ int bytesReceived = message.Length;
+ bool aware = true;
+ bool isHello = message.Buffer[0] == (byte)UdpSendOption.Hello;
+
+ // If we're aware of this connection use the one already
+ // If this is a new client then connect with them!
+ ThreadLimitedUdpServerConnection connection;
+ if (!this.allConnections.TryGetValue(connectionId, out connection))
+ {
+ lock (this.allConnections)
+ {
+ if (!this.allConnections.TryGetValue(connectionId, out connection))
+ {
+ // Check for malformed connection attempts
+ if (!isHello)
+ {
+ message.Recycle();
+ return;
+ }
+
+ if (AcceptConnection != null)
+ {
+ if (!AcceptConnection(remoteEndPoint, message.Buffer, out var response))
+ {
+ message.Recycle();
+ if (response != null)
+ {
+ SendDataRaw(response, remoteEndPoint);
+ }
+
+ return;
+ }
+ }
+
+ aware = false;
+ connection = new ThreadLimitedUdpServerConnection(this, connectionId, remoteEndPoint, this.IPMode, this.Logger);
+ if (!this.allConnections.TryAdd(connectionId, connection))
+ {
+ throw new HazelException("Failed to add a connection. This should never happen.");
+ }
+ }
+ }
+ }
+
+ // If it's a new connection invoke the NewConnection event.
+ // This needs to happen before handling the message because in localhost scenarios, the ACK and
+ // subsequent messages can happen before the NewConnection event sets up OnDataRecieved handlers
+ if (!aware)
+ {
+ // Skip header and hello byte;
+ message.Offset = 4;
+ message.Length = bytesReceived - 4;
+ message.Position = 0;
+ try
+ {
+ this.InvokeNewConnection(message, connection);
+ }
+ catch (Exception e)
+ {
+ this.Logger.WriteError("NewConnection handler threw: " + e);
+ }
+ }
+
+ // Inform the connection of the buffer (new connections need to send an ack back to client)
+ connection.HandleReceive(message, bytesReceived);
+ }
+
+ internal void SendDataRaw(byte[] response, IPEndPoint remoteEndPoint)
+ {
+ QueueRawData(response, remoteEndPoint);
+ }
+
+ protected virtual void QueueRawData(ByteSpan span, IPEndPoint remoteEndPoint)
+ {
+ this.sendQueue.TryAdd(new SendMessageInfo() { Span = span, Recipient = remoteEndPoint });
+ }
+
+ /// <summary>
+ /// Removes a virtual connection from the list.
+ /// </summary>
+ /// <param name="endPoint">Connection key of the virtual connection.</param>
+ internal bool RemoveConnectionTo(ConnectionId connectionId)
+ {
+ return this.allConnections.TryRemove(connectionId, out _);
+ }
+
+ /// <summary>
+ /// This is after all messages could be sent. Clean up anything extra.
+ /// </summary>
+ internal virtual void RemovePeerRecord(ConnectionId connectionId)
+ {
+ }
+
+ protected override void Dispose(bool disposing)
+ {
+ foreach (var kvp in this.allConnections)
+ {
+ kvp.Value.Dispose();
+ }
+
+ bool wasActive = this.isActive;
+ this.isActive = false;
+
+ // Flush outgoing packets
+ this.sendQueue?.CompleteAdding();
+
+ if (wasActive)
+ {
+ this.sendThread.Join();
+ }
+
+ try { this.socket.Shutdown(SocketShutdown.Both); } catch { }
+ try { this.socket.Close(); } catch { }
+ try { this.socket.Dispose(); } catch { }
+
+ this.receiveQueue?.CompleteAdding();
+
+ if (wasActive)
+ {
+ this.reliablePacketThread.Join();
+ this.receiveThread.Join();
+ this.processThreads.Join();
+ }
+
+ this.receiveQueue?.Dispose();
+ this.receiveQueue = null;
+ this.sendQueue?.Dispose();
+ this.sendQueue = null;
+
+ base.Dispose(disposing);
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/FewerThreads/ThreadLimitedUdpServerConnection.cs b/Tools/Hazel-Networking/Hazel/FewerThreads/ThreadLimitedUdpServerConnection.cs
new file mode 100644
index 0000000..bb139c7
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/FewerThreads/ThreadLimitedUdpServerConnection.cs
@@ -0,0 +1,110 @@
+using System;
+using System.Net;
+
+namespace Hazel.Udp.FewerThreads
+{
+ /// <summary>
+ /// Represents a servers's connection to a client that uses the UDP protocol.
+ /// </summary>
+ /// <inheritdoc/>
+ public sealed class ThreadLimitedUdpServerConnection : UdpConnection
+ {
+ public readonly DateTime CreationTime = DateTime.UtcNow;
+
+ /// <summary>
+ /// The connection listener that we use the socket of.
+ /// </summary>
+ /// <remarks>
+ /// Udp server connections utilize the same socket in the listener for sends/receives, this is the listener that
+ /// created this connection and is hence the listener this conenction sends and receives via.
+ /// </remarks>
+ public ThreadLimitedUdpConnectionListener Listener { get; private set; }
+
+ public ThreadLimitedUdpConnectionListener.ConnectionId ConnectionId { get; private set; }
+
+ /// <summary>
+ /// Creates a UdpConnection for the virtual connection to the endpoint.
+ /// </summary>
+ /// <param name="listener">The listener that created this connection.</param>
+ /// <param name="endPoint">The endpoint that we are connected to.</param>
+ /// <param name="IPMode">The IPMode we are connected using.</param>
+ internal ThreadLimitedUdpServerConnection(ThreadLimitedUdpConnectionListener listener, ThreadLimitedUdpConnectionListener.ConnectionId connectionId, IPEndPoint endPoint, IPMode IPMode, ILogger logger)
+ : base(logger)
+ {
+ this.Listener = listener;
+ this.ConnectionId = connectionId;
+ this.EndPoint = endPoint;
+ this.IPMode = IPMode;
+
+ State = ConnectionState.Connected;
+ this.InitializeKeepAliveTimer();
+ }
+
+ /// <inheritdoc />
+ protected override void WriteBytesToConnection(byte[] bytes, int length)
+ {
+ if (bytes.Length != length) throw new ArgumentException("I made an assumption here. I hope you see this error.");
+
+ // Hrm, well this is inaccurate for DTLS connections because the Listener does the encryption which may change the size.
+ // but I don't want to have a bunch of client references in the send queue...
+ // Does this perhaps mean the encryption is being done in the wrong class?
+ this.Statistics.LogPacketSend(length);
+ Listener.SendDataRaw(bytes, EndPoint);
+ }
+
+ /// <inheritdoc />
+ /// <remarks>
+ /// This will always throw a HazelException.
+ /// </remarks>
+ public override void Connect(byte[] bytes = null, int timeout = 5000)
+ {
+ throw new InvalidOperationException("Cannot manually connect a UdpServerConnection, did you mean to use UdpClientConnection?");
+ }
+
+ /// <inheritdoc />
+ /// <remarks>
+ /// This will always throw a HazelException.
+ /// </remarks>
+ public override void ConnectAsync(byte[] bytes = null)
+ {
+ throw new InvalidOperationException("Cannot manually connect a UdpServerConnection, did you mean to use UdpClientConnection?");
+ }
+
+ /// <summary>
+ /// Sends a disconnect message to the end point.
+ /// </summary>
+ protected override bool SendDisconnect(MessageWriter data = null)
+ {
+ if (!Listener.RemoveConnectionTo(this.ConnectionId)) return false;
+ this._state = ConnectionState.NotConnected;
+
+ var bytes = EmptyDisconnectBytes;
+ if (data != null && data.Length > 0)
+ {
+ if (data.SendOption != SendOption.None) throw new ArgumentException("Disconnect messages can only be unreliable.");
+
+ bytes = data.ToByteArray(true);
+ bytes[0] = (byte)UdpSendOption.Disconnect;
+ }
+
+ try
+ {
+ this.WriteBytesToConnection(bytes, bytes.Length);
+ }
+ catch { }
+
+ return true;
+ }
+
+ protected override void Dispose(bool disposing)
+ {
+ if (disposing)
+ {
+ SendDisconnect();
+ }
+
+ Listener.RemovePeerRecord(this.ConnectionId);
+ base.Dispose(disposing);
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/Hazel.csproj b/Tools/Hazel-Networking/Hazel/Hazel.csproj
new file mode 100644
index 0000000..3a7ea17
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/Hazel.csproj
@@ -0,0 +1,14 @@
+<Project Sdk="Microsoft.NET.Sdk">
+
+ <PropertyGroup>
+ <TargetFrameworks>netstandard2.0;net472</TargetFrameworks>
+ <AllowUnsafeBlocks>true</AllowUnsafeBlocks>
+ </PropertyGroup>
+
+ <ItemGroup>
+ <AssemblyAttribute Include="System.Runtime.CompilerServices.InternalsVisibleToAttribute">
+ <_Parameter1>Hazel.UnitTests</_Parameter1>
+ </AssemblyAttribute>
+ </ItemGroup>
+
+</Project>
diff --git a/Tools/Hazel-Networking/Hazel/HazelException.cs b/Tools/Hazel-Networking/Hazel/HazelException.cs
new file mode 100644
index 0000000..c0db05a
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/HazelException.cs
@@ -0,0 +1,24 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+
+namespace Hazel
+{
+ /// <summary>
+ /// Wrapper for exceptions thrown from Hazel.
+ /// </summary>
+ [Serializable]
+ public class HazelException : Exception
+ {
+ internal HazelException(string msg) : base (msg)
+ {
+
+ }
+
+ internal HazelException(string msg, Exception e) : base (msg, e)
+ {
+
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/IPMode.cs b/Tools/Hazel-Networking/Hazel/IPMode.cs
new file mode 100644
index 0000000..04c8c38
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/IPMode.cs
@@ -0,0 +1,30 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+
+
+namespace Hazel
+{
+ /// <summary>
+ /// Represents the IP version that a connection or listener will use.
+ /// </summary>
+ /// <remarks>
+ /// If you wand a client to connect or be able to connect using IPv6 then you should use <see cref="IPv4AndIPv6"/>,
+ /// this sets the underlying sockets to use IPv6 but still allow IPv4 sockets to connect for backwards compatability
+ /// and hence it is the default IPMode in most cases.
+ /// </remarks>
+ public enum IPMode
+ {
+ /// <summary>
+ /// Instruction to use IPv4 only, IPv6 connections will not be able to connect.
+ /// </summary>
+ IPv4,
+
+ /// <summary>
+ /// Instruction to use IPv6 only, IPv4 connections will not be able to connect. IPv4 addresses can be connected
+ /// by converting to IPv6 addresses.
+ /// </summary>
+ IPv6
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/IRecyclable.cs b/Tools/Hazel-Networking/Hazel/IRecyclable.cs
new file mode 100644
index 0000000..3e9769e
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/IRecyclable.cs
@@ -0,0 +1,29 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+
+namespace Hazel
+{
+ /// <summary>
+ /// Interface for all items that can be returned to an object pool.
+ /// </summary>
+ /// <threadsafety static="true" instance="true"/>
+ public interface IRecyclable
+ {
+ /// <summary>
+ /// Returns this object back to the object pool.
+ /// </summary>
+ /// <remarks>
+ /// <para>
+ /// Calling this when you are done with the object returns the object back to a pool in order to be reused.
+ /// This can reduce the amount of work the GC has to do dramatically but it is optional to call this.
+ /// </para>
+ /// <para>
+ /// Calling this indicates to Hazel that this can be reused and thus you should only call this when you are
+ /// completely finished with the object as the contents can be overwritten at any point after.
+ /// </para>
+ /// </remarks>
+ void Recycle();
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/ListenerStatistics.cs b/Tools/Hazel-Networking/Hazel/ListenerStatistics.cs
new file mode 100644
index 0000000..428c567
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/ListenerStatistics.cs
@@ -0,0 +1,23 @@
+using System.Threading;
+
+namespace Hazel
+{
+ public class ListenerStatistics
+ {
+ private int _receiveThreadBlocked;
+ public int ReceiveThreadBlocked => this._receiveThreadBlocked;
+
+ private long _bytesSent;
+ public long BytesSent => this._bytesSent;
+
+ internal void AddReceiveThreadBlocking()
+ {
+ Interlocked.Increment(ref _receiveThreadBlocked);
+ }
+
+ internal void AddBytesSent(long bytes)
+ {
+ Interlocked.Add(ref _bytesSent, bytes);
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/MessageReader.cs b/Tools/Hazel-Networking/Hazel/MessageReader.cs
new file mode 100644
index 0000000..bd3b0d8
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/MessageReader.cs
@@ -0,0 +1,452 @@
+using System;
+using System.IO;
+using System.Linq;
+using System.Runtime.CompilerServices;
+using System.Text;
+
+namespace Hazel
+{
+ public class MessageReader : IRecyclable
+ {
+ public static readonly ObjectPool<MessageReader> ReaderPool = new ObjectPool<MessageReader>(() => new MessageReader());
+
+ public byte[] Buffer;
+ public byte Tag;
+
+ public int Length; // 总长度
+ public int Offset; // length和tag后面
+
+ public int BytesRemaining => this.Length - this.Position;
+
+ private MessageReader Parent;
+
+ public int Position
+ {
+ get { return this._position; }
+ set
+ {
+ this._position = value;
+ this.readHead = value + Offset;
+ }
+ }
+
+ private int _position;
+ private int readHead;
+
+ public static MessageReader GetSized(int minSize)
+ {
+ var output = ReaderPool.GetObject();
+
+ if (output.Buffer == null || output.Buffer.Length < minSize)
+ {
+ output.Buffer = new byte[minSize];
+ }
+ else
+ {
+ Array.Clear(output.Buffer, 0, output.Buffer.Length);
+ }
+
+ output.Offset = 0;
+ output.Position = 0;
+ output.Tag = byte.MaxValue;
+ return output;
+ }
+
+ public static MessageReader Get(byte[] buffer)
+ {
+ var output = ReaderPool.GetObject();
+
+ output.Buffer = buffer;
+ output.Offset = 0;
+ output.Position = 0;
+ output.Length = buffer.Length;
+ output.Tag = byte.MaxValue;
+
+ return output;
+ }
+
+ public static MessageReader CopyMessageIntoParent(MessageReader source)
+ {
+ var output = MessageReader.GetSized(source.Length + 3);
+ System.Buffer.BlockCopy(source.Buffer, source.Offset - 3, output.Buffer, 0, source.Length + 3);
+
+ output.Offset = 0;
+ output.Position = 0;
+ output.Length = source.Length + 3;
+
+ return output;
+ }
+
+ public static MessageReader Get(MessageReader source)
+ {
+ var output = MessageReader.GetSized(source.Buffer.Length);
+ System.Buffer.BlockCopy(source.Buffer, 0, output.Buffer, 0, source.Buffer.Length);
+
+ output.Offset = source.Offset;
+
+ output._position = source._position;
+ output.readHead = source.readHead;
+
+ output.Length = source.Length;
+ output.Tag = source.Tag;
+
+ return output;
+ }
+
+ public static MessageReader Get(byte[] buffer, int offset)
+ {
+ // Ensure there is at least a header
+ if (offset + 3 > buffer.Length) return null;
+
+ var output = ReaderPool.GetObject();
+
+ output.Buffer = buffer;
+ output.Offset = offset;
+ output.Position = 0;
+
+ output.Length = output.ReadUInt16();
+ output.Tag = output.ReadByte();
+
+ output.Offset += 3;
+ output.Position = 0;
+
+ return output;
+ }
+
+ /// <summary>
+ /// Produces a MessageReader using the parent's buffer. This MessageReader should **NOT** be recycled.
+ /// </summary>
+ public MessageReader ReadMessage()
+ {
+ // Ensure there is at least a header
+ if (this.BytesRemaining < 3) throw new InvalidDataException($"ReadMessage header is longer than message length: 3 of {this.BytesRemaining}");
+
+ var output = new MessageReader();
+
+ output.Parent = this;
+ output.Buffer = this.Buffer;
+ output.Offset = this.readHead;
+ output.Position = 0;
+
+ output.Length = output.ReadUInt16();
+ output.Tag = output.ReadByte();
+
+ output.Offset += 3;
+ output.Position = 0;
+
+ if (this.BytesRemaining < output.Length + 3) throw new InvalidDataException($"Message Length at Position {this.readHead} is longer than message length: {output.Length + 3} of {this.BytesRemaining}");
+
+ this.Position += output.Length + 3;
+ return output;
+ }
+
+ /// <summary>
+ /// Produces a MessageReader with a new buffer. This MessageReader should be recycled.
+ /// </summary>
+ public MessageReader ReadMessageAsNewBuffer()
+ {
+ if (this.BytesRemaining < 3) throw new InvalidDataException($"ReadMessage header is longer than message length: 3 of {this.BytesRemaining}");
+
+ var len = this.ReadUInt16();
+ var tag = this.ReadByte();
+
+ if (this.BytesRemaining < len) throw new InvalidDataException($"Message Length at Position {this.readHead} is longer than message length: {len} of {this.BytesRemaining}");
+
+ var output = MessageReader.GetSized(len);
+
+ Array.Copy(this.Buffer, this.readHead, output.Buffer, 0, len);
+
+ output.Length = len;
+ output.Tag = tag;
+
+ this.Position += output.Length;
+ return output;
+ }
+
+ public MessageWriter StartWriter()
+ {
+ var output = new MessageWriter(this.Buffer);
+ output.Position = this.readHead;
+ return output;
+ }
+
+ public MessageReader Duplicate()
+ {
+ var output = GetSized(this.Length);
+ Array.Copy(this.Buffer, this.Offset, output.Buffer, 0, this.Length);
+ output.Length = this.Length;
+ output.Offset = 0;
+ output.Position = 0;
+
+ return output;
+ }
+
+ public void RemoveMessage(MessageReader reader)
+ {
+ var temp = MessageReader.GetSized(reader.Buffer.Length);
+ try
+ {
+ var headerOffset = reader.Offset - 3;
+ var endOfMessage = reader.Offset + reader.Length;
+ var len = reader.Buffer.Length - endOfMessage;
+
+ Array.Copy(reader.Buffer, endOfMessage, temp.Buffer, 0, len);
+ Array.Copy(temp.Buffer, 0, this.Buffer, headerOffset, len);
+
+ this.AdjustLength(reader.Offset, reader.Length + 3);
+ }
+ finally
+ {
+ temp.Recycle();
+ }
+ }
+
+ public void InsertMessage(MessageReader reader, MessageWriter writer)
+ {
+ var temp = MessageReader.GetSized(reader.Buffer.Length);
+ try
+ {
+ var headerOffset = reader.Offset - 3;
+ var startOfMessage = reader.Offset;
+ var len = reader.Buffer.Length - startOfMessage;
+ int writerOffset = 3;
+ switch (writer.SendOption)
+ {
+ case SendOption.Reliable:
+ writerOffset = 3;
+ break;
+ case SendOption.None:
+ writerOffset = 1;
+ break;
+ }
+
+ //store the original buffer in temp
+ Array.Copy(reader.Buffer, headerOffset, temp.Buffer, 0, len);
+
+ //put the contents of writer in at headerOffset
+ Array.Copy(writer.Buffer, writerOffset, this.Buffer, headerOffset, writer.Length-writerOffset);
+
+ //put the original buffer in after that
+ Array.Copy(temp.Buffer, 0, this.Buffer, headerOffset + (writer.Length-writerOffset), len - writer.Length);
+
+ this.AdjustLength(-1 * reader.Offset , -1 * (writer.Length - writerOffset));
+ }
+ finally
+ {
+ temp.Recycle();
+ }
+ }
+
+ private void AdjustLength(int offset, int amount)
+ {
+ if (this.readHead > offset)
+ {
+ this.Position -= amount;
+ }
+
+ if (Parent != null)
+ {
+ var lengthOffset = this.Offset - 3;
+ var curLen = this.Buffer[lengthOffset]
+ | (this.Buffer[lengthOffset + 1] << 8);
+
+ curLen -= amount;
+ this.Length -= amount;
+
+ this.Buffer[lengthOffset] = (byte)curLen;
+ this.Buffer[lengthOffset + 1] = (byte)(this.Buffer[lengthOffset + 1] >> 8);
+
+ Parent.AdjustLength(offset, amount);
+ }
+ }
+
+ public void Recycle()
+ {
+ this.Parent = null;
+ ReaderPool.PutObject(this);
+ }
+
+ #region Read Methods
+ public bool ReadBoolean()
+ {
+ byte val = this.FastByte();
+ return val != 0;
+ }
+
+ public sbyte ReadSByte()
+ {
+ return (sbyte)this.FastByte();
+ }
+
+ public byte ReadByte()
+ {
+ return this.FastByte();
+ }
+
+ public ushort ReadUInt16()
+ {
+ ushort output =
+ (ushort)(this.FastByte()
+ | this.FastByte() << 8);
+ return output;
+ }
+
+ public short ReadInt16()
+ {
+ short output =
+ (short)(this.FastByte()
+ | this.FastByte() << 8);
+ return output;
+ }
+
+ public uint ReadUInt32()
+ {
+ uint output = this.FastByte()
+ | (uint)this.FastByte() << 8
+ | (uint)this.FastByte() << 16
+ | (uint)this.FastByte() << 24;
+
+ return output;
+ }
+
+ public int ReadInt32()
+ {
+ int output = this.FastByte()
+ | this.FastByte() << 8
+ | this.FastByte() << 16
+ | this.FastByte() << 24;
+
+ return output;
+ }
+
+ public ulong ReadUInt64()
+ {
+ ulong output = (ulong)this.FastByte()
+ | (ulong)this.FastByte() << 8
+ | (ulong)this.FastByte() << 16
+ | (ulong)this.FastByte() << 24
+ | (ulong)this.FastByte() << 32
+ | (ulong)this.FastByte() << 40
+ | (ulong)this.FastByte() << 48
+ | (ulong)this.FastByte() << 56;
+
+ return output;
+ }
+
+ public long ReadInt64()
+ {
+ long output = (long)this.FastByte()
+ | (long)this.FastByte() << 8
+ | (long)this.FastByte() << 16
+ | (long)this.FastByte() << 24
+ | (long)this.FastByte() << 32
+ | (long)this.FastByte() << 40
+ | (long)this.FastByte() << 48
+ | (long)this.FastByte() << 56;
+
+ return output;
+ }
+
+ public unsafe float ReadSingle()
+ {
+ float output = 0;
+ fixed (byte* bufPtr = &this.Buffer[this.readHead])
+ {
+ byte* outPtr = (byte*)&output;
+
+ *outPtr = *bufPtr;
+ *(outPtr + 1) = *(bufPtr + 1);
+ *(outPtr + 2) = *(bufPtr + 2);
+ *(outPtr + 3) = *(bufPtr + 3);
+ }
+
+ this.Position += 4;
+ return output;
+ }
+
+ public string ReadString()
+ {
+ int len = this.ReadPackedInt32();
+ if (this.BytesRemaining < len) throw new InvalidDataException($"Read length is longer than message length: {len} of {this.BytesRemaining}");
+
+ string output = UTF8Encoding.UTF8.GetString(this.Buffer, this.readHead, len);
+
+ this.Position += len;
+ return output;
+ }
+
+ public byte[] ReadBytesAndSize()
+ {
+ int len = this.ReadPackedInt32();
+ if (this.BytesRemaining < len) throw new InvalidDataException($"Read length is longer than message length: {len} of {this.BytesRemaining}");
+
+ return this.ReadBytes(len);
+ }
+
+ public byte[] ReadBytes(int length)
+ {
+ if (this.BytesRemaining < length) throw new InvalidDataException($"Read length is longer than message length: {length} of {this.BytesRemaining}");
+
+ byte[] output = new byte[length];
+ Array.Copy(this.Buffer, this.readHead, output, 0, output.Length);
+ this.Position += output.Length;
+ return output;
+ }
+
+ ///
+ public int ReadPackedInt32()
+ {
+ return (int)this.ReadPackedUInt32();
+ }
+
+ ///
+ public uint ReadPackedUInt32()
+ {
+ bool readMore = true;
+ int shift = 0;
+ uint output = 0;
+
+ while (readMore)
+ {
+ if (this.BytesRemaining < 1) throw new InvalidDataException($"Read length is longer than message length.");
+
+ byte b = this.ReadByte();
+ if (b >= 0x80)
+ {
+ readMore = true;
+ b ^= 0x80;
+ }
+ else
+ {
+ readMore = false;
+ }
+
+ output |= (uint)(b << shift);
+ shift += 7;
+ }
+
+ return output;
+ }
+ #endregion
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ private byte FastByte()
+ {
+ this._position++;
+ return this.Buffer[this.readHead++];
+ }
+
+ public unsafe static bool IsLittleEndian()
+ {
+ byte b;
+ unsafe
+ {
+ int i = 1;
+ byte* bp = (byte*)&i;
+ b = *bp;
+ }
+
+ return b == 1;
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/MessageWriter.cs b/Tools/Hazel-Networking/Hazel/MessageWriter.cs
new file mode 100644
index 0000000..9caaaf2
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/MessageWriter.cs
@@ -0,0 +1,365 @@
+using System;
+using System.Collections.Generic;
+using System.IO;
+using System.Text;
+
+namespace Hazel
+{
+ /// <summary>
+ /// 嵌套结构的Message
+ /// 结构:
+ /// ------------------------------------
+ /// 2bytes (ushort) 包长度
+ /// 1bytes (tag) 协议ID,在AmongUS里是tags.cs里定义的tag和subtag
+ /// ------------------------------------
+ /// 数据 包括嵌套的子协议
+ /// ------------------------------------
+ /// </summary>
+ public class MessageWriter : IRecyclable
+ {
+ public static int BufferSize = 64000;
+ public static readonly ObjectPool<MessageWriter> WriterPool = new ObjectPool<MessageWriter>(() => new MessageWriter(BufferSize));
+
+ public byte[] Buffer;
+ public int Length; // 总长度
+ public int Position; // 写入游标
+
+ public SendOption SendOption { get; private set; }
+
+ private Stack<int> messageStarts = new Stack<int>();
+
+ public MessageWriter(byte[] buffer)
+ {
+ this.Buffer = buffer;
+ this.Length = this.Buffer.Length;
+ }
+
+ ///
+ public MessageWriter(int bufferSize)
+ {
+ this.Buffer = new byte[bufferSize];
+ }
+
+ /// <summary>
+ /// 去掉header
+ /// </summary>
+ /// <param name="includeHeader"></param>
+ /// <returns></returns>
+ /// <exception cref="NotImplementedException"></exception>
+ public byte[] ToByteArray(bool includeHeader)
+ {
+ if (includeHeader)
+ {
+ byte[] output = new byte[this.Length];
+ System.Buffer.BlockCopy(this.Buffer, 0, output, 0, this.Length);
+ return output;
+ }
+ else
+ {
+ switch (this.SendOption)
+ {
+ case SendOption.Reliable:
+ {
+ byte[] output = new byte[this.Length - 3];
+ System.Buffer.BlockCopy(this.Buffer, 3, output, 0, this.Length - 3);
+ return output;
+ }
+ case SendOption.None:
+ {
+ byte[] output = new byte[this.Length - 1];
+ System.Buffer.BlockCopy(this.Buffer, 1, output, 0, this.Length - 1);
+ return output;
+ }
+ }
+ }
+
+ throw new NotImplementedException();
+ }
+
+ ///
+ /// <param name="sendOption">The option specifying how the message should be sent.</param>
+ public static MessageWriter Get(SendOption sendOption = SendOption.None)
+ {
+ var output = WriterPool.GetObject();
+ output.Clear(sendOption);
+
+ return output;
+ }
+
+ public bool HasBytes(int expected)
+ {
+ if (this.SendOption == SendOption.None)
+ {
+ return this.Length > 1 + expected;
+ }
+
+ return this.Length > 3 + expected;
+ }
+
+ ///
+ public void StartMessage(byte typeFlag)
+ {
+ var messageStart = this.Position;
+ messageStarts.Push(messageStart);
+ this.Buffer[messageStart] = 0;
+ this.Buffer[messageStart + 1] = 0;
+ this.Position += 2;
+ this.Write(typeFlag);
+ }
+
+ ///
+ public void EndMessage()
+ {
+ var lastMessageStart = messageStarts.Pop();
+ ushort length = (ushort)(this.Position - lastMessageStart - 3); // Minus length and type byte
+ this.Buffer[lastMessageStart] = (byte)length;
+ this.Buffer[lastMessageStart + 1] = (byte)(length >> 8);
+ }
+
+ ///
+ public void CancelMessage()
+ {
+ this.Position = this.messageStarts.Pop();
+ this.Length = this.Position;
+ }
+
+ public void Clear(SendOption sendOption)
+ {
+ Array.Clear(this.Buffer, 0, this.Buffer.Length);
+ this.messageStarts.Clear();
+ this.SendOption = sendOption;
+ this.Buffer[0] = (byte)sendOption;
+ switch (sendOption)
+ {
+ default:
+ case SendOption.None:
+ this.Length = this.Position = 1;
+ break;
+ case SendOption.Reliable:
+ this.Length = this.Position = 3;
+ break;
+ }
+ }
+
+ ///
+ public void Recycle()
+ {
+ this.Position = this.Length = 0;
+ WriterPool.PutObject(this);
+ }
+
+ #region WriteMethods
+
+ public void CopyFrom(MessageReader target)
+ {
+ int offset, length;
+ if (target.Tag == byte.MaxValue)
+ {
+ offset = target.Offset;
+ length = target.Length;
+ }
+ else
+ {
+ offset = target.Offset - 3;
+ length = target.Length + 3;
+ }
+
+ System.Buffer.BlockCopy(target.Buffer, offset, this.Buffer, this.Position, length);
+ this.Position += length;
+ if (this.Position > this.Length) this.Length = this.Position;
+ }
+
+ public void Write(bool value)
+ {
+ this.Buffer[this.Position++] = (byte)(value ? 1 : 0);
+ if (this.Position > this.Length) this.Length = this.Position;
+ }
+
+ public void Write(sbyte value)
+ {
+ this.Buffer[this.Position++] = (byte)value;
+ if (this.Position > this.Length) this.Length = this.Position;
+ }
+
+ public void Write(byte value)
+ {
+ this.Buffer[this.Position++] = value;
+ if (this.Position > this.Length) this.Length = this.Position;
+ }
+
+ public void Write(short value)
+ {
+ this.Buffer[this.Position++] = (byte)value;
+ this.Buffer[this.Position++] = (byte)(value >> 8);
+ if (this.Position > this.Length) this.Length = this.Position;
+ }
+
+ public void Write(ushort value)
+ {
+ this.Buffer[this.Position++] = (byte)value;
+ this.Buffer[this.Position++] = (byte)(value >> 8);
+ if (this.Position > this.Length) this.Length = this.Position;
+ }
+
+ public void Write(uint value)
+ {
+ this.Buffer[this.Position++] = (byte)value;
+ this.Buffer[this.Position++] = (byte)(value >> 8);
+ this.Buffer[this.Position++] = (byte)(value >> 16);
+ this.Buffer[this.Position++] = (byte)(value >> 24);
+ if (this.Position > this.Length) this.Length = this.Position;
+ }
+
+ public void Write(int value)
+ {
+ this.Buffer[this.Position++] = (byte)value;
+ this.Buffer[this.Position++] = (byte)(value >> 8);
+ this.Buffer[this.Position++] = (byte)(value >> 16);
+ this.Buffer[this.Position++] = (byte)(value >> 24);
+ if (this.Position > this.Length) this.Length = this.Position;
+ }
+
+ public void Write(ulong value)
+ {
+ this.Buffer[this.Position++] = (byte)value;
+ this.Buffer[this.Position++] = (byte)(value >> 8);
+ this.Buffer[this.Position++] = (byte)(value >> 16);
+ this.Buffer[this.Position++] = (byte)(value >> 24);
+ this.Buffer[this.Position++] = (byte)(value >> 32);
+ this.Buffer[this.Position++] = (byte)(value >> 40);
+ this.Buffer[this.Position++] = (byte)(value >> 48);
+ this.Buffer[this.Position++] = (byte)(value >> 56);
+ if (this.Position > this.Length) this.Length = this.Position;
+ }
+
+ public void Write(long value)
+ {
+ this.Buffer[this.Position++] = (byte)value;
+ this.Buffer[this.Position++] = (byte)(value >> 8);
+ this.Buffer[this.Position++] = (byte)(value >> 16);
+ this.Buffer[this.Position++] = (byte)(value >> 24);
+ this.Buffer[this.Position++] = (byte)(value >> 32);
+ this.Buffer[this.Position++] = (byte)(value >> 40);
+ this.Buffer[this.Position++] = (byte)(value >> 48);
+ this.Buffer[this.Position++] = (byte)(value >> 56);
+ if (this.Position > this.Length) this.Length = this.Position;
+ }
+
+ public unsafe void Write(float value)
+ {
+ fixed (byte* ptr = &this.Buffer[this.Position])
+ {
+ byte* valuePtr = (byte*)&value;
+
+ *ptr = *valuePtr;
+ *(ptr + 1) = *(valuePtr + 1);
+ *(ptr + 2) = *(valuePtr + 2);
+ *(ptr + 3) = *(valuePtr + 3);
+ }
+
+ this.Position += 4;
+ if (this.Position > this.Length) this.Length = this.Position;
+ }
+
+ public void Write(string value)
+ {
+ var bytes = UTF8Encoding.UTF8.GetBytes(value);
+ this.WritePacked(bytes.Length);
+ this.Write(bytes);
+ }
+
+ public void WriteBytesAndSize(byte[] bytes)
+ {
+ this.WritePacked((uint)bytes.Length);
+ this.Write(bytes);
+ }
+
+ public void WriteBytesAndSize(byte[] bytes, int length)
+ {
+ this.WritePacked((uint)length);
+ this.Write(bytes, length);
+ }
+
+ public void WriteBytesAndSize(byte[] bytes, int offset, int length)
+ {
+ this.WritePacked((uint)length);
+ this.Write(bytes, offset, length);
+ }
+
+ public void Write(byte[] bytes)
+ {
+ Array.Copy(bytes, 0, this.Buffer, this.Position, bytes.Length);
+ this.Position += bytes.Length;
+ if (this.Position > this.Length) this.Length = this.Position;
+ }
+
+ public void Write(byte[] bytes, int offset, int length)
+ {
+ Array.Copy(bytes, offset, this.Buffer, this.Position, length);
+ this.Position += length;
+ if (this.Position > this.Length) this.Length = this.Position;
+ }
+
+ public void Write(byte[] bytes, int length)
+ {
+ Array.Copy(bytes, 0, this.Buffer, this.Position, length);
+ this.Position += length;
+ if (this.Position > this.Length) this.Length = this.Position;
+ }
+
+ ///
+ public void WritePacked(int value)
+ {
+ this.WritePacked((uint)value);
+ }
+
+ ///
+ public void WritePacked(uint value)
+ {
+ do
+ {
+ byte b = (byte)(value & 0xFF);
+ if (value >= 0x80)
+ {
+ b |= 0x80;
+ }
+
+ this.Write(b);
+ value >>= 7;
+ } while (value > 0);
+ }
+ #endregion
+
+ public void Write(MessageWriter msg, bool includeHeader)
+ {
+ int offset = 0;
+ if (!includeHeader)
+ {
+ switch (msg.SendOption)
+ {
+ case SendOption.None:
+ offset = 1;
+ break;
+ case SendOption.Reliable:
+ offset = 3;
+ break;
+ }
+ }
+
+ this.Write(msg.Buffer, offset, msg.Length - offset);
+ }
+
+ public unsafe static bool IsLittleEndian()
+ {
+ byte b;
+ unsafe
+ {
+ int i = 1;
+ byte* bp = (byte*)&i;
+ b = *bp;
+ }
+
+ return b == 1;
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/NetworkConnection.cs b/Tools/Hazel-Networking/Hazel/NetworkConnection.cs
new file mode 100644
index 0000000..d1de8a8
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/NetworkConnection.cs
@@ -0,0 +1,117 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Net;
+using System.Text;
+
+
+namespace Hazel
+{
+ public enum HazelInternalErrors
+ {
+ SocketExceptionSend,
+ SocketExceptionReceive,
+ ReceivedZeroBytes,
+ PingsWithoutResponse,
+ ReliablePacketWithoutResponse,
+ ConnectionDisconnected,
+ DtlsNegotiationFailed
+ }
+
+ /// <summary>
+ /// Abstract base class for a <see cref="Connection"/> to a remote end point via a network protocol like TCP or UDP.
+ /// </summary>
+ /// <threadsafety static="true" instance="true"/>
+ public abstract class NetworkConnection : Connection
+ {
+ /// <summary>
+ /// An event that gives us a chance to send well-formed disconnect messages to clients when an internal disconnect happens.
+ /// </summary>
+ public Func<HazelInternalErrors, MessageWriter> OnInternalDisconnect;
+
+ public virtual float AveragePingMs { get; }
+
+ public long GetIP4Address()
+ {
+ if (IPMode == IPMode.IPv4)
+ {
+ return this.EndPoint.Address.Address;
+ }
+ else
+ {
+ var bytes = this.EndPoint.Address.GetAddressBytes();
+ return BitConverter.ToInt64(bytes, bytes.Length - 8);
+ }
+ }
+
+ /// <summary>
+ /// Sends a disconnect message to the end point.
+ /// </summary>
+ protected abstract bool SendDisconnect(MessageWriter writer);
+
+ /// <summary>
+ /// Called when the socket has been disconnected at the remote host.
+ /// </summary>
+ protected void DisconnectRemote(string reason, MessageReader reader)
+ {
+ if (this.SendDisconnect(null))
+ {
+ try
+ {
+ InvokeDisconnected(reason, reader);
+ }
+ catch { }
+ }
+
+ this.Dispose();
+ }
+
+ /// <summary>
+ /// Called when socket is disconnected internally
+ /// </summary>
+ internal void DisconnectInternal(HazelInternalErrors error, string reason)
+ {
+ var handler = this.OnInternalDisconnect;
+ if (handler != null)
+ {
+ MessageWriter messageToRemote = handler(error);
+ if (messageToRemote != null)
+ {
+ try
+ {
+ Disconnect(reason, messageToRemote);
+ }
+ finally
+ {
+ messageToRemote.Recycle();
+ }
+ }
+ else
+ {
+ Disconnect(reason);
+ }
+ }
+ else
+ {
+ Disconnect(reason);
+ }
+ }
+
+ /// <summary>
+ /// Called when the socket has been disconnected locally.
+ /// </summary>
+ public override void Disconnect(string reason, MessageWriter writer = null)
+ {
+ if (this.SendDisconnect(writer))
+ {
+ try
+ {
+ InvokeDisconnected(reason, null);
+ }
+ catch { }
+ }
+
+ this.Dispose();
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/NetworkConnectionListener.cs b/Tools/Hazel-Networking/Hazel/NetworkConnectionListener.cs
new file mode 100644
index 0000000..af26c4c
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/NetworkConnectionListener.cs
@@ -0,0 +1,26 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Net;
+using System.Text;
+
+
+namespace Hazel
+{
+ /// <summary>
+ /// Abstract base class for a <see cref="ConnectionListener"/> for network based connections.
+ /// </summary>
+ /// <threadsafety static="true" instance="true"/>
+ public abstract class NetworkConnectionListener : ConnectionListener
+ {
+ /// <summary>
+ /// The local end point the listener is listening for new clients on.
+ /// </summary>
+ public IPEndPoint EndPoint { get; protected set; }
+
+ /// <summary>
+ /// The <see cref="IPMode">IPMode</see> the listener is listening for new clients on.
+ /// </summary>
+ public IPMode IPMode { get; protected set; }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/NewConnectionEventArgs.cs b/Tools/Hazel-Networking/Hazel/NewConnectionEventArgs.cs
new file mode 100644
index 0000000..c3fd62f
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/NewConnectionEventArgs.cs
@@ -0,0 +1,22 @@
+namespace Hazel
+{
+ public struct NewConnectionEventArgs
+ {
+ /// <summary>
+ /// The data received from the client in the handshake.
+ /// You must not recycle this. If you need the message outside of a callback, you should copy it.
+ /// </summary>
+ public readonly MessageReader HandshakeData;
+
+ /// <summary>
+ /// The <see cref="Connection"/> to the new client.
+ /// </summary>
+ public readonly Connection Connection;
+
+ public NewConnectionEventArgs(MessageReader handshakeData, Connection connection)
+ {
+ this.HandshakeData = handshakeData;
+ this.Connection = connection;
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/ObjectPool.cs b/Tools/Hazel-Networking/Hazel/ObjectPool.cs
new file mode 100644
index 0000000..510e55a
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/ObjectPool.cs
@@ -0,0 +1,108 @@
+using System;
+using System.Collections.Concurrent;
+using System.Collections.Generic;
+using System.Threading;
+
+namespace Hazel
+{
+ /// <summary>
+ /// A fairly simple object pool for items that will be created a lot.
+ /// </summary>
+ /// <typeparam name="T">The type that is pooled.</typeparam>
+ /// <threadsafety static="true" instance="true"/>
+ public sealed class ObjectPool<T> where T : IRecyclable
+ {
+ private int numberCreated;
+ public int NumberCreated { get { return numberCreated; } }
+
+ public int NumberInUse { get { return this.inuse.Count; } }
+ public int NumberNotInUse { get { return this.pool.Count; } }
+ public int Size { get { return this.NumberInUse + this.NumberNotInUse; } }
+
+#if HAZEL_BAG
+ private readonly ConcurrentBag<T> pool = new ConcurrentBag<T>();
+#else
+ private readonly List<T> pool = new List<T>();
+#endif
+
+ // Unavailable objects
+ private readonly ConcurrentDictionary<T, bool> inuse = new ConcurrentDictionary<T, bool>();
+
+ /// <summary>
+ /// The generator for creating new objects.
+ /// </summary>
+ /// <returns></returns>
+ private readonly Func<T> objectFactory;
+
+ /// <summary>
+ /// Internal constructor for our ObjectPool.
+ /// </summary>
+ internal ObjectPool(Func<T> objectFactory)
+ {
+ this.objectFactory = objectFactory;
+ }
+
+ /// <summary>
+ /// Returns a pooled object of type T, if none are available another is created.
+ /// </summary>
+ /// <returns>An instance of T.</returns>
+ internal T GetObject()
+ {
+#if HAZEL_BAG
+ if (!pool.TryTake(out T item))
+ {
+ Interlocked.Increment(ref numberCreated);
+ item = objectFactory.Invoke();
+ }
+#else
+ T item;
+ lock (this.pool)
+ {
+ if (this.pool.Count > 0)
+ {
+ var idx = this.pool.Count - 1;
+ item = this.pool[idx];
+ this.pool.RemoveAt(idx);
+ }
+ else
+ {
+ Interlocked.Increment(ref numberCreated);
+ item = objectFactory.Invoke();
+ }
+ }
+#endif
+
+ if (!inuse.TryAdd(item, true))
+ {
+ throw new Exception("Duplicate pull " + typeof(T).Name);
+ }
+
+ return item;
+ }
+
+ /// <summary>
+ /// Returns an object to the pool.
+ /// </summary>
+ /// <param name="item">The item to return.</param>
+ internal void PutObject(T item)
+ {
+ if (inuse.TryRemove(item, out bool b))
+ {
+#if HAZEL_BAG
+ pool.Add(item);
+#else
+ lock (this.pool)
+ {
+ pool.Add(item);
+ }
+#endif
+ }
+ else
+ {
+#if DEBUG
+ throw new Exception("Duplicate add " + typeof(T).Name);
+#endif
+ }
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/SendErrors.cs b/Tools/Hazel-Networking/Hazel/SendErrors.cs
new file mode 100644
index 0000000..6871c6a
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/SendErrors.cs
@@ -0,0 +1,15 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+
+namespace Hazel
+{
+ [Flags]
+ public enum SendErrors
+ {
+ None,
+ Disconnected,
+ Unknown
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/SendOption.cs b/Tools/Hazel-Networking/Hazel/SendOption.cs
new file mode 100644
index 0000000..c2ffb22
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/SendOption.cs
@@ -0,0 +1,35 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+
+namespace Hazel
+{
+ /// <summary>
+ /// Specifies how a message should be sent between connections.
+ /// </summary>
+ [Flags]
+ public enum SendOption : byte
+ {
+ /// <summary>
+ /// Requests unreliable delivery with no framentation.
+ /// </summary>
+ /// <remarks>
+ /// Sending data using unreliable delivery means that data is not guaranteed to arrive at it's destination nor is
+ /// it guarenteed to arrive only once. However, unreliable delivery can be faster than other methods and it
+ /// typically requires a smaller number of protocol bytes than other methods. There is also typically less
+ /// processing involved and less memory needed as packets are not stored once sent.
+ /// </remarks>
+ None = 0,
+
+ /// <summary>
+ /// Requests data be sent reliably but with no fragmentation.
+ /// </summary>
+ /// <remarks>
+ /// Sending data reliably means that data is guarenteed to arrive and to arrive only once. Reliable delivery
+ /// typically requires more processing, more memory (as packets need to be stored in case they need resending),
+ /// a larger number of protocol bytes and can be slower than unreliable delivery.
+ /// </remarks>
+ Reliable = 1,
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/UPnP/ILogger.cs b/Tools/Hazel-Networking/Hazel/UPnP/ILogger.cs
new file mode 100644
index 0000000..3c7abcf
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/UPnP/ILogger.cs
@@ -0,0 +1,65 @@
+using System;
+
+namespace Hazel
+{
+ public interface ILogger
+ {
+ void WriteVerbose(string msg);
+ void WriteError(string msg);
+ void WriteWarning(string msg);
+ void WriteInfo(string msg);
+ }
+
+ public class NullLogger : ILogger
+ {
+ public static readonly NullLogger Instance = new NullLogger();
+
+ public void WriteVerbose(string msg)
+ {
+ }
+
+ public void WriteError(string msg)
+ {
+ }
+
+ public void WriteWarning(string msg)
+ {
+ }
+
+ public void WriteInfo(string msg)
+ {
+ }
+ }
+
+ public class ConsoleLogger : ILogger
+ {
+ private bool verbose;
+ public ConsoleLogger(bool verbose)
+ {
+ this.verbose = verbose;
+ }
+
+ public void WriteVerbose(string msg)
+ {
+ if (this.verbose)
+ {
+ Console.WriteLine($"{DateTime.Now} [VERBOSE] {msg}");
+ }
+ }
+
+ public void WriteWarning(string msg)
+ {
+ Console.WriteLine($"{DateTime.Now} [WARN] {msg}");
+ }
+
+ public void WriteError(string msg)
+ {
+ Console.WriteLine($"{DateTime.Now} [ERROR] {msg}");
+ }
+
+ public void WriteInfo(string msg)
+ {
+ Console.WriteLine($"{DateTime.Now} [INFO] {msg}");
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/UPnP/NetUtility.cs b/Tools/Hazel-Networking/Hazel/UPnP/NetUtility.cs
new file mode 100644
index 0000000..d856823
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/UPnP/NetUtility.cs
@@ -0,0 +1,158 @@
+using System;
+using System.Collections.Generic;
+using System.Net;
+using System.Net.NetworkInformation;
+using System.Net.Sockets;
+
+namespace Hazel.UPnP
+{
+ internal class NetUtility
+ {
+ private static IList<NetworkInterface> GetValidNetworkInterfaces()
+ {
+ var nics = NetworkInterface.GetAllNetworkInterfaces();
+ if (nics == null || nics.Length < 1)
+ return new NetworkInterface[0];
+
+ var validInterfaces = new List<NetworkInterface>(nics.Length);
+
+ NetworkInterface best = null;
+ foreach (NetworkInterface adapter in nics)
+ {
+ if (adapter.NetworkInterfaceType == NetworkInterfaceType.Loopback || adapter.NetworkInterfaceType == NetworkInterfaceType.Unknown)
+ continue;
+ if (!adapter.Supports(NetworkInterfaceComponent.IPv4) && !adapter.Supports(NetworkInterfaceComponent.IPv6))
+ continue;
+ if (best == null)
+ best = adapter;
+ if (adapter.OperationalStatus != OperationalStatus.Up)
+ continue;
+
+ // make sure this adapter has any ip addresses
+ IPInterfaceProperties properties = adapter.GetIPProperties();
+ foreach (UnicastIPAddressInformation unicastAddress in properties.UnicastAddresses)
+ {
+ if (unicastAddress != null && unicastAddress.Address != null)
+ {
+ // Yes it does, add this network interface.
+ validInterfaces.Add(adapter);
+ break;
+ }
+ }
+ }
+
+ if (validInterfaces.Count == 0 && best != null)
+ validInterfaces.Add(best);
+
+ return validInterfaces;
+ }
+
+ /// <summary>
+ /// Gets the addresses from all active network interfaces, but at most one per interface.
+ /// </summary>
+ /// <param name="addressFamily">The <see cref="AddressFamily"/> of the addresses to return</param>
+ /// <returns>An <see cref="ICollection{T}"/> of <see cref="UnicastIPAddressInformation"/>.</returns>
+ public static ICollection<UnicastIPAddressInformation> GetAddressesFromNetworkInterfaces(AddressFamily addressFamily)
+ {
+ var unicastAddresses = new List<UnicastIPAddressInformation>();
+
+ foreach (NetworkInterface adapter in GetValidNetworkInterfaces())
+ {
+ IPInterfaceProperties properties = adapter.GetIPProperties();
+ foreach (UnicastIPAddressInformation unicastAddress in properties.UnicastAddresses)
+ {
+ if (unicastAddress != null && unicastAddress.Address != null && unicastAddress.Address.AddressFamily == addressFamily)
+ {
+ unicastAddresses.Add(unicastAddress);
+ break;
+ }
+ }
+ }
+
+ return unicastAddresses;
+ }
+
+ /// <summary>
+ /// Gets my local IPv4 address (not necessarily external) and subnet mask
+ /// </summary>
+ public static IPAddress GetMyAddress(out IPAddress mask)
+ {
+ var networkInterfaces = GetValidNetworkInterfaces();
+ IPInterfaceProperties properties = null;
+
+ if (networkInterfaces.Count > 0)
+ properties = networkInterfaces[0]?.GetIPProperties();
+
+ if (properties != null)
+ {
+ foreach (UnicastIPAddressInformation unicastAddress in properties.UnicastAddresses)
+ {
+ if (unicastAddress != null && unicastAddress.Address != null && unicastAddress.Address.AddressFamily == AddressFamily.InterNetwork)
+ {
+ mask = unicastAddress.IPv4Mask;
+ return unicastAddress.Address;
+ }
+ }
+ }
+
+ mask = null;
+ return null;
+ }
+
+ /// <summary>
+ /// Gets the broadcast address for the first network interface or, if not able to,
+ /// the limited broadcast address.
+ /// </summary>
+ /// <returns>An <see cref="IPAddress"/> for broadcasting.</returns>
+ public static IPAddress GetBroadcastAddress()
+ {
+ var networkInterfaces = GetValidNetworkInterfaces();
+ IPInterfaceProperties properties = null;
+
+ if (networkInterfaces.Count > 0)
+ properties = networkInterfaces[0]?.GetIPProperties();
+
+ if (properties != null)
+ {
+ foreach (UnicastIPAddressInformation unicastAddress in properties.UnicastAddresses)
+ {
+ IPAddress ipAddress = GetBroadcastAddress(unicastAddress);
+ if (ipAddress != null)
+ {
+ return ipAddress;
+ }
+ }
+ }
+
+ return IPAddress.Broadcast;
+ }
+
+ /// <summary>
+ /// Gets the broadcast address for the given <paramref name="unicastAddress"/>.
+ /// </summary>
+ /// <param name="unicastAddress">A <see cref="UnicastIPAddressInformation"/></param>
+ /// <returns>An <see cref="IPAddress"/> for broadcasting, null if the <paramref name="unicastAddress"/>
+ /// is not an IPv4 address.</returns>
+ public static IPAddress GetBroadcastAddress(UnicastIPAddressInformation unicastAddress)
+ {
+ if (unicastAddress != null && unicastAddress.Address != null && unicastAddress.Address.AddressFamily == AddressFamily.InterNetwork)
+ {
+ var mask = unicastAddress.IPv4Mask;
+ byte[] ipAdressBytes = unicastAddress.Address.GetAddressBytes();
+ byte[] subnetMaskBytes = mask.GetAddressBytes();
+
+ if (ipAdressBytes.Length != subnetMaskBytes.Length)
+ throw new ArgumentException("Lengths of IP address and subnet mask do not match.");
+
+ byte[] broadcastAddress = new byte[ipAdressBytes.Length];
+ for (int i = 0; i < broadcastAddress.Length; i++)
+ {
+ broadcastAddress[i] = (byte)(ipAdressBytes[i] | (subnetMaskBytes[i] ^ 255));
+ }
+ return new IPAddress(broadcastAddress);
+ }
+
+ return null;
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/UPnP/UPnPHelper.cs b/Tools/Hazel-Networking/Hazel/UPnP/UPnPHelper.cs
new file mode 100644
index 0000000..771709e
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/UPnP/UPnPHelper.cs
@@ -0,0 +1,347 @@
+using System;
+using System.IO;
+using System.Xml;
+using System.Net;
+using System.Net.Sockets;
+using System.Threading;
+
+namespace Hazel.UPnP
+{
+ /// <summary>
+ /// Status of the UPnP capabilities
+ /// </summary>
+ public enum UPnPStatus
+ {
+ /// <summary>
+ /// Still discovering UPnP capabilities
+ /// </summary>
+ Discovering,
+
+ /// <summary>
+ /// UPnP is not available
+ /// </summary>
+ NotAvailable,
+
+ /// <summary>
+ /// UPnP is available and ready to use
+ /// </summary>
+ Available
+ }
+
+ public class UPnPHelper : IDisposable
+ {
+ private const int DiscoveryTimeOutMs = 1000;
+
+ private string serviceUrl;
+ private string serviceName = "";
+
+ private ManualResetEvent discoveryComplete = new ManualResetEvent(false);
+ private Socket socket;
+
+ private DateTime discoveryResponseDeadline;
+
+ private EndPoint ep;
+ private byte[] buffer;
+
+ private ILogger logger;
+
+ /// <summary>
+ /// Status of the UPnP capabilities of this NetPeer
+ /// </summary>
+ public UPnPStatus Status { get; private set; }
+
+ public UPnPHelper(ILogger logger)
+ {
+ this.logger = logger;
+
+ this.socket = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp);
+ this.socket.EnableBroadcast = true;
+ this.socket.MulticastLoopback = false;
+
+ this.socket.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.ReuseAddress, 1);
+ this.socket.Bind(new IPEndPoint(IPAddress.Any, 0));
+
+ this.ep = new IPEndPoint(IPAddress.Any, 1900);
+ this.buffer = new byte[ushort.MaxValue];
+
+ ListenForUPnP();
+
+ this.Discover();
+ }
+
+ private void ListenForUPnP()
+ {
+ try
+ {
+ socket.BeginReceiveFrom(this.buffer, 0, this.buffer.Length, SocketFlags.None, ref ep, HandleMessage, null);
+ }
+ catch(Exception e)
+ {
+ this.logger.WriteInfo("Exception listening for UPnP: " + e.Message);
+ }
+ }
+
+ private void HandleMessage(IAsyncResult ar)
+ {
+ int len;
+ try
+ {
+ len = this.socket.EndReceiveFrom(ar, ref ep);
+ }
+ catch
+ {
+ return;
+ }
+
+ string resp = System.Text.Encoding.UTF8.GetString(buffer, 0, len);
+ if (resp.Contains("upnp:rootdevice") || resp.Contains("UPnP/1.0"))
+ {
+ var locationStart = resp.IndexOf("location:", StringComparison.OrdinalIgnoreCase);
+ if (locationStart >= 0)
+ {
+ locationStart += 10;
+ var locationEnd = resp.IndexOf("\r", locationStart);
+
+ resp = resp.Substring(locationStart, locationEnd - locationStart);
+ if (!ExtractServiceUrl(resp))
+ {
+ ListenForUPnP();
+ }
+ }
+ else
+ {
+ ListenForUPnP();
+ }
+ }
+ else
+ {
+ ListenForUPnP();
+ }
+ }
+
+ internal void Discover()
+ {
+ string str =
+"M-SEARCH * HTTP/1.1\r\n" +
+"HOST: 239.255.255.250:1900\r\n" +
+"ST:upnp:rootdevice\r\n" +
+"MAN:\"ssdp:discover\"\r\n" +
+"MX:3\r\n\r\n";
+
+ discoveryResponseDeadline = DateTime.UtcNow.AddSeconds(6);
+ Status = UPnPStatus.Discovering;
+
+ byte[] buffer = System.Text.Encoding.UTF8.GetBytes(str);
+
+ this.logger.WriteInfo("Attempting UPnP discovery");
+
+ socket.SendTo(buffer, new IPEndPoint(NetUtility.GetBroadcastAddress(), 1900));
+ }
+
+ internal bool ExtractServiceUrl(string resp)
+ {
+ try
+ {
+ XmlDocument desc = new XmlDocument();
+ using (var response = WebRequest.Create(resp).GetResponse())
+ {
+ desc.Load(response.GetResponseStream());
+ }
+
+ XmlNamespaceManager nsMgr = new XmlNamespaceManager(desc.NameTable);
+ nsMgr.AddNamespace("tns", "urn:schemas-upnp-org:device-1-0");
+ XmlNode typen = desc.SelectSingleNode("//tns:device/tns:deviceType/text()", nsMgr);
+ if (!typen.Value.Contains("InternetGatewayDevice"))
+ return false;
+
+ serviceName = "WANIPConnection";
+ XmlNode node = desc.SelectSingleNode("//tns:service[tns:serviceType=\"urn:schemas-upnp-org:service:" + serviceName + ":1\"]/tns:controlURL/text()", nsMgr);
+ if (node == null)
+ {
+ //try another service name
+ serviceName = "WANPPPConnection";
+ node = desc.SelectSingleNode("//tns:service[tns:serviceType=\"urn:schemas-upnp-org:service:" + serviceName + ":1\"]/tns:controlURL/text()", nsMgr);
+ if (node == null)
+ return false;
+ }
+
+ serviceUrl = CombineUrls(resp, node.Value);
+ this.logger.WriteInfo("UPnP service ready");
+ Status = UPnPStatus.Available;
+ discoveryComplete.Set();
+ return true;
+ }
+ catch (Exception e)
+ {
+ this.logger.WriteError("Exception while parsing UPnP Service URL: " + e.Message);
+ return false;
+ }
+ }
+
+ private static string CombineUrls(string gatewayURL, string subURL)
+ {
+ // Is Control URL an absolute URL?
+ if (subURL.Contains("http:") || subURL.Contains("."))
+ return subURL;
+
+ gatewayURL = gatewayURL.Replace("http://", ""); // strip any protocol
+ int n = gatewayURL.IndexOf("/");
+ if (n >= 0)
+ {
+ gatewayURL = gatewayURL.Substring(0, n); // Use first portion of URL
+ }
+
+ return "http://" + gatewayURL + subURL;
+ }
+
+ private bool CheckAvailability()
+ {
+ switch (Status)
+ {
+ case UPnPStatus.NotAvailable:
+ return false;
+ case UPnPStatus.Available:
+ return true;
+ case UPnPStatus.Discovering:
+ while (!discoveryComplete.WaitOne(DiscoveryTimeOutMs))
+ {
+ if (DateTime.UtcNow > discoveryResponseDeadline)
+ {
+ Status = UPnPStatus.NotAvailable;
+ return false;
+ }
+ }
+
+ return true;
+ }
+
+ return false;
+ }
+
+ /// <summary>
+ /// Add a forwarding rule to the router using UPnP
+ /// </summary>
+ /// <param name="externalPort">The external, WAN facing, port</param>
+ /// <param name="description">A description for the port forwarding rule</param>
+ /// <param name="internalPort">The port on the client machine to send traffic to</param>
+ /// <param name="durationSeconds">The lease duration on the port forwarding rule, in seconds. 0 for indefinite.</param>
+ public bool ForwardPort(int externalPort, string description, int internalPort = 0, int durationSeconds = 0)
+ {
+ if (!CheckAvailability())
+ return false;
+
+ if (internalPort == 0)
+ internalPort = externalPort;
+
+ try
+ {
+ var client = NetUtility.GetMyAddress(out _);
+ if (client == null)
+ return false;
+
+ SOAPRequest(serviceUrl,
+ $"<u:AddPortMapping xmlns:u=\"urn:schemas-upnp-org:service:{serviceName}:1\">" +
+ "<NewRemoteHost></NewRemoteHost>" +
+ $"<NewExternalPort>{externalPort}</NewExternalPort>" +
+ "<NewProtocol>UDP</NewProtocol>" +
+ $"<NewInternalPort>{internalPort}</NewInternalPort>" +
+ $"<NewInternalClient>{client}</NewInternalClient>" +
+ "<NewEnabled>1</NewEnabled>" +
+ $"<NewPortMappingDescription>{description}</NewPortMappingDescription>" +
+ $"<NewLeaseDuration>{durationSeconds}</NewLeaseDuration>" +
+ "</u:AddPortMapping>",
+ "AddPortMapping");
+
+ this.logger.WriteInfo("Sent UPnP port forward request.");
+ return true;
+ }
+ catch (Exception ex)
+ {
+ this.logger.WriteError("UPnP port forward failed: " + ex.Message);
+ return false;
+ }
+ }
+
+ /// <summary>
+ /// Delete a forwarding rule from the router using UPnP
+ /// </summary>
+ /// <param name="externalPort">The external, 'internet facing', port</param>
+ public bool DeleteForwardingRule(int externalPort)
+ {
+ if (!CheckAvailability())
+ return false;
+
+ try
+ {
+ SOAPRequest(serviceUrl,
+ $"<u:DeletePortMapping xmlns:u=\"urn:schemas-upnp-org:service:{serviceName}:1\">" +
+ "<NewRemoteHost></NewRemoteHost>" +
+ $"<NewExternalPort>{externalPort}</NewExternalPort>" +
+ $"<NewProtocol>UDP</NewProtocol>" +
+ "</u:DeletePortMapping>", "DeletePortMapping");
+ return true;
+ }
+ catch (Exception ex)
+ {
+ // m_peer.LogWarning("UPnP delete forwarding rule failed: " + ex.Message);
+ return false;
+ }
+ }
+
+ /// <summary>
+ /// Retrieve the extern ip using UPnP
+ /// </summary>
+ public IPAddress GetExternalIP()
+ {
+ if (!CheckAvailability())
+ return null;
+ try
+ {
+ XmlDocument xdoc = SOAPRequest(serviceUrl, "<u:GetExternalIPAddress xmlns:u=\"urn:schemas-upnp-org:service:" + serviceName + ":1\">" +
+ "</u:GetExternalIPAddress>", "GetExternalIPAddress");
+ XmlNamespaceManager nsMgr = new XmlNamespaceManager(xdoc.NameTable);
+ nsMgr.AddNamespace("tns", "urn:schemas-upnp-org:device-1-0");
+ string IP = xdoc.SelectSingleNode("//NewExternalIPAddress/text()", nsMgr).Value;
+ return IPAddress.Parse(IP);
+ }
+ catch (Exception ex)
+ {
+ // m_peer.LogWarning("Failed to get external IP: " + ex.Message);
+ return null;
+ }
+ }
+
+ private XmlDocument SOAPRequest(string url, string soap, string function)
+ {
+ string req =
+"<?xml version=\"1.0\"?>" +
+"<s:Envelope xmlns:s=\"http://schemas.xmlsoap.org/soap/envelope/\" s:encodingStyle=\"http://schemas.xmlsoap.org/soap/encoding/\">" +
+$"<s:Body>{soap}</s:Body>" +
+"</s:Envelope>";
+
+ WebRequest r = HttpWebRequest.Create(url);
+ r.Headers.Add("SOAPACTION", $"\"urn:schemas-upnp-org:service:{serviceName}:1#{function}\"");
+ r.ContentType = "text/xml; charset=\"utf-8\"";
+ r.Method = "POST";
+
+ byte[] b = System.Text.Encoding.UTF8.GetBytes(req);
+ r.ContentLength = b.Length;
+ r.GetRequestStream().Write(b, 0, b.Length);
+
+ using (WebResponse wres = r.GetResponse())
+ {
+ XmlDocument resp = new XmlDocument();
+ Stream ress = wres.GetResponseStream();
+ resp.Load(ress);
+ return resp;
+ }
+ }
+
+ public void Dispose()
+ {
+ this.discoveryComplete.Dispose();
+ try { this.socket.Shutdown(SocketShutdown.Both); } catch { }
+ this.socket.Dispose();
+ }
+ }
+} \ No newline at end of file
diff --git a/Tools/Hazel-Networking/Hazel/Udp/SendOptionInternal.cs b/Tools/Hazel-Networking/Hazel/Udp/SendOptionInternal.cs
new file mode 100644
index 0000000..74786d8
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/Udp/SendOptionInternal.cs
@@ -0,0 +1,39 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+
+
+namespace Hazel.Udp
+{
+ /// <summary>
+ /// Extra internal states for SendOption enumeration when using UDP.
+ /// </summary>
+ public enum UdpSendOption : byte
+ {
+ /// <summary>
+ /// Hello message for initiating communication.
+ /// </summary>
+ Hello = 8,
+
+ /// <summary>
+ /// A single byte of continued existence
+ /// </summary>
+ Ping = 12,
+
+ /// <summary>
+ /// Message for discontinuing communication.
+ /// </summary>
+ Disconnect = 9,
+
+ /// <summary>
+ /// Message acknowledging the receipt of a message.
+ /// </summary>
+ Acknowledgement = 10,
+
+ /// <summary>
+ /// Message that is part of a larger, fragmented message.
+ /// </summary>
+ Fragment = 11,
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/Udp/UdpBroadcastListener.cs b/Tools/Hazel-Networking/Hazel/Udp/UdpBroadcastListener.cs
new file mode 100644
index 0000000..13b8d0b
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/Udp/UdpBroadcastListener.cs
@@ -0,0 +1,157 @@
+using System;
+using System.Collections.Generic;
+using System.Net;
+using System.Net.Sockets;
+using System.Text;
+using System.Threading;
+
+namespace Hazel.Udp
+{
+ public class BroadcastPacket
+ {
+ public string Data;
+ public DateTime ReceiveTime;
+ public IPEndPoint Sender;
+
+ public BroadcastPacket(string data, IPEndPoint sender)
+ {
+ this.Data = data;
+ this.Sender = sender;
+ this.ReceiveTime = DateTime.Now;
+ }
+
+ public string GetAddress()
+ {
+ return this.Sender.Address.ToString();
+ }
+ }
+
+ public class UdpBroadcastListener : IDisposable
+ {
+ private Socket socket;
+ private EndPoint endpoint;
+ private Action<string> logger;
+
+ private byte[] buffer = new byte[1024];
+
+ private List<BroadcastPacket> packets = new List<BroadcastPacket>();
+
+ public bool Running { get; private set; }
+
+ ///
+ public UdpBroadcastListener(int port, Action<string> logger = null)
+ {
+ this.logger = logger;
+ this.socket = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp);
+ this.socket.EnableBroadcast = true;
+ this.socket.MulticastLoopback = false;
+ this.endpoint = new IPEndPoint(IPAddress.Any, port);
+ this.socket.Bind(this.endpoint);
+ }
+
+ ///
+ public void StartListen()
+ {
+ if (this.Running) return;
+ this.Running = true;
+
+ try
+ {
+ EndPoint endpt = new IPEndPoint(IPAddress.Any, 0);
+ this.socket.BeginReceiveFrom(buffer, 0, buffer.Length, SocketFlags.None, ref endpt, this.HandleData, null);
+ }
+ catch (NullReferenceException) { }
+ catch (Exception e)
+ {
+ this.logger?.Invoke("BroadcastListener: " + e);
+ this.Dispose();
+ }
+ }
+
+ private void HandleData(IAsyncResult result)
+ {
+ this.Running = false;
+
+ int numBytes;
+ EndPoint endpt = new IPEndPoint(IPAddress.Any, 0);
+ try
+ {
+ numBytes = this.socket.EndReceiveFrom(result, ref endpt);
+ }
+ catch (NullReferenceException)
+ {
+ // Already disposed
+ return;
+ }
+ catch (Exception e)
+ {
+ this.logger?.Invoke("BroadcastListener: " + e);
+ this.Dispose();
+ return;
+ }
+
+ if (numBytes < 3
+ || buffer[0] != 4 || buffer[1] != 2)
+ {
+ this.StartListen();
+ return;
+ }
+
+ IPEndPoint ipEnd = (IPEndPoint)endpt;
+ string data = UTF8Encoding.UTF8.GetString(buffer, 2, numBytes - 2);
+ int dataHash = data.GetHashCode();
+
+ lock (packets)
+ {
+ bool found = false;
+ for (int i = 0; i < this.packets.Count; ++i)
+ {
+ var pkt = this.packets[i];
+ if (pkt == null || pkt.Data == null)
+ {
+ this.packets.RemoveAt(i);
+ i--;
+ continue;
+ }
+
+ if (pkt.Data.GetHashCode() == dataHash
+ && pkt.Sender.Equals(ipEnd))
+ {
+ this.packets[i].ReceiveTime = DateTime.Now;
+ break;
+ }
+ }
+
+ if (!found)
+ {
+ this.packets.Add(new BroadcastPacket(data, ipEnd));
+ }
+ }
+
+ this.StartListen();
+ }
+
+ ///
+ public BroadcastPacket[] GetPackets()
+ {
+ lock (this.packets)
+ {
+ var output = this.packets.ToArray();
+ this.packets.Clear();
+ return output;
+ }
+ }
+
+ ///
+ public void Dispose()
+ {
+ if (this.socket != null)
+ {
+ try { this.socket.Shutdown(SocketShutdown.Both); } catch { }
+ try { this.socket.Close(); } catch { }
+ try { this.socket.Dispose(); } catch { }
+ this.socket = null;
+ }
+ }
+ }
+} \ No newline at end of file
diff --git a/Tools/Hazel-Networking/Hazel/Udp/UdpBroadcaster.cs b/Tools/Hazel-Networking/Hazel/Udp/UdpBroadcaster.cs
new file mode 100644
index 0000000..8877f86
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/Udp/UdpBroadcaster.cs
@@ -0,0 +1,127 @@
+using Hazel.UPnP;
+using System;
+using System.Net;
+using System.Net.Sockets;
+using System.Text;
+
+namespace Hazel.Udp
+{
+ public class UdpBroadcaster : IDisposable
+ {
+ private SocketBroadcast[] socketBroadcasts;
+ private byte[] data;
+ private Action<string> logger;
+
+ ///
+ public UdpBroadcaster(int port, Action<string> logger = null)
+ {
+ this.logger = logger;
+
+ var addresses = NetUtility.GetAddressesFromNetworkInterfaces(AddressFamily.InterNetwork);
+ this.socketBroadcasts = new SocketBroadcast[addresses.Count > 0 ? addresses.Count : 1];
+
+ int count = 0;
+ foreach (var addressInformation in addresses)
+ {
+ Socket socket = CreateSocket(new IPEndPoint(addressInformation.Address, 0));
+ IPAddress broadcast = NetUtility.GetBroadcastAddress(addressInformation);
+
+ this.socketBroadcasts[count] = new SocketBroadcast(socket, new IPEndPoint(broadcast, port));
+ count++;
+ }
+ if (count == 0)
+ {
+ Socket socket = CreateSocket(new IPEndPoint(IPAddress.Any, 0));
+
+ this.socketBroadcasts[0] = new SocketBroadcast(socket, new IPEndPoint(IPAddress.Broadcast, port));
+ }
+ }
+
+ private static Socket CreateSocket(IPEndPoint endPoint)
+ {
+ var socket = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp);
+ socket.EnableBroadcast = true;
+ socket.MulticastLoopback = false;
+ socket.Bind(endPoint);
+
+ return socket;
+ }
+
+ ///
+ public void SetData(string data)
+ {
+ int len = UTF8Encoding.UTF8.GetByteCount(data);
+ this.data = new byte[len + 2];
+ this.data[0] = 4;
+ this.data[1] = 2;
+
+ UTF8Encoding.UTF8.GetBytes(data, 0, data.Length, this.data, 2);
+ }
+
+ ///
+ public void Broadcast()
+ {
+ if (this.data == null)
+ {
+ return;
+ }
+
+ foreach (SocketBroadcast socketBroadcast in this.socketBroadcasts)
+ {
+ try
+ {
+ Socket socket = socketBroadcast.Socket;
+ socket.BeginSendTo(data, 0, data.Length, SocketFlags.None, socketBroadcast.Broadcast, this.FinishSendTo, socket);
+ }
+ catch (Exception e)
+ {
+ this.logger?.Invoke("BroadcastListener: " + e);
+ }
+ }
+ }
+
+ private void FinishSendTo(IAsyncResult evt)
+ {
+ try
+ {
+ Socket socket = (Socket)evt.AsyncState;
+ socket.EndSendTo(evt);
+ }
+ catch (Exception e)
+ {
+ this.logger?.Invoke("BroadcastListener: " + e);
+ }
+ }
+
+ ///
+ public void Dispose()
+ {
+ if (this.socketBroadcasts != null)
+ {
+ foreach (SocketBroadcast socketBroadcast in this.socketBroadcasts)
+ {
+ Socket socket = socketBroadcast.Socket;
+ if (socket != null)
+ {
+ try { socket.Shutdown(SocketShutdown.Both); } catch { }
+ try { socket.Close(); } catch { }
+ try { socket.Dispose(); } catch { }
+ }
+ }
+ Array.Clear(this.socketBroadcasts, 0, this.socketBroadcasts.Length);
+ }
+ }
+
+ private struct SocketBroadcast
+ {
+ public Socket Socket;
+ public IPEndPoint Broadcast;
+
+ public SocketBroadcast(Socket socket, IPEndPoint broadcast)
+ {
+ Socket = socket;
+ Broadcast = broadcast;
+ }
+ }
+ }
+} \ No newline at end of file
diff --git a/Tools/Hazel-Networking/Hazel/Udp/UdpClientConnection.cs b/Tools/Hazel-Networking/Hazel/Udp/UdpClientConnection.cs
new file mode 100644
index 0000000..f6da329
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/Udp/UdpClientConnection.cs
@@ -0,0 +1,364 @@
+using System;
+using System.Net;
+using System.Net.Sockets;
+using System.Threading;
+
+
+namespace Hazel.Udp
+{
+ /// <summary>
+ /// Represents a client's connection to a server that uses the UDP protocol.
+ /// </summary>
+ /// <inheritdoc/>
+ public sealed class UdpClientConnection : UdpConnection
+ {
+ /// <summary>
+ /// The max size Hazel attempts to read from the network.
+ /// Defaults to 8096.
+ /// </summary>
+ /// <remarks>
+ /// 8096 is 5 times the standard modern MTU of 1500, so it's already too large imo.
+ /// If Hazel ever implements fragmented packets, then we might consider a larger value since combining 5
+ /// packets into 1 reader would be realistic and would cause reallocations. That said, Hazel is not meant
+ /// for transferring large contiguous blocks of data, so... please don't?
+ /// </remarks>
+ public int ReceiveBufferSize = 8096;
+
+ /// <summary>
+ /// The socket we're connected via.
+ /// </summary>
+ private Socket socket;
+
+ /// <summary>
+ /// Reset event that is triggered when the connection is marked Connected.
+ /// </summary>
+ private ManualResetEvent connectWaitLock = new ManualResetEvent(false);
+
+ private Timer reliablePacketTimer;
+
+#if DEBUG
+ public event Action<byte[], int> DataSentRaw;
+ public event Action<byte[], int> DataReceivedRaw;
+#endif
+
+ /// <summary>
+ /// Creates a new UdpClientConnection.
+ /// </summary>
+ /// <param name="remoteEndPoint">A <see cref="NetworkEndPoint"/> to connect to.</param>
+ public UdpClientConnection(ILogger logger, IPEndPoint remoteEndPoint, IPMode ipMode = IPMode.IPv4)
+ : base(logger)
+ {
+ this.EndPoint = remoteEndPoint;
+ this.IPMode = ipMode;
+
+ this.socket = CreateSocket(ipMode);
+
+ reliablePacketTimer = new Timer(ManageReliablePacketsInternal, null, 100, Timeout.Infinite);
+ this.InitializeKeepAliveTimer();
+ }
+
+ ~UdpClientConnection()
+ {
+ this.Dispose(false);
+ }
+
+ private void ManageReliablePacketsInternal(object state)
+ {
+ base.ManageReliablePackets();
+ try
+ {
+ reliablePacketTimer.Change(100, Timeout.Infinite);
+ }
+ catch { }
+ }
+
+ /// <inheritdoc />
+ protected override void WriteBytesToConnection(byte[] bytes, int length)
+ {
+#if DEBUG
+ if (TestLagMs > 0)
+ {
+ ThreadPool.QueueUserWorkItem(a => { Thread.Sleep(this.TestLagMs); WriteBytesToConnectionReal(bytes, length); });
+ }
+ else
+#endif
+ {
+ WriteBytesToConnectionReal(bytes, length);
+ }
+ }
+
+ private void WriteBytesToConnectionReal(byte[] bytes, int length)
+ {
+#if DEBUG
+ DataSentRaw?.Invoke(bytes, length);
+#endif
+
+ try
+ {
+ this.Statistics.LogPacketSend(length);
+ socket.BeginSendTo(
+ bytes,
+ 0,
+ length,
+ SocketFlags.None,
+ EndPoint,
+ HandleSendTo,
+ null);
+ }
+ catch (NullReferenceException) { }
+ catch (ObjectDisposedException)
+ {
+ // Already disposed and disconnected...
+ }
+ catch (SocketException ex)
+ {
+ DisconnectInternal(HazelInternalErrors.SocketExceptionSend, "Could not send data as a SocketException occurred: " + ex.Message);
+ }
+ }
+
+ private void HandleSendTo(IAsyncResult result)
+ {
+ try
+ {
+ socket.EndSendTo(result);
+ }
+ catch (NullReferenceException) { }
+ catch (ObjectDisposedException)
+ {
+ // Already disposed and disconnected...
+ }
+ catch (SocketException ex)
+ {
+ DisconnectInternal(HazelInternalErrors.SocketExceptionSend, "Could not send data as a SocketException occurred: " + ex.Message);
+ }
+ }
+
+ /// <inheritdoc />
+ public override void Connect(byte[] bytes = null, int timeout = 5000)
+ {
+ this.ConnectAsync(bytes);
+
+ //Wait till hello packet is acknowledged and the state is set to Connected
+ bool timedOut = !WaitOnConnect(timeout);
+
+ //If we timed out raise an exception
+ if (timedOut)
+ {
+ Dispose();
+ throw new HazelException("Connection attempt timed out.");
+ }
+ }
+
+ /// <inheritdoc />
+ public override void ConnectAsync(byte[] bytes = null)
+ {
+ this.State = ConnectionState.Connecting;
+
+ try
+ {
+ if (IPMode == IPMode.IPv4)
+ socket.Bind(new IPEndPoint(IPAddress.Any, 0));
+ else
+ socket.Bind(new IPEndPoint(IPAddress.IPv6Any, 0));
+ }
+ catch (SocketException e)
+ {
+ this.State = ConnectionState.NotConnected;
+ throw new HazelException("A SocketException occurred while binding to the port.", e);
+ }
+
+ try
+ {
+ StartListeningForData();
+ }
+ catch (ObjectDisposedException)
+ {
+ // If the socket's been disposed then we can just end there but make sure we're in NotConnected state.
+ // If we end up here I'm really lost...
+ this.State = ConnectionState.NotConnected;
+ return;
+ }
+ catch (SocketException e)
+ {
+ Dispose();
+ throw new HazelException("A SocketException occurred while initiating a receive operation.", e);
+ }
+
+ // Write bytes to the server to tell it hi (and to punch a hole in our NAT, if present)
+ // When acknowledged set the state to connected
+ SendHello(bytes, () =>
+ {
+ this.State = ConnectionState.Connected;
+ this.InitializeKeepAliveTimer();
+ });
+ }
+
+ /// <summary>
+ /// Instructs the listener to begin listening.
+ /// </summary>
+ void StartListeningForData()
+ {
+#if DEBUG
+ if (this.TestLagMs > 0)
+ {
+ Thread.Sleep(this.TestLagMs);
+ }
+#endif
+
+ var msg = MessageReader.GetSized(this.ReceiveBufferSize);
+ try
+ {
+ socket.BeginReceive(msg.Buffer, 0, msg.Buffer.Length, SocketFlags.None, ReadCallback, msg);
+ }
+ catch
+ {
+ msg.Recycle();
+ this.Dispose();
+ }
+ }
+
+ protected override void SetState(ConnectionState state)
+ {
+ try
+ {
+ // If the server disconnects you during the hello
+ // you can go straight from Connecting to NotConnected.
+ if (state == ConnectionState.Connected
+ || state == ConnectionState.NotConnected)
+ {
+ connectWaitLock.Set();
+ }
+ else
+ {
+ connectWaitLock.Reset();
+ }
+ }
+ catch (ObjectDisposedException)
+ {
+ }
+ }
+
+ /// <summary>
+ /// Blocks until the Connection is connected.
+ /// </summary>
+ /// <param name="timeout">The number of milliseconds to wait before timing out.</param>
+ public bool WaitOnConnect(int timeout)
+ {
+ return connectWaitLock.WaitOne(timeout);
+ }
+
+ /// <summary>
+ /// Called when data has been received by the socket.
+ /// </summary>
+ /// <param name="result">The asyncronous operation's result.</param>
+ void ReadCallback(IAsyncResult result)
+ {
+ var msg = (MessageReader)result.AsyncState;
+
+ try
+ {
+ msg.Length = socket.EndReceive(result);
+ }
+ catch (SocketException e)
+ {
+ msg.Recycle();
+ DisconnectInternal(HazelInternalErrors.SocketExceptionReceive, "Socket exception while reading data: " + e.Message);
+ return;
+ }
+ catch (Exception)
+ {
+ msg.Recycle();
+ return;
+ }
+
+ //Exit if no bytes read, we've failed.
+ if (msg.Length == 0)
+ {
+ msg.Recycle();
+ DisconnectInternal(HazelInternalErrors.ReceivedZeroBytes, "Received 0 bytes");
+ return;
+ }
+
+ //Begin receiving again
+ try
+ {
+ StartListeningForData();
+ }
+ catch (SocketException e)
+ {
+ DisconnectInternal(HazelInternalErrors.SocketExceptionReceive, "Socket exception during receive: " + e.Message);
+ }
+ catch (ObjectDisposedException)
+ {
+ //If the socket's been disposed then we can just end there.
+ return;
+ }
+
+#if DEBUG
+ if (this.TestDropRate > 0)
+ {
+ if ((this.testDropCount++ % this.TestDropRate) == 0)
+ {
+ return;
+ }
+ }
+
+ DataReceivedRaw?.Invoke(msg.Buffer, msg.Length);
+#endif
+ HandleReceive(msg, msg.Length);
+ }
+
+ /// <summary>
+ /// Sends a disconnect message to the end point.
+ /// You may include optional disconnect data. The SendOption must be unreliable.
+ /// </summary>
+ protected override bool SendDisconnect(MessageWriter data = null)
+ {
+ lock (this)
+ {
+ if (this._state == ConnectionState.NotConnected) return false;
+ this.State = ConnectionState.NotConnected; // Use the property so we release the state lock
+ }
+
+ var bytes = EmptyDisconnectBytes;
+ if (data != null && data.Length > 0)
+ {
+ if (data.SendOption != SendOption.None) throw new ArgumentException("Disconnect messages can only be unreliable.");
+
+ bytes = data.ToByteArray(true);
+ bytes[0] = (byte)UdpSendOption.Disconnect;
+ }
+
+ try
+ {
+ socket.SendTo(
+ bytes,
+ 0,
+ bytes.Length,
+ SocketFlags.None,
+ EndPoint);
+ }
+ catch { }
+
+ return true;
+ }
+
+ /// <inheritdoc />
+ protected override void Dispose(bool disposing)
+ {
+ if (disposing)
+ {
+ SendDisconnect();
+ }
+
+ try { this.socket.Shutdown(SocketShutdown.Both); } catch { }
+ try { this.socket.Close(); } catch { }
+ try { this.socket.Dispose(); } catch { }
+
+ this.reliablePacketTimer.Dispose();
+ this.connectWaitLock.Dispose();
+
+ base.Dispose(disposing);
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/Udp/UdpConnection.KeepAlive.cs b/Tools/Hazel-Networking/Hazel/Udp/UdpConnection.KeepAlive.cs
new file mode 100644
index 0000000..75b4f1d
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/Udp/UdpConnection.KeepAlive.cs
@@ -0,0 +1,167 @@
+using System;
+using System.Collections.Concurrent;
+using System.Diagnostics;
+using System.Threading;
+
+
+namespace Hazel.Udp
+{
+ partial class UdpConnection
+ {
+
+ /// <summary>
+ /// Class to hold packet data
+ /// </summary>
+ public class PingPacket : IRecyclable
+ {
+ private static readonly ObjectPool<PingPacket> PacketPool = new ObjectPool<PingPacket>(() => new PingPacket());
+
+ public readonly Stopwatch Stopwatch = new Stopwatch();
+
+ internal static PingPacket GetObject()
+ {
+ return PacketPool.GetObject();
+ }
+
+ public void Recycle()
+ {
+ Stopwatch.Stop();
+ PacketPool.PutObject(this);
+ }
+ }
+
+ internal ConcurrentDictionary<ushort, PingPacket> activePingPackets = new ConcurrentDictionary<ushort, PingPacket>();
+
+ /// <summary>
+ /// The interval from data being received or transmitted to a keepalive packet being sent in milliseconds.
+ /// </summary>
+ /// <remarks>
+ /// <para>
+ /// Keepalive packets serve to close connections when an endpoint abruptly disconnects and to ensure than any
+ /// NAT devices do not close their translation for our argument. By ensuring there is regular contact the
+ /// connection can detect and prevent these issues.
+ /// </para>
+ /// <para>
+ /// The default value is 10 seconds, set to System.Threading.Timeout.Infinite to disable keepalive packets.
+ /// </para>
+ /// </remarks>
+ public int KeepAliveInterval
+ {
+ get
+ {
+ return keepAliveInterval;
+ }
+
+ set
+ {
+ keepAliveInterval = value;
+ ResetKeepAliveTimer();
+ }
+ }
+ private int keepAliveInterval = 1500;
+
+ public int MissingPingsUntilDisconnect { get; set; } = 6;
+ private volatile int pingsSinceAck = 0;
+
+ /// <summary>
+ /// The timer creating keepalive pulses.
+ /// </summary>
+ private Timer keepAliveTimer;
+
+ /// <summary>
+ /// Starts the keepalive timer.
+ /// </summary>
+ protected void InitializeKeepAliveTimer()
+ {
+ keepAliveTimer = new Timer(
+ HandleKeepAlive,
+ null,
+ keepAliveInterval,
+ keepAliveInterval
+ );
+ }
+
+ private void HandleKeepAlive(object state)
+ {
+ if (this.State != ConnectionState.Connected) return;
+
+ if (this.pingsSinceAck >= this.MissingPingsUntilDisconnect)
+ {
+ this.DisposeKeepAliveTimer();
+ this.DisconnectInternal(HazelInternalErrors.PingsWithoutResponse, $"Sent {this.pingsSinceAck} pings that remote has not responded to.");
+ return;
+ }
+
+ try
+ {
+ this.pingsSinceAck++;
+ SendPing();
+ }
+ catch
+ {
+ }
+ }
+
+ // Pings are special, quasi-reliable packets.
+ // We send them to trigger responses that validate our connection is alive
+ // An unacked ping should never be the sole cause of a disconnect.
+ // Rather, the responses will reset our pingsSinceAck, enough unacked
+ // pings should cause a disconnect.
+ private void SendPing()
+ {
+ ushort id = (ushort)Interlocked.Increment(ref lastIDAllocated);
+
+ byte[] bytes = new byte[3];
+ bytes[0] = (byte)UdpSendOption.Ping;
+ bytes[1] = (byte)(id >> 8);
+ bytes[2] = (byte)id;
+
+ PingPacket pkt;
+ if (!this.activePingPackets.TryGetValue(id, out pkt))
+ {
+ pkt = PingPacket.GetObject();
+ if (!this.activePingPackets.TryAdd(id, pkt))
+ {
+ throw new Exception("This shouldn't be possible");
+ }
+ }
+
+ pkt.Stopwatch.Restart();
+
+ WriteBytesToConnection(bytes, bytes.Length);
+
+ Statistics.LogReliableSend(0);
+ }
+
+ /// <summary>
+ /// Resets the keepalive timer to zero.
+ /// </summary>
+ protected void ResetKeepAliveTimer()
+ {
+ try
+ {
+ keepAliveTimer?.Change(keepAliveInterval, keepAliveInterval);
+ }
+ catch { }
+ }
+
+ /// <summary>
+ /// Disposes of the keep alive timer.
+ /// </summary>
+ private void DisposeKeepAliveTimer()
+ {
+ if (this.keepAliveTimer != null)
+ {
+ this.keepAliveTimer.Dispose();
+ }
+
+ foreach (var kvp in activePingPackets)
+ {
+ if (this.activePingPackets.TryRemove(kvp.Key, out var pkt))
+ {
+ pkt.Recycle();
+ }
+ }
+ }
+ }
+} \ No newline at end of file
diff --git a/Tools/Hazel-Networking/Hazel/Udp/UdpConnection.Reliable.cs b/Tools/Hazel-Networking/Hazel/Udp/UdpConnection.Reliable.cs
new file mode 100644
index 0000000..bed4738
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/Udp/UdpConnection.Reliable.cs
@@ -0,0 +1,490 @@
+using System;
+using System.Collections.Concurrent;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Threading;
+
+namespace Hazel.Udp
+{
+ partial class UdpConnection
+ {
+ private const int MinResendDelayMs = 50;
+ private const int MaxInitialResendDelayMs = 300;
+ private const int MaxAdditionalResendDelayMs = 1000;
+
+ public readonly ObjectPool<Packet> PacketPool;
+
+ /// <summary>
+ /// The starting timeout, in miliseconds, at which data will be resent.
+ /// </summary>
+ /// <remarks>
+ /// <para>
+ /// For reliable delivery data is resent at specified intervals unless an acknowledgement is received from the
+ /// receiving device. The ResendTimeout specifies the interval between the packets being resent, each time a packet
+ /// is resent the interval is increased for that packet until the duration exceeds the <see cref="DisconnectTimeoutMs"/> value.
+ /// </para>
+ /// <para>
+ /// Setting this to its default of 0 will mean the timeout is 2 times the value of the average ping, usually
+ /// resulting in a more dynamic resend that responds to endpoints on slower or faster connections.
+ /// </para>
+ /// </remarks>
+ public volatile int ResendTimeoutMs = 0;
+
+ /// <summary>
+ /// Max number of times to resend. 0 == no limit
+ /// </summary>
+ public volatile int ResendLimit = 0;
+
+ /// <summary>
+ /// A compounding multiplier to back off resend timeout.
+ /// Applied to ping before first timeout when ResendTimeout == 0.
+ /// </summary>
+ public volatile float ResendPingMultiplier = 2;
+
+ /// <summary>
+ /// Holds the last ID allocated.
+ /// </summary>
+ private int lastIDAllocated = -1;
+
+ /// <summary>
+ /// The packets of data that have been transmitted reliably and not acknowledged.
+ /// </summary>
+ internal ConcurrentDictionary<ushort, Packet> reliableDataPacketsSent = new ConcurrentDictionary<ushort, Packet>();
+
+ /// <summary>
+ /// Packet ids that have not been received, but are expected.
+ /// </summary>
+ private HashSet<ushort> reliableDataPacketsMissing = new HashSet<ushort>();
+
+ /// <summary>
+ /// The packet id that was received last.
+ /// </summary>
+ protected volatile ushort reliableReceiveLast = ushort.MaxValue;
+
+ private object PingLock = new object();
+
+ /// <summary>
+ /// Returns the average ping to this endpoint.
+ /// </summary>
+ /// <remarks>
+ /// This returns the average ping for a one-way trip as calculated from the reliable packets that have been sent
+ /// and acknowledged by the endpoint.
+ /// </remarks>
+ private float _pingMs = 500;
+
+ /// <summary>
+ /// The maximum times a message should be resent before marking the endpoint as disconnected.
+ /// </summary>
+ /// <remarks>
+ /// Reliable packets will be resent at an interval defined in <see cref="ResendTimeoutMs"/> for the number of times
+ /// specified here. Once a packet has been retransmitted this number of times and has not been acknowledged the
+ /// connection will be marked as disconnected and the <see cref="Connection.Disconnected">Disconnected</see> event
+ /// will be invoked.
+ /// </remarks>
+ public volatile int DisconnectTimeoutMs = 5000;
+
+ /// <summary>
+ /// Class to hold packet data
+ /// </summary>
+ public class Packet : IRecyclable
+ {
+ public ushort Id;
+ private byte[] Data;
+ private readonly UdpConnection Connection;
+ private int Length;
+
+ public int NextTimeoutMs;
+ public volatile bool Acknowledged;
+
+ public Action AckCallback;
+
+ public int Retransmissions;
+ public Stopwatch Stopwatch = new Stopwatch();
+
+ internal Packet(UdpConnection connection)
+ {
+ this.Connection = connection;
+ }
+
+ internal void Set(ushort id, byte[] data, int length, int timeout, Action ackCallback)
+ {
+ this.Id = id;
+ this.Data = data;
+ this.Length = length;
+
+ this.Acknowledged = false;
+ this.NextTimeoutMs = timeout;
+ this.AckCallback = ackCallback;
+ this.Retransmissions = 0;
+
+ this.Stopwatch.Restart();
+ }
+
+ // Packets resent
+ public int Resend()
+ {
+ var connection = this.Connection;
+ if (!this.Acknowledged && connection != null)
+ {
+ long lifetimeMs = this.Stopwatch.ElapsedMilliseconds;
+ if (lifetimeMs >= connection.DisconnectTimeoutMs)
+ {
+ if (connection.reliableDataPacketsSent.TryRemove(this.Id, out Packet self))
+ {
+ connection.DisconnectInternal(HazelInternalErrors.ReliablePacketWithoutResponse, $"Reliable packet {self.Id} (size={this.Length}) was not ack'd after {lifetimeMs}ms ({self.Retransmissions} resends)");
+
+ self.Recycle();
+ }
+
+ return 0;
+ }
+
+ if (lifetimeMs >= this.NextTimeoutMs)
+ {
+ ++this.Retransmissions;
+ if (connection.ResendLimit != 0
+ && this.Retransmissions > connection.ResendLimit)
+ {
+ if (connection.reliableDataPacketsSent.TryRemove(this.Id, out Packet self))
+ {
+ connection.DisconnectInternal(HazelInternalErrors.ReliablePacketWithoutResponse, $"Reliable packet {self.Id} (size={this.Length}) was not ack'd after {self.Retransmissions} resends ({lifetimeMs}ms)");
+
+ self.Recycle();
+ }
+
+ return 0;
+ }
+
+ this.NextTimeoutMs += (int)Math.Min(this.NextTimeoutMs * connection.ResendPingMultiplier, MaxAdditionalResendDelayMs);
+ try
+ {
+ connection.WriteBytesToConnection(this.Data, this.Length);
+ connection.Statistics.LogMessageResent();
+ return 1;
+ }
+ catch (InvalidOperationException)
+ {
+ connection.DisconnectInternal(HazelInternalErrors.ConnectionDisconnected, "Could not resend data as connection is no longer connected");
+ }
+ }
+ }
+
+ return 0;
+ }
+
+ /// <summary>
+ /// Returns this object back to the object pool from whence it came.
+ /// </summary>
+ public void Recycle()
+ {
+ this.Acknowledged = true;
+
+ this.Connection.PacketPool.PutObject(this);
+ }
+ }
+
+ internal int ManageReliablePackets()
+ {
+ int output = 0;
+ if (this.reliableDataPacketsSent.Count > 0)
+ {
+ foreach (var kvp in this.reliableDataPacketsSent)
+ {
+ Packet pkt = kvp.Value;
+
+ try
+ {
+ output += pkt.Resend();
+ }
+ catch { }
+ }
+ }
+
+ return output;
+ }
+
+ /// <summary>
+ /// Adds a 2 byte ID to the packet at offset and stores the packet reference for retransmission.
+ /// </summary>
+ /// <param name="buffer">The buffer to attach to.</param>
+ /// <param name="offset">The offset to attach at.</param>
+ /// <param name="ackCallback">The callback to make once the packet has been acknowledged.</param>
+ protected void AttachReliableID(byte[] buffer, int offset, Action ackCallback = null)
+ {
+ ushort id = (ushort)Interlocked.Increment(ref lastIDAllocated);
+
+ buffer[offset] = (byte)(id >> 8);
+ buffer[offset + 1] = (byte)id;
+
+ int resendDelayMs = this.ResendTimeoutMs;
+ if (resendDelayMs <= 0)
+ {
+ resendDelayMs = (_pingMs * this.ResendPingMultiplier).ClampToInt(MinResendDelayMs, MaxInitialResendDelayMs);
+ }
+
+ Packet packet = this.PacketPool.GetObject();
+ packet.Set(
+ id,
+ buffer,
+ buffer.Length,
+ resendDelayMs,
+ ackCallback);
+
+ if (!reliableDataPacketsSent.TryAdd(id, packet))
+ {
+ throw new Exception("That shouldn't be possible");
+ }
+ }
+
+ public static int ClampToInt(float value, int min, int max)
+ {
+ if (value < min) return min;
+ if (value > max) return max;
+ return (int)value;
+ }
+
+ /// <summary>
+ /// Sends the bytes reliably and stores the send.
+ /// </summary>
+ /// <param name="sendOption"></param>
+ /// <param name="data">The byte array to write to.</param>
+ /// <param name="ackCallback">The callback to make once the packet has been acknowledged.</param>
+ private void ReliableSend(byte sendOption, byte[] data, Action ackCallback = null)
+ {
+ //Inform keepalive not to send for a while
+ ResetKeepAliveTimer();
+
+ byte[] bytes = new byte[data.Length + 3];
+
+ //Add message type
+ bytes[0] = sendOption;
+
+ //Add reliable ID
+ AttachReliableID(bytes, 1, ackCallback);
+
+ //Copy data into new array
+ Buffer.BlockCopy(data, 0, bytes, bytes.Length - data.Length, data.Length);
+
+ //Write to connection
+ WriteBytesToConnection(bytes, bytes.Length);
+
+ Statistics.LogReliableSend(data.Length);
+ }
+
+ /// <summary>
+ /// Handles a reliable message being received and invokes the data event.
+ /// </summary>
+ /// <param name="message">The buffer received.</param>
+ private void ReliableMessageReceive(MessageReader message, int bytesReceived)
+ {
+ ushort id;
+ if (ProcessReliableReceive(message.Buffer, 1, out id))
+ {
+ InvokeDataReceived(SendOption.Reliable, message, 3, bytesReceived);
+ }
+ else
+ {
+ message.Recycle();
+ }
+
+ Statistics.LogReliableReceive(message.Length - 3, message.Length);
+ }
+
+ /// <summary>
+ /// Handles receives from reliable packets.
+ /// </summary>
+ /// <param name="bytes">The buffer containing the data.</param>
+ /// <param name="offset">The offset of the reliable header.</param>
+ /// <returns>Whether the packet was a new packet or not.</returns>
+ private bool ProcessReliableReceive(byte[] bytes, int offset, out ushort id)
+ {
+ byte b1 = bytes[offset];
+ byte b2 = bytes[offset + 1];
+
+ //Get the ID form the packet
+ id = (ushort)((b1 << 8) + b2);
+
+ /*
+ * It gets a little complicated here (note the fact I'm actually using a multiline comment for once...)
+ *
+ * In a simple world if our data is greater than the last reliable packet received (reliableReceiveLast)
+ * then it is guaranteed to be a new packet, if it's not we can see if we are missing that packet (lookup
+ * in reliableDataPacketsMissing).
+ *
+ * --------rrl############# (1)
+ *
+ * (where --- are packets received already and #### are packets that will be counted as new)
+ *
+ * Unfortunately if id becomes greater than 65535 it will loop back to zero so we will add a pointer that
+ * specifies any packets with an id behind it are also new (overwritePointer).
+ *
+ * ####op----------rrl##### (2)
+ *
+ * ------rll#########op---- (3)
+ *
+ * Anything behind than the reliableReceiveLast pointer (but greater than the overwritePointer is either a
+ * missing packet or something we've already received so when we change the pointers we need to make sure
+ * we keep note of what hasn't been received yet (reliableDataPacketsMissing).
+ *
+ * So...
+ */
+
+ bool result = true;
+
+ lock (reliableDataPacketsMissing)
+ {
+ //Calculate overwritePointer
+ ushort overwritePointer = (ushort)(reliableReceiveLast - 32768);
+
+ //Calculate if it is a new packet by examining if it is within the range
+ bool isNew;
+ if (overwritePointer < reliableReceiveLast)
+ isNew = id > reliableReceiveLast || id <= overwritePointer; //Figure (2)
+ else
+ isNew = id > reliableReceiveLast && id <= overwritePointer; //Figure (3)
+
+ //If it's new or we've not received anything yet
+ if (isNew)
+ {
+ // Mark items between the most recent receive and the id received as missing
+ if (id > reliableReceiveLast)
+ {
+ for (ushort i = (ushort)(reliableReceiveLast + 1); i < id; i++)
+ {
+ reliableDataPacketsMissing.Add(i);
+ }
+ }
+ else
+ {
+ int cnt = (ushort.MaxValue - reliableReceiveLast) + id;
+ for (ushort i = 1; i <= cnt; ++i)
+ {
+ reliableDataPacketsMissing.Add((ushort)(i + reliableReceiveLast));
+ }
+ }
+
+ //Update the most recently received
+ reliableReceiveLast = id;
+ }
+
+ //Else it could be a missing packet
+ else
+ {
+ //See if we're missing it, else this packet is a duplicate as so we return false
+ if (!reliableDataPacketsMissing.Remove(id))
+ {
+ result = false;
+ }
+ }
+ }
+
+ // Send an acknowledgement
+ SendAck(id);
+
+ return result;
+ }
+
+ /// <summary>
+ /// Handles acknowledgement packets to us.
+ /// </summary>
+ /// <param name="bytes">The buffer containing the data.</param>
+ private void AcknowledgementMessageReceive(byte[] bytes, int bytesReceived)
+ {
+ this.pingsSinceAck = 0;
+
+ ushort id = (ushort)((bytes[1] << 8) + bytes[2]);
+ AcknowledgeMessageId(id);
+
+ if (bytesReceived == 4)
+ {
+ byte recentPackets = bytes[3];
+ for (int i = 1; i <= 8; ++i)
+ {
+ if ((recentPackets & 1) != 0)
+ {
+ AcknowledgeMessageId((ushort)(id - i));
+ }
+
+ recentPackets >>= 1;
+ }
+ }
+
+ Statistics.LogAcknowledgementReceive(bytesReceived);
+ }
+
+ private void AcknowledgeMessageId(ushort id)
+ {
+ // Dispose of timer and remove from dictionary
+ if (reliableDataPacketsSent.TryRemove(id, out Packet packet))
+ {
+ this.Statistics.LogReliablePacketAcknowledged();
+ float rt = packet.Stopwatch.ElapsedMilliseconds;
+
+ packet.AckCallback?.Invoke();
+ packet.Recycle();
+
+ lock (PingLock)
+ {
+ this._pingMs = this._pingMs * .7f + rt * .3f;
+ }
+ }
+ else if (this.activePingPackets.TryRemove(id, out PingPacket pingPkt))
+ {
+ this.Statistics.LogReliablePacketAcknowledged();
+ float rt = pingPkt.Stopwatch.ElapsedMilliseconds;
+
+ pingPkt.Recycle();
+
+ lock (PingLock)
+ {
+ this._pingMs = this._pingMs * .7f + rt * .3f;
+ }
+ }
+ }
+
+ /// <summary>
+ /// Sends an acknowledgement for a packet given its identification bytes.
+ /// </summary>
+ /// <param name="byte1">The first identification byte.</param>
+ /// <param name="byte2">The second identification byte.</param>
+ private void SendAck(ushort id)
+ {
+ byte recentPackets = 0;
+ lock (this.reliableDataPacketsMissing)
+ {
+ for (int i = 1; i <= 8; ++i)
+ {
+ if (!this.reliableDataPacketsMissing.Contains((ushort)(id - i)))
+ {
+ recentPackets |= (byte)(1 << (i - 1));
+ }
+ }
+ }
+
+ byte[] bytes = new byte[]
+ {
+ (byte)UdpSendOption.Acknowledgement,
+ (byte)(id >> 8),
+ (byte)(id >> 0),
+ recentPackets
+ };
+
+ try
+ {
+ WriteBytesToConnection(bytes, bytes.Length);
+ }
+ catch (InvalidOperationException) { }
+ }
+
+ private void DisposeReliablePackets()
+ {
+ foreach (var kvp in reliableDataPacketsSent)
+ {
+ if (this.reliableDataPacketsSent.TryRemove(kvp.Key, out var pkt))
+ {
+ pkt.Recycle();
+ }
+ }
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/Udp/UdpConnection.cs b/Tools/Hazel-Networking/Hazel/Udp/UdpConnection.cs
new file mode 100644
index 0000000..e64576a
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/Udp/UdpConnection.cs
@@ -0,0 +1,259 @@
+using System;
+using System.Net.Sockets;
+
+namespace Hazel.Udp
+{
+ /// <summary>
+ /// Represents a connection that uses the UDP protocol.
+ /// </summary>
+ /// <inheritdoc />
+ public abstract partial class UdpConnection : NetworkConnection
+ {
+ public static readonly byte[] EmptyDisconnectBytes = new byte[] { (byte)UdpSendOption.Disconnect };
+
+ public override float AveragePingMs => this._pingMs;
+ protected readonly ILogger logger;
+
+
+ public UdpConnection(ILogger logger) : base()
+ {
+ this.logger = logger;
+ this.PacketPool = new ObjectPool<Packet>(() => new Packet(this));
+ }
+
+ internal static Socket CreateSocket(IPMode ipMode)
+ {
+ Socket socket;
+ if (ipMode == IPMode.IPv4)
+ {
+ socket = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp);
+ }
+ else
+ {
+ if (!Socket.OSSupportsIPv6)
+ throw new InvalidOperationException("IPV6 not supported!");
+
+ socket = new Socket(AddressFamily.InterNetworkV6, SocketType.Dgram, ProtocolType.Udp);
+ socket.SetSocketOption(SocketOptionLevel.IPv6, SocketOptionName.IPv6Only, false);
+ }
+
+ try
+ {
+ socket.DontFragment = false;
+ }
+ catch { }
+
+ try
+ {
+ const int SIO_UDP_CONNRESET = -1744830452;
+ socket.IOControl(SIO_UDP_CONNRESET, new byte[1], null);
+ }
+ catch { } // Only necessary on Windows
+
+ return socket;
+ }
+
+ /// <summary>
+ /// Writes the given bytes to the connection.
+ /// </summary>
+ /// <param name="bytes">The bytes to write.</param>
+ protected abstract void WriteBytesToConnection(byte[] bytes, int length);
+
+ /// <inheritdoc/>
+ public override SendErrors Send(MessageWriter msg)
+ {
+ if (this._state != ConnectionState.Connected)
+ {
+ return SendErrors.Disconnected;
+ }
+
+ try
+ {
+ byte[] buffer = new byte[msg.Length];
+ Buffer.BlockCopy(msg.Buffer, 0, buffer, 0, msg.Length);
+
+ switch (msg.SendOption)
+ {
+ case SendOption.Reliable:
+ ResetKeepAliveTimer();
+
+ AttachReliableID(buffer, 1);
+ WriteBytesToConnection(buffer, buffer.Length);
+ Statistics.LogReliableSend(buffer.Length - 3);
+ break;
+
+ default:
+ WriteBytesToConnection(buffer, buffer.Length);
+ Statistics.LogUnreliableSend(buffer.Length - 1);
+ break;
+ }
+ }
+ catch (Exception e)
+ {
+ this.logger?.WriteError("Unknown exception while sending: " + e);
+ return SendErrors.Unknown;
+ }
+
+ return SendErrors.None;
+ }
+
+ /// <summary>
+ /// Handles the reliable/fragmented sending from this connection.
+ /// </summary>
+ /// <param name="data">The data being sent.</param>
+ /// <param name="sendOption">The <see cref="SendOption"/> specified as its byte value.</param>
+ /// <param name="ackCallback">The callback to invoke when this packet is acknowledged.</param>
+ /// <returns>The bytes that should actually be sent.</returns>
+ protected virtual void HandleSend(byte[] data, byte sendOption, Action ackCallback = null)
+ {
+ switch (sendOption)
+ {
+ case (byte)UdpSendOption.Ping:
+ case (byte)SendOption.Reliable:
+ case (byte)UdpSendOption.Hello:
+ ReliableSend(sendOption, data, ackCallback);
+ break;
+
+ //Treat all else as unreliable
+ default:
+ UnreliableSend(sendOption, data);
+ break;
+ }
+ }
+
+ /// <summary>
+ /// Handles the receiving of data.
+ /// </summary>
+ /// <param name="message">The buffer containing the bytes received.</param>
+ protected internal virtual void HandleReceive(MessageReader message, int bytesReceived)
+ {
+ ushort id;
+ switch (message.Buffer[0])
+ {
+ //Handle reliable receives
+ case (byte)SendOption.Reliable:
+ ReliableMessageReceive(message, bytesReceived);
+ break;
+
+ //Handle acknowledgments
+ case (byte)UdpSendOption.Acknowledgement:
+ AcknowledgementMessageReceive(message.Buffer, bytesReceived);
+ message.Recycle();
+ break;
+
+ //We need to acknowledge hello and ping messages but dont want to invoke any events!
+ case (byte)UdpSendOption.Ping:
+ ProcessReliableReceive(message.Buffer, 1, out id);
+ Statistics.LogHelloReceive(bytesReceived);
+ message.Recycle();
+ break;
+ case (byte)UdpSendOption.Hello:
+ ProcessReliableReceive(message.Buffer, 1, out id);
+ Statistics.LogHelloReceive(bytesReceived);
+ message.Recycle();
+ break;
+
+ case (byte)UdpSendOption.Disconnect:
+ message.Offset = 1;
+ message.Position = 0;
+ DisconnectRemote("The remote sent a disconnect request", message);
+ message.Recycle();
+ break;
+
+ case (byte)SendOption.None:
+ InvokeDataReceived(SendOption.None, message, 1, bytesReceived);
+ Statistics.LogUnreliableReceive(bytesReceived - 1, bytesReceived);
+ break;
+
+ // Treat everything else as garbage
+ default:
+ message.Recycle();
+
+ // TODO: A new stat for unused data
+ Statistics.LogUnreliableReceive(bytesReceived - 1, bytesReceived);
+ break;
+ }
+ }
+
+ /// <summary>
+ /// Sends bytes using the unreliable UDP protocol.
+ /// </summary>
+ /// <param name="sendOption">The SendOption to attach.</param>
+ /// <param name="data">The data.</param>
+ void UnreliableSend(byte sendOption, byte[] data)
+ {
+ this.UnreliableSend(sendOption, data, 0, data.Length);
+ }
+
+ /// <summary>
+ /// Sends bytes using the unreliable UDP protocol.
+ /// </summary>
+ /// <param name="data">The data.</param>
+ /// <param name="sendOption">The SendOption to attach.</param>
+ /// <param name="offset"></param>
+ /// <param name="length"></param>
+ void UnreliableSend(byte sendOption, byte[] data, int offset, int length)
+ {
+ byte[] bytes = new byte[length + 1];
+
+ //Add message type
+ bytes[0] = sendOption;
+
+ //Copy data into new array
+ Buffer.BlockCopy(data, offset, bytes, bytes.Length - length, length);
+
+ //Write to connection
+ WriteBytesToConnection(bytes, bytes.Length);
+
+ Statistics.LogUnreliableSend(length);
+ }
+
+ /// <summary>
+ /// Helper method to invoke the data received event.
+ /// </summary>
+ /// <param name="sendOption">The send option the message was received with.</param>
+ /// <param name="buffer">The buffer received.</param>
+ /// <param name="dataOffset">The offset of data in the buffer.</param>
+ void InvokeDataReceived(SendOption sendOption, MessageReader buffer, int dataOffset, int bytesReceived)
+ {
+ buffer.Offset = dataOffset;
+ buffer.Length = bytesReceived - dataOffset;
+ buffer.Position = 0;
+
+ InvokeDataReceived(buffer, sendOption);
+ }
+
+ /// <summary>
+ /// Sends a hello packet to the remote endpoint.
+ /// </summary>
+ /// <param name="acknowledgeCallback">The callback to invoke when the hello packet is acknowledged.</param>
+ protected void SendHello(byte[] bytes, Action acknowledgeCallback)
+ {
+ //First byte of handshake is version indicator so add data after
+ byte[] actualBytes;
+ if (bytes == null)
+ {
+ actualBytes = new byte[1];
+ }
+ else
+ {
+ actualBytes = new byte[bytes.Length + 1];
+ Buffer.BlockCopy(bytes, 0, actualBytes, 1, bytes.Length);
+ }
+
+ HandleSend(actualBytes, (byte)UdpSendOption.Hello, acknowledgeCallback);
+ }
+
+ /// <inheritdoc/>
+ protected override void Dispose(bool disposing)
+ {
+ if (disposing)
+ {
+ DisposeKeepAliveTimer();
+ DisposeReliablePackets();
+ }
+
+ base.Dispose(disposing);
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/Udp/UdpConnectionListener.cs b/Tools/Hazel-Networking/Hazel/Udp/UdpConnectionListener.cs
new file mode 100644
index 0000000..c017a0f
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/Udp/UdpConnectionListener.cs
@@ -0,0 +1,339 @@
+using System;
+using System.Collections.Concurrent;
+using System.Linq;
+using System.Net;
+using System.Net.Sockets;
+using System.Threading;
+
+namespace Hazel.Udp
+{
+ /// <summary>
+ /// Listens for new UDP connections and creates UdpConnections for them.
+ /// </summary>
+ /// <inheritdoc />
+ public class UdpConnectionListener : NetworkConnectionListener
+ {
+ private const int SendReceiveBufferSize = 1024 * 1024;
+ private const int BufferSize = ushort.MaxValue;
+
+ private Socket socket;
+ private ILogger Logger;
+ private Timer reliablePacketTimer;
+
+ private ConcurrentDictionary<EndPoint, UdpServerConnection> allConnections = new ConcurrentDictionary<EndPoint, UdpServerConnection>();
+
+ public override double AveragePing => this.allConnections.Values.Sum(c => c.AveragePingMs) / this.allConnections.Count;
+ public override int ConnectionCount { get { return this.allConnections.Count; } }
+ public override int ReceiveQueueLength => throw new NotImplementedException();
+ public override int SendQueueLength => throw new NotImplementedException();
+
+ /// <summary>
+ /// Creates a new UdpConnectionListener for the given <see cref="IPAddress"/>, port and <see cref="IPMode"/>.
+ /// </summary>
+ /// <param name="endPoint">The endpoint to listen on.</param>
+ public UdpConnectionListener(IPEndPoint endPoint, IPMode ipMode = IPMode.IPv4, ILogger logger = null)
+ {
+ this.Logger = logger;
+ this.EndPoint = endPoint;
+ this.IPMode = ipMode;
+
+ this.socket = UdpConnection.CreateSocket(this.IPMode);
+
+ socket.ReceiveBufferSize = SendReceiveBufferSize;
+ socket.SendBufferSize = SendReceiveBufferSize;
+
+ reliablePacketTimer = new Timer(ManageReliablePackets, null, 100, Timeout.Infinite);
+ }
+
+ ~UdpConnectionListener()
+ {
+ this.Dispose(false);
+ }
+
+ private void ManageReliablePackets(object state)
+ {
+ foreach (var kvp in this.allConnections)
+ {
+ var sock = kvp.Value;
+ sock.ManageReliablePackets();
+ }
+
+ try
+ {
+ this.reliablePacketTimer.Change(100, Timeout.Infinite);
+ }
+ catch { }
+ }
+
+ /// <inheritdoc />
+ public override void Start()
+ {
+ try
+ {
+ socket.Bind(EndPoint);
+ }
+ catch (SocketException e)
+ {
+ throw new HazelException("Could not start listening as a SocketException occurred", e);
+ }
+
+ StartListeningForData();
+ }
+
+ /// <summary>
+ /// Instructs the listener to begin listening.
+ /// </summary>
+ private void StartListeningForData()
+ {
+ EndPoint remoteEP = EndPoint;
+
+ MessageReader message = null;
+ try
+ {
+ message = MessageReader.GetSized(this.ReceiveBufferSize);
+ socket.BeginReceiveFrom(message.Buffer, 0, message.Buffer.Length, SocketFlags.None, ref remoteEP, ReadCallback, message);
+ }
+ catch (SocketException sx)
+ {
+ message?.Recycle();
+
+ this.Logger?.WriteError("Socket Ex in StartListening: " + sx.Message);
+
+ Thread.Sleep(10);
+ StartListeningForData();
+ return;
+ }
+ catch (Exception ex)
+ {
+ message.Recycle();
+ this.Logger?.WriteError("Stopped due to: " + ex.Message);
+ return;
+ }
+ }
+
+ void ReadCallback(IAsyncResult result)
+ {
+ var message = (MessageReader)result.AsyncState;
+ int bytesReceived;
+ EndPoint remoteEndPoint = new IPEndPoint(this.EndPoint.Address, this.EndPoint.Port);
+
+ //End the receive operation
+ try
+ {
+ bytesReceived = socket.EndReceiveFrom(result, ref remoteEndPoint);
+
+ message.Offset = 0;
+ message.Length = bytesReceived;
+ }
+ catch (ObjectDisposedException)
+ {
+ message.Recycle();
+ return;
+ }
+ catch (SocketException sx)
+ {
+ message.Recycle();
+ if (sx.SocketErrorCode == SocketError.NotConnected)
+ {
+ this.InvokeInternalError(HazelInternalErrors.ConnectionDisconnected);
+ return;
+ }
+
+ // Client no longer reachable, pretend it didn't happen
+ // TODO should this not inform the connection this client is lost???
+
+ // This thread suggests the IP is not passed out from WinSoc so maybe not possible
+ // http://stackoverflow.com/questions/2576926/python-socket-error-on-udp-data-receive-10054
+ this.Logger?.WriteError($"Socket Ex {sx.SocketErrorCode} in ReadCallback: {sx.Message}");
+
+ Thread.Sleep(10);
+ StartListeningForData();
+ return;
+ }
+ catch (Exception ex)
+ {
+ // Idk, maybe a null ref after dispose?
+ message.Recycle();
+ this.Logger?.WriteError("Stopped due to: " + ex.Message);
+ return;
+ }
+
+ // I'm a little concerned about a infinite loop here, but it seems like it's possible
+ // to get 0 bytes read on UDP without the socket being shut down.
+ if (bytesReceived == 0)
+ {
+ message.Recycle();
+ this.Logger?.WriteInfo("Received 0 bytes");
+ Thread.Sleep(10);
+ StartListeningForData();
+ return;
+ }
+
+ //Begin receiving again
+ StartListeningForData();
+
+ bool aware = true;
+ bool isHello = message.Buffer[0] == (byte)UdpSendOption.Hello;
+
+ // If we're aware of this connection use the one already
+ // If this is a new client then connect with them!
+ UdpServerConnection connection;
+ if (!this.allConnections.TryGetValue(remoteEndPoint, out connection))
+ {
+ lock (this.allConnections)
+ {
+ if (!this.allConnections.TryGetValue(remoteEndPoint, out connection))
+ {
+ // Check for malformed connection attempts
+ if (!isHello)
+ {
+ message.Recycle();
+ return;
+ }
+
+ if (AcceptConnection != null)
+ {
+ if (!AcceptConnection((IPEndPoint)remoteEndPoint, message.Buffer, out var response))
+ {
+ message.Recycle();
+ if (response != null)
+ {
+ SendData(response, response.Length, remoteEndPoint);
+ }
+
+ return;
+ }
+ }
+
+ aware = false;
+ connection = new UdpServerConnection(this, (IPEndPoint)remoteEndPoint, this.IPMode, this.Logger);
+ if (!this.allConnections.TryAdd(remoteEndPoint, connection))
+ {
+ throw new HazelException("Failed to add a connection. This should never happen.");
+ }
+ }
+ }
+ }
+
+ // If it's a new connection invoke the NewConnection event.
+ // This needs to happen before handling the message because in localhost scenarios, the ACK and
+ // subsequent messages can happen before the NewConnection event sets up OnDataRecieved handlers
+ if (!aware)
+ {
+ // Skip header and hello byte;
+ message.Offset = 4;
+ message.Length = bytesReceived - 4;
+ message.Position = 0;
+ InvokeNewConnection(message, connection);
+ }
+
+ // Inform the connection of the buffer (new connections need to send an ack back to client)
+ connection.HandleReceive(message, bytesReceived);
+ }
+
+#if DEBUG
+ public int TestDropRate = -1;
+ private int dropCounter = 0;
+#endif
+
+ /// <summary>
+ /// Sends data from the listener socket.
+ /// </summary>
+ /// <param name="bytes">The bytes to send.</param>
+ /// <param name="endPoint">The endpoint to send to.</param>
+ internal void SendData(byte[] bytes, int length, EndPoint endPoint)
+ {
+ if (length > bytes.Length) return;
+
+#if DEBUG
+ if (TestDropRate > 0)
+ {
+ if (Interlocked.Increment(ref dropCounter) % TestDropRate == 0)
+ {
+ return;
+ }
+ }
+#endif
+
+ try
+ {
+ socket.BeginSendTo(
+ bytes,
+ 0,
+ length,
+ SocketFlags.None,
+ endPoint,
+ SendCallback,
+ null);
+
+ this.Statistics.AddBytesSent(length);
+ }
+ catch (SocketException e)
+ {
+ this.Logger?.WriteError("Could not send data as a SocketException occurred: " + e);
+ }
+ catch (ObjectDisposedException)
+ {
+ //Keep alive timer probably ran, ignore
+ return;
+ }
+ }
+
+ private void SendCallback(IAsyncResult result)
+ {
+ try
+ {
+ socket.EndSendTo(result);
+ }
+ catch { }
+ }
+
+ /// <summary>
+ /// Sends data from the listener socket.
+ /// </summary>
+ /// <param name="bytes">The bytes to send.</param>
+ /// <param name="endPoint">The endpoint to send to.</param>
+ internal void SendDataSync(byte[] bytes, int length, EndPoint endPoint)
+ {
+ try
+ {
+ socket.SendTo(
+ bytes,
+ 0,
+ length,
+ SocketFlags.None,
+ endPoint
+ );
+
+ this.Statistics.AddBytesSent(length);
+ }
+ catch { }
+ }
+
+ /// <summary>
+ /// Removes a virtual connection from the list.
+ /// </summary>
+ /// <param name="endPoint">The endpoint of the virtual connection.</param>
+ internal void RemoveConnectionTo(EndPoint endPoint)
+ {
+ this.allConnections.TryRemove(endPoint, out var conn);
+ }
+
+ /// <inheritdoc />
+ protected override void Dispose(bool disposing)
+ {
+ foreach (var kvp in this.allConnections)
+ {
+ kvp.Value.Dispose();
+ }
+
+ try { this.socket.Shutdown(SocketShutdown.Both); } catch { }
+ try { this.socket.Close(); } catch { }
+ try { this.socket.Dispose(); } catch { }
+
+ this.reliablePacketTimer.Dispose();
+
+ base.Dispose(disposing);
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/Udp/UdpServerConnection.cs b/Tools/Hazel-Networking/Hazel/Udp/UdpServerConnection.cs
new file mode 100644
index 0000000..ff5b29d
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/Udp/UdpServerConnection.cs
@@ -0,0 +1,108 @@
+using System;
+using System.Net;
+
+namespace Hazel.Udp
+{
+ /// <summary>
+ /// Represents a servers's connection to a client that uses the UDP protocol.
+ /// </summary>
+ /// <inheritdoc/>
+ internal sealed class UdpServerConnection : UdpConnection
+ {
+ /// <summary>
+ /// The connection listener that we use the socket of.
+ /// </summary>
+ /// <remarks>
+ /// Udp server connections utilize the same socket in the listener for sends/receives, this is the listener that
+ /// created this connection and is hence the listener this conenction sends and receives via.
+ /// </remarks>
+ public UdpConnectionListener Listener { get; private set; }
+
+ /// <summary>
+ /// Creates a UdpConnection for the virtual connection to the endpoint.
+ /// </summary>
+ /// <param name="listener">The listener that created this connection.</param>
+ /// <param name="endPoint">The endpoint that we are connected to.</param>
+ /// <param name="IPMode">The IPMode we are connected using.</param>
+ internal UdpServerConnection(UdpConnectionListener listener, IPEndPoint endPoint, IPMode IPMode, ILogger logger)
+ : base(logger)
+ {
+ this.Listener = listener;
+ this.EndPoint = endPoint;
+ this.IPMode = IPMode;
+
+ State = ConnectionState.Connected;
+ this.InitializeKeepAliveTimer();
+ }
+
+ /// <inheritdoc />
+ protected override void WriteBytesToConnection(byte[] bytes, int length)
+ {
+ this.Statistics.LogPacketSend(length);
+ Listener.SendData(bytes, length, EndPoint);
+ }
+
+ /// <inheritdoc />
+ /// <remarks>
+ /// This will always throw a HazelException.
+ /// </remarks>
+ public override void Connect(byte[] bytes = null, int timeout = 5000)
+ {
+ throw new InvalidOperationException("Cannot manually connect a UdpServerConnection, did you mean to use UdpClientConnection?");
+ }
+
+ /// <inheritdoc />
+ /// <remarks>
+ /// This will always throw a HazelException.
+ /// </remarks>
+ public override void ConnectAsync(byte[] bytes = null)
+ {
+ throw new InvalidOperationException("Cannot manually connect a UdpServerConnection, did you mean to use UdpClientConnection?");
+ }
+
+ /// <summary>
+ /// Sends a disconnect message to the end point.
+ /// </summary>
+ protected override bool SendDisconnect(MessageWriter data = null)
+ {
+ lock (this)
+ {
+ if (this._state != ConnectionState.Connected)
+ {
+ return false;
+ }
+
+ this._state = ConnectionState.NotConnected;
+ }
+
+ var bytes = EmptyDisconnectBytes;
+ if (data != null && data.Length > 0)
+ {
+ if (data.SendOption != SendOption.None) throw new ArgumentException("Disconnect messages can only be unreliable.");
+
+ bytes = data.ToByteArray(true);
+ bytes[0] = (byte)UdpSendOption.Disconnect;
+ }
+
+ try
+ {
+ Listener.SendDataSync(bytes, bytes.Length, EndPoint);
+ }
+ catch { }
+
+ return true;
+ }
+
+ protected override void Dispose(bool disposing)
+ {
+ Listener.RemoveConnectionTo(EndPoint);
+
+ if (disposing)
+ {
+ SendDisconnect();
+ }
+
+ base.Dispose(disposing);
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/Udp/UnityUdpClientConnection.cs b/Tools/Hazel-Networking/Hazel/Udp/UnityUdpClientConnection.cs
new file mode 100644
index 0000000..8e6063d
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/Udp/UnityUdpClientConnection.cs
@@ -0,0 +1,353 @@
+using System;
+using System.Net;
+using System.Net.Sockets;
+using System.Threading;
+
+
+namespace Hazel.Udp
+{
+ /// <summary>
+ /// Unity doesn't always get along with thread pools well, so this interface will hopefully suit that case better.
+ /// </summary>
+ /// <inheritdoc/>
+ public class UnityUdpClientConnection : UdpConnection
+ {
+ /// <summary>
+ /// The max size Hazel attempts to read from the network.
+ /// Defaults to 8096.
+ /// </summary>
+ /// <remarks>
+ /// 8096 is 5 times the standard modern MTU of 1500, so it's already too large imo.
+ /// If Hazel ever implements fragmented packets, then we might consider a larger value since combining 5
+ /// packets into 1 reader would be realistic and would cause reallocations. That said, Hazel is not meant
+ /// for transferring large contiguous blocks of data, so... please don't?
+ /// </remarks>
+ public int ReceiveBufferSize = 8096;
+
+ private Socket socket;
+
+ public UnityUdpClientConnection(ILogger logger, IPEndPoint remoteEndPoint, IPMode ipMode = IPMode.IPv4)
+ : base(logger)
+ {
+ this.EndPoint = remoteEndPoint;
+ this.IPMode = ipMode;
+
+ this.socket = CreateSocket(ipMode);
+ this.socket.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.ExclusiveAddressUse, true);
+ }
+
+ ~UnityUdpClientConnection()
+ {
+ this.Dispose(false);
+ }
+
+ public void FixedUpdate()
+ {
+ try
+ {
+ ResendPacketsIfNeeded();
+ }
+ catch (Exception e)
+ {
+ this.logger.WriteError("FixedUpdate: " + e);
+ }
+
+ try
+ {
+ ManageReliablePackets();
+ }
+ catch (Exception e)
+ {
+ this.logger.WriteError("FixedUpdate: " + e);
+ }
+ }
+
+ protected virtual void RestartConnection()
+ {
+ }
+
+ protected virtual void ResendPacketsIfNeeded()
+ {
+ }
+
+ /// <inheritdoc />
+ protected override void WriteBytesToConnection(byte[] bytes, int length)
+ {
+#if DEBUG
+ if (TestLagMs > 0)
+ {
+ ThreadPool.QueueUserWorkItem(a => { Thread.Sleep(this.TestLagMs); WriteBytesToConnectionReal(bytes, length); });
+ }
+ else
+#endif
+ {
+ WriteBytesToConnectionReal(bytes, length);
+ }
+ }
+
+ private void WriteBytesToConnectionReal(byte[] bytes, int length)
+ {
+ try
+ {
+ this.Statistics.LogPacketSend(length);
+ socket.BeginSendTo(
+ bytes,
+ 0,
+ length,
+ SocketFlags.None,
+ EndPoint,
+ HandleSendTo,
+ null);
+ }
+ catch (NullReferenceException) { }
+ catch (ObjectDisposedException)
+ {
+ // Already disposed and disconnected...
+ }
+ catch (SocketException ex)
+ {
+ DisconnectInternal(HazelInternalErrors.SocketExceptionSend, "Could not send data as a SocketException occurred: " + ex.Message);
+ }
+ }
+
+ /// <summary>
+ /// Synchronously writes the given bytes to the connection.
+ /// </summary>
+ /// <param name="bytes">The bytes to write.</param>
+ protected virtual void WriteBytesToConnectionSync(byte[] bytes, int length)
+ {
+ try
+ {
+ socket.SendTo(
+ bytes,
+ 0,
+ length,
+ SocketFlags.None,
+ EndPoint);
+ }
+ catch (NullReferenceException) { }
+ catch (ObjectDisposedException)
+ {
+ // Already disposed and disconnected...
+ }
+ catch (SocketException ex)
+ {
+ DisconnectInternal(HazelInternalErrors.SocketExceptionSend, "Could not send data as a SocketException occurred: " + ex.Message);
+ }
+ }
+
+ private void HandleSendTo(IAsyncResult result)
+ {
+ try
+ {
+ socket.EndSendTo(result);
+ }
+ catch (NullReferenceException) { }
+ catch (ObjectDisposedException)
+ {
+ // Already disposed and disconnected...
+ }
+ catch (SocketException ex)
+ {
+ DisconnectInternal(HazelInternalErrors.SocketExceptionSend, "Could not send data as a SocketException occurred: " + ex.Message);
+ }
+ }
+
+ public override void Connect(byte[] bytes = null, int timeout = 5000)
+ {
+ this.ConnectAsync(bytes);
+ for (int timer = 0; timer < timeout; timer += 100)
+ {
+ if (this.State != ConnectionState.Connecting) return;
+ Thread.Sleep(100);
+
+ // I guess if we're gonna block in Unity, then let's assume no one will pump this for us.
+ this.FixedUpdate();
+ }
+ }
+
+ /// <inheritdoc />
+ public override void ConnectAsync(byte[] bytes = null)
+ {
+ this.State = ConnectionState.Connecting;
+
+ try
+ {
+ if (IPMode == IPMode.IPv4)
+ socket.Bind(new IPEndPoint(IPAddress.Any, 0));
+ else
+ socket.Bind(new IPEndPoint(IPAddress.IPv6Any, 0));
+ }
+ catch (SocketException e)
+ {
+ this.State = ConnectionState.NotConnected;
+ throw new HazelException("A SocketException occurred while binding to the port.", e);
+ }
+
+ this.RestartConnection();
+
+ try
+ {
+ StartListeningForData();
+ }
+ catch (ObjectDisposedException)
+ {
+ // If the socket's been disposed then we can just end there but make sure we're in NotConnected state.
+ // If we end up here I'm really lost...
+ this.State = ConnectionState.NotConnected;
+ return;
+ }
+ catch (SocketException e)
+ {
+ Dispose();
+ throw new HazelException("A SocketException occurred while initiating a receive operation.", e);
+ }
+
+ // Write bytes to the server to tell it hi (and to punch a hole in our NAT, if present)
+ // When acknowledged set the state to connected
+ SendHello(bytes, () =>
+ {
+ this.InitializeKeepAliveTimer();
+ this.State = ConnectionState.Connected;
+ });
+ }
+
+ /// <summary>
+ /// Instructs the listener to begin listening.
+ /// </summary>
+ void StartListeningForData()
+ {
+ var msg = MessageReader.GetSized(this.ReceiveBufferSize);
+ try
+ {
+ EndPoint ep = this.EndPoint;
+ socket.BeginReceiveFrom(msg.Buffer, 0, msg.Buffer.Length, SocketFlags.None, ref ep, ReadCallback, msg);
+ }
+ catch
+ {
+ msg.Recycle();
+ this.Dispose();
+ }
+ }
+
+ /// <summary>
+ /// Called when data has been received by the socket.
+ /// </summary>
+ /// <param name="result">The asyncronous operation's result.</param>
+ void ReadCallback(IAsyncResult result)
+ {
+#if DEBUG
+ if (this.TestLagMs > 0)
+ {
+ Thread.Sleep(this.TestLagMs);
+ }
+#endif
+
+ var msg = (MessageReader)result.AsyncState;
+
+ try
+ {
+ EndPoint ep = this.EndPoint;
+ msg.Length = socket.EndReceiveFrom(result, ref ep);
+ }
+ catch (SocketException e)
+ {
+ msg.Recycle();
+ DisconnectInternal(HazelInternalErrors.SocketExceptionReceive, "Socket exception while reading data: " + e.Message);
+ return;
+ }
+ catch (ObjectDisposedException)
+ {
+ // Weirdly, it seems that this method can be called twice on the same AsyncState when object is disposed...
+ // So this just keeps us from hitting Duplicate Add errors at the risk of if this is a platform
+ // specific bug, we leak a MessageReader while the socket is disposing. Not a bad trade off.
+ return;
+ }
+ catch (Exception)
+ {
+ msg.Recycle();
+ return;
+ }
+
+ //Exit if no bytes read, we've failed.
+ if (msg.Length == 0)
+ {
+ msg.Recycle();
+ DisconnectInternal(HazelInternalErrors.ReceivedZeroBytes, "Received 0 bytes");
+ return;
+ }
+
+ //Begin receiving again
+ try
+ {
+ StartListeningForData();
+ }
+ catch (SocketException e)
+ {
+ DisconnectInternal(HazelInternalErrors.SocketExceptionReceive, "Socket exception during receive: " + e.Message);
+ }
+ catch (ObjectDisposedException)
+ {
+ //If the socket's been disposed then we can just end there.
+ return;
+ }
+
+#if DEBUG
+ if (this.TestDropRate > 0)
+ {
+ if ((this.testDropCount++ % this.TestDropRate) == 0)
+ {
+ return;
+ }
+ }
+#endif
+
+ HandleReceive(msg, msg.Length);
+ }
+
+ /// <summary>
+ /// Sends a disconnect message to the end point.
+ /// You may include optional disconnect data. The SendOption must be unreliable.
+ /// </summary>
+ protected override bool SendDisconnect(MessageWriter data = null)
+ {
+ lock (this)
+ {
+ if (this._state == ConnectionState.NotConnected) return false;
+ this._state = ConnectionState.NotConnected;
+ }
+
+ var bytes = EmptyDisconnectBytes;
+ if (data != null && data.Length > 0)
+ {
+ if (data.SendOption != SendOption.None) throw new ArgumentException("Disconnect messages can only be unreliable.");
+
+ bytes = data.ToByteArray(true);
+ bytes[0] = (byte)UdpSendOption.Disconnect;
+ }
+
+ try
+ {
+ this.WriteBytesToConnectionSync(bytes, bytes.Length);
+ }
+ catch { }
+
+ return true;
+ }
+
+ /// <inheritdoc />
+ protected override void Dispose(bool disposing)
+ {
+ if (disposing)
+ {
+ SendDisconnect();
+ }
+
+ try { this.socket.Shutdown(SocketShutdown.Both); } catch { }
+ try { this.socket.Close(); } catch { }
+ try { this.socket.Dispose(); } catch { }
+
+ base.Dispose(disposing);
+ }
+ }
+}