diff options
author | chai <215380520@qq.com> | 2023-10-12 22:09:49 +0800 |
---|---|---|
committer | chai <215380520@qq.com> | 2023-10-12 22:09:49 +0800 |
commit | 8d2a2cd5de40e2b94ef5007c32832ed9a063dc40 (patch) | |
tree | a63dfbe815855925c9fb8f2804bd6ccfeffbd2eb /Tools/Hazel-Networking/Hazel/Dtls | |
parent | dd0c5d50e377d9be1e728463670908a6c9d2c14f (diff) |
+hazel-networking
Diffstat (limited to 'Tools/Hazel-Networking/Hazel/Dtls')
-rw-r--r-- | Tools/Hazel-Networking/Hazel/Dtls/AesGcmRecordProtection.cs | 147 | ||||
-rw-r--r-- | Tools/Hazel-Networking/Hazel/Dtls/DtlsConnectionListener.cs | 1424 | ||||
-rw-r--r-- | Tools/Hazel-Networking/Hazel/Dtls/DtlsUnityConnection.cs | 1246 | ||||
-rw-r--r-- | Tools/Hazel-Networking/Hazel/Dtls/Handshake.cs | 734 | ||||
-rw-r--r-- | Tools/Hazel-Networking/Hazel/Dtls/IHandshakeCipherSuite.cs | 63 | ||||
-rw-r--r-- | Tools/Hazel-Networking/Hazel/Dtls/IRecordProtection.cs | 84 | ||||
-rw-r--r-- | Tools/Hazel-Networking/Hazel/Dtls/NullRecordProtection.cs | 66 | ||||
-rw-r--r-- | Tools/Hazel-Networking/Hazel/Dtls/PrfSha256.cs | 84 | ||||
-rw-r--r-- | Tools/Hazel-Networking/Hazel/Dtls/Record.cs | 123 | ||||
-rw-r--r-- | Tools/Hazel-Networking/Hazel/Dtls/ThreadedHmacHelper.cs | 159 | ||||
-rw-r--r-- | Tools/Hazel-Networking/Hazel/Dtls/X25519EcdheRsaSha256.cs | 202 |
11 files changed, 4332 insertions, 0 deletions
diff --git a/Tools/Hazel-Networking/Hazel/Dtls/AesGcmRecordProtection.cs b/Tools/Hazel-Networking/Hazel/Dtls/AesGcmRecordProtection.cs new file mode 100644 index 0000000..65df39e --- /dev/null +++ b/Tools/Hazel-Networking/Hazel/Dtls/AesGcmRecordProtection.cs @@ -0,0 +1,147 @@ +using Hazel.Crypto; +using System; +using System.Diagnostics; + +namespace Hazel.Dtls +{ + /// <summary> + /// *_AES_128_GCM_* cipher suite + /// </summary> + public class Aes128GcmRecordProtection: IRecordProtection + { + private const int ImplicitNonceSize = 4; + private const int ExplicitNonceSize = 8; + + private readonly Aes128Gcm serverWriteCipher; + private readonly Aes128Gcm clientWriteCipher; + + private readonly ByteSpan serverWriteIV; + private readonly ByteSpan clientWriteIV; + + /// <summary> + /// Create a new instance of the AES128_GCM record protection + /// </summary> + /// <param name="masterSecret">Shared secret</param> + /// <param name="serverRandom">Server random data</param> + /// <param name="clientRandom">Client random data</param> + public Aes128GcmRecordProtection(ByteSpan masterSecret, ByteSpan serverRandom, ByteSpan clientRandom) + { + ByteSpan combinedRandom = new byte[serverRandom.Length + clientRandom.Length]; + serverRandom.CopyTo(combinedRandom); + clientRandom.CopyTo(combinedRandom.Slice(serverRandom.Length)); + + // Expand master_secret to encryption keys + const int ExpandedSize = 0 + + 0 // mac_key_length + + 0 // mac_key_length + + Aes128Gcm.KeySize // enc_key_length + + Aes128Gcm.KeySize // enc_key_length + + ImplicitNonceSize // fixed_iv_length + + ImplicitNonceSize // fixed_iv_length + ; + + ByteSpan expandedKey = new byte[ExpandedSize]; + PrfSha256.ExpandSecret(expandedKey, masterSecret, PrfLabel.KEY_EXPANSION, combinedRandom); + + ByteSpan clientWriteKey = expandedKey.Slice(0, Aes128Gcm.KeySize); + ByteSpan serverWriteKey = expandedKey.Slice(Aes128Gcm.KeySize, Aes128Gcm.KeySize); + this.clientWriteIV = expandedKey.Slice(2 * Aes128Gcm.KeySize, ImplicitNonceSize); + this.serverWriteIV = expandedKey.Slice(2 * Aes128Gcm.KeySize + ImplicitNonceSize, ImplicitNonceSize); + + this.serverWriteCipher = new Aes128Gcm(serverWriteKey); + this.clientWriteCipher = new Aes128Gcm(clientWriteKey); + } + + /// <inheritdoc /> + public void Dispose() + { + this.serverWriteCipher.Dispose(); + this.clientWriteCipher.Dispose(); + } + + /// <inheritdoc /> + private static int GetEncryptedSizeImpl(int dataSize) + { + return dataSize + Aes128Gcm.CiphertextOverhead; + } + + /// <inheritdoc /> + public int GetEncryptedSize(int dataSize) + { + return GetEncryptedSizeImpl(dataSize); + } + + private static int GetDecryptedSizeImpl(int dataSize) + { + return dataSize - Aes128Gcm.CiphertextOverhead; + } + + /// <inheritdoc /> + public int GetDecryptedSize(int dataSize) + { + return GetDecryptedSizeImpl(dataSize); + } + + /// <inheritdoc /> + public void EncryptServerPlaintext(ByteSpan output, ByteSpan input, ref Record record) + { + EncryptPlaintext(output, input, ref record, this.serverWriteCipher, this.serverWriteIV); + } + + /// <inheritdoc /> + public void EncryptClientPlaintext(ByteSpan output, ByteSpan input, ref Record record) + { + EncryptPlaintext(output, input, ref record, this.clientWriteCipher, this.clientWriteIV); + } + + private static void EncryptPlaintext(ByteSpan output, ByteSpan input, ref Record record, Aes128Gcm cipher, ByteSpan writeIV) + { + Debug.Assert(output.Length >= GetEncryptedSizeImpl(input.Length)); + + // Build GCM nonce (authenticated data) + ByteSpan nonce = new byte[ImplicitNonceSize + ExplicitNonceSize]; + writeIV.CopyTo(nonce); + nonce.WriteBigEndian16(record.Epoch, ImplicitNonceSize); + nonce.WriteBigEndian48(record.SequenceNumber, ImplicitNonceSize + 2); + + // Serialize record as additional data + Record plaintextRecord = record; + plaintextRecord.Length = (ushort)input.Length; + ByteSpan associatedData = new byte[Record.Size]; + plaintextRecord.Encode(associatedData); + + cipher.Seal(output, nonce, input, associatedData); + } + + /// <inheritdoc /> + public bool DecryptCiphertextFromServer(ByteSpan output, ByteSpan input, ref Record record) + { + return DecryptCiphertext(output, input, ref record, this.serverWriteCipher, this.serverWriteIV); + } + + /// <inheritdoc /> + public bool DecryptCiphertextFromClient(ByteSpan output, ByteSpan input, ref Record record) + { + return DecryptCiphertext(output, input, ref record, this.clientWriteCipher, this.clientWriteIV); + } + + private static bool DecryptCiphertext(ByteSpan output, ByteSpan input, ref Record record, Aes128Gcm cipher, ByteSpan writeIV) + { + Debug.Assert(output.Length >= GetDecryptedSizeImpl(input.Length)); + + // Build GCM nonce (authenticated data) + ByteSpan nonce = new byte[ImplicitNonceSize + ExplicitNonceSize]; + writeIV.CopyTo(nonce); + nonce.WriteBigEndian16(record.Epoch, ImplicitNonceSize); + nonce.WriteBigEndian48(record.SequenceNumber, ImplicitNonceSize + 2); + + // Serialize record as additional data + Record plaintextRecord = record; + plaintextRecord.Length = (ushort)GetDecryptedSizeImpl(input.Length); + ByteSpan associatedData = new byte[Record.Size]; + plaintextRecord.Encode(associatedData); + + return cipher.Open(output, nonce, input, associatedData); + } + } +} diff --git a/Tools/Hazel-Networking/Hazel/Dtls/DtlsConnectionListener.cs b/Tools/Hazel-Networking/Hazel/Dtls/DtlsConnectionListener.cs new file mode 100644 index 0000000..61f41d3 --- /dev/null +++ b/Tools/Hazel-Networking/Hazel/Dtls/DtlsConnectionListener.cs @@ -0,0 +1,1424 @@ +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Diagnostics; +using System.Net; +using System.Security.Cryptography; +using System.Security.Cryptography.X509Certificates; +using System.Threading; +using Hazel.Udp.FewerThreads; +using Hazel.Crypto; + +namespace Hazel.Dtls +{ + /// <summary> + /// Listens for new UDP-DTLS connections and creates UdpConnections for them. + /// </summary> + /// <inheritdoc /> + public class DtlsConnectionListener : ThreadLimitedUdpConnectionListener + { + private const int MaxCertFragmentSizeV0 = 1200; + + // Min MTU - UDP+IP header - 1 (for good measure. :)) + private const int MaxCertFragmentSizeV1 = 576 - 32 - 1; + + /// <summary> + /// Current state of handshake sequence + /// </summary> + enum HandshakeState + { + ExpectingHello, + ExpectingClientKeyExchange, + ExpectingChangeCipherSpec, + ExpectingFinish + } + + /// <summary> + /// State to manage the current epoch `N` + /// </summary> + struct CurrentEpoch + { + public ulong NextOutgoingSequence; + + public ulong NextExpectedSequence; + public ulong PreviousSequenceWindowBitmask; + + public IRecordProtection RecordProtection; + public IRecordProtection PreviousRecordProtection; + + // Need to keep these around so we can re-transmit our + // last handshake record flight + public ByteSpan ExpectedClientFinishedVerification; + public ByteSpan ServerFinishedVerification; + public ulong NextOutgoingSequenceForPreviousEpoch; + } + + /// <summary> + /// State to manage the transition from the current + /// epoch `N` to epoch `N+1` + /// </summary> + struct NextEpoch + { + public ushort Epoch; + + public HandshakeState State; + public CipherSuite SelectedCipherSuite; + + public ulong NextOutgoingSequence; + + public IHandshakeCipherSuite Handshake; + public IRecordProtection RecordProtection; + + public ByteSpan ClientRandom; + public ByteSpan ServerRandom; + + public Sha256Stream VerificationStream; + + public ByteSpan ClientVerification; + public ByteSpan ServerVerification; + + } + + /// <summary> + /// Per-peer state + /// </summary> + sealed class PeerData : IDisposable + { + public ushort Epoch; + public bool CanHandleApplicationData; + + public HazelDtlsSessionInfo Session; + + public CurrentEpoch CurrentEpoch; + public NextEpoch NextEpoch; + + public ConnectionId ConnectionId; + + public readonly List<ByteSpan> QueuedApplicationDataMessage = new List<ByteSpan>(); + public readonly ConcurrentBag<MessageReader> ApplicationData = new ConcurrentBag<MessageReader>(); + public readonly ProtocolVersion ProtocolVersion; + + public DateTime StartOfNegotiation; + + public PeerData(ConnectionId connectionId, ulong nextExpectedSequenceNumber, ProtocolVersion protocolVersion) + { + ByteSpan block = new byte[2 * Finished.Size]; + this.CurrentEpoch.ServerFinishedVerification = block.Slice(0, Finished.Size); + this.CurrentEpoch.ExpectedClientFinishedVerification = block.Slice(Finished.Size, Finished.Size); + this.ProtocolVersion = protocolVersion; + + ResetPeer(connectionId, nextExpectedSequenceNumber); + } + + public void ResetPeer(ConnectionId connectionId, ulong nextExpectedSequenceNumber) + { + Dispose(); + + this.Epoch = 0; + this.CanHandleApplicationData = false; + this.QueuedApplicationDataMessage.Clear(); + + this.CurrentEpoch.NextOutgoingSequence = 2; // Account for our ClientHelloVerify + this.CurrentEpoch.NextExpectedSequence = nextExpectedSequenceNumber; + this.CurrentEpoch.PreviousSequenceWindowBitmask = 0; + this.CurrentEpoch.RecordProtection = NullRecordProtection.Instance; + this.CurrentEpoch.PreviousRecordProtection = null; + this.CurrentEpoch.ServerFinishedVerification.SecureClear(); + this.CurrentEpoch.ExpectedClientFinishedVerification.SecureClear(); + + this.NextEpoch.State = HandshakeState.ExpectingHello; + this.NextEpoch.RecordProtection = null; + this.NextEpoch.Handshake = null; + this.NextEpoch.ClientRandom = new byte[Random.Size]; + this.NextEpoch.ServerRandom = new byte[Random.Size]; + this.NextEpoch.VerificationStream = new Sha256Stream(); + this.NextEpoch.ClientVerification = new byte[Finished.Size]; + this.NextEpoch.ServerVerification = new byte[Finished.Size]; + + this.ConnectionId = connectionId; + + this.StartOfNegotiation = DateTime.UtcNow; + } + + public void Dispose() + { + this.CurrentEpoch.RecordProtection?.Dispose(); + this.CurrentEpoch.PreviousRecordProtection?.Dispose(); + this.NextEpoch.RecordProtection?.Dispose(); + this.NextEpoch.Handshake?.Dispose(); + this.NextEpoch.VerificationStream?.Dispose(); + + while (this.ApplicationData.TryTake(out var msg)) + { + try + { + msg.Recycle(); + } + catch { } + } + } + } + + private RandomNumberGenerator random; + + // Private key component of certificate's public key + private ByteSpan encodedCertificate; + private RSA certificatePrivateKey; + + // HMAC key to validate ClientHello cookie + private ThreadedHmacHelper hmacHelper; + private HMAC CurrentCookieHmac { + get + { + return hmacHelper.GetCurrentCookieHmacsForThread(); + } + } + private HMAC PreviousCookieHmac + { + get + { + return hmacHelper.GetPreviousCookieHmacsForThread(); + } + } + + private ConcurrentStack<ConnectionId> staleConnections = new ConcurrentStack<ConnectionId>(); + private readonly ConcurrentDictionary<IPEndPoint, PeerData> existingPeers = new ConcurrentDictionary<IPEndPoint, PeerData>(); + public int PeerCount => this.existingPeers.Count; + + // TODO: Move these into an DtlsErrorStatistics class + public int NonPeerNonHelloPacketsDropped; + public int NonVerifiedFinishedHandshake; + public int NonPeerVerifyHelloRequests; + public int PeerVerifyHelloRequests; + + private int connectionSerial_unsafe = 0; + + private Timer staleConnectionUpkeep; + + /// <summary> + /// Create a new instance of the DTLS listener + /// </summary> + /// <param name="numWorkers"></param> + /// <param name="endPoint"></param> + /// <param name="logger"></param> + /// <param name="ipMode"></param> + public DtlsConnectionListener(int numWorkers, IPEndPoint endPoint, ILogger logger, IPMode ipMode = IPMode.IPv4) + : base(numWorkers, endPoint, logger, ipMode) + { + this.random = RandomNumberGenerator.Create(); + + this.staleConnectionUpkeep = new Timer(this.HandleStaleConnections, null, 2500, 1000); + this.hmacHelper = new ThreadedHmacHelper(logger); + } + + /// <inheritdoc /> + protected override void Dispose(bool disposing) + { + base.Dispose(disposing); + + this.staleConnectionUpkeep.Dispose(); + + this.random?.Dispose(); + this.random = null; + + this.hmacHelper?.Dispose(); + this.hmacHelper = null; + + foreach (var pair in this.existingPeers) + { + pair.Value.Dispose(); + } + this.existingPeers.Clear(); + } + + /// <summary> + /// Set the certificate key pair for the listener + /// </summary> + /// <param name="certificate">Certificate for the server</param> + public void SetCertificate(X509Certificate2 certificate) + { + if (!certificate.HasPrivateKey) + { + throw new ArgumentException("Certificate must have a private key attached", nameof(certificate)); + } + + RSA privateKey = certificate.GetRSAPrivateKey(); + if (privateKey == null) + { + throw new ArgumentException("Certificate must be signed by an RSA key", nameof(certificate)); + } + + this.certificatePrivateKey?.Dispose(); + this.certificatePrivateKey = privateKey; + + this.encodedCertificate = Certificate.Encode(certificate); + } + + /// <summary> + /// Handle an incoming datagram from the network. + /// + /// This is primarily a wrapper around ProcessIncomingMessage + /// to ensure `reader.Recycle()` is always called + /// </summary> + protected override void ReadCallback(MessageReader reader, IPEndPoint peerAddress, ConnectionId connectionId) + { + try + { + ByteSpan message = new ByteSpan(reader.Buffer, reader.Offset + reader.Position, reader.BytesRemaining); + this.ProcessIncomingMessage(message, peerAddress); + } + finally + { + reader.Recycle(); + } + } + + /// <summary> + /// Handle an incoming datagram from the network + /// </summary> + private void ProcessIncomingMessage(ByteSpan message, IPEndPoint peerAddress) + { + PeerData peer = null; + if (!this.existingPeers.TryGetValue(peerAddress, out peer)) + { + lock (this.existingPeers) + { + if (!this.existingPeers.TryGetValue(peerAddress, out peer)) + { + HandleNonPeerRecord(message, peerAddress); + return; + } + } + } + + ConnectionId peerConnectionId; + + lock (peer) + { + peerConnectionId = peer.ConnectionId; + + // Each incoming packet may contain multiple DTLS + // records + while (message.Length > 0) + { + Record record; + if (!Record.Parse(out record, peer.ProtocolVersion, message)) + { + this.Logger.WriteError($"Dropping malformed record from `{peerAddress}`"); + return; + } + message = message.Slice(Record.Size); + + if (message.Length < record.Length) + { + this.Logger.WriteError($"Dropping malformed record from `{peerAddress}` Length({record.Length}) AvailableBytes({message.Length})"); + return; + } + + ByteSpan recordPayload = message.Slice(0, record.Length); + message = message.Slice(record.Length); + + // Early-out and drop ApplicationData records + if (record.ContentType == ContentType.ApplicationData && !peer.CanHandleApplicationData) + { + this.Logger.WriteInfo($"Dropping ApplicationData record from `{peerAddress}` Cannot process yet"); + continue; + } + + // Drop records from a different epoch + if (record.Epoch != peer.Epoch) + { + // Handle existing client negotiating a new connection + if (record.Epoch == 0 && record.ContentType == ContentType.Handshake) + { + ByteSpan handshakePayload = recordPayload; + + Handshake handshake; + if (!Handshake.Parse(out handshake, recordPayload)) + { + this.Logger.WriteError($"Dropping malformed re-negotiation Handshake from `{peerAddress}`"); + continue; + } + handshakePayload = handshakePayload.Slice(Handshake.Size); + + if (handshake.FragmentOffset != 0 || handshake.Length != handshake.FragmentLength) + { + this.Logger.WriteError($"Dropping fragmented re-negotiation Handshake from `{peerAddress}`"); + continue; + } + else if (handshake.MessageType != HandshakeType.ClientHello) + { + this.Logger.WriteVerbose($"Dropping non-ClientHello re-negotiation Handshake from `{peerAddress}`"); + continue; + } + else if (handshakePayload.Length < handshake.Length) + { + this.Logger.WriteError($"Dropping malformed re-negotiation Handshake from `{peerAddress}`: Length({handshake.Length}) AvailableBytes({handshakePayload.Length})"); + } + + if (!this.HandleClientHello(peer, peerAddress, ref record, ref handshake, recordPayload, handshakePayload)) + { + return; + } + continue; + } + + this.Logger.WriteVerbose($"Dropping bad-epoch record from `{peerAddress}` RecordEpoch({record.Epoch}) CurrentEpoch({peer.Epoch})"); + continue; + } + + // Prevent replay attacks by dropping records + // we've already processed + int windowIndex = (int)(peer.CurrentEpoch.NextExpectedSequence - record.SequenceNumber - 1); + ulong windowMask = 1ul << windowIndex; + if (record.SequenceNumber < peer.CurrentEpoch.NextExpectedSequence) + { + if (windowIndex >= 64) + { + this.Logger.WriteInfo($"Dropping too-old record from `{peerAddress}` Sequence({record.SequenceNumber}) Expected({peer.CurrentEpoch.NextExpectedSequence})"); + continue; + } + + if ((peer.CurrentEpoch.PreviousSequenceWindowBitmask & windowMask) != 0) + { + this.Logger.WriteInfo($"Dropping duplicate record from `{peerAddress}`"); + continue; + } + } + + // Validate record authenticity + int decryptedSize = peer.CurrentEpoch.RecordProtection.GetDecryptedSize(recordPayload.Length); + if (decryptedSize < 0) + { + this.Logger.WriteInfo($"Dropping malformed record: Length {recordPayload.Length} Decrypted length: {decryptedSize}"); + continue; + } + + ByteSpan decryptedPayload = recordPayload.ReuseSpanIfPossible(decryptedSize); + ProtocolVersion protocolVersion = peer.ProtocolVersion; + + if (!peer.CurrentEpoch.RecordProtection.DecryptCiphertextFromClient(decryptedPayload, recordPayload, ref record)) + { + this.Logger.WriteVerbose($"Dropping non-authentic {record.ContentType} record from `{peerAddress}`"); + return; + } + + recordPayload = decryptedPayload; + + // Update our squence number bookeeping + if (record.SequenceNumber >= peer.CurrentEpoch.NextExpectedSequence) + { + int windowShift = (int)(record.SequenceNumber + 1 - peer.CurrentEpoch.NextExpectedSequence); + peer.CurrentEpoch.PreviousSequenceWindowBitmask <<= windowShift; + peer.CurrentEpoch.NextExpectedSequence = record.SequenceNumber + 1; + } + else + { + peer.CurrentEpoch.PreviousSequenceWindowBitmask |= windowMask; + } + + // This is handy for debugging, but too verbose even for verbose. + // this.Logger.WriteVerbose($"Record type {record.ContentType} ({peer.NextEpoch.State})"); + switch (record.ContentType) + { + case ContentType.ChangeCipherSpec: + if (peer.NextEpoch.State != HandshakeState.ExpectingChangeCipherSpec) + { + this.Logger.WriteError($"Dropping unexpected ChangeChiperSpec record from `{peerAddress}` State({peer.NextEpoch.State})"); + break; + } + else if (peer.NextEpoch.RecordProtection == null) + { + ///NOTE(mendsley): This _should_ not + /// happen on a well-formed server. + Debug.Assert(false, "How did we receive a ChangeCipherSpec message without a pending record protection instance?"); + + this.Logger.WriteError($"Dropping ChangeCipherSpec message from `{peerAddress}`: No pending record protection"); + break; + } + + if (!ChangeCipherSpec.Parse(recordPayload)) + { + this.Logger.WriteError($"Dropping malformed ChangeCipherSpec message from `{peerAddress}`"); + break; + } + + // Migrate to the next epoch + peer.Epoch = peer.NextEpoch.Epoch; + peer.CanHandleApplicationData = false; // Need a Finished message + peer.CurrentEpoch.NextOutgoingSequenceForPreviousEpoch = peer.CurrentEpoch.NextOutgoingSequence; + peer.CurrentEpoch.PreviousRecordProtection?.Dispose(); + peer.CurrentEpoch.PreviousRecordProtection = peer.CurrentEpoch.RecordProtection; + peer.CurrentEpoch.RecordProtection = peer.NextEpoch.RecordProtection; + peer.CurrentEpoch.NextOutgoingSequence = 1; + peer.CurrentEpoch.NextExpectedSequence = 1; + peer.CurrentEpoch.PreviousSequenceWindowBitmask = 0; + peer.NextEpoch.ClientVerification.CopyTo(peer.CurrentEpoch.ExpectedClientFinishedVerification); + peer.NextEpoch.ServerVerification.CopyTo(peer.CurrentEpoch.ServerFinishedVerification); + + peer.NextEpoch.State = HandshakeState.ExpectingHello; + peer.NextEpoch.Handshake?.Dispose(); + peer.NextEpoch.Handshake = null; + peer.NextEpoch.NextOutgoingSequence = 1; + peer.NextEpoch.RecordProtection = null; + peer.NextEpoch.VerificationStream.Reset(); + peer.NextEpoch.ClientVerification.SecureClear(); + peer.NextEpoch.ServerVerification.SecureClear(); + break; + + case ContentType.Alert: + this.Logger.WriteError($"Dropping unsupported Alert record from `{peerAddress}`"); + break; + + case ContentType.Handshake: + if (!ProcessHandshake(peer, peerAddress, ref record, recordPayload)) + { + return; + } + break; + + case ContentType.ApplicationData: + // Forward data to the application + MessageReader reader = MessageReader.GetSized(recordPayload.Length); + reader.Length = recordPayload.Length; + recordPayload.CopyTo(reader.Buffer); + + peer.ApplicationData.Add(reader); + break; + } + } + } + + // The peer lock must be exited before leaving the DtlsConnectionListener context to prevent deadlocks + // because ApplicationData processing may reenter this context + while (peer.ApplicationData.TryTake(out var appMsg)) + { + base.ReadCallback(appMsg, peerAddress, peerConnectionId); + } + } + + /// <summary> + /// Process an incoming Handshake protocol message + /// </summary> + /// <param name="peer">Originating peer</param> + /// <param name="peerAddress">Peer's network address</param> + /// <param name="record">Parent record</param> + /// <param name="message">Record payload</param> + /// <returns> + /// True if further processing of the underlying datagram + /// should be continues. Otherwise, false. + /// </returns> + private bool ProcessHandshake(PeerData peer, IPEndPoint peerAddress, ref Record record, ByteSpan message) + { + // Each record may have multiple handshake payloads + while (message.Length > 0) + { + ByteSpan originalMessage = message; + + Handshake handshake; + if (!Handshake.Parse(out handshake, message)) + { + this.Logger.WriteError($"Dropping malformed Handshake message from `{peerAddress}`"); + return false; + } + message = message.Slice(Handshake.Size); + + if (message.Length < handshake.Length) + { + this.Logger.WriteError($"Dropping malformed Handshake message from `{peerAddress}`"); + return false; + } + + ByteSpan payload = message.Slice(0, (int)message.Length); + message = message.Slice((int)handshake.Length); + originalMessage = originalMessage.Slice(0, Handshake.Size + (int)handshake.Length); + + // We do not support fragmented handshake messages + // from the client + if (handshake.FragmentOffset != 0 || handshake.FragmentLength != handshake.Length) + { + this.Logger.WriteError($"Dropping fragmented Handshake message from `{peerAddress}` Offset({handshake.FragmentOffset}) FragmentLength({handshake.FragmentLength}) Length({handshake.Length})"); + continue; + } + + ByteSpan packet; + ByteSpan writer; + +#if DEBUG + this.Logger.WriteVerbose($"Received handshake {handshake.MessageType} ({peer.NextEpoch.State})"); +#endif + switch (handshake.MessageType) + { + case HandshakeType.ClientHello: + if (!this.HandleClientHello(peer, peerAddress, ref record, ref handshake, originalMessage, payload)) + { + return false; + } + break; + + case HandshakeType.ClientKeyExchange: + if (peer.NextEpoch.State != HandshakeState.ExpectingClientKeyExchange) + { + this.Logger.WriteError($"Dropping unexpected ClientKeyExchange message form `{peerAddress}` State({peer.NextEpoch.State})"); + continue; + } + else if (handshake.MessageSequence != 5) + { + this.Logger.WriteError($"Dropping bad-sequence ClientKeyExchange message from `{peerAddress}` MessageSequence({handshake.MessageSequence})"); + continue; + } + + ByteSpan sharedSecret = new byte[peer.NextEpoch.Handshake.SharedKeySize()]; + if (!peer.NextEpoch.Handshake.VerifyClientMessageAndGenerateSharedKey(sharedSecret, payload)) + { + this.Logger.WriteError($"Dropping malformed ClientKeyExchange message from `{peerAddress}`"); + return false; + } + + // Record incoming ClientKeyExchange message + // to verification stream + peer.NextEpoch.VerificationStream.AddData(originalMessage); + + ByteSpan randomSeed = new byte[2 * Random.Size]; + peer.NextEpoch.ClientRandom.CopyTo(randomSeed); + peer.NextEpoch.ServerRandom.CopyTo(randomSeed.Slice(Random.Size)); + + const int MasterSecretSize = 48; + ByteSpan masterSecret = new byte[MasterSecretSize]; + PrfSha256.ExpandSecret( + masterSecret + , sharedSecret + , PrfLabel.MASTER_SECRET + , randomSeed + ); + + // Create the record protection for the upcoming epoch + switch (peer.NextEpoch.SelectedCipherSuite) + { + case CipherSuite.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256: + peer.NextEpoch.RecordProtection = new Aes128GcmRecordProtection( + masterSecret + , peer.NextEpoch.ServerRandom + , peer.NextEpoch.ClientRandom); + break; + + default: + Debug.Assert(false, $"How did we agree to a cipher suite {peer.NextEpoch.SelectedCipherSuite} we can't create?"); + this.Logger.WriteError($"Dropping ClientKeyExchange message from `{peerAddress}` Unsuppored cipher suite"); + return false; + } + + // Generate verification signatures + ByteSpan handshakeStreamHash = new byte[Sha256Stream.DigestSize]; + peer.NextEpoch.VerificationStream.CopyOrCalculateFinalHash(handshakeStreamHash); + + PrfSha256.ExpandSecret( + peer.NextEpoch.ClientVerification + , masterSecret + , PrfLabel.CLIENT_FINISHED + , handshakeStreamHash + ); + PrfSha256.ExpandSecret( + peer.NextEpoch.ServerVerification + , masterSecret + , PrfLabel.SERVER_FINISHED + , handshakeStreamHash + ); + + + // Update handshake state + masterSecret.SecureClear(); + peer.NextEpoch.State = HandshakeState.ExpectingChangeCipherSpec; + break; + + case HandshakeType.Finished: + // Unlike other handshake messages, this is + // for the current epoch - not the next epoch + + // Cannot process a Finished message for + // epoch 0 + if (peer.Epoch == 0) + { + this.Logger.WriteError($"Dropping Finished message for 0-epoch from `{peerAddress}`"); + continue; + } + // Cannot process a Finished message when we + // are negotiating the next epoch + else if (peer.NextEpoch.State != HandshakeState.ExpectingHello) + { + this.Logger.WriteError($"Dropping Finished message while negotiating new epoch from `{peerAddress}`"); + continue; + } + // Cannot process a Finished message without + // verify data + else if (peer.CurrentEpoch.ExpectedClientFinishedVerification.Length != Finished.Size || peer.CurrentEpoch.ServerFinishedVerification.Length != Finished.Size) + { + ///NOTE(mendsley): This _should_ not + /// happen on a well-formed server. + Debug.Assert(false, "How do we have an established non-zero epoch without verify data?"); + + this.Logger.WriteError($"Dropping Finished message (no verify data) from `{peerAddress}`"); + return false; + } + // Cannot process a Finished message without + // record protection for the previous epoch + else if (peer.CurrentEpoch.PreviousRecordProtection == null) + { + ///NOTE(mendsley): This _should_ not + /// happen on a well-formed server. + Debug.Assert(false, "How do we have an established non-zero epoch with record protection for the previous epoch?"); + + this.Logger.WriteError($"Dropping Finished message from `{peerAddress}`: No previous epoch record protection"); + return false; + } + + // Verify message sequence + if (handshake.MessageSequence != 6) + { + this.Logger.WriteError($"Dropping bad-sequence Finished message from `{peerAddress}` MessageSequence({handshake.MessageSequence})"); + continue; + } + + // Verify the client has the correct + // handshake sequence + if (payload.Length != Finished.Size) + { + this.Logger.WriteError($"Dropping malformed Finished message from `{peerAddress}`"); + return false; + } + else if (1 != Crypto.Const.ConstantCompareSpans(payload, peer.CurrentEpoch.ExpectedClientFinishedVerification)) + { + +#if DEBUG + this.Logger.WriteError($"Dropping non-verified Finished Handshake from `{peerAddress}`"); +#else + Interlocked.Increment(ref this.NonVerifiedFinishedHandshake); +#endif + + // Abort the connection here + // + // The client is either broken, or + // doen not agree on our epoch settings. + // + // Either way, there is not a feasible + // way to progress the connection. + MarkConnectionAsStale(peer.ConnectionId); + this.existingPeers.TryRemove(peerAddress, out _); + + return false; + } + + ProtocolVersion protocolVersion = peer.ProtocolVersion; + + // Describe our ChangeCipherSpec+Finished + Handshake outgoingHandshake = new Handshake(); + outgoingHandshake.MessageType = HandshakeType.Finished; + outgoingHandshake.Length = Finished.Size; + outgoingHandshake.MessageSequence = 7; + outgoingHandshake.FragmentOffset = 0; + outgoingHandshake.FragmentLength = outgoingHandshake.Length; + + Record changeCipherSpecRecord = new Record(); + changeCipherSpecRecord.ContentType = ContentType.ChangeCipherSpec; + changeCipherSpecRecord.ProtocolVersion = protocolVersion; + changeCipherSpecRecord.Epoch = (ushort)(peer.Epoch - 1); + changeCipherSpecRecord.SequenceNumber = peer.CurrentEpoch.NextOutgoingSequenceForPreviousEpoch; + changeCipherSpecRecord.Length = (ushort)peer.CurrentEpoch.PreviousRecordProtection.GetEncryptedSize(ChangeCipherSpec.Size); + ++peer.CurrentEpoch.NextOutgoingSequenceForPreviousEpoch; + + int plaintextFinishedPayloadSize = Handshake.Size + (int)outgoingHandshake.Length; + Record finishedRecord = new Record(); + finishedRecord.ContentType = ContentType.Handshake; + finishedRecord.ProtocolVersion = protocolVersion; + finishedRecord.Epoch = peer.Epoch; + finishedRecord.SequenceNumber = peer.CurrentEpoch.NextOutgoingSequence; + finishedRecord.Length = (ushort)peer.CurrentEpoch.RecordProtection.GetEncryptedSize(plaintextFinishedPayloadSize); + ++peer.CurrentEpoch.NextOutgoingSequence; + + // Encode the flight into wire format + packet = new byte[Record.Size + changeCipherSpecRecord.Length + Record.Size + finishedRecord.Length]; + writer = packet; + changeCipherSpecRecord.Encode(writer); + writer = writer.Slice(Record.Size); + ChangeCipherSpec.Encode(writer); + + ByteSpan startOfFinishedRecord = packet.Slice(Record.Size + changeCipherSpecRecord.Length); + writer = startOfFinishedRecord; + finishedRecord.Encode(writer); + writer = writer.Slice(Record.Size); + outgoingHandshake.Encode(writer); + writer = writer.Slice(Handshake.Size); + peer.CurrentEpoch.ServerFinishedVerification.CopyTo(writer); + + // Protect the ChangeChipherSpec record + peer.CurrentEpoch.PreviousRecordProtection.EncryptServerPlaintext( + packet.Slice(Record.Size, changeCipherSpecRecord.Length), + packet.Slice(Record.Size, ChangeCipherSpec.Size), + ref changeCipherSpecRecord + ); + + // Protect the Finished Handshake record + peer.CurrentEpoch.RecordProtection.EncryptServerPlaintext( + startOfFinishedRecord.Slice(Record.Size, finishedRecord.Length), + startOfFinishedRecord.Slice(Record.Size, plaintextFinishedPayloadSize), + ref finishedRecord + ); + + // Current epoch can now handle application data + peer.CanHandleApplicationData = true; + + base.QueueRawData(packet, peerAddress); + break; + + // Drop messages that we do not support + case HandshakeType.CertificateVerify: + this.Logger.WriteError($"Dropping unsupported Handshake message from `{peerAddress}` MessageType({handshake.MessageType})"); + continue; + + // Drop messages that originate from the server + case HandshakeType.HelloRequest: + case HandshakeType.ServerHello: + case HandshakeType.HelloVerifyRequest: + case HandshakeType.Certificate: + case HandshakeType.ServerKeyExchange: + case HandshakeType.CertificateRequest: + case HandshakeType.ServerHelloDone: + this.Logger.WriteError($"Dropping server Handshake message from `{peerAddress}` MessageType({handshake.MessageType})"); + continue; + } + } + + return true; + } + + /// <summary> + /// Handle a ClientHello message for a peer + /// </summary> + /// <param name="peer">Originating peer</param> + /// <param name="peerAddress">Peer address</param> + /// <param name="record">Parent record</param> + /// <param name="handshake">Parent Handshake header</param> + /// <param name="payload">Handshake payload</param> + private bool HandleClientHello(PeerData peer, IPEndPoint peerAddress, ref Record record, ref Handshake handshake, ByteSpan originalMessage, ByteSpan payload) + { + // Verify message sequence + if (handshake.MessageSequence != 0) + { + this.Logger.WriteError($"Dropping bad-sequence ClientHello from `{peerAddress}` MessageSequence({handshake.MessageSequence})`"); + return true; + } + + // Make sure we can handle a ClientHello message + if (peer.NextEpoch.State != HandshakeState.ExpectingHello && peer.NextEpoch.State != HandshakeState.ExpectingClientKeyExchange) + { + // Always handle ClientHello for epoch 0 + if (record.Epoch != 0) + { + this.Logger.WriteError($"Dropping ClientHello from `{peer}` Not expecting ClientHello"); + return true; + } + } + + ProtocolVersion protocolVersion = peer.ProtocolVersion; + if (!ClientHello.Parse(out ClientHello clientHello, protocolVersion, payload)) + { + this.Logger.WriteError($"Dropping malformed ClientHello Handshake message from `{peerAddress}`"); + return false; + } + + // Find an acceptable cipher suite we can use + CipherSuite selectedCipherSuite = CipherSuite.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256; + if (!clientHello.ContainsCipherSuite(CipherSuite.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256) || !clientHello.ContainsCurve(NamedCurve.x25519)) + { + this.Logger.WriteError($"Dropping ClientHello from `{peerAddress}` No compatible cipher suite"); + return false; + } + + // If this message was not signed by us, + // request a signed message before doing anything else + if (!HelloVerifyRequest.VerifyCookie(clientHello.Cookie, peerAddress, this.CurrentCookieHmac)) + { + if (!HelloVerifyRequest.VerifyCookie(clientHello.Cookie, peerAddress, this.PreviousCookieHmac)) + { + ulong outgoingSequence = 1; + IRecordProtection recordProtection = NullRecordProtection.Instance; + if (record.Epoch != 0) + { + outgoingSequence = peer.CurrentEpoch.NextExpectedSequence; + ++peer.CurrentEpoch.NextOutgoingSequenceForPreviousEpoch; + + recordProtection = peer.CurrentEpoch.RecordProtection; + } + +#if DEBUG + this.Logger.WriteError($"Sending HelloVerifyRequest to peer `{peerAddress}`"); +#else + Interlocked.Increment(ref this.PeerVerifyHelloRequests); +#endif + this.SendHelloVerifyRequest(peerAddress, outgoingSequence, record.Epoch, recordProtection, protocolVersion); + return true; + } + } + + // Client is initiating a brand new connection. We need + // to destroy the existing connection and establish a + // new session. + if (record.Epoch == 0 && peer.Epoch != 0) + { + ConnectionId oldConnectionId = peer.ConnectionId; + peer.ResetPeer(this.AllocateConnectionId(peerAddress), record.SequenceNumber + 1); + + // Inform the parent layer that the existing + // connection should be abandoned. + MarkConnectionAsStale(oldConnectionId); + } + + // Determine if this is an original message, or a retransmission + bool recordMessagesForVerifyData = false; + if (peer.NextEpoch.State == HandshakeState.ExpectingHello) + { + // Create our handhake cipher suite + IHandshakeCipherSuite handshakeCipherSuite = null; + switch (selectedCipherSuite) + { + case CipherSuite.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256: + if (clientHello.ContainsCurve(NamedCurve.x25519)) + { + handshakeCipherSuite = new X25519EcdheRsaSha256(this.random); + } + else + { + this.Logger.WriteError($"Dropping ClientHello from `{peerAddress}` Could not create TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 cipher suite"); + return false; + } + + break; + + default: + this.Logger.WriteError($"Dropping ClientHello from `{peerAddress}` Could not create handshake cipher suite"); + return false; + } + + peer.Session = clientHello.Session; + + // Update the state of our epoch transition + peer.NextEpoch.Epoch = (ushort)(record.Epoch + 1); + peer.NextEpoch.State = HandshakeState.ExpectingClientKeyExchange; + peer.NextEpoch.SelectedCipherSuite = selectedCipherSuite; + peer.NextEpoch.Handshake = handshakeCipherSuite; + clientHello.Random.CopyTo(peer.NextEpoch.ClientRandom); + peer.NextEpoch.ServerRandom.FillWithRandom(this.random); + recordMessagesForVerifyData = true; + +#if DEBUG + this.Logger.WriteVerbose($"ClientRandom: {peer.NextEpoch.ClientRandom} ServerRandom: {peer.NextEpoch.ServerRandom}"); +#endif + + // Copy the original ClientHello + // handshake to our verification stream + peer.NextEpoch.VerificationStream.AddData( + originalMessage.Slice( + 0 + , Handshake.Size + (int)handshake.Length + ) + ); + } + + // The initial record flight from the server + // contains the following Handshake messages: + // * ServerHello + // * Certificate + // * ServerKeyExchange + // * ServerHelloDone + // + // The Certificate message is almost always + // too large to fit into a single datagram, + // so it is pre-fragmented + // (see `SetCertificates`). Therefore, we + // need to send multiple record packets for + // this flight. + // + // The first record contains the ServerHello + // handshake message, as well as the first + // portion of the Certificate message. + // + // We then send a record packet until the + // entire Certificate message has been sent + // to the client. + // + // The final record packet contains the + // ServerKeyExchange and the ServerHelloDone + // messages. + + // Describe first record of the flight + ServerHello serverHello = new ServerHello(); + serverHello.ServerProtocolVersion = protocolVersion; + serverHello.Random = peer.NextEpoch.ServerRandom; + serverHello.CipherSuite = selectedCipherSuite; + + Handshake serverHelloHandshake = new Handshake(); + serverHelloHandshake.MessageType = HandshakeType.ServerHello; + serverHelloHandshake.Length = ServerHello.MinSize; + serverHelloHandshake.MessageSequence = 1; + serverHelloHandshake.FragmentOffset = 0; + serverHelloHandshake.FragmentLength = serverHelloHandshake.Length; + + int maxCertFragmentSize = peer.Session.Version == 0 ? MaxCertFragmentSizeV0 : MaxCertFragmentSizeV1; + + // The first certificate data needs to leave room for + // * Record header + // * ServerHello header + // * ServerHello payload + // * Certificate header + + var certificateData = this.encodedCertificate; + int initialCertPadding = Record.Size + Handshake.Size + serverHello.Size + Handshake.Size; + int certInitialFragmentSize = Math.Min(certificateData.Length, maxCertFragmentSize - initialCertPadding); + + Handshake certificateHandshake = new Handshake(); + certificateHandshake.MessageType = HandshakeType.Certificate; + certificateHandshake.Length = (uint)certificateData.Length; + certificateHandshake.MessageSequence = 2; + certificateHandshake.FragmentOffset = 0; + certificateHandshake.FragmentLength = (uint)certInitialFragmentSize; + + int initialRecordPayloadSize = 0 + + Handshake.Size + serverHello.Size + + Handshake.Size + (int)certificateHandshake.FragmentLength + ; + Record initialRecord = new Record(); + initialRecord.ContentType = ContentType.Handshake; + initialRecord.ProtocolVersion = protocolVersion; + initialRecord.Epoch = peer.Epoch; + initialRecord.SequenceNumber = peer.CurrentEpoch.NextOutgoingSequence; + initialRecord.Length = (ushort)peer.CurrentEpoch.RecordProtection.GetEncryptedSize(initialRecordPayloadSize); + ++peer.CurrentEpoch.NextOutgoingSequence; + + // Convert initial record of the flight to + // wire format + ByteSpan packet = new byte[Record.Size + initialRecord.Length]; + ByteSpan writer = packet; + initialRecord.Encode(writer); + writer = writer.Slice(Record.Size); + serverHelloHandshake.Encode(writer); + writer = writer.Slice(Handshake.Size); + serverHello.Encode(writer); + writer = writer.Slice(ServerHello.MinSize); + certificateHandshake.Encode(writer); + writer = writer.Slice(Handshake.Size); + certificateData.Slice(0, certInitialFragmentSize).CopyTo(writer); + certificateData = certificateData.Slice(certInitialFragmentSize); + + // Protect initial record of the flight + peer.CurrentEpoch.RecordProtection.EncryptServerPlaintext( + packet.Slice(Record.Size, initialRecord.Length), + packet.Slice(Record.Size, initialRecordPayloadSize), + ref initialRecord + ); + + base.QueueRawData(packet, peerAddress); + + // Record record payload for verification + if (recordMessagesForVerifyData) + { + Handshake fullCeritficateHandshake = certificateHandshake; + fullCeritficateHandshake.FragmentLength = fullCeritficateHandshake.Length; + + packet = new byte[Handshake.Size + ServerHello.MinSize + Handshake.Size]; + writer = packet; + serverHelloHandshake.Encode(writer); + writer = writer.Slice(Handshake.Size); + serverHello.Encode(writer); + writer = writer.Slice(ServerHello.MinSize); + fullCeritficateHandshake.Encode(writer); + writer = writer.Slice(Handshake.Size); + + peer.NextEpoch.VerificationStream.AddData(packet); + peer.NextEpoch.VerificationStream.AddData(this.encodedCertificate); + } + + // Process additional certificate records + // Subsequent certificate data needs to leave room for + // * Record header + // * Certificate header + const int CertPadding = Record.Size + Handshake.Size; + while (certificateData.Length > 0) + { + int certFragmentSize = Math.Min(certificateData.Length, maxCertFragmentSize - CertPadding); + + certificateHandshake.FragmentOffset += certificateHandshake.FragmentLength; + certificateHandshake.FragmentLength = (uint)certFragmentSize; + + int additionalRecordPayloadSize = Handshake.Size + (int)certificateHandshake.FragmentLength; + Record additionalRecord = new Record(); + additionalRecord.ContentType = ContentType.Handshake; + additionalRecord.ProtocolVersion = protocolVersion; + additionalRecord.Epoch = peer.Epoch; + additionalRecord.SequenceNumber = peer.CurrentEpoch.NextOutgoingSequence; + additionalRecord.Length = (ushort)peer.CurrentEpoch.RecordProtection.GetEncryptedSize(additionalRecordPayloadSize); + ++peer.CurrentEpoch.NextOutgoingSequence; + + // Convert record to wire format + packet = new byte[Record.Size + additionalRecord.Length]; + writer = packet; + additionalRecord.Encode(writer); + writer = writer.Slice(Record.Size); + certificateHandshake.Encode(writer); + writer = writer.Slice(Handshake.Size); + certificateData.Slice(0, certFragmentSize).CopyTo(writer); + + certificateData = certificateData.Slice(certFragmentSize); + + // Protect record + peer.CurrentEpoch.RecordProtection.EncryptServerPlaintext( + packet.Slice(Record.Size, additionalRecord.Length), + packet.Slice(Record.Size, additionalRecordPayloadSize), + ref additionalRecord + ); + + base.QueueRawData(packet, peerAddress); + } + + // Describe final record of the flight + Handshake serverKeyExchangeHandshake = new Handshake(); + serverKeyExchangeHandshake.MessageType = HandshakeType.ServerKeyExchange; + serverKeyExchangeHandshake.Length = (uint)peer.NextEpoch.Handshake.CalculateServerMessageSize(this.certificatePrivateKey); + serverKeyExchangeHandshake.MessageSequence = 3; + serverKeyExchangeHandshake.FragmentOffset = 0; + serverKeyExchangeHandshake.FragmentLength = serverKeyExchangeHandshake.Length; + + Handshake serverHelloDoneHandshake = new Handshake(); + serverHelloDoneHandshake.MessageType = HandshakeType.ServerHelloDone; + serverHelloDoneHandshake.Length = 0; + serverHelloDoneHandshake.MessageSequence = 4; + serverHelloDoneHandshake.FragmentOffset = 0; + serverHelloDoneHandshake.FragmentLength = 0; + + int finalRecordPayloadSize = 0 + + Handshake.Size + (int)serverKeyExchangeHandshake.Length + + Handshake.Size + (int)serverHelloDoneHandshake.Length + ; + Record finalRecord = new Record(); + finalRecord.ContentType = ContentType.Handshake; + finalRecord.ProtocolVersion = protocolVersion; + finalRecord.Epoch = peer.Epoch; + finalRecord.SequenceNumber = peer.CurrentEpoch.NextOutgoingSequence; + finalRecord.Length = (ushort)peer.CurrentEpoch.RecordProtection.GetEncryptedSize(finalRecordPayloadSize); + ++peer.CurrentEpoch.NextOutgoingSequence; + + // Convert final record of the flight to wire + // format + packet = new byte[Record.Size + finalRecord.Length]; + writer = packet; + finalRecord.Encode(writer); + writer = writer.Slice(Record.Size); + serverKeyExchangeHandshake.Encode(writer); + writer = writer.Slice(Handshake.Size); + peer.NextEpoch.Handshake.EncodeServerKeyExchangeMessage(writer, this.certificatePrivateKey); + writer = writer.Slice((int)serverKeyExchangeHandshake.Length); + serverHelloDoneHandshake.Encode(writer); + + // Record record payload for verification + if (recordMessagesForVerifyData) + { + peer.NextEpoch.VerificationStream.AddData( + packet.Slice( + packet.Offset + Record.Size + , finalRecordPayloadSize + ) + ); + } + + // Protect final record of the flight + peer.CurrentEpoch.RecordProtection.EncryptServerPlaintext( + packet.Slice(Record.Size, finalRecord.Length), + packet.Slice(Record.Size, finalRecordPayloadSize), + ref finalRecord + ); + + base.QueueRawData(packet, peerAddress); + + return true; + } + + /// <summary> + /// Handle an incoming packet that is not tied to an existing peer + /// </summary> + /// <param name="message">Incoming datagram</param> + /// <param name="peerAddress">Originating address</param> + private void HandleNonPeerRecord(ByteSpan message, IPEndPoint peerAddress) + { + Record record; + if (!Record.Parse(out record, expectedProtocolVersion: null, message)) + { + this.Logger.WriteError($"Dropping malformed record from non-peer `{peerAddress}`"); + return; + } + message = message.Slice(Record.Size); + + // The protocol only supports receiving a single record + // from a non-peer. + if (record.Length != message.Length) + { + // NOTE(mendsley): This isn't always fatal. + // However, this is an indication that something + // fishy is going on. In the best case, there's a + // bug on the client or in the UDP stack (some + // stacks don't both to verify the checksum). In the + // worst case we're dealing with a malicious actor. + // In the malicious case, we'll end up dropping the + // connection later in the process. + if (message.Length < record.Length) + { + this.Logger.WriteInfo($"Dropping bad record from non-peer `{peerAddress}`. Msg length {message.Length} < {record.Length}"); + return; + } + } + + // We only accept zero-epoch records from non-peers + if (record.Epoch != 0) + { + ///NOTE(mendsley): Not logging anything here, as + /// this could easily be latent data arriving from a + /// recently disconnected peer. + return; + } + + // We only accept Handshake protocol messages from non-peers + if (record.ContentType != ContentType.Handshake) + { + this.Logger.WriteError($"Dropping non-handhsake message from non-peer `{peerAddress}`"); + return; + } + + ByteSpan originalMessage = message; + + Handshake handshake; + if (!Handshake.Parse(out handshake, message)) + { + this.Logger.WriteError($"Dropping malformed handshake message from non-peer `{peerAddress}`"); + return; + } + + // We only accept ClientHello messages from non-peers + if (handshake.MessageType != HandshakeType.ClientHello) + { +#if DEBUG + this.Logger.WriteError($"Dropping non-ClientHello ({handshake.MessageType}) message from non-peer `{peerAddress}`"); +#else + Interlocked.Increment(ref this.NonPeerNonHelloPacketsDropped); +#endif + return; + } + message = message.Slice(Handshake.Size); + + if (!ClientHello.Parse(out ClientHello clientHello, expectedProtocolVersion: null, message)) + { + this.Logger.WriteError($"Dropping malformed ClientHello message from non-peer `{peerAddress}`"); + return; + } + + // If this ClientHello is not signed by us, request the + // client send us a signed message + if (!HelloVerifyRequest.VerifyCookie(clientHello.Cookie, peerAddress, this.CurrentCookieHmac)) + { + if (!HelloVerifyRequest.VerifyCookie(clientHello.Cookie, peerAddress, this.PreviousCookieHmac)) + { +#if DEBUG + this.Logger.WriteVerbose($"Sending HelloVerifyRequest to non-peer `{peerAddress}`"); +#else + Interlocked.Increment(ref this.NonPeerVerifyHelloRequests); +#endif + this.SendHelloVerifyRequest(peerAddress, 1, 0, NullRecordProtection.Instance, clientHello.ClientProtocolVersion); + return; + } + } + + // Allocate state for the new peer and register it + PeerData peer = new PeerData(this.AllocateConnectionId(peerAddress), record.SequenceNumber + 1, clientHello.ClientProtocolVersion); + this.ProcessHandshake(peer, peerAddress, ref record, originalMessage); + this.existingPeers[peerAddress] = peer; + } + + //Send a HelloVerifyRequest handshake message to a peer + private void SendHelloVerifyRequest(IPEndPoint peerAddress, ulong recordSequence, ushort epoch, IRecordProtection recordProtection, ProtocolVersion protocolVersion) + { + Handshake handshake = new Handshake(); + handshake.MessageType = HandshakeType.HelloVerifyRequest; + handshake.Length = HelloVerifyRequest.Size; + handshake.MessageSequence = 0; + handshake.FragmentOffset = 0; + handshake.FragmentLength = handshake.Length; + + int plaintextPayloadSize = Handshake.Size + (int)handshake.Length; + + Record record = new Record(); + record.ContentType = ContentType.Handshake; + record.ProtocolVersion = protocolVersion; + record.Epoch = epoch; + record.SequenceNumber = recordSequence; + record.Length = (ushort)recordProtection.GetEncryptedSize(plaintextPayloadSize); + + // Encode record to wire format + ByteSpan packet = new byte[Record.Size + record.Length]; + ByteSpan writer = packet; + record.Encode(writer); + writer = writer.Slice(Record.Size); + handshake.Encode(writer); + writer = writer.Slice(Handshake.Size); + HelloVerifyRequest.Encode(writer, peerAddress, this.CurrentCookieHmac, protocolVersion); + + // Protect record payload + recordProtection.EncryptServerPlaintext( + packet.Slice(Record.Size, record.Length), + packet.Slice(Record.Size, plaintextPayloadSize), + ref record + ); + + base.QueueRawData(packet, peerAddress); + } + + /// <summary> + /// Handle a requrest to send a datagram to the network + /// </summary> + protected override void QueueRawData(ByteSpan span, IPEndPoint remoteEndPoint) + { + PeerData peer; + if (!this.existingPeers.TryGetValue(remoteEndPoint, out peer)) + { + // Drop messages if we don't know how to send them + return; + } + + lock (peer) + { + // If we're negotiating a new epoch, queue data + if (peer.Epoch == 0 || peer.NextEpoch.State != HandshakeState.ExpectingHello) + { + ByteSpan copyOfSpan = new byte[span.Length]; + span.CopyTo(copyOfSpan); + + peer.QueuedApplicationDataMessage.Add(copyOfSpan); + return; + } + + ProtocolVersion protocolVersion = peer.ProtocolVersion; + + // Send any queued application data now + for (int ii = 0, nn = peer.QueuedApplicationDataMessage.Count; ii != nn; ++ii) + { + ByteSpan queuedSpan = peer.QueuedApplicationDataMessage[ii]; + + Record outgoingRecord = new Record(); + outgoingRecord.ContentType = ContentType.ApplicationData; + outgoingRecord.ProtocolVersion = protocolVersion; + outgoingRecord.Epoch = peer.Epoch; + outgoingRecord.SequenceNumber = peer.CurrentEpoch.NextOutgoingSequence; + outgoingRecord.Length = (ushort)peer.CurrentEpoch.RecordProtection.GetEncryptedSize(queuedSpan.Length); + ++peer.CurrentEpoch.NextOutgoingSequence; + + // Encode the record to wire format + ByteSpan packet = new byte[Record.Size + outgoingRecord.Length]; + ByteSpan writer = packet; + outgoingRecord.Encode(writer); + writer = writer.Slice(Record.Size); + queuedSpan.CopyTo(writer); + + // Protect the record + peer.CurrentEpoch.RecordProtection.EncryptServerPlaintext( + packet.Slice(Record.Size, outgoingRecord.Length), + packet.Slice(Record.Size, queuedSpan.Length), + ref outgoingRecord + ); + + base.QueueRawData(packet, remoteEndPoint); + } + peer.QueuedApplicationDataMessage.Clear(); + + { + Record outgoingRecord = new Record(); + outgoingRecord.ContentType = ContentType.ApplicationData; + outgoingRecord.ProtocolVersion = protocolVersion; + outgoingRecord.Epoch = peer.Epoch; + outgoingRecord.SequenceNumber = peer.CurrentEpoch.NextOutgoingSequence; + outgoingRecord.Length = (ushort)peer.CurrentEpoch.RecordProtection.GetEncryptedSize(span.Length); + ++peer.CurrentEpoch.NextOutgoingSequence; + + // Encode the record to wire format + ByteSpan packet = new byte[Record.Size + outgoingRecord.Length]; + ByteSpan writer = packet; + outgoingRecord.Encode(writer); + writer = writer.Slice(Record.Size); + span.CopyTo(writer); + + // Protect the record + peer.CurrentEpoch.RecordProtection.EncryptServerPlaintext( + packet.Slice(Record.Size, outgoingRecord.Length), + packet.Slice(Record.Size, span.Length), + ref outgoingRecord + ); + + base.QueueRawData(packet, remoteEndPoint); + } + } + } + + private void HandleStaleConnections(object _) + { + TimeSpan maxAge = TimeSpan.FromSeconds(2.5f); + DateTime now = DateTime.UtcNow; + foreach (KeyValuePair<IPEndPoint, PeerData> kvp in this.existingPeers) + { + PeerData peer = kvp.Value; + lock (peer) + { + if (peer.Epoch == 0 || peer.NextEpoch.State != HandshakeState.ExpectingHello) + { + TimeSpan negotiationAge = now - peer.StartOfNegotiation; + if (negotiationAge > maxAge) + { + MarkConnectionAsStale(peer.ConnectionId); + } + } + } + } + + ConnectionId connectionId; + while (this.staleConnections.TryPop(out connectionId)) + { + ThreadLimitedUdpServerConnection connection; + if (this.allConnections.TryGetValue(connectionId, out connection)) + { + connection.Disconnect("Stale Connection", null); + } + } + } + + protected void MarkConnectionAsStale(ConnectionId connectionId) + { + if (this.allConnections.ContainsKey(connectionId)) + { + this.staleConnections.Push(connectionId); + } + } + + /// <inheritdoc /> + internal override void RemovePeerRecord(ConnectionId connectionId) + { + if (this.existingPeers.TryRemove(connectionId.EndPoint, out var peer)) + { + peer.Dispose(); + } + } + + /// <summary> + /// Allocate a new connection id + /// </summary> + private ConnectionId AllocateConnectionId(IPEndPoint endPoint) + { + int rawSerialId = Interlocked.Increment(ref this.connectionSerial_unsafe); + return ConnectionId.Create(endPoint, rawSerialId); + } + + } +} diff --git a/Tools/Hazel-Networking/Hazel/Dtls/DtlsUnityConnection.cs b/Tools/Hazel-Networking/Hazel/Dtls/DtlsUnityConnection.cs new file mode 100644 index 0000000..4da2051 --- /dev/null +++ b/Tools/Hazel-Networking/Hazel/Dtls/DtlsUnityConnection.cs @@ -0,0 +1,1246 @@ +using Hazel.Crypto; +using Hazel.Udp; +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Net; +using System.Security.Cryptography; +using System.Security.Cryptography.X509Certificates; + +namespace Hazel.Dtls +{ + /// <summary> + /// Connects to a UDP-DTLS server + /// </summary> + /// <inheritdoc /> + public class DtlsUnityConnection : UnityUdpClientConnection + { + /// <summary> + /// Current state of the handshake sequence + /// </summary> + enum HandshakeState + { + Initializing, + + ExpectingServerHello, + ExpectingCertificate, + ExpectingServerKeyExchange, + ExpectingServerHelloDone, + + ExpectingChangeCipherSpec, + ExpectingFinished, + + Established, + } + + /// <summary> + /// State data for the current epoch + /// </summary> + struct CurrentEpoch + { + public ulong NextOutgoingSequence; + + public ulong NextExpectedSequence; + public ulong PreviousSequenceWindowBitmask; + + public IRecordProtection RecordProtection; + } + + struct FragmentRange + { + public int Offset; + public int Length; + } + + /// <summary> + /// State data for the next epoch + /// </summary> + struct NextEpoch + { + public ushort Epoch; + + public HandshakeState State; + + public ulong NextOutgoingSequence; + + public DateTime NegotiationStartTime; + public DateTime NextPacketResendTime; + public int PacketResendCount; + + public CipherSuite SelectedCipherSuite; + public IRecordProtection RecordProtection; + public IHandshakeCipherSuite Handshake; + public ByteSpan Cookie; + public Sha256Stream VerificationStream; + public RSA ServerPublicKey; + + public ByteSpan ClientRandom; + public ByteSpan ServerRandom; + + public ByteSpan MasterSecret; + public ByteSpan ServerVerification; + + public List<FragmentRange> CertificateFragments; + public ByteSpan CertificatePayload; + } + + struct QueuedAppData + { + public byte[] Bytes; + public byte SendOption; + public Action AckCallback; + } + + private const ProtocolVersion DtlsVersion = ProtocolVersion.UDP; + + internal byte HazelSessionVersion = HazelDtlsSessionInfo.CurrentClientSessionVersion; + + private readonly object syncRoot = new object(); + private readonly RandomNumberGenerator random = RandomNumberGenerator.Create(); + + private ushort epoch; + private CurrentEpoch currentEpoch; + private NextEpoch nextEpoch; + private TimeSpan handshakeResendTimeout = TimeSpan.FromMilliseconds(200); + + private readonly Queue<QueuedAppData> queuedApplicationData = new Queue<QueuedAppData>(); + + private X509Certificate2Collection serverCertificates = new X509Certificate2Collection(); + + public bool HandshakeComplete + { + get + { + lock (this.syncRoot) + { + return this.nextEpoch.State == HandshakeState.Established; + } + } + } + + /// <summary> + /// Create a new instance of the DTLS connection + /// </summary> + /// <inheritdoc /> + public DtlsUnityConnection(ILogger logger, IPEndPoint remoteEndPoint, IPMode ipMode = IPMode.IPv4) + : base(logger, remoteEndPoint, ipMode) + { + this.nextEpoch.ServerRandom = new byte[Random.Size]; + this.nextEpoch.ClientRandom = new byte[Random.Size]; + this.nextEpoch.ServerVerification = new byte[Finished.Size]; + this.nextEpoch.CertificateFragments = new List<FragmentRange>(); + + this.ResetConnectionState(); + } + + /// <inheritdoc /> + protected override void Dispose(bool disposing) + { + base.Dispose(disposing); + + lock (this.syncRoot) + { + this.ResetConnectionState(); + } + } + + /// <summary> + /// Set the list of valid server certificates + /// </summary> + /// <param name="certificateCollection"> + /// List of certificates of authentic servers + /// </param> + public void SetValidServerCertificates(X509Certificate2Collection certificateCollection) + { + lock (this.syncRoot) + { + foreach (X509Certificate2 certificate in certificateCollection) + { + if (!(certificate.PublicKey.Key is RSA)) + { + throw new ArgumentException("Certificate must be signed with an RSA key", nameof(certificateCollection)); + } + } + + this.serverCertificates = certificateCollection; + } + } + + /// <summary> + /// Set the packet resend timer for handshake messages + /// </summary> + public void SetHandshakeResendTimeout(TimeSpan timeout) + { + lock (this.syncRoot) + { + this.handshakeResendTimeout = timeout; + } + } + + /// <summary> + /// Reset existing connection state + /// </summary> + private void ResetConnectionState() + { + this.currentEpoch.NextOutgoingSequence = 1; + this.currentEpoch.NextExpectedSequence = 1; + this.currentEpoch.PreviousSequenceWindowBitmask = 0; + this.currentEpoch.RecordProtection?.Dispose(); + this.currentEpoch.RecordProtection = NullRecordProtection.Instance; + + this.nextEpoch.Epoch = 1; + this.nextEpoch.State = HandshakeState.Initializing; + this.nextEpoch.NextOutgoingSequence = 1; + this.nextEpoch.NegotiationStartTime = DateTime.MinValue; + this.nextEpoch.NextPacketResendTime = DateTime.MinValue; + this.nextEpoch.SelectedCipherSuite = CipherSuite.TLS_NULL_WITH_NULL_NULL; + this.nextEpoch.RecordProtection?.Dispose(); + this.nextEpoch.RecordProtection = null; + this.nextEpoch.Handshake?.Dispose(); + this.nextEpoch.Handshake = null; + this.nextEpoch.Cookie = ByteSpan.Empty; + this.nextEpoch.VerificationStream?.Dispose(); + this.nextEpoch.VerificationStream = new Sha256Stream(); + this.nextEpoch.ServerPublicKey = null; + this.nextEpoch.ServerRandom.SecureClear(); + this.nextEpoch.ClientRandom.SecureClear(); + this.nextEpoch.MasterSecret.SecureClear(); + this.nextEpoch.ServerVerification.SecureClear(); + this.nextEpoch.CertificateFragments.Clear(); + this.nextEpoch.CertificatePayload = ByteSpan.Empty; + + this.epoch = 0; + while (this.queuedApplicationData.TryDequeue(out _)) ; + } + + /// <summary> + /// Abort the existing connection and restart the process + /// </summary> + protected override void RestartConnection() + { + lock (this.syncRoot) + { + this.ResetConnectionState(); + this.nextEpoch.ClientRandom.FillWithRandom(this.random); + this.SendClientHello(isRetransmit: false); + } + + base.RestartConnection(); + } + + /// <inheritdoc /> + protected override void ResendPacketsIfNeeded() + { + lock (this.syncRoot) + { + // Check if we need to resend handshake message + if (this.nextEpoch.State != HandshakeState.Established) + { + DateTime now = DateTime.UtcNow; + if (now >= this.nextEpoch.NextPacketResendTime) + { + double negotiationDurationMs = (now - this.nextEpoch.NegotiationStartTime).TotalMilliseconds; + this.nextEpoch.PacketResendCount++; + + if ((this.ResendLimit > 0 && this.nextEpoch.PacketResendCount > this.ResendLimit) + || negotiationDurationMs > this.DisconnectTimeoutMs) + { + this.DisconnectInternal(HazelInternalErrors.DtlsNegotiationFailed, $"DTLS negotiation failed after {this.nextEpoch.PacketResendCount} resends ({(int)negotiationDurationMs} ms)."); + } + else + { + switch (this.nextEpoch.State) + { + case HandshakeState.ExpectingServerHello: + case HandshakeState.ExpectingCertificate: + case HandshakeState.ExpectingServerKeyExchange: + case HandshakeState.ExpectingServerHelloDone: + this.SendClientHello(isRetransmit: true); + break; + + case HandshakeState.ExpectingChangeCipherSpec: + case HandshakeState.ExpectingFinished: + this.SendClientKeyExchangeFlight(isRetransmit: true); + break; + + case HandshakeState.Established: + default: + break; + } + } + } + } + } + + base.ResendPacketsIfNeeded(); + } + + /// <summary> + /// Flush any queued application data packets + /// </summary> + private void FlushQueuedApplicationData() + { + while (this.queuedApplicationData.TryDequeue(out var queuedData)) + { + base.HandleSend(queuedData.Bytes, queuedData.SendOption, queuedData.AckCallback); + } + } + + /// <summary> + /// Request from the application to write data to the DTLS + /// stream. If appropriate, returns a byte span to send to + /// the wire. + /// </summary> + /// <param name="bytes">Plaintext bytes to write</param> + /// <param name="length">Length of the bytes to write</param> + /// <returns> + /// Encrypted data to put on the wire if appropriate, + /// otherwise an empty span + /// </returns> + private ByteSpan WriteBytesToConnectionInternal(byte[] bytes, int length) + { + lock (this.syncRoot) + { + Record outgoinRecord = new Record(); + outgoinRecord.ContentType = ContentType.ApplicationData; + outgoinRecord.ProtocolVersion = DtlsVersion; + outgoinRecord.Epoch = this.epoch; + outgoinRecord.SequenceNumber = this.currentEpoch.NextOutgoingSequence; + outgoinRecord.Length = (ushort)this.currentEpoch.RecordProtection.GetEncryptedSize(length); + ++this.currentEpoch.NextOutgoingSequence; + + // Encode the record to wire format + ByteSpan packet = new byte[Record.Size + outgoinRecord.Length]; + ByteSpan writer = packet; + outgoinRecord.Encode(writer); + writer = writer.Slice(Record.Size); + new ByteSpan(bytes, 0, length).CopyTo(writer); + + // Protect the record + this.currentEpoch.RecordProtection.EncryptClientPlaintext( + packet.Slice(Record.Size, outgoinRecord.Length), + packet.Slice(Record.Size, length), + ref outgoinRecord + ); + + return packet; + } + } + + protected override void HandleSend(byte[] data, byte sendOption, Action ackCallback = null) + { + lock (this.syncRoot) + { + // If we're negotiating a new epoch, queue data + if (this.nextEpoch.State != HandshakeState.Established) + { + this.queuedApplicationData.Enqueue(new QueuedAppData + { + Bytes = data, + SendOption = sendOption, + AckCallback = ackCallback + }); + + return; + } + } + + base.HandleSend(data, sendOption, ackCallback); + } + + /// <inheritdoc /> + protected override void WriteBytesToConnection(byte[] bytes, int length) + { + ByteSpan wireData = this.WriteBytesToConnectionInternal(bytes, length); + if (wireData.Length > 0) + { + Debug.Assert(wireData.Offset == 0, "Got a non-zero write data offset"); + base.WriteBytesToConnection(wireData.GetUnderlyingArray(), wireData.Length); + } + } + + /// <inheritdoc /> + protected override void WriteBytesToConnectionSync(byte[] bytes, int length) + { + ByteSpan wireData = this.WriteBytesToConnectionInternal(bytes, length); + if (wireData.Length > 0) + { + Debug.Assert(wireData.Offset == 0, "Got a non-zero write data offset"); + base.WriteBytesToConnectionSync(wireData.GetUnderlyingArray(), wireData.Length); + } + } + + /// <inheritdoc /> + protected internal override void HandleReceive(MessageReader reader, int bytesReceived) + { + ByteSpan message = new ByteSpan(reader.Buffer, reader.Offset + reader.Position, reader.BytesRemaining); + lock (this.syncRoot) + { + this.HandleReceive(message); + } + + reader.Recycle(); + } + + /// <summary> + /// Handle an incoming datagram + /// </summary> + /// <param name="span">Bytes of the datagram</param> + private void HandleReceive(ByteSpan span) + { + // Each incoming packet may contain multiple DTLS + // records + while (span.Length > 0) + { + Record record; + if (!Record.Parse(out record, DtlsVersion, span)) + { + this.logger.WriteError("Dropping malformed record"); + return; + } + span = span.Slice(Record.Size); + + if (span.Length < record.Length) + { + this.logger.WriteError($"Dropping malformed record. Length({record.Length}) Available Bytes({span.Length})"); + return; + } + + ByteSpan recordPayload = span.Slice(0, record.Length); + span = span.Slice(record.Length); + + // Early out and drop ApplicationData records + if (record.ContentType == ContentType.ApplicationData && this.nextEpoch.State != HandshakeState.Established) + { + this.logger.WriteError("Dropping ApplicationData record. Cannot process yet"); + continue; + } + + // Drop records from a different epoch + if (record.Epoch != this.epoch) + { + this.logger.WriteError($"Dropping bad-epoch record. RecordEpoch({record.Epoch}) Epoch({this.epoch})"); + continue; + } + + // Prevent replay attacks by dropping records + // we've already processed + int windowIndex = (int)(this.currentEpoch.NextExpectedSequence - record.SequenceNumber - 1); + ulong windowMask = 1ul << windowIndex; + if (record.SequenceNumber < this.currentEpoch.NextExpectedSequence) + { + if (windowIndex >= 64) + { + this.logger.WriteError($"Dropping too-old record: Sequnce({record.SequenceNumber}) Expected({this.currentEpoch.NextExpectedSequence})"); + continue; + } + + if ((this.currentEpoch.PreviousSequenceWindowBitmask & windowMask) != 0) + { + this.logger.WriteWarning("Dropping duplicate record"); + continue; + } + } + + // Verify record authenticity + int decryptedSize = this.currentEpoch.RecordProtection.GetDecryptedSize(recordPayload.Length); + ByteSpan decryptedPayload = recordPayload.ReuseSpanIfPossible(decryptedSize); + + if (!this.currentEpoch.RecordProtection.DecryptCiphertextFromServer(decryptedPayload, recordPayload, ref record)) + { + this.logger.WriteError("Dropping non-authentic record"); + return; + } + + recordPayload = decryptedPayload; + + // Update out sequence number bookkeeping + if (record.SequenceNumber >= this.currentEpoch.NextExpectedSequence) + { + int windowShift = (int)(record.SequenceNumber + 1 - this.currentEpoch.NextExpectedSequence); + this.currentEpoch.PreviousSequenceWindowBitmask <<= windowShift; + this.currentEpoch.NextExpectedSequence = record.SequenceNumber + 1; + } + else + { + this.currentEpoch.PreviousSequenceWindowBitmask |= windowMask; + } + + // This is handy for debugging, but too verbose even for verbose. + // this.logger.WriteVerbose($"Content type was {record.ContentType} ({this.nextEpoch.State})"); + switch (record.ContentType) + { + case ContentType.ChangeCipherSpec: + if (this.nextEpoch.State != HandshakeState.ExpectingChangeCipherSpec) + { + this.logger.WriteError($"Dropping unexpected ChangeCipherSpec State({this.nextEpoch.State})"); + break; + } + else if (this.nextEpoch.RecordProtection == null) + { + ///NOTE(mendsley): This _should_ not + /// happen on a well-formed client. + Debug.Assert(false, "How did we receive a ChangeCipherSpec message without a pending record protection instance?"); + break; + } + + if (!ChangeCipherSpec.Parse(recordPayload)) + { + this.logger.WriteError("Dropping malformed ChangeCipherSpec message"); + break; + } + + // Migrate to the next epoch + this.epoch = this.nextEpoch.Epoch; + this.currentEpoch.RecordProtection = this.nextEpoch.RecordProtection; + this.currentEpoch.NextOutgoingSequence = this.nextEpoch.NextOutgoingSequence; + this.currentEpoch.NextExpectedSequence = 1; + this.currentEpoch.PreviousSequenceWindowBitmask = 0; + + this.nextEpoch.State = HandshakeState.ExpectingFinished; + this.nextEpoch.SelectedCipherSuite = CipherSuite.TLS_NULL_WITH_NULL_NULL; + this.nextEpoch.RecordProtection = null; + this.nextEpoch.Handshake?.Dispose(); + this.nextEpoch.Cookie = ByteSpan.Empty; + this.nextEpoch.VerificationStream.Reset(); + this.nextEpoch.ServerPublicKey = null; + this.nextEpoch.ServerRandom.SecureClear(); + this.nextEpoch.ClientRandom.SecureClear(); + this.nextEpoch.MasterSecret.SecureClear(); + break; + + case ContentType.Alert: + this.logger.WriteError("Dropping unsupported alert record"); + continue; + + case ContentType.Handshake: + if (!ProcessHandshake(ref record, recordPayload)) + { + return; + } + break; + + case ContentType.ApplicationData: + // Forward data to the application + MessageReader reader = MessageReader.GetSized(recordPayload.Length); + reader.Length = recordPayload.Length; + recordPayload.CopyTo(reader.Buffer); + + base.HandleReceive(reader, recordPayload.Length); + break; + } + } + } + + /// <summary> + /// Process an incoming Handshake protocol message + /// </summary> + /// <param name="record">Parent record</param> + /// <param name="message">Record payload</param> + /// <returns> + /// True if further processing of the underlying datagram + /// should be continues. Otherwise, false. + /// </returns> + private bool ProcessHandshake(ref Record record, ByteSpan message) + { + // Each record may have multiple Handshake messages + while (message.Length > 0) + { + ByteSpan originalPayload = message; + + Handshake handshake; + if (!Handshake.Parse(out handshake, message)) + { + this.logger.WriteError("Dropping malformed handshake message"); + return false; + } + message = message.Slice(Handshake.Size); + + // Check for fragmented messages + if (handshake.FragmentOffset != 0 || handshake.FragmentLength != handshake.Length) + { + // We only support fragmentation on Certificate messages + if (handshake.MessageType != HandshakeType.Certificate) + { + this.logger.WriteError($"Dropping fragmented handshake message Type({handshake.MessageType}) Offset({handshake.FragmentOffset}) FragmentLength({handshake.FragmentLength}) Length({handshake.Length})"); + continue; + } + + if (message.Length < handshake.FragmentLength) + { + this.logger.WriteError($"Dropping malformed fragmented handshake message: AvailableBytes({message.Length}) Size({handshake.FragmentLength})"); + return false; + } + + originalPayload = originalPayload.Slice(0, (int)(Handshake.Size + handshake.FragmentLength)); + message = message.Slice((int)handshake.FragmentLength); + } + else + { + if (message.Length < handshake.Length) + { + this.logger.WriteError($"Dropping malformed handshake message: AvailableBytes({message.Length}) Size({handshake.Length})"); + return false; + } + + originalPayload = originalPayload.Slice(0, (int)(Handshake.Size + handshake.Length)); + message = message.Slice((int)handshake.Length); + } + + ByteSpan payload = originalPayload.Slice(Handshake.Size); + +#if DEBUG + this.logger.WriteVerbose($"Handshake record was {handshake.MessageType} (Frag: {handshake.FragmentOffset}) ({this.nextEpoch.State})"); +#endif + switch (handshake.MessageType) + { + case HandshakeType.HelloVerifyRequest: + if (this.nextEpoch.State != HandshakeState.ExpectingServerHello) + { + this.logger.WriteError($"Dropping unexpected HelloVerifyRequest handshake message State({this.nextEpoch.State})"); + continue; + } + else if (handshake.MessageSequence != 0) + { + this.logger.WriteError($"Dropping bad-sequence HelloVerifyRequest MessageSequence({handshake.MessageSequence})"); + continue; + } + + HelloVerifyRequest helloVerifyRequest; + if (!HelloVerifyRequest.Parse(out helloVerifyRequest, DtlsVersion, payload)) + { + this.logger.WriteError("Dropping malformed HelloVerifyRequest handshake message"); + continue; + } + + // If the cookie differs, save it and restart the handshake + if (this.nextEpoch.Cookie.Length == helloVerifyRequest.Cookie.Length + && Const.ConstantCompareSpans(this.nextEpoch.Cookie, helloVerifyRequest.Cookie) == 1) + { + this.logger.WriteWarning("Dropping duplicate HelloVerifyRequest handshake message"); + continue; + } + + this.nextEpoch.Cookie = new byte[helloVerifyRequest.Cookie.Length]; + helloVerifyRequest.Cookie.CopyTo(this.nextEpoch.Cookie); + this.nextEpoch.ClientRandom.FillWithRandom(this.random); + + // We don't need to resend here. We already have the cookie so we already sent it once. + this.SendClientHello(isRetransmit: false); + + break; + + case HandshakeType.ServerHello: + if (this.nextEpoch.State != HandshakeState.ExpectingServerHello) + { + this.logger.WriteError($"Dropping unexpected ServerHello handshake message State({this.nextEpoch.State})"); + continue; + } + else if (handshake.MessageSequence != 1) + { + this.logger.WriteError($"Dropping bad-sequence ServerHello MessageSequence({handshake.MessageSequence})"); + continue; + } + + ServerHello serverHello; + if (!ServerHello.Parse(out serverHello, payload)) + { + this.logger.WriteError("Dropping malformed ServerHello message"); + continue; + } + + switch (serverHello.CipherSuite) + { + case CipherSuite.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256: + this.nextEpoch.Handshake = new X25519EcdheRsaSha256(this.random); + break; + + default: + this.logger.WriteError($"Dropping malformed ServerHello message. Unsupported CipherSuite({serverHello.CipherSuite})"); + continue; + } + + // Save server parameters + this.nextEpoch.SelectedCipherSuite = serverHello.CipherSuite; + serverHello.Random.CopyTo(this.nextEpoch.ServerRandom); + this.nextEpoch.State = HandshakeState.ExpectingCertificate; + this.nextEpoch.CertificateFragments.Clear(); + this.nextEpoch.CertificatePayload = ByteSpan.Empty; + +#if DEBUG + this.logger.WriteVerbose($"ClientRandom: {this.nextEpoch.ClientRandom} ServerRandom: {this.nextEpoch.ServerRandom}"); +#endif + + // Append ServerHelllo message to the verification stream + this.nextEpoch.VerificationStream.AddData(originalPayload); + break; + + case HandshakeType.Certificate: + if (this.nextEpoch.State != HandshakeState.ExpectingCertificate) + { + this.logger.WriteError($"Dropping unexpected Certificate handshake message State({this.nextEpoch.State})"); + continue; + } + else if (handshake.MessageSequence != 2) + { + this.logger.WriteError($"Dropping bad-sequence Certificate MessageSequence({handshake.MessageSequence})"); + continue; + } + + // If this is a fragmented message + if (handshake.FragmentLength != handshake.Length) + { + if (this.nextEpoch.CertificatePayload.Length != handshake.Length) + { + this.nextEpoch.CertificatePayload = new byte[handshake.Length]; + this.nextEpoch.CertificateFragments.Clear(); + } + + // Should we add this fragment? + // According to the RFC 9147 Section 5.5, we are supposed to be tolerant of overlapping segments + // But if we... weren't... Hazel isn't going to change the fragment sizes. So would it really hurt? + // So let's just ignore that and assume that the sender always wants to send the same fragments. + if (IsFragmentOverlapping(this.nextEpoch.CertificateFragments, handshake.FragmentOffset, handshake.FragmentLength)) + { + continue; + } + + payload.CopyTo(this.nextEpoch.CertificatePayload.Slice((int)handshake.FragmentOffset, (int)handshake.FragmentLength)); + this.nextEpoch.CertificateFragments.Add(new FragmentRange {Offset = (int)handshake.FragmentOffset, Length = (int)handshake.FragmentLength }); + this.nextEpoch.CertificateFragments.Sort((FragmentRange lhs, FragmentRange rhs) => { + return lhs.Offset.CompareTo(rhs.Offset); + }); + + // Have we completed the message? + int currentOffset = 0; + bool valid = true; + foreach (FragmentRange range in this.nextEpoch.CertificateFragments) + { + if (range.Offset != currentOffset) + { + valid = false; + break; + } + + currentOffset += range.Length; + } + + if (currentOffset != this.nextEpoch.CertificatePayload.Length) + { + valid = false; + } + + // Still waiting on more fragments? + if (!valid) + { + continue; + } + + // Replace the message payload, and continue + this.nextEpoch.CertificateFragments.Clear(); + payload = this.nextEpoch.CertificatePayload; + } + + X509Certificate2 certificate; + if (!Certificate.Parse(out certificate, payload)) + { + this.logger.WriteError("Dropping malformed Certificate message"); + continue; + } + + // Verify the certificate is authenticate + if (!this.serverCertificates.Contains(certificate)) + { + this.logger.WriteError("Dropping malformed Certificate message: Certificate not authentic"); + continue; + } + + RSA publicKey = certificate.PublicKey.Key as RSA; + if (publicKey == null) + { + this.logger.WriteError("Dropping malfomed Certificate message: Certificate is not RSA signed"); + continue; + } + + // Add the final Certificate message to the verification stream + Handshake fullCertificateHandhake = handshake; + fullCertificateHandhake.FragmentOffset = 0; + fullCertificateHandhake.FragmentLength = fullCertificateHandhake.Length; + + ByteSpan serializedCertificateHandshake = new byte[Handshake.Size]; + fullCertificateHandhake.Encode(serializedCertificateHandshake); + this.nextEpoch.VerificationStream.AddData(serializedCertificateHandshake); + this.nextEpoch.VerificationStream.AddData(payload); + + this.nextEpoch.ServerPublicKey = publicKey; + this.nextEpoch.State = HandshakeState.ExpectingServerKeyExchange; + break; + + case HandshakeType.ServerKeyExchange: + if (this.nextEpoch.State != HandshakeState.ExpectingServerKeyExchange) + { + this.logger.WriteError($"Dropping unexpected ServerKeyExchange handshake message State({this.nextEpoch.State})"); + continue; + } + else if (this.nextEpoch.ServerPublicKey == null) + { + ///NOTE(mendsley): This _should_ not + /// happen on a well-formed client + Debug.Assert(false, "How are we processing a ServerKeyExchange message without a server public key?"); + + this.logger.WriteError($"Dropping unexpected ServerKeyExchange handshake message: No server public key"); + continue; + } + else if (this.nextEpoch.Handshake == null) + { + ///NOTE(mendsley): This _should_ not + /// happen on a well-formed client + Debug.Assert(false, "How did we receive a ServerKeyExchange message without a handshake instance?"); + + this.logger.WriteError($"Dropping unexpected ServerKeyExchange handshake message: No key agreement interface"); + continue; + } + else if (handshake.MessageSequence != 3) + { + this.logger.WriteError($"Dropping bad-sequence ServerKeyExchange MessageSequence({handshake.MessageSequence})"); + continue; + } + + ByteSpan sharedSecret = new byte[this.nextEpoch.Handshake.SharedKeySize()]; + if (!this.nextEpoch.Handshake.VerifyServerMessageAndGenerateSharedKey(sharedSecret, payload, this.nextEpoch.ServerPublicKey)) + { + this.logger.WriteError("Dropping malformed ServerKeyExchangeMessage"); + return false; + } + + // Generate the session master secret + ByteSpan randomSeed = new byte[2 * Random.Size]; + this.nextEpoch.ClientRandom.CopyTo(randomSeed); + this.nextEpoch.ServerRandom.CopyTo(randomSeed.Slice(Random.Size)); + + const int MasterSecretSize = 48; + ByteSpan masterSecret = new byte[MasterSecretSize]; + PrfSha256.ExpandSecret( + masterSecret + , sharedSecret + , PrfLabel.MASTER_SECRET + , randomSeed + ); + + // Create record protection for the upcoming epoch + switch (this.nextEpoch.SelectedCipherSuite) + { + case CipherSuite.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256: + this.nextEpoch.RecordProtection = new Aes128GcmRecordProtection( + masterSecret + , this.nextEpoch.ServerRandom + , this.nextEpoch.ClientRandom + ); + break; + + default: + ///NOTE(mendsley): this _should_ not + /// happen on a well-formed client. + Debug.Assert(false, "SeverHello processing already approved this ciphersuite"); + + this.logger.WriteError($"Dropping malformed ServerKeyExchangeMessage: Could not create record protection"); + return false; + } + + this.nextEpoch.State = HandshakeState.ExpectingServerHelloDone; + this.nextEpoch.MasterSecret = masterSecret; + + // Append ServerKeyExchange to the verification stream + this.nextEpoch.VerificationStream.AddData(originalPayload); + break; + + case HandshakeType.ServerHelloDone: + if (this.nextEpoch.State != HandshakeState.ExpectingServerHelloDone) + { + this.logger.WriteError($"Dropping unexpected ServerHelloDone handshake message State({this.nextEpoch.State})"); + continue; + } + else if (handshake.MessageSequence != 4) + { + this.logger.WriteError($"Dropping bad-sequence ServerHelloDone MessageSequence({handshake.MessageSequence})"); + continue; + } + + this.nextEpoch.State = HandshakeState.ExpectingChangeCipherSpec; + + // Append ServerHelloDone to the verification stream + this.nextEpoch.VerificationStream.AddData(originalPayload); + + this.SendClientKeyExchangeFlight(isRetransmit: false); + break; + + case HandshakeType.Finished: + if (this.nextEpoch.State != HandshakeState.ExpectingFinished) + { + this.logger.WriteError($"Dropping unexpected Finished handshake message State({this.nextEpoch.State})"); + continue; + } + else if (payload.Length != Finished.Size) + { + this.logger.WriteError($"Dropping malformed Finished handshake message Size({payload.Length})"); + continue; + } + else if (handshake.MessageSequence != 7) + { + this.logger.WriteError($"Dropping bad-sequence Finished MessageSequence({handshake.MessageSequence})"); + continue; + } + + // Verify the digest from the server + if (1 != Crypto.Const.ConstantCompareSpans(payload, this.nextEpoch.ServerVerification)) + { + this.logger.WriteError("Dropping non-verified Finished handshake message"); + return false; + } + + ++this.nextEpoch.Epoch; + this.nextEpoch.State = HandshakeState.Established; + this.nextEpoch.NegotiationStartTime = DateTime.MinValue; + this.nextEpoch.NextPacketResendTime = DateTime.MinValue; + this.nextEpoch.ServerVerification.SecureClear(); + this.nextEpoch.MasterSecret.SecureClear(); + + this.FlushQueuedApplicationData(); + break; + + // Drop messages we do not support + case HandshakeType.CertificateRequest: + case HandshakeType.HelloRequest: + this.logger.WriteError($"Dropping unsupported handshake message MessageType({handshake.MessageType})"); + break; + + // Drop messages that originate from the client + case HandshakeType.ClientHello: + case HandshakeType.ClientKeyExchange: + case HandshakeType.CertificateVerify: + this.logger.WriteError($"Dropping client handshake message MessageType({handshake.MessageType})"); + break; + } + } + + return true; + } + + private bool IsFragmentOverlapping(List<FragmentRange> fragments, uint newOffset, uint newLength) + { + foreach (var frag in fragments) + { + // New fragment overlaps an existing one + if (newOffset <= frag.Offset + && frag.Offset < newOffset + newLength) + { + return true; + } + + // Existing fragment overlaps this new one + if (frag.Offset <= newOffset + && newOffset < frag.Offset + frag.Length) + { + return true; + } + } + + return false; + } + + /// <summary> + /// Send (resend) a ClientHello message to the server + /// </summary> + protected virtual void SendClientHello(bool isRetransmit) + { +#if DEBUG + var verb = isRetransmit ? "Resending" : "Sending"; + this.logger.WriteVerbose($"{verb} ClientHello in state: {this.nextEpoch.State}. Epoch: {this.epoch} Cookie: {this.nextEpoch.Cookie} Random: {this.nextEpoch.ClientRandom}"); +#endif + + // Describe our ClientHello flight + ClientHello clientHello = new ClientHello(); + clientHello.ClientProtocolVersion = DtlsVersion; + clientHello.Random = this.nextEpoch.ClientRandom; + clientHello.Cookie = this.nextEpoch.Cookie; + clientHello.Session = new HazelDtlsSessionInfo(this.HazelSessionVersion); + clientHello.CipherSuites = new byte[2]; + clientHello.CipherSuites.WriteBigEndian16((ushort)CipherSuite.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256); + clientHello.SupportedCurves = new byte[2]; + clientHello.SupportedCurves.WriteBigEndian16((ushort)NamedCurve.x25519); + + Handshake handshake = new Handshake(); + handshake.MessageType = HandshakeType.ClientHello; + handshake.Length = (uint)clientHello.CalculateSize(); + handshake.MessageSequence = 0; + handshake.FragmentOffset = 0; + handshake.FragmentLength = handshake.Length; + + // Describe the record + int plaintextLength = (int)(Handshake.Size + handshake.Length); + Record outgoingRecord = new Record(); + outgoingRecord.ContentType = ContentType.Handshake; + outgoingRecord.ProtocolVersion = DtlsVersion; + outgoingRecord.Epoch = this.epoch; + outgoingRecord.SequenceNumber = this.currentEpoch.NextOutgoingSequence; + outgoingRecord.Length = (ushort)this.currentEpoch.RecordProtection.GetEncryptedSize(plaintextLength); + ++this.currentEpoch.NextOutgoingSequence; + + // Convert the record to wire format + ByteSpan packet = new byte[Record.Size + outgoingRecord.Length]; + ByteSpan writer = packet; + outgoingRecord.Encode(packet); + writer = writer.Slice(Record.Size); + handshake.Encode(writer); + writer = writer.Slice(Handshake.Size); + clientHello.Encode(writer); + + // If this is our first valid attempt at contacting the server: + // - Reset our verification stream + // - Write ClientHello to the verification stream + // - We next expect a ServerHello + // + // ClientHello+Cookie triggers many sequential packets in response + // It's important to make forward progress as the packets may be reordered in-flight + // But with enough resends, we will read them all in an appropriate order + if (!isRetransmit) + { + this.nextEpoch.VerificationStream.Reset(); + this.nextEpoch.VerificationStream.AddData( + packet.Slice(Record.Size, Handshake.Size + (int)handshake.Length) + ); + + this.nextEpoch.State = HandshakeState.ExpectingServerHello; + } + + // Protect the record + this.currentEpoch.RecordProtection.EncryptClientPlaintext( + packet.Slice(Record.Size, outgoingRecord.Length), + packet.Slice(Record.Size, plaintextLength), + ref outgoingRecord + ); + + if (this.nextEpoch.NegotiationStartTime == DateTime.MinValue) this.nextEpoch.NegotiationStartTime = DateTime.UtcNow; + this.nextEpoch.NextPacketResendTime = DateTime.UtcNow + this.handshakeResendTimeout; + + base.WriteBytesToConnection(packet.GetUnderlyingArray(), packet.Length); + } + + protected void Test_SendClientHello(Func<ClientHello, ByteSpan, ByteSpan> encodeCallback) + { + // Reset our verification stream + this.nextEpoch.VerificationStream.Reset(); + + // Describe our ClientHello flight + ClientHello clientHello = new ClientHello(); + clientHello.Random = this.nextEpoch.ClientRandom; + clientHello.Cookie = this.nextEpoch.Cookie; + clientHello.CipherSuites = new byte[2]; + clientHello.CipherSuites.WriteBigEndian16((ushort)CipherSuite.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256); + clientHello.SupportedCurves = new byte[2]; + clientHello.SupportedCurves.WriteBigEndian16((ushort)NamedCurve.x25519); + + Handshake handshake = new Handshake(); + handshake.MessageType = HandshakeType.ClientHello; + handshake.Length = (uint)clientHello.CalculateSize(); + handshake.MessageSequence = 0; + handshake.FragmentOffset = 0; + handshake.FragmentLength = handshake.Length; + + // Describe the record + int plaintextLength = (int)(Handshake.Size + handshake.Length); + Record outgoingRecord = new Record(); + outgoingRecord.ContentType = ContentType.Handshake; + outgoingRecord.ProtocolVersion = DtlsVersion; + outgoingRecord.Epoch = this.epoch; + outgoingRecord.SequenceNumber = this.currentEpoch.NextOutgoingSequence; + outgoingRecord.Length = (ushort)this.currentEpoch.RecordProtection.GetEncryptedSize(plaintextLength); + ++this.currentEpoch.NextOutgoingSequence; + + // Convert the record to wire format + ByteSpan packet = new byte[Record.Size + outgoingRecord.Length]; + ByteSpan writer = packet; + outgoingRecord.Encode(packet); + writer = writer.Slice(Record.Size); + handshake.Encode(writer); + writer = writer.Slice(Handshake.Size); + + writer = encodeCallback(clientHello, writer); + + // Write ClientHello to the verification stream + this.nextEpoch.VerificationStream.AddData( + packet.Slice( + Record.Size + , Handshake.Size + (int)handshake.Length + ) + ); + + // Protect the record + this.currentEpoch.RecordProtection.EncryptClientPlaintext( + packet.Slice(Record.Size, outgoingRecord.Length), + packet.Slice(Record.Size, plaintextLength), + ref outgoingRecord + ); + + this.nextEpoch.State = HandshakeState.ExpectingServerHello; + if (this.nextEpoch.NegotiationStartTime == DateTime.MinValue) this.nextEpoch.NegotiationStartTime = DateTime.UtcNow; + this.nextEpoch.NextPacketResendTime = DateTime.UtcNow + this.handshakeResendTimeout; + base.WriteBytesToConnection(packet.GetUnderlyingArray(), packet.Length); + } + + /// <summary> + /// Send (resend) the ClientKeyExchange flight + /// </summary> + /// <param name="isRetransmit"> + /// True if this is a retransmit of the flight. Otherwise, + /// false + /// </param> + protected virtual void SendClientKeyExchangeFlight(bool isRetransmit) + { +#if DEBUG + var verb = isRetransmit ? "Resending" : "Sending"; + this.logger.WriteVerbose($"{verb} ClientKeyExchangeFlight in state: {this.nextEpoch.State}"); +#endif + if (this.nextEpoch.State == HandshakeState.Established) + { + return; + } + + // Describe our flight + Handshake keyExchangeHandshake = new Handshake(); + keyExchangeHandshake.MessageType = HandshakeType.ClientKeyExchange; + keyExchangeHandshake.Length = (ushort)this.nextEpoch.Handshake.CalculateClientMessageSize(); + keyExchangeHandshake.MessageSequence = 5; + keyExchangeHandshake.FragmentOffset = 0; + keyExchangeHandshake.FragmentLength = keyExchangeHandshake.Length; + + Record keyExchangeRecord = new Record(); + keyExchangeRecord.ContentType = ContentType.Handshake; + keyExchangeRecord.ProtocolVersion = DtlsVersion; + keyExchangeRecord.Epoch = this.epoch; + keyExchangeRecord.SequenceNumber = this.currentEpoch.NextOutgoingSequence; + keyExchangeRecord.Length = (ushort)this.currentEpoch.RecordProtection.GetEncryptedSize(Handshake.Size + (int)keyExchangeHandshake.Length); + ++this.currentEpoch.NextOutgoingSequence; + + Record changeCipherSpecRecord = new Record(); + changeCipherSpecRecord.ContentType = ContentType.ChangeCipherSpec; + changeCipherSpecRecord.ProtocolVersion = DtlsVersion; + changeCipherSpecRecord.Epoch = this.epoch; + changeCipherSpecRecord.SequenceNumber = this.currentEpoch.NextOutgoingSequence; + changeCipherSpecRecord.Length = (ushort)this.currentEpoch.RecordProtection.GetEncryptedSize(ChangeCipherSpec.Size); + ++this.currentEpoch.NextOutgoingSequence; + + Handshake finishedHandshake = new Handshake(); + finishedHandshake.MessageType = HandshakeType.Finished; + finishedHandshake.Length = Finished.Size; + finishedHandshake.MessageSequence = 6; + finishedHandshake.FragmentOffset = 0; + finishedHandshake.FragmentLength = finishedHandshake.Length; + + Record finishedRecord = new Record(); + finishedRecord.ContentType = ContentType.Handshake; + finishedRecord.ProtocolVersion = DtlsVersion; + finishedRecord.Epoch = this.nextEpoch.Epoch; + finishedRecord.SequenceNumber = this.nextEpoch.NextOutgoingSequence; + finishedRecord.Length = (ushort)this.nextEpoch.RecordProtection.GetEncryptedSize(Handshake.Size + (int)finishedHandshake.Length); + ++this.nextEpoch.NextOutgoingSequence; + + // Encode flight to wire format + int packetLength = 0 + + Record.Size + keyExchangeRecord.Length + + Record.Size + changeCipherSpecRecord.Length + + Record.Size + finishedRecord.Length; + ; + ByteSpan packet = new byte[packetLength]; + ByteSpan writer = packet; + + keyExchangeRecord.Encode(writer); + writer = writer.Slice(Record.Size); + keyExchangeHandshake.Encode(writer); + writer = writer.Slice(Handshake.Size); + this.nextEpoch.Handshake.EncodeClientKeyExchangeMessage(writer); + + ByteSpan startOfChangeCipherSpecRecord = packet.Slice(Record.Size + keyExchangeRecord.Length); + writer = startOfChangeCipherSpecRecord; + changeCipherSpecRecord.Encode(writer); + writer = writer.Slice(Record.Size); + ChangeCipherSpec.Encode(writer); + writer = writer.Slice(ChangeCipherSpec.Size); + + ByteSpan startOfFinishedRecord = startOfChangeCipherSpecRecord.Slice(Record.Size + changeCipherSpecRecord.Length); + writer = startOfFinishedRecord; + finishedRecord.Encode(writer); + writer = writer.Slice(Record.Size); + finishedHandshake.Encode(writer); + writer = writer.Slice(Handshake.Size); + + // Interject here to writer our client key exchange + // message into the verification stream + if (!isRetransmit) + { + this.nextEpoch.VerificationStream.AddData( + packet.Slice( + Record.Size + , Handshake.Size + (int)keyExchangeHandshake.Length + ) + ); + } + + // Calculate the hash of the verification stream + ByteSpan handshakeHash = new byte[Sha256Stream.DigestSize]; + this.nextEpoch.VerificationStream.CopyOrCalculateFinalHash(handshakeHash); + + // Expand our master secret into Finished digests for the client and server + PrfSha256.ExpandSecret( + this.nextEpoch.ServerVerification + , this.nextEpoch.MasterSecret + , PrfLabel.SERVER_FINISHED + , handshakeHash + ); + + PrfSha256.ExpandSecret( + writer.Slice(0, Finished.Size) + , this.nextEpoch.MasterSecret + , PrfLabel.CLIENT_FINISHED + , handshakeHash + ); + writer = writer.Slice(Finished.Size); + + // Protect the ClientKeyExchange record + this.currentEpoch.RecordProtection.EncryptClientPlaintext( + packet.Slice(Record.Size, keyExchangeRecord.Length), + packet.Slice(Record.Size, Handshake.Size + (int)keyExchangeHandshake.Length), + ref keyExchangeRecord + ); + + // Protect the ChangeCipherSpec record + this.currentEpoch.RecordProtection.EncryptClientPlaintext( + startOfChangeCipherSpecRecord.Slice(Record.Size, changeCipherSpecRecord.Length), + startOfChangeCipherSpecRecord.Slice(Record.Size, ChangeCipherSpec.Size), + ref changeCipherSpecRecord + ); + + // Protect the Finished record + this.nextEpoch.RecordProtection.EncryptClientPlaintext( + startOfFinishedRecord.Slice(Record.Size, finishedRecord.Length), + startOfFinishedRecord.Slice(Record.Size, Handshake.Size + (int)finishedHandshake.Length), + ref finishedRecord + ); + + this.nextEpoch.State = HandshakeState.ExpectingChangeCipherSpec; + this.nextEpoch.NextPacketResendTime = DateTime.UtcNow + this.handshakeResendTimeout; +#if DEBUG + if (DropClientKeyExchangeFlight()) + { + return; + } +#endif + base.WriteBytesToConnection(packet.GetUnderlyingArray(), packet.Length); + } + + protected virtual bool DropClientKeyExchangeFlight() + { + return false; + } + } +} diff --git a/Tools/Hazel-Networking/Hazel/Dtls/Handshake.cs b/Tools/Hazel-Networking/Hazel/Dtls/Handshake.cs new file mode 100644 index 0000000..f840053 --- /dev/null +++ b/Tools/Hazel-Networking/Hazel/Dtls/Handshake.cs @@ -0,0 +1,734 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Net; +using System.Security.Cryptography; +using System.Security.Cryptography.X509Certificates; + +namespace Hazel.Dtls +{ + /// <summary> + /// Handshake message type + /// </summary> + public enum HandshakeType : byte + { + HelloRequest = 0, + ClientHello = 1, + ServerHello = 2, + HelloVerifyRequest = 3, + Certificate = 11, + ServerKeyExchange = 12, + CertificateRequest = 13, + ServerHelloDone = 14, + CertificateVerify = 15, + ClientKeyExchange = 16, + Finished = 20, + } + + /// <summary> + /// List of cipher suites + /// </summary> + public enum CipherSuite + { + TLS_NULL_WITH_NULL_NULL = 0x0000, + TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 = 0xC02F, + } + + /// <summary> + /// List of compression methods + /// </summary> + public enum CompressionMethod : byte + { + Null = 0, + } + + /// <summary> + /// Extension type + /// </summary> + public enum ExtensionType : ushort + { + EllipticCurves = 10, + } + + /// <summary> + /// Named curves + /// </summary> + public enum NamedCurve : ushort + { + Reserved = 0, + secp256r1 = 23, + x25519 = 29, + } + + /// <summary> + /// Elliptic curve type + /// </summary> + public enum ECCurveType : byte + { + NamedCurve = 3, + } + + /// <summary> + /// Hash algorithms + /// </summary> + public enum HashAlgorithm : byte + { + None = 0, + Sha256 = 4, + } + + /// <summary> + /// Signature algorithms + /// </summary> + public enum SignatureAlgorithm : byte + { + Anonymous = 0, + RSA = 1, + ECDSA = 3, + } + + /// <summary> + /// Random state for entropy + /// </summary> + public struct Random + { + public const int Size = 0 + + 4 // gmt_unix_time + + 28 // random_bytes + ; + } + + /// <summary> + /// Encode/decode handshake protocol header + /// </summary> + public struct Handshake + { + public HandshakeType MessageType; + public uint Length; + public ushort MessageSequence; + public uint FragmentOffset; + public uint FragmentLength; + + public const int Size = 12; + + /// <summary> + /// Parse a Handshake protocol header from wire format + /// </summary> + /// <returns>True if we successfully decode a handshake header. Otherwise false</returns> + public static bool Parse(out Handshake header, ByteSpan span) + { + header = new Handshake(); + + if (span.Length < Size) + { + return false; + } + + header.MessageType = (HandshakeType)span[0]; + header.Length = span.ReadBigEndian24(1); + header.MessageSequence = span.ReadBigEndian16(4); + header.FragmentOffset = span.ReadBigEndian24(6); + header.FragmentLength = span.ReadBigEndian24(9); + return true; + } + + /// <summary> + /// Encode the Handshake protocol header to wire format + /// </summary> + /// <param name="span"></param> + public void Encode(ByteSpan span) + { + span[0] = (byte)this.MessageType; + span.WriteBigEndian24(this.Length, 1); + span.WriteBigEndian16(this.MessageSequence, 4); + span.WriteBigEndian24(this.FragmentOffset, 6); + span.WriteBigEndian24(this.FragmentLength, 9); + } + } + + /// <summary> + /// Encode/decode ClientHello Handshake message + /// </summary> + public struct ClientHello + { + public ProtocolVersion ClientProtocolVersion; + public ByteSpan Random; + public ByteSpan Cookie; + public HazelDtlsSessionInfo Session; + public ByteSpan CipherSuites; + public ByteSpan SupportedCurves; + + public const int MinSize = 0 + + 2 // client_version + + Dtls.Random.Size // random + + 1 // session_id (size) + + 1 // cookie (size) + + 2 // cipher_suites (size) + + 1 // compression_methods (size) + + 1 // compression_method[0] (NULL) + + + 2 // extensions size + + + 0 // NamedCurveList extensions[0] + + 2 // extensions[0].extension_type + + 2 // extensions[0].extension_data (length) + + 2 // extensions[0].named_curve_list (size) + ; + + /// <summary> + /// Calculate the size in bytes required for the ClientHello payload + /// </summary> + /// <returns></returns> + public int CalculateSize() + { + return MinSize + + this.Session.PayloadSize + + this.Cookie.Length + + this.CipherSuites.Length + + this.SupportedCurves.Length + ; + } + + /// <summary> + /// Parse a Handshake ClientHello payload from wire format + /// </summary> + /// <returns>True if we successfully decode the ClientHello message. Otherwise false</returns> + public static bool Parse(out ClientHello result, ProtocolVersion? expectedProtocolVersion, ByteSpan span) + { + result = new ClientHello(); + if (span.Length < MinSize) + { + return false; + } + + result.ClientProtocolVersion = (ProtocolVersion)span.ReadBigEndian16(); + if (expectedProtocolVersion.HasValue && result.ClientProtocolVersion != expectedProtocolVersion.Value) + { + return false; + } + + span = span.Slice(2); + + result.Random = span.Slice(0, Dtls.Random.Size); + span = span.Slice(Dtls.Random.Size); + + if (!HazelDtlsSessionInfo.Parse(out result.Session, span)) + { + return false; + } + + span = span.Slice(result.Session.FullSize); + + byte cookieSize = span[0]; + if (span.Length < 1 + cookieSize) + { + return false; + } + result.Cookie = span.Slice(1, cookieSize); + span = span.Slice(1 + cookieSize); + + ushort cipherSuiteSize = span.ReadBigEndian16(); + if (span.Length < 2 + cipherSuiteSize) + { + return false; + } + else if (cipherSuiteSize % 2 != 0) + { + return false; + } + result.CipherSuites = span.Slice(2, cipherSuiteSize); + span = span.Slice(2 + cipherSuiteSize); + + int compressionMethodsSize = span[0]; + bool foundNullCompressionMethod = false; + for (int ii = 0; ii != compressionMethodsSize; ++ii) + { + if (span[1+ii] == (byte)CompressionMethod.Null) + { + foundNullCompressionMethod = true; + break; + } + } + + if (!foundNullCompressionMethod + || span.Length < 1 + compressionMethodsSize) + { + return false; + } + + span = span.Slice(1 + compressionMethodsSize); + + // Parse extensions + if (span.Length > 0) + { + if (span.Length < 2) + { + return false; + } + + ushort extensionsSize = span.ReadBigEndian16(); + span = span.Slice(2); + if (span.Length != extensionsSize) + { + return false; + } + + while (span.Length > 0) + { + // Parse extension header + if (span.Length < 4) + { + return false; + } + + ExtensionType extensionType = (ExtensionType)span.ReadBigEndian16(0); + ushort extensionLength = span.ReadBigEndian16(2); + + if (span.Length < 4 + extensionLength) + { + return false; + } + + ByteSpan extensionData = span.Slice(4, extensionLength); + span = span.Slice(4 + extensionLength); + result.ParseExtension(extensionType, extensionData); + } + } + + return true; + } + + /// <summary> + /// Decode a ClientHello extension + /// </summary> + /// <param name="extensionType">Extension type</param> + /// <param name="extensionData">Extension data</param> + private void ParseExtension(ExtensionType extensionType, ByteSpan extensionData) + { + switch (extensionType) + { + case ExtensionType.EllipticCurves: + if (extensionData.Length % 2 != 0) + { + break; + } + else if (extensionData.Length < 2) + { + break; + } + + ushort namedCurveSize = extensionData.ReadBigEndian16(0); + if (namedCurveSize % 2 != 0) + { + break; + } + + this.SupportedCurves = extensionData.Slice(2, namedCurveSize); + break; + } + } + + /// <summary> + /// Determines if the ClientHello message advertises support + /// for the specified cipher suite + /// </summary> + public bool ContainsCipherSuite(CipherSuite cipherSuite) + { + ByteSpan iterator = this.CipherSuites; + while (iterator.Length >= 2) + { + if (iterator.ReadBigEndian16() == (ushort)cipherSuite) + { + return true; + } + + iterator = iterator.Slice(2); + } + + return false; + } + + /// <summary> + /// Determines if the ClientHello message advertises support + /// for the specified curve + /// </summary> + public bool ContainsCurve(NamedCurve curve) + { + ByteSpan iterator = this.SupportedCurves; + while (iterator.Length >= 2) + { + if (iterator.ReadBigEndian16() == (ushort)curve) + { + return true; + } + + iterator = iterator.Slice(2); + } + + return false; + } + + /// <summary> + /// Encode Handshake ClientHello payload to wire format + /// </summary> + public void Encode(ByteSpan span) + { + span.WriteBigEndian16((ushort)this.ClientProtocolVersion); + span = span.Slice(2); + + Debug.Assert(this.Random.Length == Dtls.Random.Size); + this.Random.CopyTo(span); + span = span.Slice(Dtls.Random.Size); + + this.Session.Encode(span); + span = span.Slice(this.Session.FullSize); + + span[0] = (byte)this.Cookie.Length; + this.Cookie.CopyTo(span.Slice(1)); + span = span.Slice(1 + this.Cookie.Length); + + span.WriteBigEndian16((ushort)this.CipherSuites.Length); + this.CipherSuites.CopyTo(span.Slice(2)); + span = span.Slice(2 + this.CipherSuites.Length); + + span[0] = 1; + span[1] = (byte)CompressionMethod.Null; + span = span.Slice(2); + + // Extensions size + span.WriteBigEndian16((ushort)(6 + this.SupportedCurves.Length)); + span = span.Slice(2); + + // Supported curves extension + span.WriteBigEndian16((ushort)ExtensionType.EllipticCurves); + span.WriteBigEndian16((ushort)(2 + this.SupportedCurves.Length), 2); + span.WriteBigEndian16((ushort)this.SupportedCurves.Length, 4); + this.SupportedCurves.CopyTo(span.Slice(6)); + } + } + + /// <summary> + /// Encode/Decode session information in ClientHello + /// </summary> + public struct HazelDtlsSessionInfo + { + public const byte CurrentClientSessionSize = 1; + public const byte CurrentClientSessionVersion = 1; + + public byte FullSize => (byte)(1 + this.PayloadSize); + public byte PayloadSize; + public byte Version; + + public HazelDtlsSessionInfo(byte version) + { + this.Version = version; + switch (version) + { + case 0: // Does not write version byte + this.PayloadSize = 0; + return; + case 1: // Writes version byte only + this.PayloadSize = 1; + return; + } + + throw new ArgumentOutOfRangeException("Unimplemented Hazel session version"); + } + + public void Encode(ByteSpan writer) + { + writer[0] = this.PayloadSize; + + if (this.Version > 0) + { + writer[1] = this.Version; + } + } + + public static bool Parse(out HazelDtlsSessionInfo result, ByteSpan reader) + { + result = new HazelDtlsSessionInfo(); + if (reader.Length < 1) + { + return false; + } + + result.PayloadSize = reader[0]; + + // Back compat, length may be zero, version defaults to 0. + if (result.PayloadSize == 0) + { + result.Version = 0; + return true; + } + + // Forward compat, if length > 1, ignore the rest + result.Version = reader[1]; + return true; + } + } + + /// <summary> + /// Encode/decode Handshake HelloVerifyRequest message + /// </summary> + public struct HelloVerifyRequest + { + public const int CookieSize = 20; + public const int Size = 0 + + 2 // server_version + + 1 // cookie (size) + + CookieSize // cookie (data) + ; + + public ProtocolVersion ServerProtocolVersion; + public ByteSpan Cookie; + + /// <summary> + /// Parse a Handshake HelloVerifyRequest payload from wire + /// format + /// </summary> + /// <returns> + /// True if we successfully decode the HelloVerifyRequest + /// message. Otherwise false. + /// </returns> + public static bool Parse(out HelloVerifyRequest result, ProtocolVersion? expectedProtocolVersion, ByteSpan span) + { + result = new HelloVerifyRequest(); + if (span.Length < 3) + { + return false; + } + + result.ServerProtocolVersion = (ProtocolVersion)span.ReadBigEndian16(0); + if (expectedProtocolVersion.HasValue && result.ServerProtocolVersion != expectedProtocolVersion.Value) + { + return false; + } + + byte cookieSize = span[2]; + span = span.Slice(3); + + if (span.Length < cookieSize) + { + return false; + } + + result.Cookie = span; + return true; + } + + /// <summary> + /// Encode a HelloVerifyRequest payload to wire format + /// </summary> + /// <param name="peerAddress">Address of the remote peer</param> + /// <param name="hmac">Listener HMAC signature provider</param> + public static void Encode(ByteSpan span, EndPoint peerAddress, HMAC hmac, ProtocolVersion protocolVersion) + { + ByteSpan cookie = ComputeAddressMac(peerAddress, hmac); + + span.WriteBigEndian16((ushort)protocolVersion); + span[2] = (byte)CookieSize; + cookie.CopyTo(span.Slice(3)); + } + + /// <summary> + /// Generate an HMAC for a peer address + /// </summary> + /// <param name="peerAddress">Address of the remote peer</param> + /// <param name="hmac">Listener HMAC signature provider</param> + public static ByteSpan ComputeAddressMac(EndPoint peerAddress, HMAC hmac) + { + SocketAddress address = peerAddress.Serialize(); + byte[] data = new byte[address.Size]; + for (int ii = 0, nn = data.Length; ii != nn; ++ii) + { + data[ii] = address[ii]; + } + + ///NOTE(mendsley): Lame that we need to allocate+copy here + ByteSpan signature = hmac.ComputeHash(data); + return signature.Slice(0, CookieSize); + } + + /// <summary> + /// Verify a client's cookie was signed by our listener + /// </summary> + /// <param name="cookie">Wire format cookie</param> + /// <param name="peerAddress">Address of the remote peer</param> + /// <param name="hmac">Listener HMAC signature provider</param> + /// <returns>True if the cookie is valid. Otherwise false</returns> + public static bool VerifyCookie(ByteSpan cookie, EndPoint peerAddress, HMAC hmac) + { + if (cookie.Length != CookieSize) + { + return false; + } + + ByteSpan expectedHash = ComputeAddressMac(peerAddress, hmac); + if (expectedHash.Length != cookie.Length) + { + return false; + } + + return (1 == Crypto.Const.ConstantCompareSpans(cookie, expectedHash)); + } + } + + /// <summary> + /// Encode/decode Handshake ServerHello message + /// </summary> + public struct ServerHello + { + public ProtocolVersion ServerProtocolVersion; + public ByteSpan Random; + public CipherSuite CipherSuite; + public HazelDtlsSessionInfo Session; + + public const int MinSize = 0 + + 2 // server_version + + Dtls.Random.Size // random + + 1 // session_id (size) + + 2 // cipher_suite + + 1 // compression_method + ; + + public int Size => MinSize + Session.PayloadSize; + + /// <summary> + /// Parse a Handshake ServerHello payload from wire format + /// </summary> + /// <returns> + /// True if we successfully decode the ServerHello + /// message. Otherwise false. + /// </returns> + public static bool Parse(out ServerHello result, ByteSpan span) + { + result = new ServerHello(); + if (span.Length < MinSize) + { + return false; + } + + result.ServerProtocolVersion = (ProtocolVersion)span.ReadBigEndian16(); + span = span.Slice(2); + + result.Random = span.Slice(0, Dtls.Random.Size); + span = span.Slice(Dtls.Random.Size); + + if (!HazelDtlsSessionInfo.Parse(out result.Session, span)) + { + return false; + } + + span = span.Slice(result.Session.FullSize); + + result.CipherSuite = (CipherSuite)span.ReadBigEndian16(); + span = span.Slice(2); + + CompressionMethod compressionMethod = (CompressionMethod)span[0]; + if (compressionMethod != CompressionMethod.Null) + { + return false; + } + + return true; + } + + /// <summary> + /// Encode Handshake ServerHello to wire format + /// </summary> + public void Encode(ByteSpan span) + { + Debug.Assert(this.Random.Length == Dtls.Random.Size); + + span.WriteBigEndian16((ushort)this.ServerProtocolVersion, 0); + span = span.Slice(2); + + this.Random.CopyTo(span); + span = span.Slice(Dtls.Random.Size); + + this.Session.Encode(span); + span = span.Slice(this.Session.FullSize); + + span.WriteBigEndian16((ushort)this.CipherSuite); + span = span.Slice(2); + + span[0] = (byte)CompressionMethod.Null; + } + } + + /// <summary> + /// Encode/decode Handshake Certificate message + /// </summary> + public struct Certificate + { + /// <summary> + /// Encode a certificate to wire formate + /// </summary> + public static ByteSpan Encode(X509Certificate2 certificate) + { + ByteSpan certData = certificate.GetRawCertData(); + int totalSize = certData.Length + 3 + 3; + + ByteSpan result = new byte[totalSize]; + + ByteSpan writer = result; + writer.WriteBigEndian24((uint)certData.Length + 3); + writer = writer.Slice(3); + writer.WriteBigEndian24((uint)certData.Length); + writer = writer.Slice(3); + + certData.CopyTo(writer); + return result; + } + + /// <summary> + /// Parse a Handshake Certificate payload from wire format + /// </summary> + /// <returns>True if we successfully decode the Certificate message. Otherwise false</returns> + public static bool Parse(out X509Certificate2 certificate, ByteSpan span) + { + certificate = null; + if (span.Length < 6) + { + return false; + } + + uint totalSize = span.ReadBigEndian24(); + span = span.Slice(3); + + if (span.Length < totalSize) + { + return false; + } + + uint certificateSize = span.ReadBigEndian24(); + span = span.Slice(3); + if (span.Length < certificateSize) + { + return false; + } + + byte[] rawData = new byte[certificateSize]; + span.CopyTo(rawData, 0); + try + { + certificate = new X509Certificate2(rawData); + } + catch (Exception) + { + return false; + } + + return true; + } + } + + /// <summary> + /// Encode/decode Handshake Finished message + /// </summary> + public struct Finished + { + public const int Size = 12; + } +} diff --git a/Tools/Hazel-Networking/Hazel/Dtls/IHandshakeCipherSuite.cs b/Tools/Hazel-Networking/Hazel/Dtls/IHandshakeCipherSuite.cs new file mode 100644 index 0000000..eedd977 --- /dev/null +++ b/Tools/Hazel-Networking/Hazel/Dtls/IHandshakeCipherSuite.cs @@ -0,0 +1,63 @@ +using System; + +namespace Hazel.Dtls +{ + /// <summary> + /// DTLS cipher suite interface for the handshake portion of + /// the connection. + /// </summary> + public interface IHandshakeCipherSuite : IDisposable + { + /// <summary> + /// Gets the size of the shared key + /// </summary> + /// <returns>Size of the shared key in bytes </returns> + int SharedKeySize(); + + /// <summary> + /// Calculate the size of the ServerKeyExchnage message + /// </summary> + /// <param name="privateKey"> + /// Private key that will be used to sign the message + /// </param> + /// <returns>Size of the message in bytes</returns> + int CalculateServerMessageSize(object privateKey); + + /// <summary> + /// Encodes the ServerKeyExchange message + /// </summary> + /// <param name="privateKey">Private key to use for signing</param> + void EncodeServerKeyExchangeMessage(ByteSpan output, object privateKey); + + /// <summary> + /// Verifies the authenticity of a server key exchange + /// message and calculates the shared secret. + /// </summary> + /// <returns> + /// True if the authenticity has been validated and a shared key + /// was generated. Otherwise, false. + /// </returns> + bool VerifyServerMessageAndGenerateSharedKey(ByteSpan output, ByteSpan serverKeyExchangeMessage, object publicKey); + + /// <summary> + /// Calculate the size of the ClientKeyExchange message + /// </summary> + /// <returns>Size of the message in bytes</returns> + int CalculateClientMessageSize(); + + /// <summary> + /// Encodes the ClientKeyExchangeMessage + /// </summary> + void EncodeClientKeyExchangeMessage(ByteSpan output); + + /// <summary> + /// Verifies the validity of a client key exchange message + /// and calculats the hsared secret. + /// </summary> + /// <returns> + /// True if the client exchange message is valid and a + /// shared key was generated. Otherwise, false. + /// </returns> + bool VerifyClientMessageAndGenerateSharedKey(ByteSpan output, ByteSpan clientKeyExchangeMessage); + } +} diff --git a/Tools/Hazel-Networking/Hazel/Dtls/IRecordProtection.cs b/Tools/Hazel-Networking/Hazel/Dtls/IRecordProtection.cs new file mode 100644 index 0000000..cbee1b0 --- /dev/null +++ b/Tools/Hazel-Networking/Hazel/Dtls/IRecordProtection.cs @@ -0,0 +1,84 @@ +using System; + +namespace Hazel.Dtls +{ + /// <summary> + /// DTLS cipher suite interface for protection of record payload. + /// </summary> + public interface IRecordProtection : IDisposable + { + /// <summary> + /// Calculate the size of an encrypted plaintext + /// </summary> + /// <param name="dataSize">Size of plaintext in bytes</param> + /// <returns>Size of encrypted ciphertext in bytes</returns> + int GetEncryptedSize(int dataSize); + + /// <summary> + /// Calculate the size of decrypted ciphertext + /// </summary> + /// <param name="dataSize">Size of ciphertext in bytes</param> + /// <returns>Size of decrypted plaintext in bytes</returns> + int GetDecryptedSize(int dataSize); + + /// <summary> + /// Encrypt a plaintext intput with server keys + /// + /// Output may overlap with input. + /// </summary> + /// <param name="output">Output ciphertext</param> + /// <param name="input">Input plaintext</param> + /// <param name="record">Parent DTLS record</param> + void EncryptServerPlaintext(ByteSpan output, ByteSpan input, ref Record record); + + /// <summary> + /// Encrypt a plaintext intput with client keys + /// + /// Output may overlap with input. + /// </summary> + /// <param name="output">Output ciphertext</param> + /// <param name="input">Input plaintext</param> + /// <param name="record">Parent DTLS record</param> + void EncryptClientPlaintext(ByteSpan output, ByteSpan input, ref Record record); + + /// <summary> + /// Decrypt a ciphertext intput with server keys + /// + /// Output may overlap with input. + /// </summary> + /// <param name="output">Output plaintext</param> + /// <param name="input">Input ciphertext</param> + /// <param name="record">Parent DTLS record</param> + /// <returns>True if the input was authenticated and decrypted. Otherwise false</returns> + bool DecryptCiphertextFromServer(ByteSpan output, ByteSpan input, ref Record record); + + /// <summary> + /// Decrypt a ciphertext intput with client keys + /// + /// Output may overlap with input. + /// </summary> + /// <param name="output">Output plaintext</param> + /// <param name="input">Input ciphertext</param> + /// <param name="record">Parent DTLS record</param> + /// <returns>True if the input was authenticated and decrypted. Otherwise false</returns> + bool DecryptCiphertextFromClient(ByteSpan output, ByteSpan input, ref Record record); + } + + /// <summary> + /// Factory to create record protection from cipher suite identifiers + /// </summary> + public sealed class RecordProtectionFactory + { + public static IRecordProtection Create(CipherSuite cipherSuite, ByteSpan masterSecret, ByteSpan serverRandom, ByteSpan clientRandom) + { + switch (cipherSuite) + { + case CipherSuite.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256: + return new Aes128GcmRecordProtection(masterSecret, serverRandom, clientRandom); + + default: + return null; + } + } + } +} diff --git a/Tools/Hazel-Networking/Hazel/Dtls/NullRecordProtection.cs b/Tools/Hazel-Networking/Hazel/Dtls/NullRecordProtection.cs new file mode 100644 index 0000000..76fa132 --- /dev/null +++ b/Tools/Hazel-Networking/Hazel/Dtls/NullRecordProtection.cs @@ -0,0 +1,66 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace Hazel.Dtls +{ + /// <summary> + /// Passthrough record protection implementaion + /// </summary> + public class NullRecordProtection : IRecordProtection + { + public readonly static NullRecordProtection Instance = new NullRecordProtection(); + + public void Dispose() + { + } + + public int GetEncryptedSize(int dataSize) + { + return dataSize; + } + + public int GetDecryptedSize(int dataSize) + { + return dataSize; + } + + public void EncryptServerPlaintext(ByteSpan output, ByteSpan input, ref Record record) + { + CopyMaybeOverlappingSpans(output, input); + } + + public void EncryptClientPlaintext(ByteSpan output, ByteSpan input, ref Record record) + { + CopyMaybeOverlappingSpans(output, input); + } + + public bool DecryptCiphertextFromServer(ByteSpan output, ByteSpan input, ref Record record) + { + CopyMaybeOverlappingSpans(output, input); + return true; + } + + public bool DecryptCiphertextFromClient(ByteSpan output, ByteSpan input, ref Record record) + { + CopyMaybeOverlappingSpans(output, input); + return true; + } + + private static void CopyMaybeOverlappingSpans(ByteSpan output, ByteSpan input) + { + // Early out if the ranges `output` is equal to `input` + if (output.GetUnderlyingArray() == input.GetUnderlyingArray()) + { + if (output.Offset == input.Offset && output.Length == input.Length) + { + return; + } + } + + input.CopyTo(output); + } + } +} diff --git a/Tools/Hazel-Networking/Hazel/Dtls/PrfSha256.cs b/Tools/Hazel-Networking/Hazel/Dtls/PrfSha256.cs new file mode 100644 index 0000000..1fa7f17 --- /dev/null +++ b/Tools/Hazel-Networking/Hazel/Dtls/PrfSha256.cs @@ -0,0 +1,84 @@ +using System.Text; +using System.Security.Cryptography; + +namespace Hazel.Dtls +{ + /// <summary> + /// Common Psuedorandom Function labels for TLS + /// </summary> + public struct PrfLabel + { + public static readonly ByteSpan MASTER_SECRET = LabelToBytes("master secert"); + public static readonly ByteSpan KEY_EXPANSION = LabelToBytes("key expansion"); + public static readonly ByteSpan CLIENT_FINISHED = LabelToBytes("client finished"); + public static readonly ByteSpan SERVER_FINISHED = LabelToBytes("server finished"); + + /// <summary> + /// Convert a text label to a byte sequence + /// </summary> + public static ByteSpan LabelToBytes(string label) + { + return Encoding.ASCII.GetBytes(label); + } + } + + /// <summary> + /// The P_SHA256 Psuedorandom Function + /// </summary> + public struct PrfSha256 + { + /// <summary> + /// Expand a secret key + /// </summary> + /// <param name="output">Output span. Length determines how much data to generate</param> + /// <param name="key">Original key to expand</param> + /// <param name="label">Label (treated as a salt)</param> + /// <param name="initialSeed">Seed for expansion (treated as a salt)</param> + public static void ExpandSecret(ByteSpan output, ByteSpan key, string label, ByteSpan initialSeed) + { + ExpandSecret(output, key, PrfLabel.LabelToBytes(label), initialSeed); + } + + /// <summary> + /// Expand a secret key + /// </summary> + /// <param name="output">Output span. Length determines how much data to generate</param> + /// <param name="key">Original key to expand</param> + /// <param name="label">Label (treated as a salt)</param> + /// <param name="initialSeed">Seed for expansion (treated as a salt)</param> + public static void ExpandSecret(ByteSpan output, ByteSpan key, ByteSpan label, ByteSpan initialSeed) + { + ByteSpan writer = output; + + byte[] roundSeed = new byte[label.Length + initialSeed.Length]; + label.CopyTo(roundSeed); + initialSeed.CopyTo(roundSeed, label.Length); + + byte[] hashA = roundSeed; + + using (HMACSHA256 hmac = new HMACSHA256(key.ToArray())) + { + byte[] input = new byte[hmac.OutputBlockSize + roundSeed.Length]; + new ByteSpan(roundSeed).CopyTo(input, hmac.OutputBlockSize); + + while (writer.Length > 0) + { + // Update hashA + hashA = hmac.ComputeHash(hashA); + + // generate hash input + new ByteSpan(hashA).CopyTo(input); + + ByteSpan roundOutput = hmac.ComputeHash(input); + if (roundOutput.Length > writer.Length) + { + roundOutput = roundOutput.Slice(0, writer.Length); + } + + roundOutput.CopyTo(writer); + writer = writer.Slice(roundOutput.Length); + } + } + } + } +} diff --git a/Tools/Hazel-Networking/Hazel/Dtls/Record.cs b/Tools/Hazel-Networking/Hazel/Dtls/Record.cs new file mode 100644 index 0000000..23eaa95 --- /dev/null +++ b/Tools/Hazel-Networking/Hazel/Dtls/Record.cs @@ -0,0 +1,123 @@ +namespace Hazel.Dtls +{ + /// <summary> + /// DTLS version constants + /// </summary> + public enum ProtocolVersion : ushort + { + /// <summary> + /// Use to obfuscate DTLS as regular UDP packets + /// </summary> + UDP = 0, + + /// <summary> + /// DTLS 1.2 + /// </summary> + DTLS1_2 = 0xFEFD, + } + + /// <summary> + /// DTLS record content type + /// </summary> + public enum ContentType : byte + { + ChangeCipherSpec = 20, + Alert = 21, + Handshake = 22, + ApplicationData = 23, + } + + /// <summary> + /// Encode/decode DTLS record header + /// </summary> + public struct Record + { + public ContentType ContentType; + public ProtocolVersion ProtocolVersion; + public ushort Epoch; + public ulong SequenceNumber; + public ushort Length; + + public const int Size = 13; + + /// <summary> + /// Parse a DTLS record from wire format + /// </summary> + /// <returns>True if we successfully parse the record header. Otherwise false</returns> + public static bool Parse(out Record record, ProtocolVersion? expectedProtocolVersion, ByteSpan span) + { + record = new Record(); + + if (span.Length < Size) + { + return false; + } + + record.ContentType = (ContentType)span[0]; + record.ProtocolVersion = (ProtocolVersion)span.ReadBigEndian16(1); + record.Epoch = span.ReadBigEndian16(3); + record.SequenceNumber = span.ReadBigEndian48(5); + record.Length = span.ReadBigEndian16(11); + + if (expectedProtocolVersion.HasValue && record.ProtocolVersion != expectedProtocolVersion.Value) + { + return false; + } + + return true; + } + + /// <summary> + /// Encode a DTLS record to wire format + /// </summary> + public void Encode(ByteSpan span) + { + span[0] = (byte)this.ContentType; + span.WriteBigEndian16((ushort)this.ProtocolVersion, 1); + span.WriteBigEndian16(this.Epoch, 3); + span.WriteBigEndian48(this.SequenceNumber, 5); + span.WriteBigEndian16(this.Length, 11); + } + } + + public struct ChangeCipherSpec + { + public const int Size = 1; + + enum Value : byte + { + ChangeCipherSpec = 1, + } + + /// <summary> + /// Parse a ChangeCipherSpec record from wire format + /// </summary> + /// <returns> + /// True if we successfully parse the ChangeCipherSpec + /// record. Otherwise, false. + /// </returns> + public static bool Parse(ByteSpan span) + { + if (span.Length != 1) + { + return false; + } + + Value value = (Value)span[0]; + if (value != Value.ChangeCipherSpec) + { + return false; + } + + return true; + } + + /// <summary> + /// Encode a ChangeCipherSpec record to wire format + /// </summary> + public static void Encode(ByteSpan span) + { + span[0] = (byte)Value.ChangeCipherSpec; + } + } +} diff --git a/Tools/Hazel-Networking/Hazel/Dtls/ThreadedHmacHelper.cs b/Tools/Hazel-Networking/Hazel/Dtls/ThreadedHmacHelper.cs new file mode 100644 index 0000000..38da061 --- /dev/null +++ b/Tools/Hazel-Networking/Hazel/Dtls/ThreadedHmacHelper.cs @@ -0,0 +1,159 @@ +using System; +using System.Collections.Concurrent; +using System.Security.Cryptography; +using System.Threading; + +namespace Hazel.Dtls +{ + internal class ThreadedHmacHelper : IDisposable + { + private class ThreadHmacs + { + public HMAC currentHmac; + public HMAC previousHmac; + public HMAC hmacToDispose; + } + + private static readonly int CookieHmacRotationTimeout = (int)TimeSpan.FromHours(1.0).TotalMilliseconds; + + private readonly ILogger logger; + private readonly ConcurrentDictionary<int, ThreadHmacs> hmacs; + private Timer rotateKeyTimer; + private RandomNumberGenerator cryptoRandom; + private byte[] currentHmacKey; + + public ThreadedHmacHelper(ILogger logger) + { + this.hmacs = new ConcurrentDictionary<int, ThreadHmacs>(); + this.rotateKeyTimer = new Timer(RotateKeys, null, CookieHmacRotationTimeout, CookieHmacRotationTimeout); + this.cryptoRandom = RandomNumberGenerator.Create(); + + this.logger = logger; + SetHmacKey(); + } + + /// <summary> + /// [ThreadSafe] Get the current cookie hmac for the current thread. + /// </summary> + public HMAC GetCurrentCookieHmacsForThread() + { + return GetHmacsForThread().currentHmac; + } + + /// <summary> + /// [ThreadSafe] Get the previous cookie hmac for the current thread. + /// </summary> + public HMAC GetPreviousCookieHmacsForThread() + { + return GetHmacsForThread().previousHmac; + } + + public void Dispose() + { + ManualResetEvent signalRotateKeyTimerEnded = new ManualResetEvent(false); + this.rotateKeyTimer.Dispose(signalRotateKeyTimerEnded); + signalRotateKeyTimerEnded.WaitOne(); + signalRotateKeyTimerEnded.Dispose(); + signalRotateKeyTimerEnded = null; + this.rotateKeyTimer = null; + + this.cryptoRandom.Dispose(); + this.cryptoRandom = null; + + foreach (var threadIdToHmac in this.hmacs) + { + ThreadHmacs threadHmacs = threadIdToHmac.Value; + threadHmacs.currentHmac?.Dispose(); + threadHmacs.currentHmac = null; + threadHmacs.previousHmac?.Dispose(); + threadHmacs.previousHmac = null; + threadHmacs.hmacToDispose?.Dispose(); + threadHmacs.hmacToDispose = null; + } + + this.hmacs.Clear(); + } + + private ThreadHmacs GetHmacsForThread() + { + int threadId = Thread.CurrentThread.ManagedThreadId; + + if (!this.hmacs.TryGetValue(threadId, out ThreadHmacs threadHmacs)) + { + threadHmacs = CreateNewThreadHmacs(); + + if (!this.hmacs.TryAdd(threadId, threadHmacs)) + { + this.logger.WriteError($"Cannot add threadHmacs for thread {threadId} during GetHmacsForThread! Should never happen!"); + } + } + + return threadHmacs; + } + + /// <summary> + /// Rotates the hmacs of all active threads + /// </summary> + private void RotateKeys(object _) + { + SetHmacKey(); + + foreach (var threadIds in this.hmacs) + { + RotateKey(threadIds.Key); + } + } + + /// <summary> + /// Rotate hmacs of single thread + /// </summary> + /// <param name="threadId">Managed thread Id of thread calling this method.</param> + private void RotateKey(int threadId) + { + ThreadHmacs threadHmacs; + + if (!this.hmacs.TryGetValue(threadId, out threadHmacs)) + { + this.logger.WriteError($"Cannot find thread {threadId} in hmacs during rotation! Should never happen!"); + return; + } + + // No thread should still have a reference to hmacToDispose, which should now have a lifetime of > 1 hour + threadHmacs.hmacToDispose?.Dispose(); + threadHmacs.hmacToDispose = threadHmacs.previousHmac; + threadHmacs.previousHmac = threadHmacs.currentHmac; + threadHmacs.currentHmac = CreateNewCookieHMAC(); + } + + private ThreadHmacs CreateNewThreadHmacs() + { + return new ThreadHmacs + { + previousHmac = CreateNewCookieHMAC(), + currentHmac = CreateNewCookieHMAC() + }; + } + + /// <summary> + /// Create a new cookie HMAC signer + /// </summary> + private HMAC CreateNewCookieHMAC() + { + const string HMACProvider = "System.Security.Cryptography.HMACSHA1"; + HMAC hmac = HMAC.Create(HMACProvider); + hmac.Key = this.currentHmacKey; + return hmac; + } + + /// <summary> + /// Creates a new cryptographically secure random Hmac key + /// </summary> + private void SetHmacKey() + { + // MSDN recommends 64 bytes key for HMACSHA-1 + byte[] newKey = new byte[64]; + this.cryptoRandom.GetBytes(newKey); + this.currentHmacKey = newKey; + } + } +} diff --git a/Tools/Hazel-Networking/Hazel/Dtls/X25519EcdheRsaSha256.cs b/Tools/Hazel-Networking/Hazel/Dtls/X25519EcdheRsaSha256.cs new file mode 100644 index 0000000..f567252 --- /dev/null +++ b/Tools/Hazel-Networking/Hazel/Dtls/X25519EcdheRsaSha256.cs @@ -0,0 +1,202 @@ +using Hazel.Crypto; +using System; +using System.Diagnostics; +using System.Security.Cryptography; + +namespace Hazel.Dtls +{ + /// <summary> + /// ECDHE_RSA_*_256 cipher suite + /// </summary> + public class X25519EcdheRsaSha256 : IHandshakeCipherSuite + { + private readonly ByteSpan privateAgreementKey; + private SHA256 sha256 = SHA256.Create(); + + /// <summary> + /// Create a new instance of the x25519 key exchange + /// </summary> + /// <param name="random">Random data source</param> + public X25519EcdheRsaSha256(RandomNumberGenerator random) + { + byte[] buffer = new byte[X25519.KeySize]; + random.GetBytes(buffer); + this.privateAgreementKey = buffer; + } + + /// <inheritdoc /> + public void Dispose() + { + this.sha256?.Dispose(); + this.sha256 = null; + } + + /// <inheritdoc /> + public int SharedKeySize() + { + return X25519.KeySize; + } + + /// <summary> + /// Calculate the server message size given an RSA key size + /// </summary> + /// <param name="keySize"> + /// Size of the private key (in bits) + /// </param> + /// <returns> + /// Size of the ServerKeyExchange message in bytes + /// </returns> + private static int CalculateServerMessageSize(int keySize) + { + int signatureSize = keySize / 8; + + return 0 + + 1 // ECCurveType ServerKeyExchange.params.curve_params.curve_type + + 2 // NamedCurve ServerKeyExchange.params.curve_params.namedcurve + + 1 + X25519.KeySize // ECPoint ServerKeyExchange.params.public + + 1 // HashAlgorithm ServerKeyExchange.algorithm.hash + + 1 // SignatureAlgorithm ServerKeyExchange.signed_params.algorithm.signature + + 2 // ServerKeyExchange.signed_params.size + + signatureSize // ServerKeyExchange.signed_params.opaque + ; + } + + /// <inheritdoc /> + public int CalculateServerMessageSize(object privateKey) + { + RSA rsaPrivateKey = privateKey as RSA; + if (rsaPrivateKey == null) + { + throw new ArgumentException("Invalid private key", nameof(privateKey)); + } + + return CalculateServerMessageSize(rsaPrivateKey.KeySize); + } + + /// <inheritdoc /> + public void EncodeServerKeyExchangeMessage(ByteSpan output, object privateKey) + { + RSA rsaPrivateKey = privateKey as RSA; + if (rsaPrivateKey == null) + { + throw new ArgumentException("Invalid private key", nameof(privateKey)); + } + + output[0] = (byte)ECCurveType.NamedCurve; + output.WriteBigEndian16((ushort)NamedCurve.x25519, 1); + output[3] = (byte)X25519.KeySize; + X25519.Func(output.Slice(4, X25519.KeySize), this.privateAgreementKey); + + // Hash the key parameters + byte[] paramterDigest = this.sha256.ComputeHash(output.GetUnderlyingArray(), output.Offset, 4 + X25519.KeySize); + + // Sign the paramter digest + RSAPKCS1SignatureFormatter signer = new RSAPKCS1SignatureFormatter(rsaPrivateKey); + signer.SetHashAlgorithm("SHA256"); + ByteSpan signature = signer.CreateSignature(paramterDigest); + + Debug.Assert(signature.Length == rsaPrivateKey.KeySize/8); + output[4 + X25519.KeySize] = (byte)HashAlgorithm.Sha256; + output[5 + X25519.KeySize] = (byte)SignatureAlgorithm.RSA; + output.Slice(6+X25519.KeySize).WriteBigEndian16((ushort)signature.Length); + signature.CopyTo(output.Slice(8+X25519.KeySize)); + } + + /// <inheritdoc /> + public bool VerifyServerMessageAndGenerateSharedKey(ByteSpan output, ByteSpan serverKeyExchangeMessage, object publicKey) + { + RSA rsaPublicKey = publicKey as RSA; + if (rsaPublicKey == null) + { + return false; + } + else if (output.Length != X25519.KeySize) + { + return false; + } + + // Verify message is compatible with this cipher suite + if (serverKeyExchangeMessage.Length != CalculateServerMessageSize(rsaPublicKey.KeySize)) + { + return false; + } + else if (serverKeyExchangeMessage[0] != (byte)ECCurveType.NamedCurve) + { + return false; + } + else if (serverKeyExchangeMessage.ReadBigEndian16(1) != (ushort)NamedCurve.x25519) + { + return false; + } + else if (serverKeyExchangeMessage[3] != X25519.KeySize) + { + return false; + } + else if (serverKeyExchangeMessage[4 + X25519.KeySize] != (byte)HashAlgorithm.Sha256) + { + return false; + } + else if (serverKeyExchangeMessage[5 + X25519.KeySize] != (byte)SignatureAlgorithm.RSA) + { + return false; + } + + ByteSpan keyParameters = serverKeyExchangeMessage.Slice(0, 4+X25519.KeySize); + ByteSpan othersPublicKey = keyParameters.Slice(4); + ushort signatureSize = serverKeyExchangeMessage.ReadBigEndian16(6 + X25519.KeySize); + ByteSpan signature = serverKeyExchangeMessage.Slice(4+keyParameters.Length); + + if (signatureSize != signature.Length) + { + return false; + } + + // Hash the key parameters + byte[] parameterDigest = this.sha256.ComputeHash(keyParameters.GetUnderlyingArray(), keyParameters.Offset, keyParameters.Length); + + // Verify the signature + RSAPKCS1SignatureDeformatter verifier = new RSAPKCS1SignatureDeformatter(rsaPublicKey); + verifier.SetHashAlgorithm("SHA256"); + if (!verifier.VerifySignature(parameterDigest, signature.ToArray())) + { + return false; + } + + // Signature has been validated, generate the shared key + return X25519.Func(output, this.privateAgreementKey, othersPublicKey); + } + + private static int ClientMessageSize = 0 + + 1 + X25519.KeySize // ECPoint ClientKeyExchange.ecdh_Yc + ; + + /// <inheritdoc /> + public int CalculateClientMessageSize() + { + return ClientMessageSize; + } + + /// <inheritdoc /> + public void EncodeClientKeyExchangeMessage(ByteSpan output) + { + output[0] = (byte)X25519.KeySize; + X25519.Func(output.Slice(1, X25519.KeySize), this.privateAgreementKey); + } + + /// <inheritdoc /> + public bool VerifyClientMessageAndGenerateSharedKey(ByteSpan output, ByteSpan clientKeyExchangeMessage) + { + if (clientKeyExchangeMessage.Length != ClientMessageSize) + { + return false; + } + else if (clientKeyExchangeMessage[0] != (byte)X25519.KeySize) + { + return false; + } + + ByteSpan othersPublicKey = clientKeyExchangeMessage.Slice(1); + return X25519.Func(output, this.privateAgreementKey, othersPublicKey); + } + } +} |