aboutsummaryrefslogtreecommitdiff
path: root/Tools/Hazel-Networking/Hazel/Dtls
diff options
context:
space:
mode:
Diffstat (limited to 'Tools/Hazel-Networking/Hazel/Dtls')
-rw-r--r--Tools/Hazel-Networking/Hazel/Dtls/AesGcmRecordProtection.cs147
-rw-r--r--Tools/Hazel-Networking/Hazel/Dtls/DtlsConnectionListener.cs1424
-rw-r--r--Tools/Hazel-Networking/Hazel/Dtls/DtlsUnityConnection.cs1246
-rw-r--r--Tools/Hazel-Networking/Hazel/Dtls/Handshake.cs734
-rw-r--r--Tools/Hazel-Networking/Hazel/Dtls/IHandshakeCipherSuite.cs63
-rw-r--r--Tools/Hazel-Networking/Hazel/Dtls/IRecordProtection.cs84
-rw-r--r--Tools/Hazel-Networking/Hazel/Dtls/NullRecordProtection.cs66
-rw-r--r--Tools/Hazel-Networking/Hazel/Dtls/PrfSha256.cs84
-rw-r--r--Tools/Hazel-Networking/Hazel/Dtls/Record.cs123
-rw-r--r--Tools/Hazel-Networking/Hazel/Dtls/ThreadedHmacHelper.cs159
-rw-r--r--Tools/Hazel-Networking/Hazel/Dtls/X25519EcdheRsaSha256.cs202
11 files changed, 4332 insertions, 0 deletions
diff --git a/Tools/Hazel-Networking/Hazel/Dtls/AesGcmRecordProtection.cs b/Tools/Hazel-Networking/Hazel/Dtls/AesGcmRecordProtection.cs
new file mode 100644
index 0000000..65df39e
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/Dtls/AesGcmRecordProtection.cs
@@ -0,0 +1,147 @@
+using Hazel.Crypto;
+using System;
+using System.Diagnostics;
+
+namespace Hazel.Dtls
+{
+ /// <summary>
+ /// *_AES_128_GCM_* cipher suite
+ /// </summary>
+ public class Aes128GcmRecordProtection: IRecordProtection
+ {
+ private const int ImplicitNonceSize = 4;
+ private const int ExplicitNonceSize = 8;
+
+ private readonly Aes128Gcm serverWriteCipher;
+ private readonly Aes128Gcm clientWriteCipher;
+
+ private readonly ByteSpan serverWriteIV;
+ private readonly ByteSpan clientWriteIV;
+
+ /// <summary>
+ /// Create a new instance of the AES128_GCM record protection
+ /// </summary>
+ /// <param name="masterSecret">Shared secret</param>
+ /// <param name="serverRandom">Server random data</param>
+ /// <param name="clientRandom">Client random data</param>
+ public Aes128GcmRecordProtection(ByteSpan masterSecret, ByteSpan serverRandom, ByteSpan clientRandom)
+ {
+ ByteSpan combinedRandom = new byte[serverRandom.Length + clientRandom.Length];
+ serverRandom.CopyTo(combinedRandom);
+ clientRandom.CopyTo(combinedRandom.Slice(serverRandom.Length));
+
+ // Expand master_secret to encryption keys
+ const int ExpandedSize = 0
+ + 0 // mac_key_length
+ + 0 // mac_key_length
+ + Aes128Gcm.KeySize // enc_key_length
+ + Aes128Gcm.KeySize // enc_key_length
+ + ImplicitNonceSize // fixed_iv_length
+ + ImplicitNonceSize // fixed_iv_length
+ ;
+
+ ByteSpan expandedKey = new byte[ExpandedSize];
+ PrfSha256.ExpandSecret(expandedKey, masterSecret, PrfLabel.KEY_EXPANSION, combinedRandom);
+
+ ByteSpan clientWriteKey = expandedKey.Slice(0, Aes128Gcm.KeySize);
+ ByteSpan serverWriteKey = expandedKey.Slice(Aes128Gcm.KeySize, Aes128Gcm.KeySize);
+ this.clientWriteIV = expandedKey.Slice(2 * Aes128Gcm.KeySize, ImplicitNonceSize);
+ this.serverWriteIV = expandedKey.Slice(2 * Aes128Gcm.KeySize + ImplicitNonceSize, ImplicitNonceSize);
+
+ this.serverWriteCipher = new Aes128Gcm(serverWriteKey);
+ this.clientWriteCipher = new Aes128Gcm(clientWriteKey);
+ }
+
+ /// <inheritdoc />
+ public void Dispose()
+ {
+ this.serverWriteCipher.Dispose();
+ this.clientWriteCipher.Dispose();
+ }
+
+ /// <inheritdoc />
+ private static int GetEncryptedSizeImpl(int dataSize)
+ {
+ return dataSize + Aes128Gcm.CiphertextOverhead;
+ }
+
+ /// <inheritdoc />
+ public int GetEncryptedSize(int dataSize)
+ {
+ return GetEncryptedSizeImpl(dataSize);
+ }
+
+ private static int GetDecryptedSizeImpl(int dataSize)
+ {
+ return dataSize - Aes128Gcm.CiphertextOverhead;
+ }
+
+ /// <inheritdoc />
+ public int GetDecryptedSize(int dataSize)
+ {
+ return GetDecryptedSizeImpl(dataSize);
+ }
+
+ /// <inheritdoc />
+ public void EncryptServerPlaintext(ByteSpan output, ByteSpan input, ref Record record)
+ {
+ EncryptPlaintext(output, input, ref record, this.serverWriteCipher, this.serverWriteIV);
+ }
+
+ /// <inheritdoc />
+ public void EncryptClientPlaintext(ByteSpan output, ByteSpan input, ref Record record)
+ {
+ EncryptPlaintext(output, input, ref record, this.clientWriteCipher, this.clientWriteIV);
+ }
+
+ private static void EncryptPlaintext(ByteSpan output, ByteSpan input, ref Record record, Aes128Gcm cipher, ByteSpan writeIV)
+ {
+ Debug.Assert(output.Length >= GetEncryptedSizeImpl(input.Length));
+
+ // Build GCM nonce (authenticated data)
+ ByteSpan nonce = new byte[ImplicitNonceSize + ExplicitNonceSize];
+ writeIV.CopyTo(nonce);
+ nonce.WriteBigEndian16(record.Epoch, ImplicitNonceSize);
+ nonce.WriteBigEndian48(record.SequenceNumber, ImplicitNonceSize + 2);
+
+ // Serialize record as additional data
+ Record plaintextRecord = record;
+ plaintextRecord.Length = (ushort)input.Length;
+ ByteSpan associatedData = new byte[Record.Size];
+ plaintextRecord.Encode(associatedData);
+
+ cipher.Seal(output, nonce, input, associatedData);
+ }
+
+ /// <inheritdoc />
+ public bool DecryptCiphertextFromServer(ByteSpan output, ByteSpan input, ref Record record)
+ {
+ return DecryptCiphertext(output, input, ref record, this.serverWriteCipher, this.serverWriteIV);
+ }
+
+ /// <inheritdoc />
+ public bool DecryptCiphertextFromClient(ByteSpan output, ByteSpan input, ref Record record)
+ {
+ return DecryptCiphertext(output, input, ref record, this.clientWriteCipher, this.clientWriteIV);
+ }
+
+ private static bool DecryptCiphertext(ByteSpan output, ByteSpan input, ref Record record, Aes128Gcm cipher, ByteSpan writeIV)
+ {
+ Debug.Assert(output.Length >= GetDecryptedSizeImpl(input.Length));
+
+ // Build GCM nonce (authenticated data)
+ ByteSpan nonce = new byte[ImplicitNonceSize + ExplicitNonceSize];
+ writeIV.CopyTo(nonce);
+ nonce.WriteBigEndian16(record.Epoch, ImplicitNonceSize);
+ nonce.WriteBigEndian48(record.SequenceNumber, ImplicitNonceSize + 2);
+
+ // Serialize record as additional data
+ Record plaintextRecord = record;
+ plaintextRecord.Length = (ushort)GetDecryptedSizeImpl(input.Length);
+ ByteSpan associatedData = new byte[Record.Size];
+ plaintextRecord.Encode(associatedData);
+
+ return cipher.Open(output, nonce, input, associatedData);
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/Dtls/DtlsConnectionListener.cs b/Tools/Hazel-Networking/Hazel/Dtls/DtlsConnectionListener.cs
new file mode 100644
index 0000000..61f41d3
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/Dtls/DtlsConnectionListener.cs
@@ -0,0 +1,1424 @@
+using System;
+using System.Collections.Concurrent;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Net;
+using System.Security.Cryptography;
+using System.Security.Cryptography.X509Certificates;
+using System.Threading;
+using Hazel.Udp.FewerThreads;
+using Hazel.Crypto;
+
+namespace Hazel.Dtls
+{
+ /// <summary>
+ /// Listens for new UDP-DTLS connections and creates UdpConnections for them.
+ /// </summary>
+ /// <inheritdoc />
+ public class DtlsConnectionListener : ThreadLimitedUdpConnectionListener
+ {
+ private const int MaxCertFragmentSizeV0 = 1200;
+
+ // Min MTU - UDP+IP header - 1 (for good measure. :))
+ private const int MaxCertFragmentSizeV1 = 576 - 32 - 1;
+
+ /// <summary>
+ /// Current state of handshake sequence
+ /// </summary>
+ enum HandshakeState
+ {
+ ExpectingHello,
+ ExpectingClientKeyExchange,
+ ExpectingChangeCipherSpec,
+ ExpectingFinish
+ }
+
+ /// <summary>
+ /// State to manage the current epoch `N`
+ /// </summary>
+ struct CurrentEpoch
+ {
+ public ulong NextOutgoingSequence;
+
+ public ulong NextExpectedSequence;
+ public ulong PreviousSequenceWindowBitmask;
+
+ public IRecordProtection RecordProtection;
+ public IRecordProtection PreviousRecordProtection;
+
+ // Need to keep these around so we can re-transmit our
+ // last handshake record flight
+ public ByteSpan ExpectedClientFinishedVerification;
+ public ByteSpan ServerFinishedVerification;
+ public ulong NextOutgoingSequenceForPreviousEpoch;
+ }
+
+ /// <summary>
+ /// State to manage the transition from the current
+ /// epoch `N` to epoch `N+1`
+ /// </summary>
+ struct NextEpoch
+ {
+ public ushort Epoch;
+
+ public HandshakeState State;
+ public CipherSuite SelectedCipherSuite;
+
+ public ulong NextOutgoingSequence;
+
+ public IHandshakeCipherSuite Handshake;
+ public IRecordProtection RecordProtection;
+
+ public ByteSpan ClientRandom;
+ public ByteSpan ServerRandom;
+
+ public Sha256Stream VerificationStream;
+
+ public ByteSpan ClientVerification;
+ public ByteSpan ServerVerification;
+
+ }
+
+ /// <summary>
+ /// Per-peer state
+ /// </summary>
+ sealed class PeerData : IDisposable
+ {
+ public ushort Epoch;
+ public bool CanHandleApplicationData;
+
+ public HazelDtlsSessionInfo Session;
+
+ public CurrentEpoch CurrentEpoch;
+ public NextEpoch NextEpoch;
+
+ public ConnectionId ConnectionId;
+
+ public readonly List<ByteSpan> QueuedApplicationDataMessage = new List<ByteSpan>();
+ public readonly ConcurrentBag<MessageReader> ApplicationData = new ConcurrentBag<MessageReader>();
+ public readonly ProtocolVersion ProtocolVersion;
+
+ public DateTime StartOfNegotiation;
+
+ public PeerData(ConnectionId connectionId, ulong nextExpectedSequenceNumber, ProtocolVersion protocolVersion)
+ {
+ ByteSpan block = new byte[2 * Finished.Size];
+ this.CurrentEpoch.ServerFinishedVerification = block.Slice(0, Finished.Size);
+ this.CurrentEpoch.ExpectedClientFinishedVerification = block.Slice(Finished.Size, Finished.Size);
+ this.ProtocolVersion = protocolVersion;
+
+ ResetPeer(connectionId, nextExpectedSequenceNumber);
+ }
+
+ public void ResetPeer(ConnectionId connectionId, ulong nextExpectedSequenceNumber)
+ {
+ Dispose();
+
+ this.Epoch = 0;
+ this.CanHandleApplicationData = false;
+ this.QueuedApplicationDataMessage.Clear();
+
+ this.CurrentEpoch.NextOutgoingSequence = 2; // Account for our ClientHelloVerify
+ this.CurrentEpoch.NextExpectedSequence = nextExpectedSequenceNumber;
+ this.CurrentEpoch.PreviousSequenceWindowBitmask = 0;
+ this.CurrentEpoch.RecordProtection = NullRecordProtection.Instance;
+ this.CurrentEpoch.PreviousRecordProtection = null;
+ this.CurrentEpoch.ServerFinishedVerification.SecureClear();
+ this.CurrentEpoch.ExpectedClientFinishedVerification.SecureClear();
+
+ this.NextEpoch.State = HandshakeState.ExpectingHello;
+ this.NextEpoch.RecordProtection = null;
+ this.NextEpoch.Handshake = null;
+ this.NextEpoch.ClientRandom = new byte[Random.Size];
+ this.NextEpoch.ServerRandom = new byte[Random.Size];
+ this.NextEpoch.VerificationStream = new Sha256Stream();
+ this.NextEpoch.ClientVerification = new byte[Finished.Size];
+ this.NextEpoch.ServerVerification = new byte[Finished.Size];
+
+ this.ConnectionId = connectionId;
+
+ this.StartOfNegotiation = DateTime.UtcNow;
+ }
+
+ public void Dispose()
+ {
+ this.CurrentEpoch.RecordProtection?.Dispose();
+ this.CurrentEpoch.PreviousRecordProtection?.Dispose();
+ this.NextEpoch.RecordProtection?.Dispose();
+ this.NextEpoch.Handshake?.Dispose();
+ this.NextEpoch.VerificationStream?.Dispose();
+
+ while (this.ApplicationData.TryTake(out var msg))
+ {
+ try
+ {
+ msg.Recycle();
+ }
+ catch { }
+ }
+ }
+ }
+
+ private RandomNumberGenerator random;
+
+ // Private key component of certificate's public key
+ private ByteSpan encodedCertificate;
+ private RSA certificatePrivateKey;
+
+ // HMAC key to validate ClientHello cookie
+ private ThreadedHmacHelper hmacHelper;
+ private HMAC CurrentCookieHmac {
+ get
+ {
+ return hmacHelper.GetCurrentCookieHmacsForThread();
+ }
+ }
+ private HMAC PreviousCookieHmac
+ {
+ get
+ {
+ return hmacHelper.GetPreviousCookieHmacsForThread();
+ }
+ }
+
+ private ConcurrentStack<ConnectionId> staleConnections = new ConcurrentStack<ConnectionId>();
+ private readonly ConcurrentDictionary<IPEndPoint, PeerData> existingPeers = new ConcurrentDictionary<IPEndPoint, PeerData>();
+ public int PeerCount => this.existingPeers.Count;
+
+ // TODO: Move these into an DtlsErrorStatistics class
+ public int NonPeerNonHelloPacketsDropped;
+ public int NonVerifiedFinishedHandshake;
+ public int NonPeerVerifyHelloRequests;
+ public int PeerVerifyHelloRequests;
+
+ private int connectionSerial_unsafe = 0;
+
+ private Timer staleConnectionUpkeep;
+
+ /// <summary>
+ /// Create a new instance of the DTLS listener
+ /// </summary>
+ /// <param name="numWorkers"></param>
+ /// <param name="endPoint"></param>
+ /// <param name="logger"></param>
+ /// <param name="ipMode"></param>
+ public DtlsConnectionListener(int numWorkers, IPEndPoint endPoint, ILogger logger, IPMode ipMode = IPMode.IPv4)
+ : base(numWorkers, endPoint, logger, ipMode)
+ {
+ this.random = RandomNumberGenerator.Create();
+
+ this.staleConnectionUpkeep = new Timer(this.HandleStaleConnections, null, 2500, 1000);
+ this.hmacHelper = new ThreadedHmacHelper(logger);
+ }
+
+ /// <inheritdoc />
+ protected override void Dispose(bool disposing)
+ {
+ base.Dispose(disposing);
+
+ this.staleConnectionUpkeep.Dispose();
+
+ this.random?.Dispose();
+ this.random = null;
+
+ this.hmacHelper?.Dispose();
+ this.hmacHelper = null;
+
+ foreach (var pair in this.existingPeers)
+ {
+ pair.Value.Dispose();
+ }
+ this.existingPeers.Clear();
+ }
+
+ /// <summary>
+ /// Set the certificate key pair for the listener
+ /// </summary>
+ /// <param name="certificate">Certificate for the server</param>
+ public void SetCertificate(X509Certificate2 certificate)
+ {
+ if (!certificate.HasPrivateKey)
+ {
+ throw new ArgumentException("Certificate must have a private key attached", nameof(certificate));
+ }
+
+ RSA privateKey = certificate.GetRSAPrivateKey();
+ if (privateKey == null)
+ {
+ throw new ArgumentException("Certificate must be signed by an RSA key", nameof(certificate));
+ }
+
+ this.certificatePrivateKey?.Dispose();
+ this.certificatePrivateKey = privateKey;
+
+ this.encodedCertificate = Certificate.Encode(certificate);
+ }
+
+ /// <summary>
+ /// Handle an incoming datagram from the network.
+ ///
+ /// This is primarily a wrapper around ProcessIncomingMessage
+ /// to ensure `reader.Recycle()` is always called
+ /// </summary>
+ protected override void ReadCallback(MessageReader reader, IPEndPoint peerAddress, ConnectionId connectionId)
+ {
+ try
+ {
+ ByteSpan message = new ByteSpan(reader.Buffer, reader.Offset + reader.Position, reader.BytesRemaining);
+ this.ProcessIncomingMessage(message, peerAddress);
+ }
+ finally
+ {
+ reader.Recycle();
+ }
+ }
+
+ /// <summary>
+ /// Handle an incoming datagram from the network
+ /// </summary>
+ private void ProcessIncomingMessage(ByteSpan message, IPEndPoint peerAddress)
+ {
+ PeerData peer = null;
+ if (!this.existingPeers.TryGetValue(peerAddress, out peer))
+ {
+ lock (this.existingPeers)
+ {
+ if (!this.existingPeers.TryGetValue(peerAddress, out peer))
+ {
+ HandleNonPeerRecord(message, peerAddress);
+ return;
+ }
+ }
+ }
+
+ ConnectionId peerConnectionId;
+
+ lock (peer)
+ {
+ peerConnectionId = peer.ConnectionId;
+
+ // Each incoming packet may contain multiple DTLS
+ // records
+ while (message.Length > 0)
+ {
+ Record record;
+ if (!Record.Parse(out record, peer.ProtocolVersion, message))
+ {
+ this.Logger.WriteError($"Dropping malformed record from `{peerAddress}`");
+ return;
+ }
+ message = message.Slice(Record.Size);
+
+ if (message.Length < record.Length)
+ {
+ this.Logger.WriteError($"Dropping malformed record from `{peerAddress}` Length({record.Length}) AvailableBytes({message.Length})");
+ return;
+ }
+
+ ByteSpan recordPayload = message.Slice(0, record.Length);
+ message = message.Slice(record.Length);
+
+ // Early-out and drop ApplicationData records
+ if (record.ContentType == ContentType.ApplicationData && !peer.CanHandleApplicationData)
+ {
+ this.Logger.WriteInfo($"Dropping ApplicationData record from `{peerAddress}` Cannot process yet");
+ continue;
+ }
+
+ // Drop records from a different epoch
+ if (record.Epoch != peer.Epoch)
+ {
+ // Handle existing client negotiating a new connection
+ if (record.Epoch == 0 && record.ContentType == ContentType.Handshake)
+ {
+ ByteSpan handshakePayload = recordPayload;
+
+ Handshake handshake;
+ if (!Handshake.Parse(out handshake, recordPayload))
+ {
+ this.Logger.WriteError($"Dropping malformed re-negotiation Handshake from `{peerAddress}`");
+ continue;
+ }
+ handshakePayload = handshakePayload.Slice(Handshake.Size);
+
+ if (handshake.FragmentOffset != 0 || handshake.Length != handshake.FragmentLength)
+ {
+ this.Logger.WriteError($"Dropping fragmented re-negotiation Handshake from `{peerAddress}`");
+ continue;
+ }
+ else if (handshake.MessageType != HandshakeType.ClientHello)
+ {
+ this.Logger.WriteVerbose($"Dropping non-ClientHello re-negotiation Handshake from `{peerAddress}`");
+ continue;
+ }
+ else if (handshakePayload.Length < handshake.Length)
+ {
+ this.Logger.WriteError($"Dropping malformed re-negotiation Handshake from `{peerAddress}`: Length({handshake.Length}) AvailableBytes({handshakePayload.Length})");
+ }
+
+ if (!this.HandleClientHello(peer, peerAddress, ref record, ref handshake, recordPayload, handshakePayload))
+ {
+ return;
+ }
+ continue;
+ }
+
+ this.Logger.WriteVerbose($"Dropping bad-epoch record from `{peerAddress}` RecordEpoch({record.Epoch}) CurrentEpoch({peer.Epoch})");
+ continue;
+ }
+
+ // Prevent replay attacks by dropping records
+ // we've already processed
+ int windowIndex = (int)(peer.CurrentEpoch.NextExpectedSequence - record.SequenceNumber - 1);
+ ulong windowMask = 1ul << windowIndex;
+ if (record.SequenceNumber < peer.CurrentEpoch.NextExpectedSequence)
+ {
+ if (windowIndex >= 64)
+ {
+ this.Logger.WriteInfo($"Dropping too-old record from `{peerAddress}` Sequence({record.SequenceNumber}) Expected({peer.CurrentEpoch.NextExpectedSequence})");
+ continue;
+ }
+
+ if ((peer.CurrentEpoch.PreviousSequenceWindowBitmask & windowMask) != 0)
+ {
+ this.Logger.WriteInfo($"Dropping duplicate record from `{peerAddress}`");
+ continue;
+ }
+ }
+
+ // Validate record authenticity
+ int decryptedSize = peer.CurrentEpoch.RecordProtection.GetDecryptedSize(recordPayload.Length);
+ if (decryptedSize < 0)
+ {
+ this.Logger.WriteInfo($"Dropping malformed record: Length {recordPayload.Length} Decrypted length: {decryptedSize}");
+ continue;
+ }
+
+ ByteSpan decryptedPayload = recordPayload.ReuseSpanIfPossible(decryptedSize);
+ ProtocolVersion protocolVersion = peer.ProtocolVersion;
+
+ if (!peer.CurrentEpoch.RecordProtection.DecryptCiphertextFromClient(decryptedPayload, recordPayload, ref record))
+ {
+ this.Logger.WriteVerbose($"Dropping non-authentic {record.ContentType} record from `{peerAddress}`");
+ return;
+ }
+
+ recordPayload = decryptedPayload;
+
+ // Update our squence number bookeeping
+ if (record.SequenceNumber >= peer.CurrentEpoch.NextExpectedSequence)
+ {
+ int windowShift = (int)(record.SequenceNumber + 1 - peer.CurrentEpoch.NextExpectedSequence);
+ peer.CurrentEpoch.PreviousSequenceWindowBitmask <<= windowShift;
+ peer.CurrentEpoch.NextExpectedSequence = record.SequenceNumber + 1;
+ }
+ else
+ {
+ peer.CurrentEpoch.PreviousSequenceWindowBitmask |= windowMask;
+ }
+
+ // This is handy for debugging, but too verbose even for verbose.
+ // this.Logger.WriteVerbose($"Record type {record.ContentType} ({peer.NextEpoch.State})");
+ switch (record.ContentType)
+ {
+ case ContentType.ChangeCipherSpec:
+ if (peer.NextEpoch.State != HandshakeState.ExpectingChangeCipherSpec)
+ {
+ this.Logger.WriteError($"Dropping unexpected ChangeChiperSpec record from `{peerAddress}` State({peer.NextEpoch.State})");
+ break;
+ }
+ else if (peer.NextEpoch.RecordProtection == null)
+ {
+ ///NOTE(mendsley): This _should_ not
+ /// happen on a well-formed server.
+ Debug.Assert(false, "How did we receive a ChangeCipherSpec message without a pending record protection instance?");
+
+ this.Logger.WriteError($"Dropping ChangeCipherSpec message from `{peerAddress}`: No pending record protection");
+ break;
+ }
+
+ if (!ChangeCipherSpec.Parse(recordPayload))
+ {
+ this.Logger.WriteError($"Dropping malformed ChangeCipherSpec message from `{peerAddress}`");
+ break;
+ }
+
+ // Migrate to the next epoch
+ peer.Epoch = peer.NextEpoch.Epoch;
+ peer.CanHandleApplicationData = false; // Need a Finished message
+ peer.CurrentEpoch.NextOutgoingSequenceForPreviousEpoch = peer.CurrentEpoch.NextOutgoingSequence;
+ peer.CurrentEpoch.PreviousRecordProtection?.Dispose();
+ peer.CurrentEpoch.PreviousRecordProtection = peer.CurrentEpoch.RecordProtection;
+ peer.CurrentEpoch.RecordProtection = peer.NextEpoch.RecordProtection;
+ peer.CurrentEpoch.NextOutgoingSequence = 1;
+ peer.CurrentEpoch.NextExpectedSequence = 1;
+ peer.CurrentEpoch.PreviousSequenceWindowBitmask = 0;
+ peer.NextEpoch.ClientVerification.CopyTo(peer.CurrentEpoch.ExpectedClientFinishedVerification);
+ peer.NextEpoch.ServerVerification.CopyTo(peer.CurrentEpoch.ServerFinishedVerification);
+
+ peer.NextEpoch.State = HandshakeState.ExpectingHello;
+ peer.NextEpoch.Handshake?.Dispose();
+ peer.NextEpoch.Handshake = null;
+ peer.NextEpoch.NextOutgoingSequence = 1;
+ peer.NextEpoch.RecordProtection = null;
+ peer.NextEpoch.VerificationStream.Reset();
+ peer.NextEpoch.ClientVerification.SecureClear();
+ peer.NextEpoch.ServerVerification.SecureClear();
+ break;
+
+ case ContentType.Alert:
+ this.Logger.WriteError($"Dropping unsupported Alert record from `{peerAddress}`");
+ break;
+
+ case ContentType.Handshake:
+ if (!ProcessHandshake(peer, peerAddress, ref record, recordPayload))
+ {
+ return;
+ }
+ break;
+
+ case ContentType.ApplicationData:
+ // Forward data to the application
+ MessageReader reader = MessageReader.GetSized(recordPayload.Length);
+ reader.Length = recordPayload.Length;
+ recordPayload.CopyTo(reader.Buffer);
+
+ peer.ApplicationData.Add(reader);
+ break;
+ }
+ }
+ }
+
+ // The peer lock must be exited before leaving the DtlsConnectionListener context to prevent deadlocks
+ // because ApplicationData processing may reenter this context
+ while (peer.ApplicationData.TryTake(out var appMsg))
+ {
+ base.ReadCallback(appMsg, peerAddress, peerConnectionId);
+ }
+ }
+
+ /// <summary>
+ /// Process an incoming Handshake protocol message
+ /// </summary>
+ /// <param name="peer">Originating peer</param>
+ /// <param name="peerAddress">Peer's network address</param>
+ /// <param name="record">Parent record</param>
+ /// <param name="message">Record payload</param>
+ /// <returns>
+ /// True if further processing of the underlying datagram
+ /// should be continues. Otherwise, false.
+ /// </returns>
+ private bool ProcessHandshake(PeerData peer, IPEndPoint peerAddress, ref Record record, ByteSpan message)
+ {
+ // Each record may have multiple handshake payloads
+ while (message.Length > 0)
+ {
+ ByteSpan originalMessage = message;
+
+ Handshake handshake;
+ if (!Handshake.Parse(out handshake, message))
+ {
+ this.Logger.WriteError($"Dropping malformed Handshake message from `{peerAddress}`");
+ return false;
+ }
+ message = message.Slice(Handshake.Size);
+
+ if (message.Length < handshake.Length)
+ {
+ this.Logger.WriteError($"Dropping malformed Handshake message from `{peerAddress}`");
+ return false;
+ }
+
+ ByteSpan payload = message.Slice(0, (int)message.Length);
+ message = message.Slice((int)handshake.Length);
+ originalMessage = originalMessage.Slice(0, Handshake.Size + (int)handshake.Length);
+
+ // We do not support fragmented handshake messages
+ // from the client
+ if (handshake.FragmentOffset != 0 || handshake.FragmentLength != handshake.Length)
+ {
+ this.Logger.WriteError($"Dropping fragmented Handshake message from `{peerAddress}` Offset({handshake.FragmentOffset}) FragmentLength({handshake.FragmentLength}) Length({handshake.Length})");
+ continue;
+ }
+
+ ByteSpan packet;
+ ByteSpan writer;
+
+#if DEBUG
+ this.Logger.WriteVerbose($"Received handshake {handshake.MessageType} ({peer.NextEpoch.State})");
+#endif
+ switch (handshake.MessageType)
+ {
+ case HandshakeType.ClientHello:
+ if (!this.HandleClientHello(peer, peerAddress, ref record, ref handshake, originalMessage, payload))
+ {
+ return false;
+ }
+ break;
+
+ case HandshakeType.ClientKeyExchange:
+ if (peer.NextEpoch.State != HandshakeState.ExpectingClientKeyExchange)
+ {
+ this.Logger.WriteError($"Dropping unexpected ClientKeyExchange message form `{peerAddress}` State({peer.NextEpoch.State})");
+ continue;
+ }
+ else if (handshake.MessageSequence != 5)
+ {
+ this.Logger.WriteError($"Dropping bad-sequence ClientKeyExchange message from `{peerAddress}` MessageSequence({handshake.MessageSequence})");
+ continue;
+ }
+
+ ByteSpan sharedSecret = new byte[peer.NextEpoch.Handshake.SharedKeySize()];
+ if (!peer.NextEpoch.Handshake.VerifyClientMessageAndGenerateSharedKey(sharedSecret, payload))
+ {
+ this.Logger.WriteError($"Dropping malformed ClientKeyExchange message from `{peerAddress}`");
+ return false;
+ }
+
+ // Record incoming ClientKeyExchange message
+ // to verification stream
+ peer.NextEpoch.VerificationStream.AddData(originalMessage);
+
+ ByteSpan randomSeed = new byte[2 * Random.Size];
+ peer.NextEpoch.ClientRandom.CopyTo(randomSeed);
+ peer.NextEpoch.ServerRandom.CopyTo(randomSeed.Slice(Random.Size));
+
+ const int MasterSecretSize = 48;
+ ByteSpan masterSecret = new byte[MasterSecretSize];
+ PrfSha256.ExpandSecret(
+ masterSecret
+ , sharedSecret
+ , PrfLabel.MASTER_SECRET
+ , randomSeed
+ );
+
+ // Create the record protection for the upcoming epoch
+ switch (peer.NextEpoch.SelectedCipherSuite)
+ {
+ case CipherSuite.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256:
+ peer.NextEpoch.RecordProtection = new Aes128GcmRecordProtection(
+ masterSecret
+ , peer.NextEpoch.ServerRandom
+ , peer.NextEpoch.ClientRandom);
+ break;
+
+ default:
+ Debug.Assert(false, $"How did we agree to a cipher suite {peer.NextEpoch.SelectedCipherSuite} we can't create?");
+ this.Logger.WriteError($"Dropping ClientKeyExchange message from `{peerAddress}` Unsuppored cipher suite");
+ return false;
+ }
+
+ // Generate verification signatures
+ ByteSpan handshakeStreamHash = new byte[Sha256Stream.DigestSize];
+ peer.NextEpoch.VerificationStream.CopyOrCalculateFinalHash(handshakeStreamHash);
+
+ PrfSha256.ExpandSecret(
+ peer.NextEpoch.ClientVerification
+ , masterSecret
+ , PrfLabel.CLIENT_FINISHED
+ , handshakeStreamHash
+ );
+ PrfSha256.ExpandSecret(
+ peer.NextEpoch.ServerVerification
+ , masterSecret
+ , PrfLabel.SERVER_FINISHED
+ , handshakeStreamHash
+ );
+
+
+ // Update handshake state
+ masterSecret.SecureClear();
+ peer.NextEpoch.State = HandshakeState.ExpectingChangeCipherSpec;
+ break;
+
+ case HandshakeType.Finished:
+ // Unlike other handshake messages, this is
+ // for the current epoch - not the next epoch
+
+ // Cannot process a Finished message for
+ // epoch 0
+ if (peer.Epoch == 0)
+ {
+ this.Logger.WriteError($"Dropping Finished message for 0-epoch from `{peerAddress}`");
+ continue;
+ }
+ // Cannot process a Finished message when we
+ // are negotiating the next epoch
+ else if (peer.NextEpoch.State != HandshakeState.ExpectingHello)
+ {
+ this.Logger.WriteError($"Dropping Finished message while negotiating new epoch from `{peerAddress}`");
+ continue;
+ }
+ // Cannot process a Finished message without
+ // verify data
+ else if (peer.CurrentEpoch.ExpectedClientFinishedVerification.Length != Finished.Size || peer.CurrentEpoch.ServerFinishedVerification.Length != Finished.Size)
+ {
+ ///NOTE(mendsley): This _should_ not
+ /// happen on a well-formed server.
+ Debug.Assert(false, "How do we have an established non-zero epoch without verify data?");
+
+ this.Logger.WriteError($"Dropping Finished message (no verify data) from `{peerAddress}`");
+ return false;
+ }
+ // Cannot process a Finished message without
+ // record protection for the previous epoch
+ else if (peer.CurrentEpoch.PreviousRecordProtection == null)
+ {
+ ///NOTE(mendsley): This _should_ not
+ /// happen on a well-formed server.
+ Debug.Assert(false, "How do we have an established non-zero epoch with record protection for the previous epoch?");
+
+ this.Logger.WriteError($"Dropping Finished message from `{peerAddress}`: No previous epoch record protection");
+ return false;
+ }
+
+ // Verify message sequence
+ if (handshake.MessageSequence != 6)
+ {
+ this.Logger.WriteError($"Dropping bad-sequence Finished message from `{peerAddress}` MessageSequence({handshake.MessageSequence})");
+ continue;
+ }
+
+ // Verify the client has the correct
+ // handshake sequence
+ if (payload.Length != Finished.Size)
+ {
+ this.Logger.WriteError($"Dropping malformed Finished message from `{peerAddress}`");
+ return false;
+ }
+ else if (1 != Crypto.Const.ConstantCompareSpans(payload, peer.CurrentEpoch.ExpectedClientFinishedVerification))
+ {
+
+#if DEBUG
+ this.Logger.WriteError($"Dropping non-verified Finished Handshake from `{peerAddress}`");
+#else
+ Interlocked.Increment(ref this.NonVerifiedFinishedHandshake);
+#endif
+
+ // Abort the connection here
+ //
+ // The client is either broken, or
+ // doen not agree on our epoch settings.
+ //
+ // Either way, there is not a feasible
+ // way to progress the connection.
+ MarkConnectionAsStale(peer.ConnectionId);
+ this.existingPeers.TryRemove(peerAddress, out _);
+
+ return false;
+ }
+
+ ProtocolVersion protocolVersion = peer.ProtocolVersion;
+
+ // Describe our ChangeCipherSpec+Finished
+ Handshake outgoingHandshake = new Handshake();
+ outgoingHandshake.MessageType = HandshakeType.Finished;
+ outgoingHandshake.Length = Finished.Size;
+ outgoingHandshake.MessageSequence = 7;
+ outgoingHandshake.FragmentOffset = 0;
+ outgoingHandshake.FragmentLength = outgoingHandshake.Length;
+
+ Record changeCipherSpecRecord = new Record();
+ changeCipherSpecRecord.ContentType = ContentType.ChangeCipherSpec;
+ changeCipherSpecRecord.ProtocolVersion = protocolVersion;
+ changeCipherSpecRecord.Epoch = (ushort)(peer.Epoch - 1);
+ changeCipherSpecRecord.SequenceNumber = peer.CurrentEpoch.NextOutgoingSequenceForPreviousEpoch;
+ changeCipherSpecRecord.Length = (ushort)peer.CurrentEpoch.PreviousRecordProtection.GetEncryptedSize(ChangeCipherSpec.Size);
+ ++peer.CurrentEpoch.NextOutgoingSequenceForPreviousEpoch;
+
+ int plaintextFinishedPayloadSize = Handshake.Size + (int)outgoingHandshake.Length;
+ Record finishedRecord = new Record();
+ finishedRecord.ContentType = ContentType.Handshake;
+ finishedRecord.ProtocolVersion = protocolVersion;
+ finishedRecord.Epoch = peer.Epoch;
+ finishedRecord.SequenceNumber = peer.CurrentEpoch.NextOutgoingSequence;
+ finishedRecord.Length = (ushort)peer.CurrentEpoch.RecordProtection.GetEncryptedSize(plaintextFinishedPayloadSize);
+ ++peer.CurrentEpoch.NextOutgoingSequence;
+
+ // Encode the flight into wire format
+ packet = new byte[Record.Size + changeCipherSpecRecord.Length + Record.Size + finishedRecord.Length];
+ writer = packet;
+ changeCipherSpecRecord.Encode(writer);
+ writer = writer.Slice(Record.Size);
+ ChangeCipherSpec.Encode(writer);
+
+ ByteSpan startOfFinishedRecord = packet.Slice(Record.Size + changeCipherSpecRecord.Length);
+ writer = startOfFinishedRecord;
+ finishedRecord.Encode(writer);
+ writer = writer.Slice(Record.Size);
+ outgoingHandshake.Encode(writer);
+ writer = writer.Slice(Handshake.Size);
+ peer.CurrentEpoch.ServerFinishedVerification.CopyTo(writer);
+
+ // Protect the ChangeChipherSpec record
+ peer.CurrentEpoch.PreviousRecordProtection.EncryptServerPlaintext(
+ packet.Slice(Record.Size, changeCipherSpecRecord.Length),
+ packet.Slice(Record.Size, ChangeCipherSpec.Size),
+ ref changeCipherSpecRecord
+ );
+
+ // Protect the Finished Handshake record
+ peer.CurrentEpoch.RecordProtection.EncryptServerPlaintext(
+ startOfFinishedRecord.Slice(Record.Size, finishedRecord.Length),
+ startOfFinishedRecord.Slice(Record.Size, plaintextFinishedPayloadSize),
+ ref finishedRecord
+ );
+
+ // Current epoch can now handle application data
+ peer.CanHandleApplicationData = true;
+
+ base.QueueRawData(packet, peerAddress);
+ break;
+
+ // Drop messages that we do not support
+ case HandshakeType.CertificateVerify:
+ this.Logger.WriteError($"Dropping unsupported Handshake message from `{peerAddress}` MessageType({handshake.MessageType})");
+ continue;
+
+ // Drop messages that originate from the server
+ case HandshakeType.HelloRequest:
+ case HandshakeType.ServerHello:
+ case HandshakeType.HelloVerifyRequest:
+ case HandshakeType.Certificate:
+ case HandshakeType.ServerKeyExchange:
+ case HandshakeType.CertificateRequest:
+ case HandshakeType.ServerHelloDone:
+ this.Logger.WriteError($"Dropping server Handshake message from `{peerAddress}` MessageType({handshake.MessageType})");
+ continue;
+ }
+ }
+
+ return true;
+ }
+
+ /// <summary>
+ /// Handle a ClientHello message for a peer
+ /// </summary>
+ /// <param name="peer">Originating peer</param>
+ /// <param name="peerAddress">Peer address</param>
+ /// <param name="record">Parent record</param>
+ /// <param name="handshake">Parent Handshake header</param>
+ /// <param name="payload">Handshake payload</param>
+ private bool HandleClientHello(PeerData peer, IPEndPoint peerAddress, ref Record record, ref Handshake handshake, ByteSpan originalMessage, ByteSpan payload)
+ {
+ // Verify message sequence
+ if (handshake.MessageSequence != 0)
+ {
+ this.Logger.WriteError($"Dropping bad-sequence ClientHello from `{peerAddress}` MessageSequence({handshake.MessageSequence})`");
+ return true;
+ }
+
+ // Make sure we can handle a ClientHello message
+ if (peer.NextEpoch.State != HandshakeState.ExpectingHello && peer.NextEpoch.State != HandshakeState.ExpectingClientKeyExchange)
+ {
+ // Always handle ClientHello for epoch 0
+ if (record.Epoch != 0)
+ {
+ this.Logger.WriteError($"Dropping ClientHello from `{peer}` Not expecting ClientHello");
+ return true;
+ }
+ }
+
+ ProtocolVersion protocolVersion = peer.ProtocolVersion;
+ if (!ClientHello.Parse(out ClientHello clientHello, protocolVersion, payload))
+ {
+ this.Logger.WriteError($"Dropping malformed ClientHello Handshake message from `{peerAddress}`");
+ return false;
+ }
+
+ // Find an acceptable cipher suite we can use
+ CipherSuite selectedCipherSuite = CipherSuite.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256;
+ if (!clientHello.ContainsCipherSuite(CipherSuite.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256) || !clientHello.ContainsCurve(NamedCurve.x25519))
+ {
+ this.Logger.WriteError($"Dropping ClientHello from `{peerAddress}` No compatible cipher suite");
+ return false;
+ }
+
+ // If this message was not signed by us,
+ // request a signed message before doing anything else
+ if (!HelloVerifyRequest.VerifyCookie(clientHello.Cookie, peerAddress, this.CurrentCookieHmac))
+ {
+ if (!HelloVerifyRequest.VerifyCookie(clientHello.Cookie, peerAddress, this.PreviousCookieHmac))
+ {
+ ulong outgoingSequence = 1;
+ IRecordProtection recordProtection = NullRecordProtection.Instance;
+ if (record.Epoch != 0)
+ {
+ outgoingSequence = peer.CurrentEpoch.NextExpectedSequence;
+ ++peer.CurrentEpoch.NextOutgoingSequenceForPreviousEpoch;
+
+ recordProtection = peer.CurrentEpoch.RecordProtection;
+ }
+
+#if DEBUG
+ this.Logger.WriteError($"Sending HelloVerifyRequest to peer `{peerAddress}`");
+#else
+ Interlocked.Increment(ref this.PeerVerifyHelloRequests);
+#endif
+ this.SendHelloVerifyRequest(peerAddress, outgoingSequence, record.Epoch, recordProtection, protocolVersion);
+ return true;
+ }
+ }
+
+ // Client is initiating a brand new connection. We need
+ // to destroy the existing connection and establish a
+ // new session.
+ if (record.Epoch == 0 && peer.Epoch != 0)
+ {
+ ConnectionId oldConnectionId = peer.ConnectionId;
+ peer.ResetPeer(this.AllocateConnectionId(peerAddress), record.SequenceNumber + 1);
+
+ // Inform the parent layer that the existing
+ // connection should be abandoned.
+ MarkConnectionAsStale(oldConnectionId);
+ }
+
+ // Determine if this is an original message, or a retransmission
+ bool recordMessagesForVerifyData = false;
+ if (peer.NextEpoch.State == HandshakeState.ExpectingHello)
+ {
+ // Create our handhake cipher suite
+ IHandshakeCipherSuite handshakeCipherSuite = null;
+ switch (selectedCipherSuite)
+ {
+ case CipherSuite.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256:
+ if (clientHello.ContainsCurve(NamedCurve.x25519))
+ {
+ handshakeCipherSuite = new X25519EcdheRsaSha256(this.random);
+ }
+ else
+ {
+ this.Logger.WriteError($"Dropping ClientHello from `{peerAddress}` Could not create TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 cipher suite");
+ return false;
+ }
+
+ break;
+
+ default:
+ this.Logger.WriteError($"Dropping ClientHello from `{peerAddress}` Could not create handshake cipher suite");
+ return false;
+ }
+
+ peer.Session = clientHello.Session;
+
+ // Update the state of our epoch transition
+ peer.NextEpoch.Epoch = (ushort)(record.Epoch + 1);
+ peer.NextEpoch.State = HandshakeState.ExpectingClientKeyExchange;
+ peer.NextEpoch.SelectedCipherSuite = selectedCipherSuite;
+ peer.NextEpoch.Handshake = handshakeCipherSuite;
+ clientHello.Random.CopyTo(peer.NextEpoch.ClientRandom);
+ peer.NextEpoch.ServerRandom.FillWithRandom(this.random);
+ recordMessagesForVerifyData = true;
+
+#if DEBUG
+ this.Logger.WriteVerbose($"ClientRandom: {peer.NextEpoch.ClientRandom} ServerRandom: {peer.NextEpoch.ServerRandom}");
+#endif
+
+ // Copy the original ClientHello
+ // handshake to our verification stream
+ peer.NextEpoch.VerificationStream.AddData(
+ originalMessage.Slice(
+ 0
+ , Handshake.Size + (int)handshake.Length
+ )
+ );
+ }
+
+ // The initial record flight from the server
+ // contains the following Handshake messages:
+ // * ServerHello
+ // * Certificate
+ // * ServerKeyExchange
+ // * ServerHelloDone
+ //
+ // The Certificate message is almost always
+ // too large to fit into a single datagram,
+ // so it is pre-fragmented
+ // (see `SetCertificates`). Therefore, we
+ // need to send multiple record packets for
+ // this flight.
+ //
+ // The first record contains the ServerHello
+ // handshake message, as well as the first
+ // portion of the Certificate message.
+ //
+ // We then send a record packet until the
+ // entire Certificate message has been sent
+ // to the client.
+ //
+ // The final record packet contains the
+ // ServerKeyExchange and the ServerHelloDone
+ // messages.
+
+ // Describe first record of the flight
+ ServerHello serverHello = new ServerHello();
+ serverHello.ServerProtocolVersion = protocolVersion;
+ serverHello.Random = peer.NextEpoch.ServerRandom;
+ serverHello.CipherSuite = selectedCipherSuite;
+
+ Handshake serverHelloHandshake = new Handshake();
+ serverHelloHandshake.MessageType = HandshakeType.ServerHello;
+ serverHelloHandshake.Length = ServerHello.MinSize;
+ serverHelloHandshake.MessageSequence = 1;
+ serverHelloHandshake.FragmentOffset = 0;
+ serverHelloHandshake.FragmentLength = serverHelloHandshake.Length;
+
+ int maxCertFragmentSize = peer.Session.Version == 0 ? MaxCertFragmentSizeV0 : MaxCertFragmentSizeV1;
+
+ // The first certificate data needs to leave room for
+ // * Record header
+ // * ServerHello header
+ // * ServerHello payload
+ // * Certificate header
+
+ var certificateData = this.encodedCertificate;
+ int initialCertPadding = Record.Size + Handshake.Size + serverHello.Size + Handshake.Size;
+ int certInitialFragmentSize = Math.Min(certificateData.Length, maxCertFragmentSize - initialCertPadding);
+
+ Handshake certificateHandshake = new Handshake();
+ certificateHandshake.MessageType = HandshakeType.Certificate;
+ certificateHandshake.Length = (uint)certificateData.Length;
+ certificateHandshake.MessageSequence = 2;
+ certificateHandshake.FragmentOffset = 0;
+ certificateHandshake.FragmentLength = (uint)certInitialFragmentSize;
+
+ int initialRecordPayloadSize = 0
+ + Handshake.Size + serverHello.Size
+ + Handshake.Size + (int)certificateHandshake.FragmentLength
+ ;
+ Record initialRecord = new Record();
+ initialRecord.ContentType = ContentType.Handshake;
+ initialRecord.ProtocolVersion = protocolVersion;
+ initialRecord.Epoch = peer.Epoch;
+ initialRecord.SequenceNumber = peer.CurrentEpoch.NextOutgoingSequence;
+ initialRecord.Length = (ushort)peer.CurrentEpoch.RecordProtection.GetEncryptedSize(initialRecordPayloadSize);
+ ++peer.CurrentEpoch.NextOutgoingSequence;
+
+ // Convert initial record of the flight to
+ // wire format
+ ByteSpan packet = new byte[Record.Size + initialRecord.Length];
+ ByteSpan writer = packet;
+ initialRecord.Encode(writer);
+ writer = writer.Slice(Record.Size);
+ serverHelloHandshake.Encode(writer);
+ writer = writer.Slice(Handshake.Size);
+ serverHello.Encode(writer);
+ writer = writer.Slice(ServerHello.MinSize);
+ certificateHandshake.Encode(writer);
+ writer = writer.Slice(Handshake.Size);
+ certificateData.Slice(0, certInitialFragmentSize).CopyTo(writer);
+ certificateData = certificateData.Slice(certInitialFragmentSize);
+
+ // Protect initial record of the flight
+ peer.CurrentEpoch.RecordProtection.EncryptServerPlaintext(
+ packet.Slice(Record.Size, initialRecord.Length),
+ packet.Slice(Record.Size, initialRecordPayloadSize),
+ ref initialRecord
+ );
+
+ base.QueueRawData(packet, peerAddress);
+
+ // Record record payload for verification
+ if (recordMessagesForVerifyData)
+ {
+ Handshake fullCeritficateHandshake = certificateHandshake;
+ fullCeritficateHandshake.FragmentLength = fullCeritficateHandshake.Length;
+
+ packet = new byte[Handshake.Size + ServerHello.MinSize + Handshake.Size];
+ writer = packet;
+ serverHelloHandshake.Encode(writer);
+ writer = writer.Slice(Handshake.Size);
+ serverHello.Encode(writer);
+ writer = writer.Slice(ServerHello.MinSize);
+ fullCeritficateHandshake.Encode(writer);
+ writer = writer.Slice(Handshake.Size);
+
+ peer.NextEpoch.VerificationStream.AddData(packet);
+ peer.NextEpoch.VerificationStream.AddData(this.encodedCertificate);
+ }
+
+ // Process additional certificate records
+ // Subsequent certificate data needs to leave room for
+ // * Record header
+ // * Certificate header
+ const int CertPadding = Record.Size + Handshake.Size;
+ while (certificateData.Length > 0)
+ {
+ int certFragmentSize = Math.Min(certificateData.Length, maxCertFragmentSize - CertPadding);
+
+ certificateHandshake.FragmentOffset += certificateHandshake.FragmentLength;
+ certificateHandshake.FragmentLength = (uint)certFragmentSize;
+
+ int additionalRecordPayloadSize = Handshake.Size + (int)certificateHandshake.FragmentLength;
+ Record additionalRecord = new Record();
+ additionalRecord.ContentType = ContentType.Handshake;
+ additionalRecord.ProtocolVersion = protocolVersion;
+ additionalRecord.Epoch = peer.Epoch;
+ additionalRecord.SequenceNumber = peer.CurrentEpoch.NextOutgoingSequence;
+ additionalRecord.Length = (ushort)peer.CurrentEpoch.RecordProtection.GetEncryptedSize(additionalRecordPayloadSize);
+ ++peer.CurrentEpoch.NextOutgoingSequence;
+
+ // Convert record to wire format
+ packet = new byte[Record.Size + additionalRecord.Length];
+ writer = packet;
+ additionalRecord.Encode(writer);
+ writer = writer.Slice(Record.Size);
+ certificateHandshake.Encode(writer);
+ writer = writer.Slice(Handshake.Size);
+ certificateData.Slice(0, certFragmentSize).CopyTo(writer);
+
+ certificateData = certificateData.Slice(certFragmentSize);
+
+ // Protect record
+ peer.CurrentEpoch.RecordProtection.EncryptServerPlaintext(
+ packet.Slice(Record.Size, additionalRecord.Length),
+ packet.Slice(Record.Size, additionalRecordPayloadSize),
+ ref additionalRecord
+ );
+
+ base.QueueRawData(packet, peerAddress);
+ }
+
+ // Describe final record of the flight
+ Handshake serverKeyExchangeHandshake = new Handshake();
+ serverKeyExchangeHandshake.MessageType = HandshakeType.ServerKeyExchange;
+ serverKeyExchangeHandshake.Length = (uint)peer.NextEpoch.Handshake.CalculateServerMessageSize(this.certificatePrivateKey);
+ serverKeyExchangeHandshake.MessageSequence = 3;
+ serverKeyExchangeHandshake.FragmentOffset = 0;
+ serverKeyExchangeHandshake.FragmentLength = serverKeyExchangeHandshake.Length;
+
+ Handshake serverHelloDoneHandshake = new Handshake();
+ serverHelloDoneHandshake.MessageType = HandshakeType.ServerHelloDone;
+ serverHelloDoneHandshake.Length = 0;
+ serverHelloDoneHandshake.MessageSequence = 4;
+ serverHelloDoneHandshake.FragmentOffset = 0;
+ serverHelloDoneHandshake.FragmentLength = 0;
+
+ int finalRecordPayloadSize = 0
+ + Handshake.Size + (int)serverKeyExchangeHandshake.Length
+ + Handshake.Size + (int)serverHelloDoneHandshake.Length
+ ;
+ Record finalRecord = new Record();
+ finalRecord.ContentType = ContentType.Handshake;
+ finalRecord.ProtocolVersion = protocolVersion;
+ finalRecord.Epoch = peer.Epoch;
+ finalRecord.SequenceNumber = peer.CurrentEpoch.NextOutgoingSequence;
+ finalRecord.Length = (ushort)peer.CurrentEpoch.RecordProtection.GetEncryptedSize(finalRecordPayloadSize);
+ ++peer.CurrentEpoch.NextOutgoingSequence;
+
+ // Convert final record of the flight to wire
+ // format
+ packet = new byte[Record.Size + finalRecord.Length];
+ writer = packet;
+ finalRecord.Encode(writer);
+ writer = writer.Slice(Record.Size);
+ serverKeyExchangeHandshake.Encode(writer);
+ writer = writer.Slice(Handshake.Size);
+ peer.NextEpoch.Handshake.EncodeServerKeyExchangeMessage(writer, this.certificatePrivateKey);
+ writer = writer.Slice((int)serverKeyExchangeHandshake.Length);
+ serverHelloDoneHandshake.Encode(writer);
+
+ // Record record payload for verification
+ if (recordMessagesForVerifyData)
+ {
+ peer.NextEpoch.VerificationStream.AddData(
+ packet.Slice(
+ packet.Offset + Record.Size
+ , finalRecordPayloadSize
+ )
+ );
+ }
+
+ // Protect final record of the flight
+ peer.CurrentEpoch.RecordProtection.EncryptServerPlaintext(
+ packet.Slice(Record.Size, finalRecord.Length),
+ packet.Slice(Record.Size, finalRecordPayloadSize),
+ ref finalRecord
+ );
+
+ base.QueueRawData(packet, peerAddress);
+
+ return true;
+ }
+
+ /// <summary>
+ /// Handle an incoming packet that is not tied to an existing peer
+ /// </summary>
+ /// <param name="message">Incoming datagram</param>
+ /// <param name="peerAddress">Originating address</param>
+ private void HandleNonPeerRecord(ByteSpan message, IPEndPoint peerAddress)
+ {
+ Record record;
+ if (!Record.Parse(out record, expectedProtocolVersion: null, message))
+ {
+ this.Logger.WriteError($"Dropping malformed record from non-peer `{peerAddress}`");
+ return;
+ }
+ message = message.Slice(Record.Size);
+
+ // The protocol only supports receiving a single record
+ // from a non-peer.
+ if (record.Length != message.Length)
+ {
+ // NOTE(mendsley): This isn't always fatal.
+ // However, this is an indication that something
+ // fishy is going on. In the best case, there's a
+ // bug on the client or in the UDP stack (some
+ // stacks don't both to verify the checksum). In the
+ // worst case we're dealing with a malicious actor.
+ // In the malicious case, we'll end up dropping the
+ // connection later in the process.
+ if (message.Length < record.Length)
+ {
+ this.Logger.WriteInfo($"Dropping bad record from non-peer `{peerAddress}`. Msg length {message.Length} < {record.Length}");
+ return;
+ }
+ }
+
+ // We only accept zero-epoch records from non-peers
+ if (record.Epoch != 0)
+ {
+ ///NOTE(mendsley): Not logging anything here, as
+ /// this could easily be latent data arriving from a
+ /// recently disconnected peer.
+ return;
+ }
+
+ // We only accept Handshake protocol messages from non-peers
+ if (record.ContentType != ContentType.Handshake)
+ {
+ this.Logger.WriteError($"Dropping non-handhsake message from non-peer `{peerAddress}`");
+ return;
+ }
+
+ ByteSpan originalMessage = message;
+
+ Handshake handshake;
+ if (!Handshake.Parse(out handshake, message))
+ {
+ this.Logger.WriteError($"Dropping malformed handshake message from non-peer `{peerAddress}`");
+ return;
+ }
+
+ // We only accept ClientHello messages from non-peers
+ if (handshake.MessageType != HandshakeType.ClientHello)
+ {
+#if DEBUG
+ this.Logger.WriteError($"Dropping non-ClientHello ({handshake.MessageType}) message from non-peer `{peerAddress}`");
+#else
+ Interlocked.Increment(ref this.NonPeerNonHelloPacketsDropped);
+#endif
+ return;
+ }
+ message = message.Slice(Handshake.Size);
+
+ if (!ClientHello.Parse(out ClientHello clientHello, expectedProtocolVersion: null, message))
+ {
+ this.Logger.WriteError($"Dropping malformed ClientHello message from non-peer `{peerAddress}`");
+ return;
+ }
+
+ // If this ClientHello is not signed by us, request the
+ // client send us a signed message
+ if (!HelloVerifyRequest.VerifyCookie(clientHello.Cookie, peerAddress, this.CurrentCookieHmac))
+ {
+ if (!HelloVerifyRequest.VerifyCookie(clientHello.Cookie, peerAddress, this.PreviousCookieHmac))
+ {
+#if DEBUG
+ this.Logger.WriteVerbose($"Sending HelloVerifyRequest to non-peer `{peerAddress}`");
+#else
+ Interlocked.Increment(ref this.NonPeerVerifyHelloRequests);
+#endif
+ this.SendHelloVerifyRequest(peerAddress, 1, 0, NullRecordProtection.Instance, clientHello.ClientProtocolVersion);
+ return;
+ }
+ }
+
+ // Allocate state for the new peer and register it
+ PeerData peer = new PeerData(this.AllocateConnectionId(peerAddress), record.SequenceNumber + 1, clientHello.ClientProtocolVersion);
+ this.ProcessHandshake(peer, peerAddress, ref record, originalMessage);
+ this.existingPeers[peerAddress] = peer;
+ }
+
+ //Send a HelloVerifyRequest handshake message to a peer
+ private void SendHelloVerifyRequest(IPEndPoint peerAddress, ulong recordSequence, ushort epoch, IRecordProtection recordProtection, ProtocolVersion protocolVersion)
+ {
+ Handshake handshake = new Handshake();
+ handshake.MessageType = HandshakeType.HelloVerifyRequest;
+ handshake.Length = HelloVerifyRequest.Size;
+ handshake.MessageSequence = 0;
+ handshake.FragmentOffset = 0;
+ handshake.FragmentLength = handshake.Length;
+
+ int plaintextPayloadSize = Handshake.Size + (int)handshake.Length;
+
+ Record record = new Record();
+ record.ContentType = ContentType.Handshake;
+ record.ProtocolVersion = protocolVersion;
+ record.Epoch = epoch;
+ record.SequenceNumber = recordSequence;
+ record.Length = (ushort)recordProtection.GetEncryptedSize(plaintextPayloadSize);
+
+ // Encode record to wire format
+ ByteSpan packet = new byte[Record.Size + record.Length];
+ ByteSpan writer = packet;
+ record.Encode(writer);
+ writer = writer.Slice(Record.Size);
+ handshake.Encode(writer);
+ writer = writer.Slice(Handshake.Size);
+ HelloVerifyRequest.Encode(writer, peerAddress, this.CurrentCookieHmac, protocolVersion);
+
+ // Protect record payload
+ recordProtection.EncryptServerPlaintext(
+ packet.Slice(Record.Size, record.Length),
+ packet.Slice(Record.Size, plaintextPayloadSize),
+ ref record
+ );
+
+ base.QueueRawData(packet, peerAddress);
+ }
+
+ /// <summary>
+ /// Handle a requrest to send a datagram to the network
+ /// </summary>
+ protected override void QueueRawData(ByteSpan span, IPEndPoint remoteEndPoint)
+ {
+ PeerData peer;
+ if (!this.existingPeers.TryGetValue(remoteEndPoint, out peer))
+ {
+ // Drop messages if we don't know how to send them
+ return;
+ }
+
+ lock (peer)
+ {
+ // If we're negotiating a new epoch, queue data
+ if (peer.Epoch == 0 || peer.NextEpoch.State != HandshakeState.ExpectingHello)
+ {
+ ByteSpan copyOfSpan = new byte[span.Length];
+ span.CopyTo(copyOfSpan);
+
+ peer.QueuedApplicationDataMessage.Add(copyOfSpan);
+ return;
+ }
+
+ ProtocolVersion protocolVersion = peer.ProtocolVersion;
+
+ // Send any queued application data now
+ for (int ii = 0, nn = peer.QueuedApplicationDataMessage.Count; ii != nn; ++ii)
+ {
+ ByteSpan queuedSpan = peer.QueuedApplicationDataMessage[ii];
+
+ Record outgoingRecord = new Record();
+ outgoingRecord.ContentType = ContentType.ApplicationData;
+ outgoingRecord.ProtocolVersion = protocolVersion;
+ outgoingRecord.Epoch = peer.Epoch;
+ outgoingRecord.SequenceNumber = peer.CurrentEpoch.NextOutgoingSequence;
+ outgoingRecord.Length = (ushort)peer.CurrentEpoch.RecordProtection.GetEncryptedSize(queuedSpan.Length);
+ ++peer.CurrentEpoch.NextOutgoingSequence;
+
+ // Encode the record to wire format
+ ByteSpan packet = new byte[Record.Size + outgoingRecord.Length];
+ ByteSpan writer = packet;
+ outgoingRecord.Encode(writer);
+ writer = writer.Slice(Record.Size);
+ queuedSpan.CopyTo(writer);
+
+ // Protect the record
+ peer.CurrentEpoch.RecordProtection.EncryptServerPlaintext(
+ packet.Slice(Record.Size, outgoingRecord.Length),
+ packet.Slice(Record.Size, queuedSpan.Length),
+ ref outgoingRecord
+ );
+
+ base.QueueRawData(packet, remoteEndPoint);
+ }
+ peer.QueuedApplicationDataMessage.Clear();
+
+ {
+ Record outgoingRecord = new Record();
+ outgoingRecord.ContentType = ContentType.ApplicationData;
+ outgoingRecord.ProtocolVersion = protocolVersion;
+ outgoingRecord.Epoch = peer.Epoch;
+ outgoingRecord.SequenceNumber = peer.CurrentEpoch.NextOutgoingSequence;
+ outgoingRecord.Length = (ushort)peer.CurrentEpoch.RecordProtection.GetEncryptedSize(span.Length);
+ ++peer.CurrentEpoch.NextOutgoingSequence;
+
+ // Encode the record to wire format
+ ByteSpan packet = new byte[Record.Size + outgoingRecord.Length];
+ ByteSpan writer = packet;
+ outgoingRecord.Encode(writer);
+ writer = writer.Slice(Record.Size);
+ span.CopyTo(writer);
+
+ // Protect the record
+ peer.CurrentEpoch.RecordProtection.EncryptServerPlaintext(
+ packet.Slice(Record.Size, outgoingRecord.Length),
+ packet.Slice(Record.Size, span.Length),
+ ref outgoingRecord
+ );
+
+ base.QueueRawData(packet, remoteEndPoint);
+ }
+ }
+ }
+
+ private void HandleStaleConnections(object _)
+ {
+ TimeSpan maxAge = TimeSpan.FromSeconds(2.5f);
+ DateTime now = DateTime.UtcNow;
+ foreach (KeyValuePair<IPEndPoint, PeerData> kvp in this.existingPeers)
+ {
+ PeerData peer = kvp.Value;
+ lock (peer)
+ {
+ if (peer.Epoch == 0 || peer.NextEpoch.State != HandshakeState.ExpectingHello)
+ {
+ TimeSpan negotiationAge = now - peer.StartOfNegotiation;
+ if (negotiationAge > maxAge)
+ {
+ MarkConnectionAsStale(peer.ConnectionId);
+ }
+ }
+ }
+ }
+
+ ConnectionId connectionId;
+ while (this.staleConnections.TryPop(out connectionId))
+ {
+ ThreadLimitedUdpServerConnection connection;
+ if (this.allConnections.TryGetValue(connectionId, out connection))
+ {
+ connection.Disconnect("Stale Connection", null);
+ }
+ }
+ }
+
+ protected void MarkConnectionAsStale(ConnectionId connectionId)
+ {
+ if (this.allConnections.ContainsKey(connectionId))
+ {
+ this.staleConnections.Push(connectionId);
+ }
+ }
+
+ /// <inheritdoc />
+ internal override void RemovePeerRecord(ConnectionId connectionId)
+ {
+ if (this.existingPeers.TryRemove(connectionId.EndPoint, out var peer))
+ {
+ peer.Dispose();
+ }
+ }
+
+ /// <summary>
+ /// Allocate a new connection id
+ /// </summary>
+ private ConnectionId AllocateConnectionId(IPEndPoint endPoint)
+ {
+ int rawSerialId = Interlocked.Increment(ref this.connectionSerial_unsafe);
+ return ConnectionId.Create(endPoint, rawSerialId);
+ }
+
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/Dtls/DtlsUnityConnection.cs b/Tools/Hazel-Networking/Hazel/Dtls/DtlsUnityConnection.cs
new file mode 100644
index 0000000..4da2051
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/Dtls/DtlsUnityConnection.cs
@@ -0,0 +1,1246 @@
+using Hazel.Crypto;
+using Hazel.Udp;
+using System;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Net;
+using System.Security.Cryptography;
+using System.Security.Cryptography.X509Certificates;
+
+namespace Hazel.Dtls
+{
+ /// <summary>
+ /// Connects to a UDP-DTLS server
+ /// </summary>
+ /// <inheritdoc />
+ public class DtlsUnityConnection : UnityUdpClientConnection
+ {
+ /// <summary>
+ /// Current state of the handshake sequence
+ /// </summary>
+ enum HandshakeState
+ {
+ Initializing,
+
+ ExpectingServerHello,
+ ExpectingCertificate,
+ ExpectingServerKeyExchange,
+ ExpectingServerHelloDone,
+
+ ExpectingChangeCipherSpec,
+ ExpectingFinished,
+
+ Established,
+ }
+
+ /// <summary>
+ /// State data for the current epoch
+ /// </summary>
+ struct CurrentEpoch
+ {
+ public ulong NextOutgoingSequence;
+
+ public ulong NextExpectedSequence;
+ public ulong PreviousSequenceWindowBitmask;
+
+ public IRecordProtection RecordProtection;
+ }
+
+ struct FragmentRange
+ {
+ public int Offset;
+ public int Length;
+ }
+
+ /// <summary>
+ /// State data for the next epoch
+ /// </summary>
+ struct NextEpoch
+ {
+ public ushort Epoch;
+
+ public HandshakeState State;
+
+ public ulong NextOutgoingSequence;
+
+ public DateTime NegotiationStartTime;
+ public DateTime NextPacketResendTime;
+ public int PacketResendCount;
+
+ public CipherSuite SelectedCipherSuite;
+ public IRecordProtection RecordProtection;
+ public IHandshakeCipherSuite Handshake;
+ public ByteSpan Cookie;
+ public Sha256Stream VerificationStream;
+ public RSA ServerPublicKey;
+
+ public ByteSpan ClientRandom;
+ public ByteSpan ServerRandom;
+
+ public ByteSpan MasterSecret;
+ public ByteSpan ServerVerification;
+
+ public List<FragmentRange> CertificateFragments;
+ public ByteSpan CertificatePayload;
+ }
+
+ struct QueuedAppData
+ {
+ public byte[] Bytes;
+ public byte SendOption;
+ public Action AckCallback;
+ }
+
+ private const ProtocolVersion DtlsVersion = ProtocolVersion.UDP;
+
+ internal byte HazelSessionVersion = HazelDtlsSessionInfo.CurrentClientSessionVersion;
+
+ private readonly object syncRoot = new object();
+ private readonly RandomNumberGenerator random = RandomNumberGenerator.Create();
+
+ private ushort epoch;
+ private CurrentEpoch currentEpoch;
+ private NextEpoch nextEpoch;
+ private TimeSpan handshakeResendTimeout = TimeSpan.FromMilliseconds(200);
+
+ private readonly Queue<QueuedAppData> queuedApplicationData = new Queue<QueuedAppData>();
+
+ private X509Certificate2Collection serverCertificates = new X509Certificate2Collection();
+
+ public bool HandshakeComplete
+ {
+ get
+ {
+ lock (this.syncRoot)
+ {
+ return this.nextEpoch.State == HandshakeState.Established;
+ }
+ }
+ }
+
+ /// <summary>
+ /// Create a new instance of the DTLS connection
+ /// </summary>
+ /// <inheritdoc />
+ public DtlsUnityConnection(ILogger logger, IPEndPoint remoteEndPoint, IPMode ipMode = IPMode.IPv4)
+ : base(logger, remoteEndPoint, ipMode)
+ {
+ this.nextEpoch.ServerRandom = new byte[Random.Size];
+ this.nextEpoch.ClientRandom = new byte[Random.Size];
+ this.nextEpoch.ServerVerification = new byte[Finished.Size];
+ this.nextEpoch.CertificateFragments = new List<FragmentRange>();
+
+ this.ResetConnectionState();
+ }
+
+ /// <inheritdoc />
+ protected override void Dispose(bool disposing)
+ {
+ base.Dispose(disposing);
+
+ lock (this.syncRoot)
+ {
+ this.ResetConnectionState();
+ }
+ }
+
+ /// <summary>
+ /// Set the list of valid server certificates
+ /// </summary>
+ /// <param name="certificateCollection">
+ /// List of certificates of authentic servers
+ /// </param>
+ public void SetValidServerCertificates(X509Certificate2Collection certificateCollection)
+ {
+ lock (this.syncRoot)
+ {
+ foreach (X509Certificate2 certificate in certificateCollection)
+ {
+ if (!(certificate.PublicKey.Key is RSA))
+ {
+ throw new ArgumentException("Certificate must be signed with an RSA key", nameof(certificateCollection));
+ }
+ }
+
+ this.serverCertificates = certificateCollection;
+ }
+ }
+
+ /// <summary>
+ /// Set the packet resend timer for handshake messages
+ /// </summary>
+ public void SetHandshakeResendTimeout(TimeSpan timeout)
+ {
+ lock (this.syncRoot)
+ {
+ this.handshakeResendTimeout = timeout;
+ }
+ }
+
+ /// <summary>
+ /// Reset existing connection state
+ /// </summary>
+ private void ResetConnectionState()
+ {
+ this.currentEpoch.NextOutgoingSequence = 1;
+ this.currentEpoch.NextExpectedSequence = 1;
+ this.currentEpoch.PreviousSequenceWindowBitmask = 0;
+ this.currentEpoch.RecordProtection?.Dispose();
+ this.currentEpoch.RecordProtection = NullRecordProtection.Instance;
+
+ this.nextEpoch.Epoch = 1;
+ this.nextEpoch.State = HandshakeState.Initializing;
+ this.nextEpoch.NextOutgoingSequence = 1;
+ this.nextEpoch.NegotiationStartTime = DateTime.MinValue;
+ this.nextEpoch.NextPacketResendTime = DateTime.MinValue;
+ this.nextEpoch.SelectedCipherSuite = CipherSuite.TLS_NULL_WITH_NULL_NULL;
+ this.nextEpoch.RecordProtection?.Dispose();
+ this.nextEpoch.RecordProtection = null;
+ this.nextEpoch.Handshake?.Dispose();
+ this.nextEpoch.Handshake = null;
+ this.nextEpoch.Cookie = ByteSpan.Empty;
+ this.nextEpoch.VerificationStream?.Dispose();
+ this.nextEpoch.VerificationStream = new Sha256Stream();
+ this.nextEpoch.ServerPublicKey = null;
+ this.nextEpoch.ServerRandom.SecureClear();
+ this.nextEpoch.ClientRandom.SecureClear();
+ this.nextEpoch.MasterSecret.SecureClear();
+ this.nextEpoch.ServerVerification.SecureClear();
+ this.nextEpoch.CertificateFragments.Clear();
+ this.nextEpoch.CertificatePayload = ByteSpan.Empty;
+
+ this.epoch = 0;
+ while (this.queuedApplicationData.TryDequeue(out _)) ;
+ }
+
+ /// <summary>
+ /// Abort the existing connection and restart the process
+ /// </summary>
+ protected override void RestartConnection()
+ {
+ lock (this.syncRoot)
+ {
+ this.ResetConnectionState();
+ this.nextEpoch.ClientRandom.FillWithRandom(this.random);
+ this.SendClientHello(isRetransmit: false);
+ }
+
+ base.RestartConnection();
+ }
+
+ /// <inheritdoc />
+ protected override void ResendPacketsIfNeeded()
+ {
+ lock (this.syncRoot)
+ {
+ // Check if we need to resend handshake message
+ if (this.nextEpoch.State != HandshakeState.Established)
+ {
+ DateTime now = DateTime.UtcNow;
+ if (now >= this.nextEpoch.NextPacketResendTime)
+ {
+ double negotiationDurationMs = (now - this.nextEpoch.NegotiationStartTime).TotalMilliseconds;
+ this.nextEpoch.PacketResendCount++;
+
+ if ((this.ResendLimit > 0 && this.nextEpoch.PacketResendCount > this.ResendLimit)
+ || negotiationDurationMs > this.DisconnectTimeoutMs)
+ {
+ this.DisconnectInternal(HazelInternalErrors.DtlsNegotiationFailed, $"DTLS negotiation failed after {this.nextEpoch.PacketResendCount} resends ({(int)negotiationDurationMs} ms).");
+ }
+ else
+ {
+ switch (this.nextEpoch.State)
+ {
+ case HandshakeState.ExpectingServerHello:
+ case HandshakeState.ExpectingCertificate:
+ case HandshakeState.ExpectingServerKeyExchange:
+ case HandshakeState.ExpectingServerHelloDone:
+ this.SendClientHello(isRetransmit: true);
+ break;
+
+ case HandshakeState.ExpectingChangeCipherSpec:
+ case HandshakeState.ExpectingFinished:
+ this.SendClientKeyExchangeFlight(isRetransmit: true);
+ break;
+
+ case HandshakeState.Established:
+ default:
+ break;
+ }
+ }
+ }
+ }
+ }
+
+ base.ResendPacketsIfNeeded();
+ }
+
+ /// <summary>
+ /// Flush any queued application data packets
+ /// </summary>
+ private void FlushQueuedApplicationData()
+ {
+ while (this.queuedApplicationData.TryDequeue(out var queuedData))
+ {
+ base.HandleSend(queuedData.Bytes, queuedData.SendOption, queuedData.AckCallback);
+ }
+ }
+
+ /// <summary>
+ /// Request from the application to write data to the DTLS
+ /// stream. If appropriate, returns a byte span to send to
+ /// the wire.
+ /// </summary>
+ /// <param name="bytes">Plaintext bytes to write</param>
+ /// <param name="length">Length of the bytes to write</param>
+ /// <returns>
+ /// Encrypted data to put on the wire if appropriate,
+ /// otherwise an empty span
+ /// </returns>
+ private ByteSpan WriteBytesToConnectionInternal(byte[] bytes, int length)
+ {
+ lock (this.syncRoot)
+ {
+ Record outgoinRecord = new Record();
+ outgoinRecord.ContentType = ContentType.ApplicationData;
+ outgoinRecord.ProtocolVersion = DtlsVersion;
+ outgoinRecord.Epoch = this.epoch;
+ outgoinRecord.SequenceNumber = this.currentEpoch.NextOutgoingSequence;
+ outgoinRecord.Length = (ushort)this.currentEpoch.RecordProtection.GetEncryptedSize(length);
+ ++this.currentEpoch.NextOutgoingSequence;
+
+ // Encode the record to wire format
+ ByteSpan packet = new byte[Record.Size + outgoinRecord.Length];
+ ByteSpan writer = packet;
+ outgoinRecord.Encode(writer);
+ writer = writer.Slice(Record.Size);
+ new ByteSpan(bytes, 0, length).CopyTo(writer);
+
+ // Protect the record
+ this.currentEpoch.RecordProtection.EncryptClientPlaintext(
+ packet.Slice(Record.Size, outgoinRecord.Length),
+ packet.Slice(Record.Size, length),
+ ref outgoinRecord
+ );
+
+ return packet;
+ }
+ }
+
+ protected override void HandleSend(byte[] data, byte sendOption, Action ackCallback = null)
+ {
+ lock (this.syncRoot)
+ {
+ // If we're negotiating a new epoch, queue data
+ if (this.nextEpoch.State != HandshakeState.Established)
+ {
+ this.queuedApplicationData.Enqueue(new QueuedAppData
+ {
+ Bytes = data,
+ SendOption = sendOption,
+ AckCallback = ackCallback
+ });
+
+ return;
+ }
+ }
+
+ base.HandleSend(data, sendOption, ackCallback);
+ }
+
+ /// <inheritdoc />
+ protected override void WriteBytesToConnection(byte[] bytes, int length)
+ {
+ ByteSpan wireData = this.WriteBytesToConnectionInternal(bytes, length);
+ if (wireData.Length > 0)
+ {
+ Debug.Assert(wireData.Offset == 0, "Got a non-zero write data offset");
+ base.WriteBytesToConnection(wireData.GetUnderlyingArray(), wireData.Length);
+ }
+ }
+
+ /// <inheritdoc />
+ protected override void WriteBytesToConnectionSync(byte[] bytes, int length)
+ {
+ ByteSpan wireData = this.WriteBytesToConnectionInternal(bytes, length);
+ if (wireData.Length > 0)
+ {
+ Debug.Assert(wireData.Offset == 0, "Got a non-zero write data offset");
+ base.WriteBytesToConnectionSync(wireData.GetUnderlyingArray(), wireData.Length);
+ }
+ }
+
+ /// <inheritdoc />
+ protected internal override void HandleReceive(MessageReader reader, int bytesReceived)
+ {
+ ByteSpan message = new ByteSpan(reader.Buffer, reader.Offset + reader.Position, reader.BytesRemaining);
+ lock (this.syncRoot)
+ {
+ this.HandleReceive(message);
+ }
+
+ reader.Recycle();
+ }
+
+ /// <summary>
+ /// Handle an incoming datagram
+ /// </summary>
+ /// <param name="span">Bytes of the datagram</param>
+ private void HandleReceive(ByteSpan span)
+ {
+ // Each incoming packet may contain multiple DTLS
+ // records
+ while (span.Length > 0)
+ {
+ Record record;
+ if (!Record.Parse(out record, DtlsVersion, span))
+ {
+ this.logger.WriteError("Dropping malformed record");
+ return;
+ }
+ span = span.Slice(Record.Size);
+
+ if (span.Length < record.Length)
+ {
+ this.logger.WriteError($"Dropping malformed record. Length({record.Length}) Available Bytes({span.Length})");
+ return;
+ }
+
+ ByteSpan recordPayload = span.Slice(0, record.Length);
+ span = span.Slice(record.Length);
+
+ // Early out and drop ApplicationData records
+ if (record.ContentType == ContentType.ApplicationData && this.nextEpoch.State != HandshakeState.Established)
+ {
+ this.logger.WriteError("Dropping ApplicationData record. Cannot process yet");
+ continue;
+ }
+
+ // Drop records from a different epoch
+ if (record.Epoch != this.epoch)
+ {
+ this.logger.WriteError($"Dropping bad-epoch record. RecordEpoch({record.Epoch}) Epoch({this.epoch})");
+ continue;
+ }
+
+ // Prevent replay attacks by dropping records
+ // we've already processed
+ int windowIndex = (int)(this.currentEpoch.NextExpectedSequence - record.SequenceNumber - 1);
+ ulong windowMask = 1ul << windowIndex;
+ if (record.SequenceNumber < this.currentEpoch.NextExpectedSequence)
+ {
+ if (windowIndex >= 64)
+ {
+ this.logger.WriteError($"Dropping too-old record: Sequnce({record.SequenceNumber}) Expected({this.currentEpoch.NextExpectedSequence})");
+ continue;
+ }
+
+ if ((this.currentEpoch.PreviousSequenceWindowBitmask & windowMask) != 0)
+ {
+ this.logger.WriteWarning("Dropping duplicate record");
+ continue;
+ }
+ }
+
+ // Verify record authenticity
+ int decryptedSize = this.currentEpoch.RecordProtection.GetDecryptedSize(recordPayload.Length);
+ ByteSpan decryptedPayload = recordPayload.ReuseSpanIfPossible(decryptedSize);
+
+ if (!this.currentEpoch.RecordProtection.DecryptCiphertextFromServer(decryptedPayload, recordPayload, ref record))
+ {
+ this.logger.WriteError("Dropping non-authentic record");
+ return;
+ }
+
+ recordPayload = decryptedPayload;
+
+ // Update out sequence number bookkeeping
+ if (record.SequenceNumber >= this.currentEpoch.NextExpectedSequence)
+ {
+ int windowShift = (int)(record.SequenceNumber + 1 - this.currentEpoch.NextExpectedSequence);
+ this.currentEpoch.PreviousSequenceWindowBitmask <<= windowShift;
+ this.currentEpoch.NextExpectedSequence = record.SequenceNumber + 1;
+ }
+ else
+ {
+ this.currentEpoch.PreviousSequenceWindowBitmask |= windowMask;
+ }
+
+ // This is handy for debugging, but too verbose even for verbose.
+ // this.logger.WriteVerbose($"Content type was {record.ContentType} ({this.nextEpoch.State})");
+ switch (record.ContentType)
+ {
+ case ContentType.ChangeCipherSpec:
+ if (this.nextEpoch.State != HandshakeState.ExpectingChangeCipherSpec)
+ {
+ this.logger.WriteError($"Dropping unexpected ChangeCipherSpec State({this.nextEpoch.State})");
+ break;
+ }
+ else if (this.nextEpoch.RecordProtection == null)
+ {
+ ///NOTE(mendsley): This _should_ not
+ /// happen on a well-formed client.
+ Debug.Assert(false, "How did we receive a ChangeCipherSpec message without a pending record protection instance?");
+ break;
+ }
+
+ if (!ChangeCipherSpec.Parse(recordPayload))
+ {
+ this.logger.WriteError("Dropping malformed ChangeCipherSpec message");
+ break;
+ }
+
+ // Migrate to the next epoch
+ this.epoch = this.nextEpoch.Epoch;
+ this.currentEpoch.RecordProtection = this.nextEpoch.RecordProtection;
+ this.currentEpoch.NextOutgoingSequence = this.nextEpoch.NextOutgoingSequence;
+ this.currentEpoch.NextExpectedSequence = 1;
+ this.currentEpoch.PreviousSequenceWindowBitmask = 0;
+
+ this.nextEpoch.State = HandshakeState.ExpectingFinished;
+ this.nextEpoch.SelectedCipherSuite = CipherSuite.TLS_NULL_WITH_NULL_NULL;
+ this.nextEpoch.RecordProtection = null;
+ this.nextEpoch.Handshake?.Dispose();
+ this.nextEpoch.Cookie = ByteSpan.Empty;
+ this.nextEpoch.VerificationStream.Reset();
+ this.nextEpoch.ServerPublicKey = null;
+ this.nextEpoch.ServerRandom.SecureClear();
+ this.nextEpoch.ClientRandom.SecureClear();
+ this.nextEpoch.MasterSecret.SecureClear();
+ break;
+
+ case ContentType.Alert:
+ this.logger.WriteError("Dropping unsupported alert record");
+ continue;
+
+ case ContentType.Handshake:
+ if (!ProcessHandshake(ref record, recordPayload))
+ {
+ return;
+ }
+ break;
+
+ case ContentType.ApplicationData:
+ // Forward data to the application
+ MessageReader reader = MessageReader.GetSized(recordPayload.Length);
+ reader.Length = recordPayload.Length;
+ recordPayload.CopyTo(reader.Buffer);
+
+ base.HandleReceive(reader, recordPayload.Length);
+ break;
+ }
+ }
+ }
+
+ /// <summary>
+ /// Process an incoming Handshake protocol message
+ /// </summary>
+ /// <param name="record">Parent record</param>
+ /// <param name="message">Record payload</param>
+ /// <returns>
+ /// True if further processing of the underlying datagram
+ /// should be continues. Otherwise, false.
+ /// </returns>
+ private bool ProcessHandshake(ref Record record, ByteSpan message)
+ {
+ // Each record may have multiple Handshake messages
+ while (message.Length > 0)
+ {
+ ByteSpan originalPayload = message;
+
+ Handshake handshake;
+ if (!Handshake.Parse(out handshake, message))
+ {
+ this.logger.WriteError("Dropping malformed handshake message");
+ return false;
+ }
+ message = message.Slice(Handshake.Size);
+
+ // Check for fragmented messages
+ if (handshake.FragmentOffset != 0 || handshake.FragmentLength != handshake.Length)
+ {
+ // We only support fragmentation on Certificate messages
+ if (handshake.MessageType != HandshakeType.Certificate)
+ {
+ this.logger.WriteError($"Dropping fragmented handshake message Type({handshake.MessageType}) Offset({handshake.FragmentOffset}) FragmentLength({handshake.FragmentLength}) Length({handshake.Length})");
+ continue;
+ }
+
+ if (message.Length < handshake.FragmentLength)
+ {
+ this.logger.WriteError($"Dropping malformed fragmented handshake message: AvailableBytes({message.Length}) Size({handshake.FragmentLength})");
+ return false;
+ }
+
+ originalPayload = originalPayload.Slice(0, (int)(Handshake.Size + handshake.FragmentLength));
+ message = message.Slice((int)handshake.FragmentLength);
+ }
+ else
+ {
+ if (message.Length < handshake.Length)
+ {
+ this.logger.WriteError($"Dropping malformed handshake message: AvailableBytes({message.Length}) Size({handshake.Length})");
+ return false;
+ }
+
+ originalPayload = originalPayload.Slice(0, (int)(Handshake.Size + handshake.Length));
+ message = message.Slice((int)handshake.Length);
+ }
+
+ ByteSpan payload = originalPayload.Slice(Handshake.Size);
+
+#if DEBUG
+ this.logger.WriteVerbose($"Handshake record was {handshake.MessageType} (Frag: {handshake.FragmentOffset}) ({this.nextEpoch.State})");
+#endif
+ switch (handshake.MessageType)
+ {
+ case HandshakeType.HelloVerifyRequest:
+ if (this.nextEpoch.State != HandshakeState.ExpectingServerHello)
+ {
+ this.logger.WriteError($"Dropping unexpected HelloVerifyRequest handshake message State({this.nextEpoch.State})");
+ continue;
+ }
+ else if (handshake.MessageSequence != 0)
+ {
+ this.logger.WriteError($"Dropping bad-sequence HelloVerifyRequest MessageSequence({handshake.MessageSequence})");
+ continue;
+ }
+
+ HelloVerifyRequest helloVerifyRequest;
+ if (!HelloVerifyRequest.Parse(out helloVerifyRequest, DtlsVersion, payload))
+ {
+ this.logger.WriteError("Dropping malformed HelloVerifyRequest handshake message");
+ continue;
+ }
+
+ // If the cookie differs, save it and restart the handshake
+ if (this.nextEpoch.Cookie.Length == helloVerifyRequest.Cookie.Length
+ && Const.ConstantCompareSpans(this.nextEpoch.Cookie, helloVerifyRequest.Cookie) == 1)
+ {
+ this.logger.WriteWarning("Dropping duplicate HelloVerifyRequest handshake message");
+ continue;
+ }
+
+ this.nextEpoch.Cookie = new byte[helloVerifyRequest.Cookie.Length];
+ helloVerifyRequest.Cookie.CopyTo(this.nextEpoch.Cookie);
+ this.nextEpoch.ClientRandom.FillWithRandom(this.random);
+
+ // We don't need to resend here. We already have the cookie so we already sent it once.
+ this.SendClientHello(isRetransmit: false);
+
+ break;
+
+ case HandshakeType.ServerHello:
+ if (this.nextEpoch.State != HandshakeState.ExpectingServerHello)
+ {
+ this.logger.WriteError($"Dropping unexpected ServerHello handshake message State({this.nextEpoch.State})");
+ continue;
+ }
+ else if (handshake.MessageSequence != 1)
+ {
+ this.logger.WriteError($"Dropping bad-sequence ServerHello MessageSequence({handshake.MessageSequence})");
+ continue;
+ }
+
+ ServerHello serverHello;
+ if (!ServerHello.Parse(out serverHello, payload))
+ {
+ this.logger.WriteError("Dropping malformed ServerHello message");
+ continue;
+ }
+
+ switch (serverHello.CipherSuite)
+ {
+ case CipherSuite.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256:
+ this.nextEpoch.Handshake = new X25519EcdheRsaSha256(this.random);
+ break;
+
+ default:
+ this.logger.WriteError($"Dropping malformed ServerHello message. Unsupported CipherSuite({serverHello.CipherSuite})");
+ continue;
+ }
+
+ // Save server parameters
+ this.nextEpoch.SelectedCipherSuite = serverHello.CipherSuite;
+ serverHello.Random.CopyTo(this.nextEpoch.ServerRandom);
+ this.nextEpoch.State = HandshakeState.ExpectingCertificate;
+ this.nextEpoch.CertificateFragments.Clear();
+ this.nextEpoch.CertificatePayload = ByteSpan.Empty;
+
+#if DEBUG
+ this.logger.WriteVerbose($"ClientRandom: {this.nextEpoch.ClientRandom} ServerRandom: {this.nextEpoch.ServerRandom}");
+#endif
+
+ // Append ServerHelllo message to the verification stream
+ this.nextEpoch.VerificationStream.AddData(originalPayload);
+ break;
+
+ case HandshakeType.Certificate:
+ if (this.nextEpoch.State != HandshakeState.ExpectingCertificate)
+ {
+ this.logger.WriteError($"Dropping unexpected Certificate handshake message State({this.nextEpoch.State})");
+ continue;
+ }
+ else if (handshake.MessageSequence != 2)
+ {
+ this.logger.WriteError($"Dropping bad-sequence Certificate MessageSequence({handshake.MessageSequence})");
+ continue;
+ }
+
+ // If this is a fragmented message
+ if (handshake.FragmentLength != handshake.Length)
+ {
+ if (this.nextEpoch.CertificatePayload.Length != handshake.Length)
+ {
+ this.nextEpoch.CertificatePayload = new byte[handshake.Length];
+ this.nextEpoch.CertificateFragments.Clear();
+ }
+
+ // Should we add this fragment?
+ // According to the RFC 9147 Section 5.5, we are supposed to be tolerant of overlapping segments
+ // But if we... weren't... Hazel isn't going to change the fragment sizes. So would it really hurt?
+ // So let's just ignore that and assume that the sender always wants to send the same fragments.
+ if (IsFragmentOverlapping(this.nextEpoch.CertificateFragments, handshake.FragmentOffset, handshake.FragmentLength))
+ {
+ continue;
+ }
+
+ payload.CopyTo(this.nextEpoch.CertificatePayload.Slice((int)handshake.FragmentOffset, (int)handshake.FragmentLength));
+ this.nextEpoch.CertificateFragments.Add(new FragmentRange {Offset = (int)handshake.FragmentOffset, Length = (int)handshake.FragmentLength });
+ this.nextEpoch.CertificateFragments.Sort((FragmentRange lhs, FragmentRange rhs) => {
+ return lhs.Offset.CompareTo(rhs.Offset);
+ });
+
+ // Have we completed the message?
+ int currentOffset = 0;
+ bool valid = true;
+ foreach (FragmentRange range in this.nextEpoch.CertificateFragments)
+ {
+ if (range.Offset != currentOffset)
+ {
+ valid = false;
+ break;
+ }
+
+ currentOffset += range.Length;
+ }
+
+ if (currentOffset != this.nextEpoch.CertificatePayload.Length)
+ {
+ valid = false;
+ }
+
+ // Still waiting on more fragments?
+ if (!valid)
+ {
+ continue;
+ }
+
+ // Replace the message payload, and continue
+ this.nextEpoch.CertificateFragments.Clear();
+ payload = this.nextEpoch.CertificatePayload;
+ }
+
+ X509Certificate2 certificate;
+ if (!Certificate.Parse(out certificate, payload))
+ {
+ this.logger.WriteError("Dropping malformed Certificate message");
+ continue;
+ }
+
+ // Verify the certificate is authenticate
+ if (!this.serverCertificates.Contains(certificate))
+ {
+ this.logger.WriteError("Dropping malformed Certificate message: Certificate not authentic");
+ continue;
+ }
+
+ RSA publicKey = certificate.PublicKey.Key as RSA;
+ if (publicKey == null)
+ {
+ this.logger.WriteError("Dropping malfomed Certificate message: Certificate is not RSA signed");
+ continue;
+ }
+
+ // Add the final Certificate message to the verification stream
+ Handshake fullCertificateHandhake = handshake;
+ fullCertificateHandhake.FragmentOffset = 0;
+ fullCertificateHandhake.FragmentLength = fullCertificateHandhake.Length;
+
+ ByteSpan serializedCertificateHandshake = new byte[Handshake.Size];
+ fullCertificateHandhake.Encode(serializedCertificateHandshake);
+ this.nextEpoch.VerificationStream.AddData(serializedCertificateHandshake);
+ this.nextEpoch.VerificationStream.AddData(payload);
+
+ this.nextEpoch.ServerPublicKey = publicKey;
+ this.nextEpoch.State = HandshakeState.ExpectingServerKeyExchange;
+ break;
+
+ case HandshakeType.ServerKeyExchange:
+ if (this.nextEpoch.State != HandshakeState.ExpectingServerKeyExchange)
+ {
+ this.logger.WriteError($"Dropping unexpected ServerKeyExchange handshake message State({this.nextEpoch.State})");
+ continue;
+ }
+ else if (this.nextEpoch.ServerPublicKey == null)
+ {
+ ///NOTE(mendsley): This _should_ not
+ /// happen on a well-formed client
+ Debug.Assert(false, "How are we processing a ServerKeyExchange message without a server public key?");
+
+ this.logger.WriteError($"Dropping unexpected ServerKeyExchange handshake message: No server public key");
+ continue;
+ }
+ else if (this.nextEpoch.Handshake == null)
+ {
+ ///NOTE(mendsley): This _should_ not
+ /// happen on a well-formed client
+ Debug.Assert(false, "How did we receive a ServerKeyExchange message without a handshake instance?");
+
+ this.logger.WriteError($"Dropping unexpected ServerKeyExchange handshake message: No key agreement interface");
+ continue;
+ }
+ else if (handshake.MessageSequence != 3)
+ {
+ this.logger.WriteError($"Dropping bad-sequence ServerKeyExchange MessageSequence({handshake.MessageSequence})");
+ continue;
+ }
+
+ ByteSpan sharedSecret = new byte[this.nextEpoch.Handshake.SharedKeySize()];
+ if (!this.nextEpoch.Handshake.VerifyServerMessageAndGenerateSharedKey(sharedSecret, payload, this.nextEpoch.ServerPublicKey))
+ {
+ this.logger.WriteError("Dropping malformed ServerKeyExchangeMessage");
+ return false;
+ }
+
+ // Generate the session master secret
+ ByteSpan randomSeed = new byte[2 * Random.Size];
+ this.nextEpoch.ClientRandom.CopyTo(randomSeed);
+ this.nextEpoch.ServerRandom.CopyTo(randomSeed.Slice(Random.Size));
+
+ const int MasterSecretSize = 48;
+ ByteSpan masterSecret = new byte[MasterSecretSize];
+ PrfSha256.ExpandSecret(
+ masterSecret
+ , sharedSecret
+ , PrfLabel.MASTER_SECRET
+ , randomSeed
+ );
+
+ // Create record protection for the upcoming epoch
+ switch (this.nextEpoch.SelectedCipherSuite)
+ {
+ case CipherSuite.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256:
+ this.nextEpoch.RecordProtection = new Aes128GcmRecordProtection(
+ masterSecret
+ , this.nextEpoch.ServerRandom
+ , this.nextEpoch.ClientRandom
+ );
+ break;
+
+ default:
+ ///NOTE(mendsley): this _should_ not
+ /// happen on a well-formed client.
+ Debug.Assert(false, "SeverHello processing already approved this ciphersuite");
+
+ this.logger.WriteError($"Dropping malformed ServerKeyExchangeMessage: Could not create record protection");
+ return false;
+ }
+
+ this.nextEpoch.State = HandshakeState.ExpectingServerHelloDone;
+ this.nextEpoch.MasterSecret = masterSecret;
+
+ // Append ServerKeyExchange to the verification stream
+ this.nextEpoch.VerificationStream.AddData(originalPayload);
+ break;
+
+ case HandshakeType.ServerHelloDone:
+ if (this.nextEpoch.State != HandshakeState.ExpectingServerHelloDone)
+ {
+ this.logger.WriteError($"Dropping unexpected ServerHelloDone handshake message State({this.nextEpoch.State})");
+ continue;
+ }
+ else if (handshake.MessageSequence != 4)
+ {
+ this.logger.WriteError($"Dropping bad-sequence ServerHelloDone MessageSequence({handshake.MessageSequence})");
+ continue;
+ }
+
+ this.nextEpoch.State = HandshakeState.ExpectingChangeCipherSpec;
+
+ // Append ServerHelloDone to the verification stream
+ this.nextEpoch.VerificationStream.AddData(originalPayload);
+
+ this.SendClientKeyExchangeFlight(isRetransmit: false);
+ break;
+
+ case HandshakeType.Finished:
+ if (this.nextEpoch.State != HandshakeState.ExpectingFinished)
+ {
+ this.logger.WriteError($"Dropping unexpected Finished handshake message State({this.nextEpoch.State})");
+ continue;
+ }
+ else if (payload.Length != Finished.Size)
+ {
+ this.logger.WriteError($"Dropping malformed Finished handshake message Size({payload.Length})");
+ continue;
+ }
+ else if (handshake.MessageSequence != 7)
+ {
+ this.logger.WriteError($"Dropping bad-sequence Finished MessageSequence({handshake.MessageSequence})");
+ continue;
+ }
+
+ // Verify the digest from the server
+ if (1 != Crypto.Const.ConstantCompareSpans(payload, this.nextEpoch.ServerVerification))
+ {
+ this.logger.WriteError("Dropping non-verified Finished handshake message");
+ return false;
+ }
+
+ ++this.nextEpoch.Epoch;
+ this.nextEpoch.State = HandshakeState.Established;
+ this.nextEpoch.NegotiationStartTime = DateTime.MinValue;
+ this.nextEpoch.NextPacketResendTime = DateTime.MinValue;
+ this.nextEpoch.ServerVerification.SecureClear();
+ this.nextEpoch.MasterSecret.SecureClear();
+
+ this.FlushQueuedApplicationData();
+ break;
+
+ // Drop messages we do not support
+ case HandshakeType.CertificateRequest:
+ case HandshakeType.HelloRequest:
+ this.logger.WriteError($"Dropping unsupported handshake message MessageType({handshake.MessageType})");
+ break;
+
+ // Drop messages that originate from the client
+ case HandshakeType.ClientHello:
+ case HandshakeType.ClientKeyExchange:
+ case HandshakeType.CertificateVerify:
+ this.logger.WriteError($"Dropping client handshake message MessageType({handshake.MessageType})");
+ break;
+ }
+ }
+
+ return true;
+ }
+
+ private bool IsFragmentOverlapping(List<FragmentRange> fragments, uint newOffset, uint newLength)
+ {
+ foreach (var frag in fragments)
+ {
+ // New fragment overlaps an existing one
+ if (newOffset <= frag.Offset
+ && frag.Offset < newOffset + newLength)
+ {
+ return true;
+ }
+
+ // Existing fragment overlaps this new one
+ if (frag.Offset <= newOffset
+ && newOffset < frag.Offset + frag.Length)
+ {
+ return true;
+ }
+ }
+
+ return false;
+ }
+
+ /// <summary>
+ /// Send (resend) a ClientHello message to the server
+ /// </summary>
+ protected virtual void SendClientHello(bool isRetransmit)
+ {
+#if DEBUG
+ var verb = isRetransmit ? "Resending" : "Sending";
+ this.logger.WriteVerbose($"{verb} ClientHello in state: {this.nextEpoch.State}. Epoch: {this.epoch} Cookie: {this.nextEpoch.Cookie} Random: {this.nextEpoch.ClientRandom}");
+#endif
+
+ // Describe our ClientHello flight
+ ClientHello clientHello = new ClientHello();
+ clientHello.ClientProtocolVersion = DtlsVersion;
+ clientHello.Random = this.nextEpoch.ClientRandom;
+ clientHello.Cookie = this.nextEpoch.Cookie;
+ clientHello.Session = new HazelDtlsSessionInfo(this.HazelSessionVersion);
+ clientHello.CipherSuites = new byte[2];
+ clientHello.CipherSuites.WriteBigEndian16((ushort)CipherSuite.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256);
+ clientHello.SupportedCurves = new byte[2];
+ clientHello.SupportedCurves.WriteBigEndian16((ushort)NamedCurve.x25519);
+
+ Handshake handshake = new Handshake();
+ handshake.MessageType = HandshakeType.ClientHello;
+ handshake.Length = (uint)clientHello.CalculateSize();
+ handshake.MessageSequence = 0;
+ handshake.FragmentOffset = 0;
+ handshake.FragmentLength = handshake.Length;
+
+ // Describe the record
+ int plaintextLength = (int)(Handshake.Size + handshake.Length);
+ Record outgoingRecord = new Record();
+ outgoingRecord.ContentType = ContentType.Handshake;
+ outgoingRecord.ProtocolVersion = DtlsVersion;
+ outgoingRecord.Epoch = this.epoch;
+ outgoingRecord.SequenceNumber = this.currentEpoch.NextOutgoingSequence;
+ outgoingRecord.Length = (ushort)this.currentEpoch.RecordProtection.GetEncryptedSize(plaintextLength);
+ ++this.currentEpoch.NextOutgoingSequence;
+
+ // Convert the record to wire format
+ ByteSpan packet = new byte[Record.Size + outgoingRecord.Length];
+ ByteSpan writer = packet;
+ outgoingRecord.Encode(packet);
+ writer = writer.Slice(Record.Size);
+ handshake.Encode(writer);
+ writer = writer.Slice(Handshake.Size);
+ clientHello.Encode(writer);
+
+ // If this is our first valid attempt at contacting the server:
+ // - Reset our verification stream
+ // - Write ClientHello to the verification stream
+ // - We next expect a ServerHello
+ //
+ // ClientHello+Cookie triggers many sequential packets in response
+ // It's important to make forward progress as the packets may be reordered in-flight
+ // But with enough resends, we will read them all in an appropriate order
+ if (!isRetransmit)
+ {
+ this.nextEpoch.VerificationStream.Reset();
+ this.nextEpoch.VerificationStream.AddData(
+ packet.Slice(Record.Size, Handshake.Size + (int)handshake.Length)
+ );
+
+ this.nextEpoch.State = HandshakeState.ExpectingServerHello;
+ }
+
+ // Protect the record
+ this.currentEpoch.RecordProtection.EncryptClientPlaintext(
+ packet.Slice(Record.Size, outgoingRecord.Length),
+ packet.Slice(Record.Size, plaintextLength),
+ ref outgoingRecord
+ );
+
+ if (this.nextEpoch.NegotiationStartTime == DateTime.MinValue) this.nextEpoch.NegotiationStartTime = DateTime.UtcNow;
+ this.nextEpoch.NextPacketResendTime = DateTime.UtcNow + this.handshakeResendTimeout;
+
+ base.WriteBytesToConnection(packet.GetUnderlyingArray(), packet.Length);
+ }
+
+ protected void Test_SendClientHello(Func<ClientHello, ByteSpan, ByteSpan> encodeCallback)
+ {
+ // Reset our verification stream
+ this.nextEpoch.VerificationStream.Reset();
+
+ // Describe our ClientHello flight
+ ClientHello clientHello = new ClientHello();
+ clientHello.Random = this.nextEpoch.ClientRandom;
+ clientHello.Cookie = this.nextEpoch.Cookie;
+ clientHello.CipherSuites = new byte[2];
+ clientHello.CipherSuites.WriteBigEndian16((ushort)CipherSuite.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256);
+ clientHello.SupportedCurves = new byte[2];
+ clientHello.SupportedCurves.WriteBigEndian16((ushort)NamedCurve.x25519);
+
+ Handshake handshake = new Handshake();
+ handshake.MessageType = HandshakeType.ClientHello;
+ handshake.Length = (uint)clientHello.CalculateSize();
+ handshake.MessageSequence = 0;
+ handshake.FragmentOffset = 0;
+ handshake.FragmentLength = handshake.Length;
+
+ // Describe the record
+ int plaintextLength = (int)(Handshake.Size + handshake.Length);
+ Record outgoingRecord = new Record();
+ outgoingRecord.ContentType = ContentType.Handshake;
+ outgoingRecord.ProtocolVersion = DtlsVersion;
+ outgoingRecord.Epoch = this.epoch;
+ outgoingRecord.SequenceNumber = this.currentEpoch.NextOutgoingSequence;
+ outgoingRecord.Length = (ushort)this.currentEpoch.RecordProtection.GetEncryptedSize(plaintextLength);
+ ++this.currentEpoch.NextOutgoingSequence;
+
+ // Convert the record to wire format
+ ByteSpan packet = new byte[Record.Size + outgoingRecord.Length];
+ ByteSpan writer = packet;
+ outgoingRecord.Encode(packet);
+ writer = writer.Slice(Record.Size);
+ handshake.Encode(writer);
+ writer = writer.Slice(Handshake.Size);
+
+ writer = encodeCallback(clientHello, writer);
+
+ // Write ClientHello to the verification stream
+ this.nextEpoch.VerificationStream.AddData(
+ packet.Slice(
+ Record.Size
+ , Handshake.Size + (int)handshake.Length
+ )
+ );
+
+ // Protect the record
+ this.currentEpoch.RecordProtection.EncryptClientPlaintext(
+ packet.Slice(Record.Size, outgoingRecord.Length),
+ packet.Slice(Record.Size, plaintextLength),
+ ref outgoingRecord
+ );
+
+ this.nextEpoch.State = HandshakeState.ExpectingServerHello;
+ if (this.nextEpoch.NegotiationStartTime == DateTime.MinValue) this.nextEpoch.NegotiationStartTime = DateTime.UtcNow;
+ this.nextEpoch.NextPacketResendTime = DateTime.UtcNow + this.handshakeResendTimeout;
+ base.WriteBytesToConnection(packet.GetUnderlyingArray(), packet.Length);
+ }
+
+ /// <summary>
+ /// Send (resend) the ClientKeyExchange flight
+ /// </summary>
+ /// <param name="isRetransmit">
+ /// True if this is a retransmit of the flight. Otherwise,
+ /// false
+ /// </param>
+ protected virtual void SendClientKeyExchangeFlight(bool isRetransmit)
+ {
+#if DEBUG
+ var verb = isRetransmit ? "Resending" : "Sending";
+ this.logger.WriteVerbose($"{verb} ClientKeyExchangeFlight in state: {this.nextEpoch.State}");
+#endif
+ if (this.nextEpoch.State == HandshakeState.Established)
+ {
+ return;
+ }
+
+ // Describe our flight
+ Handshake keyExchangeHandshake = new Handshake();
+ keyExchangeHandshake.MessageType = HandshakeType.ClientKeyExchange;
+ keyExchangeHandshake.Length = (ushort)this.nextEpoch.Handshake.CalculateClientMessageSize();
+ keyExchangeHandshake.MessageSequence = 5;
+ keyExchangeHandshake.FragmentOffset = 0;
+ keyExchangeHandshake.FragmentLength = keyExchangeHandshake.Length;
+
+ Record keyExchangeRecord = new Record();
+ keyExchangeRecord.ContentType = ContentType.Handshake;
+ keyExchangeRecord.ProtocolVersion = DtlsVersion;
+ keyExchangeRecord.Epoch = this.epoch;
+ keyExchangeRecord.SequenceNumber = this.currentEpoch.NextOutgoingSequence;
+ keyExchangeRecord.Length = (ushort)this.currentEpoch.RecordProtection.GetEncryptedSize(Handshake.Size + (int)keyExchangeHandshake.Length);
+ ++this.currentEpoch.NextOutgoingSequence;
+
+ Record changeCipherSpecRecord = new Record();
+ changeCipherSpecRecord.ContentType = ContentType.ChangeCipherSpec;
+ changeCipherSpecRecord.ProtocolVersion = DtlsVersion;
+ changeCipherSpecRecord.Epoch = this.epoch;
+ changeCipherSpecRecord.SequenceNumber = this.currentEpoch.NextOutgoingSequence;
+ changeCipherSpecRecord.Length = (ushort)this.currentEpoch.RecordProtection.GetEncryptedSize(ChangeCipherSpec.Size);
+ ++this.currentEpoch.NextOutgoingSequence;
+
+ Handshake finishedHandshake = new Handshake();
+ finishedHandshake.MessageType = HandshakeType.Finished;
+ finishedHandshake.Length = Finished.Size;
+ finishedHandshake.MessageSequence = 6;
+ finishedHandshake.FragmentOffset = 0;
+ finishedHandshake.FragmentLength = finishedHandshake.Length;
+
+ Record finishedRecord = new Record();
+ finishedRecord.ContentType = ContentType.Handshake;
+ finishedRecord.ProtocolVersion = DtlsVersion;
+ finishedRecord.Epoch = this.nextEpoch.Epoch;
+ finishedRecord.SequenceNumber = this.nextEpoch.NextOutgoingSequence;
+ finishedRecord.Length = (ushort)this.nextEpoch.RecordProtection.GetEncryptedSize(Handshake.Size + (int)finishedHandshake.Length);
+ ++this.nextEpoch.NextOutgoingSequence;
+
+ // Encode flight to wire format
+ int packetLength = 0
+ + Record.Size + keyExchangeRecord.Length
+ + Record.Size + changeCipherSpecRecord.Length
+ + Record.Size + finishedRecord.Length;
+ ;
+ ByteSpan packet = new byte[packetLength];
+ ByteSpan writer = packet;
+
+ keyExchangeRecord.Encode(writer);
+ writer = writer.Slice(Record.Size);
+ keyExchangeHandshake.Encode(writer);
+ writer = writer.Slice(Handshake.Size);
+ this.nextEpoch.Handshake.EncodeClientKeyExchangeMessage(writer);
+
+ ByteSpan startOfChangeCipherSpecRecord = packet.Slice(Record.Size + keyExchangeRecord.Length);
+ writer = startOfChangeCipherSpecRecord;
+ changeCipherSpecRecord.Encode(writer);
+ writer = writer.Slice(Record.Size);
+ ChangeCipherSpec.Encode(writer);
+ writer = writer.Slice(ChangeCipherSpec.Size);
+
+ ByteSpan startOfFinishedRecord = startOfChangeCipherSpecRecord.Slice(Record.Size + changeCipherSpecRecord.Length);
+ writer = startOfFinishedRecord;
+ finishedRecord.Encode(writer);
+ writer = writer.Slice(Record.Size);
+ finishedHandshake.Encode(writer);
+ writer = writer.Slice(Handshake.Size);
+
+ // Interject here to writer our client key exchange
+ // message into the verification stream
+ if (!isRetransmit)
+ {
+ this.nextEpoch.VerificationStream.AddData(
+ packet.Slice(
+ Record.Size
+ , Handshake.Size + (int)keyExchangeHandshake.Length
+ )
+ );
+ }
+
+ // Calculate the hash of the verification stream
+ ByteSpan handshakeHash = new byte[Sha256Stream.DigestSize];
+ this.nextEpoch.VerificationStream.CopyOrCalculateFinalHash(handshakeHash);
+
+ // Expand our master secret into Finished digests for the client and server
+ PrfSha256.ExpandSecret(
+ this.nextEpoch.ServerVerification
+ , this.nextEpoch.MasterSecret
+ , PrfLabel.SERVER_FINISHED
+ , handshakeHash
+ );
+
+ PrfSha256.ExpandSecret(
+ writer.Slice(0, Finished.Size)
+ , this.nextEpoch.MasterSecret
+ , PrfLabel.CLIENT_FINISHED
+ , handshakeHash
+ );
+ writer = writer.Slice(Finished.Size);
+
+ // Protect the ClientKeyExchange record
+ this.currentEpoch.RecordProtection.EncryptClientPlaintext(
+ packet.Slice(Record.Size, keyExchangeRecord.Length),
+ packet.Slice(Record.Size, Handshake.Size + (int)keyExchangeHandshake.Length),
+ ref keyExchangeRecord
+ );
+
+ // Protect the ChangeCipherSpec record
+ this.currentEpoch.RecordProtection.EncryptClientPlaintext(
+ startOfChangeCipherSpecRecord.Slice(Record.Size, changeCipherSpecRecord.Length),
+ startOfChangeCipherSpecRecord.Slice(Record.Size, ChangeCipherSpec.Size),
+ ref changeCipherSpecRecord
+ );
+
+ // Protect the Finished record
+ this.nextEpoch.RecordProtection.EncryptClientPlaintext(
+ startOfFinishedRecord.Slice(Record.Size, finishedRecord.Length),
+ startOfFinishedRecord.Slice(Record.Size, Handshake.Size + (int)finishedHandshake.Length),
+ ref finishedRecord
+ );
+
+ this.nextEpoch.State = HandshakeState.ExpectingChangeCipherSpec;
+ this.nextEpoch.NextPacketResendTime = DateTime.UtcNow + this.handshakeResendTimeout;
+#if DEBUG
+ if (DropClientKeyExchangeFlight())
+ {
+ return;
+ }
+#endif
+ base.WriteBytesToConnection(packet.GetUnderlyingArray(), packet.Length);
+ }
+
+ protected virtual bool DropClientKeyExchangeFlight()
+ {
+ return false;
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/Dtls/Handshake.cs b/Tools/Hazel-Networking/Hazel/Dtls/Handshake.cs
new file mode 100644
index 0000000..f840053
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/Dtls/Handshake.cs
@@ -0,0 +1,734 @@
+using System;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Net;
+using System.Security.Cryptography;
+using System.Security.Cryptography.X509Certificates;
+
+namespace Hazel.Dtls
+{
+ /// <summary>
+ /// Handshake message type
+ /// </summary>
+ public enum HandshakeType : byte
+ {
+ HelloRequest = 0,
+ ClientHello = 1,
+ ServerHello = 2,
+ HelloVerifyRequest = 3,
+ Certificate = 11,
+ ServerKeyExchange = 12,
+ CertificateRequest = 13,
+ ServerHelloDone = 14,
+ CertificateVerify = 15,
+ ClientKeyExchange = 16,
+ Finished = 20,
+ }
+
+ /// <summary>
+ /// List of cipher suites
+ /// </summary>
+ public enum CipherSuite
+ {
+ TLS_NULL_WITH_NULL_NULL = 0x0000,
+ TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 = 0xC02F,
+ }
+
+ /// <summary>
+ /// List of compression methods
+ /// </summary>
+ public enum CompressionMethod : byte
+ {
+ Null = 0,
+ }
+
+ /// <summary>
+ /// Extension type
+ /// </summary>
+ public enum ExtensionType : ushort
+ {
+ EllipticCurves = 10,
+ }
+
+ /// <summary>
+ /// Named curves
+ /// </summary>
+ public enum NamedCurve : ushort
+ {
+ Reserved = 0,
+ secp256r1 = 23,
+ x25519 = 29,
+ }
+
+ /// <summary>
+ /// Elliptic curve type
+ /// </summary>
+ public enum ECCurveType : byte
+ {
+ NamedCurve = 3,
+ }
+
+ /// <summary>
+ /// Hash algorithms
+ /// </summary>
+ public enum HashAlgorithm : byte
+ {
+ None = 0,
+ Sha256 = 4,
+ }
+
+ /// <summary>
+ /// Signature algorithms
+ /// </summary>
+ public enum SignatureAlgorithm : byte
+ {
+ Anonymous = 0,
+ RSA = 1,
+ ECDSA = 3,
+ }
+
+ /// <summary>
+ /// Random state for entropy
+ /// </summary>
+ public struct Random
+ {
+ public const int Size = 0
+ + 4 // gmt_unix_time
+ + 28 // random_bytes
+ ;
+ }
+
+ /// <summary>
+ /// Encode/decode handshake protocol header
+ /// </summary>
+ public struct Handshake
+ {
+ public HandshakeType MessageType;
+ public uint Length;
+ public ushort MessageSequence;
+ public uint FragmentOffset;
+ public uint FragmentLength;
+
+ public const int Size = 12;
+
+ /// <summary>
+ /// Parse a Handshake protocol header from wire format
+ /// </summary>
+ /// <returns>True if we successfully decode a handshake header. Otherwise false</returns>
+ public static bool Parse(out Handshake header, ByteSpan span)
+ {
+ header = new Handshake();
+
+ if (span.Length < Size)
+ {
+ return false;
+ }
+
+ header.MessageType = (HandshakeType)span[0];
+ header.Length = span.ReadBigEndian24(1);
+ header.MessageSequence = span.ReadBigEndian16(4);
+ header.FragmentOffset = span.ReadBigEndian24(6);
+ header.FragmentLength = span.ReadBigEndian24(9);
+ return true;
+ }
+
+ /// <summary>
+ /// Encode the Handshake protocol header to wire format
+ /// </summary>
+ /// <param name="span"></param>
+ public void Encode(ByteSpan span)
+ {
+ span[0] = (byte)this.MessageType;
+ span.WriteBigEndian24(this.Length, 1);
+ span.WriteBigEndian16(this.MessageSequence, 4);
+ span.WriteBigEndian24(this.FragmentOffset, 6);
+ span.WriteBigEndian24(this.FragmentLength, 9);
+ }
+ }
+
+ /// <summary>
+ /// Encode/decode ClientHello Handshake message
+ /// </summary>
+ public struct ClientHello
+ {
+ public ProtocolVersion ClientProtocolVersion;
+ public ByteSpan Random;
+ public ByteSpan Cookie;
+ public HazelDtlsSessionInfo Session;
+ public ByteSpan CipherSuites;
+ public ByteSpan SupportedCurves;
+
+ public const int MinSize = 0
+ + 2 // client_version
+ + Dtls.Random.Size // random
+ + 1 // session_id (size)
+ + 1 // cookie (size)
+ + 2 // cipher_suites (size)
+ + 1 // compression_methods (size)
+ + 1 // compression_method[0] (NULL)
+
+ + 2 // extensions size
+
+ + 0 // NamedCurveList extensions[0]
+ + 2 // extensions[0].extension_type
+ + 2 // extensions[0].extension_data (length)
+ + 2 // extensions[0].named_curve_list (size)
+ ;
+
+ /// <summary>
+ /// Calculate the size in bytes required for the ClientHello payload
+ /// </summary>
+ /// <returns></returns>
+ public int CalculateSize()
+ {
+ return MinSize
+ + this.Session.PayloadSize
+ + this.Cookie.Length
+ + this.CipherSuites.Length
+ + this.SupportedCurves.Length
+ ;
+ }
+
+ /// <summary>
+ /// Parse a Handshake ClientHello payload from wire format
+ /// </summary>
+ /// <returns>True if we successfully decode the ClientHello message. Otherwise false</returns>
+ public static bool Parse(out ClientHello result, ProtocolVersion? expectedProtocolVersion, ByteSpan span)
+ {
+ result = new ClientHello();
+ if (span.Length < MinSize)
+ {
+ return false;
+ }
+
+ result.ClientProtocolVersion = (ProtocolVersion)span.ReadBigEndian16();
+ if (expectedProtocolVersion.HasValue && result.ClientProtocolVersion != expectedProtocolVersion.Value)
+ {
+ return false;
+ }
+
+ span = span.Slice(2);
+
+ result.Random = span.Slice(0, Dtls.Random.Size);
+ span = span.Slice(Dtls.Random.Size);
+
+ if (!HazelDtlsSessionInfo.Parse(out result.Session, span))
+ {
+ return false;
+ }
+
+ span = span.Slice(result.Session.FullSize);
+
+ byte cookieSize = span[0];
+ if (span.Length < 1 + cookieSize)
+ {
+ return false;
+ }
+ result.Cookie = span.Slice(1, cookieSize);
+ span = span.Slice(1 + cookieSize);
+
+ ushort cipherSuiteSize = span.ReadBigEndian16();
+ if (span.Length < 2 + cipherSuiteSize)
+ {
+ return false;
+ }
+ else if (cipherSuiteSize % 2 != 0)
+ {
+ return false;
+ }
+ result.CipherSuites = span.Slice(2, cipherSuiteSize);
+ span = span.Slice(2 + cipherSuiteSize);
+
+ int compressionMethodsSize = span[0];
+ bool foundNullCompressionMethod = false;
+ for (int ii = 0; ii != compressionMethodsSize; ++ii)
+ {
+ if (span[1+ii] == (byte)CompressionMethod.Null)
+ {
+ foundNullCompressionMethod = true;
+ break;
+ }
+ }
+
+ if (!foundNullCompressionMethod
+ || span.Length < 1 + compressionMethodsSize)
+ {
+ return false;
+ }
+
+ span = span.Slice(1 + compressionMethodsSize);
+
+ // Parse extensions
+ if (span.Length > 0)
+ {
+ if (span.Length < 2)
+ {
+ return false;
+ }
+
+ ushort extensionsSize = span.ReadBigEndian16();
+ span = span.Slice(2);
+ if (span.Length != extensionsSize)
+ {
+ return false;
+ }
+
+ while (span.Length > 0)
+ {
+ // Parse extension header
+ if (span.Length < 4)
+ {
+ return false;
+ }
+
+ ExtensionType extensionType = (ExtensionType)span.ReadBigEndian16(0);
+ ushort extensionLength = span.ReadBigEndian16(2);
+
+ if (span.Length < 4 + extensionLength)
+ {
+ return false;
+ }
+
+ ByteSpan extensionData = span.Slice(4, extensionLength);
+ span = span.Slice(4 + extensionLength);
+ result.ParseExtension(extensionType, extensionData);
+ }
+ }
+
+ return true;
+ }
+
+ /// <summary>
+ /// Decode a ClientHello extension
+ /// </summary>
+ /// <param name="extensionType">Extension type</param>
+ /// <param name="extensionData">Extension data</param>
+ private void ParseExtension(ExtensionType extensionType, ByteSpan extensionData)
+ {
+ switch (extensionType)
+ {
+ case ExtensionType.EllipticCurves:
+ if (extensionData.Length % 2 != 0)
+ {
+ break;
+ }
+ else if (extensionData.Length < 2)
+ {
+ break;
+ }
+
+ ushort namedCurveSize = extensionData.ReadBigEndian16(0);
+ if (namedCurveSize % 2 != 0)
+ {
+ break;
+ }
+
+ this.SupportedCurves = extensionData.Slice(2, namedCurveSize);
+ break;
+ }
+ }
+
+ /// <summary>
+ /// Determines if the ClientHello message advertises support
+ /// for the specified cipher suite
+ /// </summary>
+ public bool ContainsCipherSuite(CipherSuite cipherSuite)
+ {
+ ByteSpan iterator = this.CipherSuites;
+ while (iterator.Length >= 2)
+ {
+ if (iterator.ReadBigEndian16() == (ushort)cipherSuite)
+ {
+ return true;
+ }
+
+ iterator = iterator.Slice(2);
+ }
+
+ return false;
+ }
+
+ /// <summary>
+ /// Determines if the ClientHello message advertises support
+ /// for the specified curve
+ /// </summary>
+ public bool ContainsCurve(NamedCurve curve)
+ {
+ ByteSpan iterator = this.SupportedCurves;
+ while (iterator.Length >= 2)
+ {
+ if (iterator.ReadBigEndian16() == (ushort)curve)
+ {
+ return true;
+ }
+
+ iterator = iterator.Slice(2);
+ }
+
+ return false;
+ }
+
+ /// <summary>
+ /// Encode Handshake ClientHello payload to wire format
+ /// </summary>
+ public void Encode(ByteSpan span)
+ {
+ span.WriteBigEndian16((ushort)this.ClientProtocolVersion);
+ span = span.Slice(2);
+
+ Debug.Assert(this.Random.Length == Dtls.Random.Size);
+ this.Random.CopyTo(span);
+ span = span.Slice(Dtls.Random.Size);
+
+ this.Session.Encode(span);
+ span = span.Slice(this.Session.FullSize);
+
+ span[0] = (byte)this.Cookie.Length;
+ this.Cookie.CopyTo(span.Slice(1));
+ span = span.Slice(1 + this.Cookie.Length);
+
+ span.WriteBigEndian16((ushort)this.CipherSuites.Length);
+ this.CipherSuites.CopyTo(span.Slice(2));
+ span = span.Slice(2 + this.CipherSuites.Length);
+
+ span[0] = 1;
+ span[1] = (byte)CompressionMethod.Null;
+ span = span.Slice(2);
+
+ // Extensions size
+ span.WriteBigEndian16((ushort)(6 + this.SupportedCurves.Length));
+ span = span.Slice(2);
+
+ // Supported curves extension
+ span.WriteBigEndian16((ushort)ExtensionType.EllipticCurves);
+ span.WriteBigEndian16((ushort)(2 + this.SupportedCurves.Length), 2);
+ span.WriteBigEndian16((ushort)this.SupportedCurves.Length, 4);
+ this.SupportedCurves.CopyTo(span.Slice(6));
+ }
+ }
+
+ /// <summary>
+ /// Encode/Decode session information in ClientHello
+ /// </summary>
+ public struct HazelDtlsSessionInfo
+ {
+ public const byte CurrentClientSessionSize = 1;
+ public const byte CurrentClientSessionVersion = 1;
+
+ public byte FullSize => (byte)(1 + this.PayloadSize);
+ public byte PayloadSize;
+ public byte Version;
+
+ public HazelDtlsSessionInfo(byte version)
+ {
+ this.Version = version;
+ switch (version)
+ {
+ case 0: // Does not write version byte
+ this.PayloadSize = 0;
+ return;
+ case 1: // Writes version byte only
+ this.PayloadSize = 1;
+ return;
+ }
+
+ throw new ArgumentOutOfRangeException("Unimplemented Hazel session version");
+ }
+
+ public void Encode(ByteSpan writer)
+ {
+ writer[0] = this.PayloadSize;
+
+ if (this.Version > 0)
+ {
+ writer[1] = this.Version;
+ }
+ }
+
+ public static bool Parse(out HazelDtlsSessionInfo result, ByteSpan reader)
+ {
+ result = new HazelDtlsSessionInfo();
+ if (reader.Length < 1)
+ {
+ return false;
+ }
+
+ result.PayloadSize = reader[0];
+
+ // Back compat, length may be zero, version defaults to 0.
+ if (result.PayloadSize == 0)
+ {
+ result.Version = 0;
+ return true;
+ }
+
+ // Forward compat, if length > 1, ignore the rest
+ result.Version = reader[1];
+ return true;
+ }
+ }
+
+ /// <summary>
+ /// Encode/decode Handshake HelloVerifyRequest message
+ /// </summary>
+ public struct HelloVerifyRequest
+ {
+ public const int CookieSize = 20;
+ public const int Size = 0
+ + 2 // server_version
+ + 1 // cookie (size)
+ + CookieSize // cookie (data)
+ ;
+
+ public ProtocolVersion ServerProtocolVersion;
+ public ByteSpan Cookie;
+
+ /// <summary>
+ /// Parse a Handshake HelloVerifyRequest payload from wire
+ /// format
+ /// </summary>
+ /// <returns>
+ /// True if we successfully decode the HelloVerifyRequest
+ /// message. Otherwise false.
+ /// </returns>
+ public static bool Parse(out HelloVerifyRequest result, ProtocolVersion? expectedProtocolVersion, ByteSpan span)
+ {
+ result = new HelloVerifyRequest();
+ if (span.Length < 3)
+ {
+ return false;
+ }
+
+ result.ServerProtocolVersion = (ProtocolVersion)span.ReadBigEndian16(0);
+ if (expectedProtocolVersion.HasValue && result.ServerProtocolVersion != expectedProtocolVersion.Value)
+ {
+ return false;
+ }
+
+ byte cookieSize = span[2];
+ span = span.Slice(3);
+
+ if (span.Length < cookieSize)
+ {
+ return false;
+ }
+
+ result.Cookie = span;
+ return true;
+ }
+
+ /// <summary>
+ /// Encode a HelloVerifyRequest payload to wire format
+ /// </summary>
+ /// <param name="peerAddress">Address of the remote peer</param>
+ /// <param name="hmac">Listener HMAC signature provider</param>
+ public static void Encode(ByteSpan span, EndPoint peerAddress, HMAC hmac, ProtocolVersion protocolVersion)
+ {
+ ByteSpan cookie = ComputeAddressMac(peerAddress, hmac);
+
+ span.WriteBigEndian16((ushort)protocolVersion);
+ span[2] = (byte)CookieSize;
+ cookie.CopyTo(span.Slice(3));
+ }
+
+ /// <summary>
+ /// Generate an HMAC for a peer address
+ /// </summary>
+ /// <param name="peerAddress">Address of the remote peer</param>
+ /// <param name="hmac">Listener HMAC signature provider</param>
+ public static ByteSpan ComputeAddressMac(EndPoint peerAddress, HMAC hmac)
+ {
+ SocketAddress address = peerAddress.Serialize();
+ byte[] data = new byte[address.Size];
+ for (int ii = 0, nn = data.Length; ii != nn; ++ii)
+ {
+ data[ii] = address[ii];
+ }
+
+ ///NOTE(mendsley): Lame that we need to allocate+copy here
+ ByteSpan signature = hmac.ComputeHash(data);
+ return signature.Slice(0, CookieSize);
+ }
+
+ /// <summary>
+ /// Verify a client's cookie was signed by our listener
+ /// </summary>
+ /// <param name="cookie">Wire format cookie</param>
+ /// <param name="peerAddress">Address of the remote peer</param>
+ /// <param name="hmac">Listener HMAC signature provider</param>
+ /// <returns>True if the cookie is valid. Otherwise false</returns>
+ public static bool VerifyCookie(ByteSpan cookie, EndPoint peerAddress, HMAC hmac)
+ {
+ if (cookie.Length != CookieSize)
+ {
+ return false;
+ }
+
+ ByteSpan expectedHash = ComputeAddressMac(peerAddress, hmac);
+ if (expectedHash.Length != cookie.Length)
+ {
+ return false;
+ }
+
+ return (1 == Crypto.Const.ConstantCompareSpans(cookie, expectedHash));
+ }
+ }
+
+ /// <summary>
+ /// Encode/decode Handshake ServerHello message
+ /// </summary>
+ public struct ServerHello
+ {
+ public ProtocolVersion ServerProtocolVersion;
+ public ByteSpan Random;
+ public CipherSuite CipherSuite;
+ public HazelDtlsSessionInfo Session;
+
+ public const int MinSize = 0
+ + 2 // server_version
+ + Dtls.Random.Size // random
+ + 1 // session_id (size)
+ + 2 // cipher_suite
+ + 1 // compression_method
+ ;
+
+ public int Size => MinSize + Session.PayloadSize;
+
+ /// <summary>
+ /// Parse a Handshake ServerHello payload from wire format
+ /// </summary>
+ /// <returns>
+ /// True if we successfully decode the ServerHello
+ /// message. Otherwise false.
+ /// </returns>
+ public static bool Parse(out ServerHello result, ByteSpan span)
+ {
+ result = new ServerHello();
+ if (span.Length < MinSize)
+ {
+ return false;
+ }
+
+ result.ServerProtocolVersion = (ProtocolVersion)span.ReadBigEndian16();
+ span = span.Slice(2);
+
+ result.Random = span.Slice(0, Dtls.Random.Size);
+ span = span.Slice(Dtls.Random.Size);
+
+ if (!HazelDtlsSessionInfo.Parse(out result.Session, span))
+ {
+ return false;
+ }
+
+ span = span.Slice(result.Session.FullSize);
+
+ result.CipherSuite = (CipherSuite)span.ReadBigEndian16();
+ span = span.Slice(2);
+
+ CompressionMethod compressionMethod = (CompressionMethod)span[0];
+ if (compressionMethod != CompressionMethod.Null)
+ {
+ return false;
+ }
+
+ return true;
+ }
+
+ /// <summary>
+ /// Encode Handshake ServerHello to wire format
+ /// </summary>
+ public void Encode(ByteSpan span)
+ {
+ Debug.Assert(this.Random.Length == Dtls.Random.Size);
+
+ span.WriteBigEndian16((ushort)this.ServerProtocolVersion, 0);
+ span = span.Slice(2);
+
+ this.Random.CopyTo(span);
+ span = span.Slice(Dtls.Random.Size);
+
+ this.Session.Encode(span);
+ span = span.Slice(this.Session.FullSize);
+
+ span.WriteBigEndian16((ushort)this.CipherSuite);
+ span = span.Slice(2);
+
+ span[0] = (byte)CompressionMethod.Null;
+ }
+ }
+
+ /// <summary>
+ /// Encode/decode Handshake Certificate message
+ /// </summary>
+ public struct Certificate
+ {
+ /// <summary>
+ /// Encode a certificate to wire formate
+ /// </summary>
+ public static ByteSpan Encode(X509Certificate2 certificate)
+ {
+ ByteSpan certData = certificate.GetRawCertData();
+ int totalSize = certData.Length + 3 + 3;
+
+ ByteSpan result = new byte[totalSize];
+
+ ByteSpan writer = result;
+ writer.WriteBigEndian24((uint)certData.Length + 3);
+ writer = writer.Slice(3);
+ writer.WriteBigEndian24((uint)certData.Length);
+ writer = writer.Slice(3);
+
+ certData.CopyTo(writer);
+ return result;
+ }
+
+ /// <summary>
+ /// Parse a Handshake Certificate payload from wire format
+ /// </summary>
+ /// <returns>True if we successfully decode the Certificate message. Otherwise false</returns>
+ public static bool Parse(out X509Certificate2 certificate, ByteSpan span)
+ {
+ certificate = null;
+ if (span.Length < 6)
+ {
+ return false;
+ }
+
+ uint totalSize = span.ReadBigEndian24();
+ span = span.Slice(3);
+
+ if (span.Length < totalSize)
+ {
+ return false;
+ }
+
+ uint certificateSize = span.ReadBigEndian24();
+ span = span.Slice(3);
+ if (span.Length < certificateSize)
+ {
+ return false;
+ }
+
+ byte[] rawData = new byte[certificateSize];
+ span.CopyTo(rawData, 0);
+ try
+ {
+ certificate = new X509Certificate2(rawData);
+ }
+ catch (Exception)
+ {
+ return false;
+ }
+
+ return true;
+ }
+ }
+
+ /// <summary>
+ /// Encode/decode Handshake Finished message
+ /// </summary>
+ public struct Finished
+ {
+ public const int Size = 12;
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/Dtls/IHandshakeCipherSuite.cs b/Tools/Hazel-Networking/Hazel/Dtls/IHandshakeCipherSuite.cs
new file mode 100644
index 0000000..eedd977
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/Dtls/IHandshakeCipherSuite.cs
@@ -0,0 +1,63 @@
+using System;
+
+namespace Hazel.Dtls
+{
+ /// <summary>
+ /// DTLS cipher suite interface for the handshake portion of
+ /// the connection.
+ /// </summary>
+ public interface IHandshakeCipherSuite : IDisposable
+ {
+ /// <summary>
+ /// Gets the size of the shared key
+ /// </summary>
+ /// <returns>Size of the shared key in bytes </returns>
+ int SharedKeySize();
+
+ /// <summary>
+ /// Calculate the size of the ServerKeyExchnage message
+ /// </summary>
+ /// <param name="privateKey">
+ /// Private key that will be used to sign the message
+ /// </param>
+ /// <returns>Size of the message in bytes</returns>
+ int CalculateServerMessageSize(object privateKey);
+
+ /// <summary>
+ /// Encodes the ServerKeyExchange message
+ /// </summary>
+ /// <param name="privateKey">Private key to use for signing</param>
+ void EncodeServerKeyExchangeMessage(ByteSpan output, object privateKey);
+
+ /// <summary>
+ /// Verifies the authenticity of a server key exchange
+ /// message and calculates the shared secret.
+ /// </summary>
+ /// <returns>
+ /// True if the authenticity has been validated and a shared key
+ /// was generated. Otherwise, false.
+ /// </returns>
+ bool VerifyServerMessageAndGenerateSharedKey(ByteSpan output, ByteSpan serverKeyExchangeMessage, object publicKey);
+
+ /// <summary>
+ /// Calculate the size of the ClientKeyExchange message
+ /// </summary>
+ /// <returns>Size of the message in bytes</returns>
+ int CalculateClientMessageSize();
+
+ /// <summary>
+ /// Encodes the ClientKeyExchangeMessage
+ /// </summary>
+ void EncodeClientKeyExchangeMessage(ByteSpan output);
+
+ /// <summary>
+ /// Verifies the validity of a client key exchange message
+ /// and calculats the hsared secret.
+ /// </summary>
+ /// <returns>
+ /// True if the client exchange message is valid and a
+ /// shared key was generated. Otherwise, false.
+ /// </returns>
+ bool VerifyClientMessageAndGenerateSharedKey(ByteSpan output, ByteSpan clientKeyExchangeMessage);
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/Dtls/IRecordProtection.cs b/Tools/Hazel-Networking/Hazel/Dtls/IRecordProtection.cs
new file mode 100644
index 0000000..cbee1b0
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/Dtls/IRecordProtection.cs
@@ -0,0 +1,84 @@
+using System;
+
+namespace Hazel.Dtls
+{
+ /// <summary>
+ /// DTLS cipher suite interface for protection of record payload.
+ /// </summary>
+ public interface IRecordProtection : IDisposable
+ {
+ /// <summary>
+ /// Calculate the size of an encrypted plaintext
+ /// </summary>
+ /// <param name="dataSize">Size of plaintext in bytes</param>
+ /// <returns>Size of encrypted ciphertext in bytes</returns>
+ int GetEncryptedSize(int dataSize);
+
+ /// <summary>
+ /// Calculate the size of decrypted ciphertext
+ /// </summary>
+ /// <param name="dataSize">Size of ciphertext in bytes</param>
+ /// <returns>Size of decrypted plaintext in bytes</returns>
+ int GetDecryptedSize(int dataSize);
+
+ /// <summary>
+ /// Encrypt a plaintext intput with server keys
+ ///
+ /// Output may overlap with input.
+ /// </summary>
+ /// <param name="output">Output ciphertext</param>
+ /// <param name="input">Input plaintext</param>
+ /// <param name="record">Parent DTLS record</param>
+ void EncryptServerPlaintext(ByteSpan output, ByteSpan input, ref Record record);
+
+ /// <summary>
+ /// Encrypt a plaintext intput with client keys
+ ///
+ /// Output may overlap with input.
+ /// </summary>
+ /// <param name="output">Output ciphertext</param>
+ /// <param name="input">Input plaintext</param>
+ /// <param name="record">Parent DTLS record</param>
+ void EncryptClientPlaintext(ByteSpan output, ByteSpan input, ref Record record);
+
+ /// <summary>
+ /// Decrypt a ciphertext intput with server keys
+ ///
+ /// Output may overlap with input.
+ /// </summary>
+ /// <param name="output">Output plaintext</param>
+ /// <param name="input">Input ciphertext</param>
+ /// <param name="record">Parent DTLS record</param>
+ /// <returns>True if the input was authenticated and decrypted. Otherwise false</returns>
+ bool DecryptCiphertextFromServer(ByteSpan output, ByteSpan input, ref Record record);
+
+ /// <summary>
+ /// Decrypt a ciphertext intput with client keys
+ ///
+ /// Output may overlap with input.
+ /// </summary>
+ /// <param name="output">Output plaintext</param>
+ /// <param name="input">Input ciphertext</param>
+ /// <param name="record">Parent DTLS record</param>
+ /// <returns>True if the input was authenticated and decrypted. Otherwise false</returns>
+ bool DecryptCiphertextFromClient(ByteSpan output, ByteSpan input, ref Record record);
+ }
+
+ /// <summary>
+ /// Factory to create record protection from cipher suite identifiers
+ /// </summary>
+ public sealed class RecordProtectionFactory
+ {
+ public static IRecordProtection Create(CipherSuite cipherSuite, ByteSpan masterSecret, ByteSpan serverRandom, ByteSpan clientRandom)
+ {
+ switch (cipherSuite)
+ {
+ case CipherSuite.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256:
+ return new Aes128GcmRecordProtection(masterSecret, serverRandom, clientRandom);
+
+ default:
+ return null;
+ }
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/Dtls/NullRecordProtection.cs b/Tools/Hazel-Networking/Hazel/Dtls/NullRecordProtection.cs
new file mode 100644
index 0000000..76fa132
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/Dtls/NullRecordProtection.cs
@@ -0,0 +1,66 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+
+namespace Hazel.Dtls
+{
+ /// <summary>
+ /// Passthrough record protection implementaion
+ /// </summary>
+ public class NullRecordProtection : IRecordProtection
+ {
+ public readonly static NullRecordProtection Instance = new NullRecordProtection();
+
+ public void Dispose()
+ {
+ }
+
+ public int GetEncryptedSize(int dataSize)
+ {
+ return dataSize;
+ }
+
+ public int GetDecryptedSize(int dataSize)
+ {
+ return dataSize;
+ }
+
+ public void EncryptServerPlaintext(ByteSpan output, ByteSpan input, ref Record record)
+ {
+ CopyMaybeOverlappingSpans(output, input);
+ }
+
+ public void EncryptClientPlaintext(ByteSpan output, ByteSpan input, ref Record record)
+ {
+ CopyMaybeOverlappingSpans(output, input);
+ }
+
+ public bool DecryptCiphertextFromServer(ByteSpan output, ByteSpan input, ref Record record)
+ {
+ CopyMaybeOverlappingSpans(output, input);
+ return true;
+ }
+
+ public bool DecryptCiphertextFromClient(ByteSpan output, ByteSpan input, ref Record record)
+ {
+ CopyMaybeOverlappingSpans(output, input);
+ return true;
+ }
+
+ private static void CopyMaybeOverlappingSpans(ByteSpan output, ByteSpan input)
+ {
+ // Early out if the ranges `output` is equal to `input`
+ if (output.GetUnderlyingArray() == input.GetUnderlyingArray())
+ {
+ if (output.Offset == input.Offset && output.Length == input.Length)
+ {
+ return;
+ }
+ }
+
+ input.CopyTo(output);
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/Dtls/PrfSha256.cs b/Tools/Hazel-Networking/Hazel/Dtls/PrfSha256.cs
new file mode 100644
index 0000000..1fa7f17
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/Dtls/PrfSha256.cs
@@ -0,0 +1,84 @@
+using System.Text;
+using System.Security.Cryptography;
+
+namespace Hazel.Dtls
+{
+ /// <summary>
+ /// Common Psuedorandom Function labels for TLS
+ /// </summary>
+ public struct PrfLabel
+ {
+ public static readonly ByteSpan MASTER_SECRET = LabelToBytes("master secert");
+ public static readonly ByteSpan KEY_EXPANSION = LabelToBytes("key expansion");
+ public static readonly ByteSpan CLIENT_FINISHED = LabelToBytes("client finished");
+ public static readonly ByteSpan SERVER_FINISHED = LabelToBytes("server finished");
+
+ /// <summary>
+ /// Convert a text label to a byte sequence
+ /// </summary>
+ public static ByteSpan LabelToBytes(string label)
+ {
+ return Encoding.ASCII.GetBytes(label);
+ }
+ }
+
+ /// <summary>
+ /// The P_SHA256 Psuedorandom Function
+ /// </summary>
+ public struct PrfSha256
+ {
+ /// <summary>
+ /// Expand a secret key
+ /// </summary>
+ /// <param name="output">Output span. Length determines how much data to generate</param>
+ /// <param name="key">Original key to expand</param>
+ /// <param name="label">Label (treated as a salt)</param>
+ /// <param name="initialSeed">Seed for expansion (treated as a salt)</param>
+ public static void ExpandSecret(ByteSpan output, ByteSpan key, string label, ByteSpan initialSeed)
+ {
+ ExpandSecret(output, key, PrfLabel.LabelToBytes(label), initialSeed);
+ }
+
+ /// <summary>
+ /// Expand a secret key
+ /// </summary>
+ /// <param name="output">Output span. Length determines how much data to generate</param>
+ /// <param name="key">Original key to expand</param>
+ /// <param name="label">Label (treated as a salt)</param>
+ /// <param name="initialSeed">Seed for expansion (treated as a salt)</param>
+ public static void ExpandSecret(ByteSpan output, ByteSpan key, ByteSpan label, ByteSpan initialSeed)
+ {
+ ByteSpan writer = output;
+
+ byte[] roundSeed = new byte[label.Length + initialSeed.Length];
+ label.CopyTo(roundSeed);
+ initialSeed.CopyTo(roundSeed, label.Length);
+
+ byte[] hashA = roundSeed;
+
+ using (HMACSHA256 hmac = new HMACSHA256(key.ToArray()))
+ {
+ byte[] input = new byte[hmac.OutputBlockSize + roundSeed.Length];
+ new ByteSpan(roundSeed).CopyTo(input, hmac.OutputBlockSize);
+
+ while (writer.Length > 0)
+ {
+ // Update hashA
+ hashA = hmac.ComputeHash(hashA);
+
+ // generate hash input
+ new ByteSpan(hashA).CopyTo(input);
+
+ ByteSpan roundOutput = hmac.ComputeHash(input);
+ if (roundOutput.Length > writer.Length)
+ {
+ roundOutput = roundOutput.Slice(0, writer.Length);
+ }
+
+ roundOutput.CopyTo(writer);
+ writer = writer.Slice(roundOutput.Length);
+ }
+ }
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/Dtls/Record.cs b/Tools/Hazel-Networking/Hazel/Dtls/Record.cs
new file mode 100644
index 0000000..23eaa95
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/Dtls/Record.cs
@@ -0,0 +1,123 @@
+namespace Hazel.Dtls
+{
+ /// <summary>
+ /// DTLS version constants
+ /// </summary>
+ public enum ProtocolVersion : ushort
+ {
+ /// <summary>
+ /// Use to obfuscate DTLS as regular UDP packets
+ /// </summary>
+ UDP = 0,
+
+ /// <summary>
+ /// DTLS 1.2
+ /// </summary>
+ DTLS1_2 = 0xFEFD,
+ }
+
+ /// <summary>
+ /// DTLS record content type
+ /// </summary>
+ public enum ContentType : byte
+ {
+ ChangeCipherSpec = 20,
+ Alert = 21,
+ Handshake = 22,
+ ApplicationData = 23,
+ }
+
+ /// <summary>
+ /// Encode/decode DTLS record header
+ /// </summary>
+ public struct Record
+ {
+ public ContentType ContentType;
+ public ProtocolVersion ProtocolVersion;
+ public ushort Epoch;
+ public ulong SequenceNumber;
+ public ushort Length;
+
+ public const int Size = 13;
+
+ /// <summary>
+ /// Parse a DTLS record from wire format
+ /// </summary>
+ /// <returns>True if we successfully parse the record header. Otherwise false</returns>
+ public static bool Parse(out Record record, ProtocolVersion? expectedProtocolVersion, ByteSpan span)
+ {
+ record = new Record();
+
+ if (span.Length < Size)
+ {
+ return false;
+ }
+
+ record.ContentType = (ContentType)span[0];
+ record.ProtocolVersion = (ProtocolVersion)span.ReadBigEndian16(1);
+ record.Epoch = span.ReadBigEndian16(3);
+ record.SequenceNumber = span.ReadBigEndian48(5);
+ record.Length = span.ReadBigEndian16(11);
+
+ if (expectedProtocolVersion.HasValue && record.ProtocolVersion != expectedProtocolVersion.Value)
+ {
+ return false;
+ }
+
+ return true;
+ }
+
+ /// <summary>
+ /// Encode a DTLS record to wire format
+ /// </summary>
+ public void Encode(ByteSpan span)
+ {
+ span[0] = (byte)this.ContentType;
+ span.WriteBigEndian16((ushort)this.ProtocolVersion, 1);
+ span.WriteBigEndian16(this.Epoch, 3);
+ span.WriteBigEndian48(this.SequenceNumber, 5);
+ span.WriteBigEndian16(this.Length, 11);
+ }
+ }
+
+ public struct ChangeCipherSpec
+ {
+ public const int Size = 1;
+
+ enum Value : byte
+ {
+ ChangeCipherSpec = 1,
+ }
+
+ /// <summary>
+ /// Parse a ChangeCipherSpec record from wire format
+ /// </summary>
+ /// <returns>
+ /// True if we successfully parse the ChangeCipherSpec
+ /// record. Otherwise, false.
+ /// </returns>
+ public static bool Parse(ByteSpan span)
+ {
+ if (span.Length != 1)
+ {
+ return false;
+ }
+
+ Value value = (Value)span[0];
+ if (value != Value.ChangeCipherSpec)
+ {
+ return false;
+ }
+
+ return true;
+ }
+
+ /// <summary>
+ /// Encode a ChangeCipherSpec record to wire format
+ /// </summary>
+ public static void Encode(ByteSpan span)
+ {
+ span[0] = (byte)Value.ChangeCipherSpec;
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/Dtls/ThreadedHmacHelper.cs b/Tools/Hazel-Networking/Hazel/Dtls/ThreadedHmacHelper.cs
new file mode 100644
index 0000000..38da061
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/Dtls/ThreadedHmacHelper.cs
@@ -0,0 +1,159 @@
+using System;
+using System.Collections.Concurrent;
+using System.Security.Cryptography;
+using System.Threading;
+
+namespace Hazel.Dtls
+{
+ internal class ThreadedHmacHelper : IDisposable
+ {
+ private class ThreadHmacs
+ {
+ public HMAC currentHmac;
+ public HMAC previousHmac;
+ public HMAC hmacToDispose;
+ }
+
+ private static readonly int CookieHmacRotationTimeout = (int)TimeSpan.FromHours(1.0).TotalMilliseconds;
+
+ private readonly ILogger logger;
+ private readonly ConcurrentDictionary<int, ThreadHmacs> hmacs;
+ private Timer rotateKeyTimer;
+ private RandomNumberGenerator cryptoRandom;
+ private byte[] currentHmacKey;
+
+ public ThreadedHmacHelper(ILogger logger)
+ {
+ this.hmacs = new ConcurrentDictionary<int, ThreadHmacs>();
+ this.rotateKeyTimer = new Timer(RotateKeys, null, CookieHmacRotationTimeout, CookieHmacRotationTimeout);
+ this.cryptoRandom = RandomNumberGenerator.Create();
+
+ this.logger = logger;
+ SetHmacKey();
+ }
+
+ /// <summary>
+ /// [ThreadSafe] Get the current cookie hmac for the current thread.
+ /// </summary>
+ public HMAC GetCurrentCookieHmacsForThread()
+ {
+ return GetHmacsForThread().currentHmac;
+ }
+
+ /// <summary>
+ /// [ThreadSafe] Get the previous cookie hmac for the current thread.
+ /// </summary>
+ public HMAC GetPreviousCookieHmacsForThread()
+ {
+ return GetHmacsForThread().previousHmac;
+ }
+
+ public void Dispose()
+ {
+ ManualResetEvent signalRotateKeyTimerEnded = new ManualResetEvent(false);
+ this.rotateKeyTimer.Dispose(signalRotateKeyTimerEnded);
+ signalRotateKeyTimerEnded.WaitOne();
+ signalRotateKeyTimerEnded.Dispose();
+ signalRotateKeyTimerEnded = null;
+ this.rotateKeyTimer = null;
+
+ this.cryptoRandom.Dispose();
+ this.cryptoRandom = null;
+
+ foreach (var threadIdToHmac in this.hmacs)
+ {
+ ThreadHmacs threadHmacs = threadIdToHmac.Value;
+ threadHmacs.currentHmac?.Dispose();
+ threadHmacs.currentHmac = null;
+ threadHmacs.previousHmac?.Dispose();
+ threadHmacs.previousHmac = null;
+ threadHmacs.hmacToDispose?.Dispose();
+ threadHmacs.hmacToDispose = null;
+ }
+
+ this.hmacs.Clear();
+ }
+
+ private ThreadHmacs GetHmacsForThread()
+ {
+ int threadId = Thread.CurrentThread.ManagedThreadId;
+
+ if (!this.hmacs.TryGetValue(threadId, out ThreadHmacs threadHmacs))
+ {
+ threadHmacs = CreateNewThreadHmacs();
+
+ if (!this.hmacs.TryAdd(threadId, threadHmacs))
+ {
+ this.logger.WriteError($"Cannot add threadHmacs for thread {threadId} during GetHmacsForThread! Should never happen!");
+ }
+ }
+
+ return threadHmacs;
+ }
+
+ /// <summary>
+ /// Rotates the hmacs of all active threads
+ /// </summary>
+ private void RotateKeys(object _)
+ {
+ SetHmacKey();
+
+ foreach (var threadIds in this.hmacs)
+ {
+ RotateKey(threadIds.Key);
+ }
+ }
+
+ /// <summary>
+ /// Rotate hmacs of single thread
+ /// </summary>
+ /// <param name="threadId">Managed thread Id of thread calling this method.</param>
+ private void RotateKey(int threadId)
+ {
+ ThreadHmacs threadHmacs;
+
+ if (!this.hmacs.TryGetValue(threadId, out threadHmacs))
+ {
+ this.logger.WriteError($"Cannot find thread {threadId} in hmacs during rotation! Should never happen!");
+ return;
+ }
+
+ // No thread should still have a reference to hmacToDispose, which should now have a lifetime of > 1 hour
+ threadHmacs.hmacToDispose?.Dispose();
+ threadHmacs.hmacToDispose = threadHmacs.previousHmac;
+ threadHmacs.previousHmac = threadHmacs.currentHmac;
+ threadHmacs.currentHmac = CreateNewCookieHMAC();
+ }
+
+ private ThreadHmacs CreateNewThreadHmacs()
+ {
+ return new ThreadHmacs
+ {
+ previousHmac = CreateNewCookieHMAC(),
+ currentHmac = CreateNewCookieHMAC()
+ };
+ }
+
+ /// <summary>
+ /// Create a new cookie HMAC signer
+ /// </summary>
+ private HMAC CreateNewCookieHMAC()
+ {
+ const string HMACProvider = "System.Security.Cryptography.HMACSHA1";
+ HMAC hmac = HMAC.Create(HMACProvider);
+ hmac.Key = this.currentHmacKey;
+ return hmac;
+ }
+
+ /// <summary>
+ /// Creates a new cryptographically secure random Hmac key
+ /// </summary>
+ private void SetHmacKey()
+ {
+ // MSDN recommends 64 bytes key for HMACSHA-1
+ byte[] newKey = new byte[64];
+ this.cryptoRandom.GetBytes(newKey);
+ this.currentHmacKey = newKey;
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/Dtls/X25519EcdheRsaSha256.cs b/Tools/Hazel-Networking/Hazel/Dtls/X25519EcdheRsaSha256.cs
new file mode 100644
index 0000000..f567252
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/Dtls/X25519EcdheRsaSha256.cs
@@ -0,0 +1,202 @@
+using Hazel.Crypto;
+using System;
+using System.Diagnostics;
+using System.Security.Cryptography;
+
+namespace Hazel.Dtls
+{
+ /// <summary>
+ /// ECDHE_RSA_*_256 cipher suite
+ /// </summary>
+ public class X25519EcdheRsaSha256 : IHandshakeCipherSuite
+ {
+ private readonly ByteSpan privateAgreementKey;
+ private SHA256 sha256 = SHA256.Create();
+
+ /// <summary>
+ /// Create a new instance of the x25519 key exchange
+ /// </summary>
+ /// <param name="random">Random data source</param>
+ public X25519EcdheRsaSha256(RandomNumberGenerator random)
+ {
+ byte[] buffer = new byte[X25519.KeySize];
+ random.GetBytes(buffer);
+ this.privateAgreementKey = buffer;
+ }
+
+ /// <inheritdoc />
+ public void Dispose()
+ {
+ this.sha256?.Dispose();
+ this.sha256 = null;
+ }
+
+ /// <inheritdoc />
+ public int SharedKeySize()
+ {
+ return X25519.KeySize;
+ }
+
+ /// <summary>
+ /// Calculate the server message size given an RSA key size
+ /// </summary>
+ /// <param name="keySize">
+ /// Size of the private key (in bits)
+ /// </param>
+ /// <returns>
+ /// Size of the ServerKeyExchange message in bytes
+ /// </returns>
+ private static int CalculateServerMessageSize(int keySize)
+ {
+ int signatureSize = keySize / 8;
+
+ return 0
+ + 1 // ECCurveType ServerKeyExchange.params.curve_params.curve_type
+ + 2 // NamedCurve ServerKeyExchange.params.curve_params.namedcurve
+ + 1 + X25519.KeySize // ECPoint ServerKeyExchange.params.public
+ + 1 // HashAlgorithm ServerKeyExchange.algorithm.hash
+ + 1 // SignatureAlgorithm ServerKeyExchange.signed_params.algorithm.signature
+ + 2 // ServerKeyExchange.signed_params.size
+ + signatureSize // ServerKeyExchange.signed_params.opaque
+ ;
+ }
+
+ /// <inheritdoc />
+ public int CalculateServerMessageSize(object privateKey)
+ {
+ RSA rsaPrivateKey = privateKey as RSA;
+ if (rsaPrivateKey == null)
+ {
+ throw new ArgumentException("Invalid private key", nameof(privateKey));
+ }
+
+ return CalculateServerMessageSize(rsaPrivateKey.KeySize);
+ }
+
+ /// <inheritdoc />
+ public void EncodeServerKeyExchangeMessage(ByteSpan output, object privateKey)
+ {
+ RSA rsaPrivateKey = privateKey as RSA;
+ if (rsaPrivateKey == null)
+ {
+ throw new ArgumentException("Invalid private key", nameof(privateKey));
+ }
+
+ output[0] = (byte)ECCurveType.NamedCurve;
+ output.WriteBigEndian16((ushort)NamedCurve.x25519, 1);
+ output[3] = (byte)X25519.KeySize;
+ X25519.Func(output.Slice(4, X25519.KeySize), this.privateAgreementKey);
+
+ // Hash the key parameters
+ byte[] paramterDigest = this.sha256.ComputeHash(output.GetUnderlyingArray(), output.Offset, 4 + X25519.KeySize);
+
+ // Sign the paramter digest
+ RSAPKCS1SignatureFormatter signer = new RSAPKCS1SignatureFormatter(rsaPrivateKey);
+ signer.SetHashAlgorithm("SHA256");
+ ByteSpan signature = signer.CreateSignature(paramterDigest);
+
+ Debug.Assert(signature.Length == rsaPrivateKey.KeySize/8);
+ output[4 + X25519.KeySize] = (byte)HashAlgorithm.Sha256;
+ output[5 + X25519.KeySize] = (byte)SignatureAlgorithm.RSA;
+ output.Slice(6+X25519.KeySize).WriteBigEndian16((ushort)signature.Length);
+ signature.CopyTo(output.Slice(8+X25519.KeySize));
+ }
+
+ /// <inheritdoc />
+ public bool VerifyServerMessageAndGenerateSharedKey(ByteSpan output, ByteSpan serverKeyExchangeMessage, object publicKey)
+ {
+ RSA rsaPublicKey = publicKey as RSA;
+ if (rsaPublicKey == null)
+ {
+ return false;
+ }
+ else if (output.Length != X25519.KeySize)
+ {
+ return false;
+ }
+
+ // Verify message is compatible with this cipher suite
+ if (serverKeyExchangeMessage.Length != CalculateServerMessageSize(rsaPublicKey.KeySize))
+ {
+ return false;
+ }
+ else if (serverKeyExchangeMessage[0] != (byte)ECCurveType.NamedCurve)
+ {
+ return false;
+ }
+ else if (serverKeyExchangeMessage.ReadBigEndian16(1) != (ushort)NamedCurve.x25519)
+ {
+ return false;
+ }
+ else if (serverKeyExchangeMessage[3] != X25519.KeySize)
+ {
+ return false;
+ }
+ else if (serverKeyExchangeMessage[4 + X25519.KeySize] != (byte)HashAlgorithm.Sha256)
+ {
+ return false;
+ }
+ else if (serverKeyExchangeMessage[5 + X25519.KeySize] != (byte)SignatureAlgorithm.RSA)
+ {
+ return false;
+ }
+
+ ByteSpan keyParameters = serverKeyExchangeMessage.Slice(0, 4+X25519.KeySize);
+ ByteSpan othersPublicKey = keyParameters.Slice(4);
+ ushort signatureSize = serverKeyExchangeMessage.ReadBigEndian16(6 + X25519.KeySize);
+ ByteSpan signature = serverKeyExchangeMessage.Slice(4+keyParameters.Length);
+
+ if (signatureSize != signature.Length)
+ {
+ return false;
+ }
+
+ // Hash the key parameters
+ byte[] parameterDigest = this.sha256.ComputeHash(keyParameters.GetUnderlyingArray(), keyParameters.Offset, keyParameters.Length);
+
+ // Verify the signature
+ RSAPKCS1SignatureDeformatter verifier = new RSAPKCS1SignatureDeformatter(rsaPublicKey);
+ verifier.SetHashAlgorithm("SHA256");
+ if (!verifier.VerifySignature(parameterDigest, signature.ToArray()))
+ {
+ return false;
+ }
+
+ // Signature has been validated, generate the shared key
+ return X25519.Func(output, this.privateAgreementKey, othersPublicKey);
+ }
+
+ private static int ClientMessageSize = 0
+ + 1 + X25519.KeySize // ECPoint ClientKeyExchange.ecdh_Yc
+ ;
+
+ /// <inheritdoc />
+ public int CalculateClientMessageSize()
+ {
+ return ClientMessageSize;
+ }
+
+ /// <inheritdoc />
+ public void EncodeClientKeyExchangeMessage(ByteSpan output)
+ {
+ output[0] = (byte)X25519.KeySize;
+ X25519.Func(output.Slice(1, X25519.KeySize), this.privateAgreementKey);
+ }
+
+ /// <inheritdoc />
+ public bool VerifyClientMessageAndGenerateSharedKey(ByteSpan output, ByteSpan clientKeyExchangeMessage)
+ {
+ if (clientKeyExchangeMessage.Length != ClientMessageSize)
+ {
+ return false;
+ }
+ else if (clientKeyExchangeMessage[0] != (byte)X25519.KeySize)
+ {
+ return false;
+ }
+
+ ByteSpan othersPublicKey = clientKeyExchangeMessage.Slice(1);
+ return X25519.Func(output, this.privateAgreementKey, othersPublicKey);
+ }
+ }
+}