using Hazel.Crypto; using Hazel.Udp; using System; using System.Collections.Generic; using System.Diagnostics; using System.Net; using System.Security.Cryptography; using System.Security.Cryptography.X509Certificates; namespace Hazel.Dtls { /// /// Connects to a UDP-DTLS server /// /// public class DtlsUnityConnection : UnityUdpClientConnection { /// /// Current state of the handshake sequence /// enum HandshakeState { Initializing, ExpectingServerHello, ExpectingCertificate, ExpectingServerKeyExchange, ExpectingServerHelloDone, ExpectingChangeCipherSpec, ExpectingFinished, Established, } /// /// State data for the current epoch /// struct CurrentEpoch { public ulong NextOutgoingSequence; public ulong NextExpectedSequence; public ulong PreviousSequenceWindowBitmask; public IRecordProtection RecordProtection; } struct FragmentRange { public int Offset; public int Length; } /// /// State data for the next epoch /// struct NextEpoch { public ushort Epoch; public HandshakeState State; public ulong NextOutgoingSequence; public DateTime NegotiationStartTime; public DateTime NextPacketResendTime; public int PacketResendCount; public CipherSuite SelectedCipherSuite; public IRecordProtection RecordProtection; public IHandshakeCipherSuite Handshake; public ByteSpan Cookie; public Sha256Stream VerificationStream; public RSA ServerPublicKey; public ByteSpan ClientRandom; public ByteSpan ServerRandom; public ByteSpan MasterSecret; public ByteSpan ServerVerification; public List CertificateFragments; public ByteSpan CertificatePayload; } struct QueuedAppData { public byte[] Bytes; public byte SendOption; public Action AckCallback; } private const ProtocolVersion DtlsVersion = ProtocolVersion.UDP; internal byte HazelSessionVersion = HazelDtlsSessionInfo.CurrentClientSessionVersion; private readonly object syncRoot = new object(); private readonly RandomNumberGenerator random = RandomNumberGenerator.Create(); private ushort epoch; private CurrentEpoch currentEpoch; private NextEpoch nextEpoch; private TimeSpan handshakeResendTimeout = TimeSpan.FromMilliseconds(200); private readonly Queue queuedApplicationData = new Queue(); private X509Certificate2Collection serverCertificates = new X509Certificate2Collection(); public bool HandshakeComplete { get { lock (this.syncRoot) { return this.nextEpoch.State == HandshakeState.Established; } } } /// /// Create a new instance of the DTLS connection /// /// public DtlsUnityConnection(ILogger logger, IPEndPoint remoteEndPoint, IPMode ipMode = IPMode.IPv4) : base(logger, remoteEndPoint, ipMode) { this.nextEpoch.ServerRandom = new byte[Random.Size]; this.nextEpoch.ClientRandom = new byte[Random.Size]; this.nextEpoch.ServerVerification = new byte[Finished.Size]; this.nextEpoch.CertificateFragments = new List(); this.ResetConnectionState(); } /// protected override void Dispose(bool disposing) { base.Dispose(disposing); lock (this.syncRoot) { this.ResetConnectionState(); } } /// /// Set the list of valid server certificates /// /// /// List of certificates of authentic servers /// public void SetValidServerCertificates(X509Certificate2Collection certificateCollection) { lock (this.syncRoot) { foreach (X509Certificate2 certificate in certificateCollection) { if (!(certificate.PublicKey.Key is RSA)) { throw new ArgumentException("Certificate must be signed with an RSA key", nameof(certificateCollection)); } } this.serverCertificates = certificateCollection; } } /// /// Set the packet resend timer for handshake messages /// public void SetHandshakeResendTimeout(TimeSpan timeout) { lock (this.syncRoot) { this.handshakeResendTimeout = timeout; } } /// /// Reset existing connection state /// private void ResetConnectionState() { this.currentEpoch.NextOutgoingSequence = 1; this.currentEpoch.NextExpectedSequence = 1; this.currentEpoch.PreviousSequenceWindowBitmask = 0; this.currentEpoch.RecordProtection?.Dispose(); this.currentEpoch.RecordProtection = NullRecordProtection.Instance; this.nextEpoch.Epoch = 1; this.nextEpoch.State = HandshakeState.Initializing; this.nextEpoch.NextOutgoingSequence = 1; this.nextEpoch.NegotiationStartTime = DateTime.MinValue; this.nextEpoch.NextPacketResendTime = DateTime.MinValue; this.nextEpoch.SelectedCipherSuite = CipherSuite.TLS_NULL_WITH_NULL_NULL; this.nextEpoch.RecordProtection?.Dispose(); this.nextEpoch.RecordProtection = null; this.nextEpoch.Handshake?.Dispose(); this.nextEpoch.Handshake = null; this.nextEpoch.Cookie = ByteSpan.Empty; this.nextEpoch.VerificationStream?.Dispose(); this.nextEpoch.VerificationStream = new Sha256Stream(); this.nextEpoch.ServerPublicKey = null; this.nextEpoch.ServerRandom.SecureClear(); this.nextEpoch.ClientRandom.SecureClear(); this.nextEpoch.MasterSecret.SecureClear(); this.nextEpoch.ServerVerification.SecureClear(); this.nextEpoch.CertificateFragments.Clear(); this.nextEpoch.CertificatePayload = ByteSpan.Empty; this.epoch = 0; while (this.queuedApplicationData.TryDequeue(out _)) ; } /// /// Abort the existing connection and restart the process /// protected override void RestartConnection() { lock (this.syncRoot) { this.ResetConnectionState(); this.nextEpoch.ClientRandom.FillWithRandom(this.random); this.SendClientHello(isRetransmit: false); } base.RestartConnection(); } /// protected override void ResendPacketsIfNeeded() { lock (this.syncRoot) { // Check if we need to resend handshake message if (this.nextEpoch.State != HandshakeState.Established) { DateTime now = DateTime.UtcNow; if (now >= this.nextEpoch.NextPacketResendTime) { double negotiationDurationMs = (now - this.nextEpoch.NegotiationStartTime).TotalMilliseconds; this.nextEpoch.PacketResendCount++; if ((this.ResendLimit > 0 && this.nextEpoch.PacketResendCount > this.ResendLimit) || negotiationDurationMs > this.DisconnectTimeoutMs) { this.DisconnectInternal(HazelInternalErrors.DtlsNegotiationFailed, $"DTLS negotiation failed after {this.nextEpoch.PacketResendCount} resends ({(int)negotiationDurationMs} ms)."); } else { switch (this.nextEpoch.State) { case HandshakeState.ExpectingServerHello: case HandshakeState.ExpectingCertificate: case HandshakeState.ExpectingServerKeyExchange: case HandshakeState.ExpectingServerHelloDone: this.SendClientHello(isRetransmit: true); break; case HandshakeState.ExpectingChangeCipherSpec: case HandshakeState.ExpectingFinished: this.SendClientKeyExchangeFlight(isRetransmit: true); break; case HandshakeState.Established: default: break; } } } } } base.ResendPacketsIfNeeded(); } /// /// Flush any queued application data packets /// private void FlushQueuedApplicationData() { while (this.queuedApplicationData.TryDequeue(out var queuedData)) { base.HandleSend(queuedData.Bytes, queuedData.SendOption, queuedData.AckCallback); } } /// /// Request from the application to write data to the DTLS /// stream. If appropriate, returns a byte span to send to /// the wire. /// /// Plaintext bytes to write /// Length of the bytes to write /// /// Encrypted data to put on the wire if appropriate, /// otherwise an empty span /// private ByteSpan WriteBytesToConnectionInternal(byte[] bytes, int length) { lock (this.syncRoot) { Record outgoinRecord = new Record(); outgoinRecord.ContentType = ContentType.ApplicationData; outgoinRecord.ProtocolVersion = DtlsVersion; outgoinRecord.Epoch = this.epoch; outgoinRecord.SequenceNumber = this.currentEpoch.NextOutgoingSequence; outgoinRecord.Length = (ushort)this.currentEpoch.RecordProtection.GetEncryptedSize(length); ++this.currentEpoch.NextOutgoingSequence; // Encode the record to wire format ByteSpan packet = new byte[Record.Size + outgoinRecord.Length]; ByteSpan writer = packet; outgoinRecord.Encode(writer); writer = writer.Slice(Record.Size); new ByteSpan(bytes, 0, length).CopyTo(writer); // Protect the record this.currentEpoch.RecordProtection.EncryptClientPlaintext( packet.Slice(Record.Size, outgoinRecord.Length), packet.Slice(Record.Size, length), ref outgoinRecord ); return packet; } } protected override void HandleSend(byte[] data, byte sendOption, Action ackCallback = null) { lock (this.syncRoot) { // If we're negotiating a new epoch, queue data if (this.nextEpoch.State != HandshakeState.Established) { this.queuedApplicationData.Enqueue(new QueuedAppData { Bytes = data, SendOption = sendOption, AckCallback = ackCallback }); return; } } base.HandleSend(data, sendOption, ackCallback); } /// protected override void WriteBytesToConnection(byte[] bytes, int length) { ByteSpan wireData = this.WriteBytesToConnectionInternal(bytes, length); if (wireData.Length > 0) { Debug.Assert(wireData.Offset == 0, "Got a non-zero write data offset"); base.WriteBytesToConnection(wireData.GetUnderlyingArray(), wireData.Length); } } /// protected override void WriteBytesToConnectionSync(byte[] bytes, int length) { ByteSpan wireData = this.WriteBytesToConnectionInternal(bytes, length); if (wireData.Length > 0) { Debug.Assert(wireData.Offset == 0, "Got a non-zero write data offset"); base.WriteBytesToConnectionSync(wireData.GetUnderlyingArray(), wireData.Length); } } /// protected internal override void HandleReceive(MessageReader reader, int bytesReceived) { ByteSpan message = new ByteSpan(reader.Buffer, reader.Offset + reader.Position, reader.BytesRemaining); lock (this.syncRoot) { this.HandleReceive(message); } reader.Recycle(); } /// /// Handle an incoming datagram /// /// Bytes of the datagram private void HandleReceive(ByteSpan span) { // Each incoming packet may contain multiple DTLS // records while (span.Length > 0) { Record record; if (!Record.Parse(out record, DtlsVersion, span)) { this.logger.WriteError("Dropping malformed record"); return; } span = span.Slice(Record.Size); if (span.Length < record.Length) { this.logger.WriteError($"Dropping malformed record. Length({record.Length}) Available Bytes({span.Length})"); return; } ByteSpan recordPayload = span.Slice(0, record.Length); span = span.Slice(record.Length); // Early out and drop ApplicationData records if (record.ContentType == ContentType.ApplicationData && this.nextEpoch.State != HandshakeState.Established) { this.logger.WriteError("Dropping ApplicationData record. Cannot process yet"); continue; } // Drop records from a different epoch if (record.Epoch != this.epoch) { this.logger.WriteError($"Dropping bad-epoch record. RecordEpoch({record.Epoch}) Epoch({this.epoch})"); continue; } // Prevent replay attacks by dropping records // we've already processed int windowIndex = (int)(this.currentEpoch.NextExpectedSequence - record.SequenceNumber - 1); ulong windowMask = 1ul << windowIndex; if (record.SequenceNumber < this.currentEpoch.NextExpectedSequence) { if (windowIndex >= 64) { this.logger.WriteError($"Dropping too-old record: Sequnce({record.SequenceNumber}) Expected({this.currentEpoch.NextExpectedSequence})"); continue; } if ((this.currentEpoch.PreviousSequenceWindowBitmask & windowMask) != 0) { this.logger.WriteWarning("Dropping duplicate record"); continue; } } // Verify record authenticity int decryptedSize = this.currentEpoch.RecordProtection.GetDecryptedSize(recordPayload.Length); ByteSpan decryptedPayload = recordPayload.ReuseSpanIfPossible(decryptedSize); if (!this.currentEpoch.RecordProtection.DecryptCiphertextFromServer(decryptedPayload, recordPayload, ref record)) { this.logger.WriteError("Dropping non-authentic record"); return; } recordPayload = decryptedPayload; // Update out sequence number bookkeeping if (record.SequenceNumber >= this.currentEpoch.NextExpectedSequence) { int windowShift = (int)(record.SequenceNumber + 1 - this.currentEpoch.NextExpectedSequence); this.currentEpoch.PreviousSequenceWindowBitmask <<= windowShift; this.currentEpoch.NextExpectedSequence = record.SequenceNumber + 1; } else { this.currentEpoch.PreviousSequenceWindowBitmask |= windowMask; } // This is handy for debugging, but too verbose even for verbose. // this.logger.WriteVerbose($"Content type was {record.ContentType} ({this.nextEpoch.State})"); switch (record.ContentType) { case ContentType.ChangeCipherSpec: if (this.nextEpoch.State != HandshakeState.ExpectingChangeCipherSpec) { this.logger.WriteError($"Dropping unexpected ChangeCipherSpec State({this.nextEpoch.State})"); break; } else if (this.nextEpoch.RecordProtection == null) { ///NOTE(mendsley): This _should_ not /// happen on a well-formed client. Debug.Assert(false, "How did we receive a ChangeCipherSpec message without a pending record protection instance?"); break; } if (!ChangeCipherSpec.Parse(recordPayload)) { this.logger.WriteError("Dropping malformed ChangeCipherSpec message"); break; } // Migrate to the next epoch this.epoch = this.nextEpoch.Epoch; this.currentEpoch.RecordProtection = this.nextEpoch.RecordProtection; this.currentEpoch.NextOutgoingSequence = this.nextEpoch.NextOutgoingSequence; this.currentEpoch.NextExpectedSequence = 1; this.currentEpoch.PreviousSequenceWindowBitmask = 0; this.nextEpoch.State = HandshakeState.ExpectingFinished; this.nextEpoch.SelectedCipherSuite = CipherSuite.TLS_NULL_WITH_NULL_NULL; this.nextEpoch.RecordProtection = null; this.nextEpoch.Handshake?.Dispose(); this.nextEpoch.Cookie = ByteSpan.Empty; this.nextEpoch.VerificationStream.Reset(); this.nextEpoch.ServerPublicKey = null; this.nextEpoch.ServerRandom.SecureClear(); this.nextEpoch.ClientRandom.SecureClear(); this.nextEpoch.MasterSecret.SecureClear(); break; case ContentType.Alert: this.logger.WriteError("Dropping unsupported alert record"); continue; case ContentType.Handshake: if (!ProcessHandshake(ref record, recordPayload)) { return; } break; case ContentType.ApplicationData: // Forward data to the application MessageReader reader = MessageReader.GetSized(recordPayload.Length); reader.Length = recordPayload.Length; recordPayload.CopyTo(reader.Buffer); base.HandleReceive(reader, recordPayload.Length); break; } } } /// /// Process an incoming Handshake protocol message /// /// Parent record /// Record payload /// /// True if further processing of the underlying datagram /// should be continues. Otherwise, false. /// private bool ProcessHandshake(ref Record record, ByteSpan message) { // Each record may have multiple Handshake messages while (message.Length > 0) { ByteSpan originalPayload = message; Handshake handshake; if (!Handshake.Parse(out handshake, message)) { this.logger.WriteError("Dropping malformed handshake message"); return false; } message = message.Slice(Handshake.Size); // Check for fragmented messages if (handshake.FragmentOffset != 0 || handshake.FragmentLength != handshake.Length) { // We only support fragmentation on Certificate messages if (handshake.MessageType != HandshakeType.Certificate) { this.logger.WriteError($"Dropping fragmented handshake message Type({handshake.MessageType}) Offset({handshake.FragmentOffset}) FragmentLength({handshake.FragmentLength}) Length({handshake.Length})"); continue; } if (message.Length < handshake.FragmentLength) { this.logger.WriteError($"Dropping malformed fragmented handshake message: AvailableBytes({message.Length}) Size({handshake.FragmentLength})"); return false; } originalPayload = originalPayload.Slice(0, (int)(Handshake.Size + handshake.FragmentLength)); message = message.Slice((int)handshake.FragmentLength); } else { if (message.Length < handshake.Length) { this.logger.WriteError($"Dropping malformed handshake message: AvailableBytes({message.Length}) Size({handshake.Length})"); return false; } originalPayload = originalPayload.Slice(0, (int)(Handshake.Size + handshake.Length)); message = message.Slice((int)handshake.Length); } ByteSpan payload = originalPayload.Slice(Handshake.Size); #if DEBUG this.logger.WriteVerbose($"Handshake record was {handshake.MessageType} (Frag: {handshake.FragmentOffset}) ({this.nextEpoch.State})"); #endif switch (handshake.MessageType) { case HandshakeType.HelloVerifyRequest: if (this.nextEpoch.State != HandshakeState.ExpectingServerHello) { this.logger.WriteError($"Dropping unexpected HelloVerifyRequest handshake message State({this.nextEpoch.State})"); continue; } else if (handshake.MessageSequence != 0) { this.logger.WriteError($"Dropping bad-sequence HelloVerifyRequest MessageSequence({handshake.MessageSequence})"); continue; } HelloVerifyRequest helloVerifyRequest; if (!HelloVerifyRequest.Parse(out helloVerifyRequest, DtlsVersion, payload)) { this.logger.WriteError("Dropping malformed HelloVerifyRequest handshake message"); continue; } // If the cookie differs, save it and restart the handshake if (this.nextEpoch.Cookie.Length == helloVerifyRequest.Cookie.Length && Const.ConstantCompareSpans(this.nextEpoch.Cookie, helloVerifyRequest.Cookie) == 1) { this.logger.WriteWarning("Dropping duplicate HelloVerifyRequest handshake message"); continue; } this.nextEpoch.Cookie = new byte[helloVerifyRequest.Cookie.Length]; helloVerifyRequest.Cookie.CopyTo(this.nextEpoch.Cookie); this.nextEpoch.ClientRandom.FillWithRandom(this.random); // We don't need to resend here. We already have the cookie so we already sent it once. this.SendClientHello(isRetransmit: false); break; case HandshakeType.ServerHello: if (this.nextEpoch.State != HandshakeState.ExpectingServerHello) { this.logger.WriteError($"Dropping unexpected ServerHello handshake message State({this.nextEpoch.State})"); continue; } else if (handshake.MessageSequence != 1) { this.logger.WriteError($"Dropping bad-sequence ServerHello MessageSequence({handshake.MessageSequence})"); continue; } ServerHello serverHello; if (!ServerHello.Parse(out serverHello, payload)) { this.logger.WriteError("Dropping malformed ServerHello message"); continue; } switch (serverHello.CipherSuite) { case CipherSuite.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256: this.nextEpoch.Handshake = new X25519EcdheRsaSha256(this.random); break; default: this.logger.WriteError($"Dropping malformed ServerHello message. Unsupported CipherSuite({serverHello.CipherSuite})"); continue; } // Save server parameters this.nextEpoch.SelectedCipherSuite = serverHello.CipherSuite; serverHello.Random.CopyTo(this.nextEpoch.ServerRandom); this.nextEpoch.State = HandshakeState.ExpectingCertificate; this.nextEpoch.CertificateFragments.Clear(); this.nextEpoch.CertificatePayload = ByteSpan.Empty; #if DEBUG this.logger.WriteVerbose($"ClientRandom: {this.nextEpoch.ClientRandom} ServerRandom: {this.nextEpoch.ServerRandom}"); #endif // Append ServerHelllo message to the verification stream this.nextEpoch.VerificationStream.AddData(originalPayload); break; case HandshakeType.Certificate: if (this.nextEpoch.State != HandshakeState.ExpectingCertificate) { this.logger.WriteError($"Dropping unexpected Certificate handshake message State({this.nextEpoch.State})"); continue; } else if (handshake.MessageSequence != 2) { this.logger.WriteError($"Dropping bad-sequence Certificate MessageSequence({handshake.MessageSequence})"); continue; } // If this is a fragmented message if (handshake.FragmentLength != handshake.Length) { if (this.nextEpoch.CertificatePayload.Length != handshake.Length) { this.nextEpoch.CertificatePayload = new byte[handshake.Length]; this.nextEpoch.CertificateFragments.Clear(); } // Should we add this fragment? // According to the RFC 9147 Section 5.5, we are supposed to be tolerant of overlapping segments // But if we... weren't... Hazel isn't going to change the fragment sizes. So would it really hurt? // So let's just ignore that and assume that the sender always wants to send the same fragments. if (IsFragmentOverlapping(this.nextEpoch.CertificateFragments, handshake.FragmentOffset, handshake.FragmentLength)) { continue; } payload.CopyTo(this.nextEpoch.CertificatePayload.Slice((int)handshake.FragmentOffset, (int)handshake.FragmentLength)); this.nextEpoch.CertificateFragments.Add(new FragmentRange {Offset = (int)handshake.FragmentOffset, Length = (int)handshake.FragmentLength }); this.nextEpoch.CertificateFragments.Sort((FragmentRange lhs, FragmentRange rhs) => { return lhs.Offset.CompareTo(rhs.Offset); }); // Have we completed the message? int currentOffset = 0; bool valid = true; foreach (FragmentRange range in this.nextEpoch.CertificateFragments) { if (range.Offset != currentOffset) { valid = false; break; } currentOffset += range.Length; } if (currentOffset != this.nextEpoch.CertificatePayload.Length) { valid = false; } // Still waiting on more fragments? if (!valid) { continue; } // Replace the message payload, and continue this.nextEpoch.CertificateFragments.Clear(); payload = this.nextEpoch.CertificatePayload; } X509Certificate2 certificate; if (!Certificate.Parse(out certificate, payload)) { this.logger.WriteError("Dropping malformed Certificate message"); continue; } // Verify the certificate is authenticate if (!this.serverCertificates.Contains(certificate)) { this.logger.WriteError("Dropping malformed Certificate message: Certificate not authentic"); continue; } RSA publicKey = certificate.PublicKey.Key as RSA; if (publicKey == null) { this.logger.WriteError("Dropping malfomed Certificate message: Certificate is not RSA signed"); continue; } // Add the final Certificate message to the verification stream Handshake fullCertificateHandhake = handshake; fullCertificateHandhake.FragmentOffset = 0; fullCertificateHandhake.FragmentLength = fullCertificateHandhake.Length; ByteSpan serializedCertificateHandshake = new byte[Handshake.Size]; fullCertificateHandhake.Encode(serializedCertificateHandshake); this.nextEpoch.VerificationStream.AddData(serializedCertificateHandshake); this.nextEpoch.VerificationStream.AddData(payload); this.nextEpoch.ServerPublicKey = publicKey; this.nextEpoch.State = HandshakeState.ExpectingServerKeyExchange; break; case HandshakeType.ServerKeyExchange: if (this.nextEpoch.State != HandshakeState.ExpectingServerKeyExchange) { this.logger.WriteError($"Dropping unexpected ServerKeyExchange handshake message State({this.nextEpoch.State})"); continue; } else if (this.nextEpoch.ServerPublicKey == null) { ///NOTE(mendsley): This _should_ not /// happen on a well-formed client Debug.Assert(false, "How are we processing a ServerKeyExchange message without a server public key?"); this.logger.WriteError($"Dropping unexpected ServerKeyExchange handshake message: No server public key"); continue; } else if (this.nextEpoch.Handshake == null) { ///NOTE(mendsley): This _should_ not /// happen on a well-formed client Debug.Assert(false, "How did we receive a ServerKeyExchange message without a handshake instance?"); this.logger.WriteError($"Dropping unexpected ServerKeyExchange handshake message: No key agreement interface"); continue; } else if (handshake.MessageSequence != 3) { this.logger.WriteError($"Dropping bad-sequence ServerKeyExchange MessageSequence({handshake.MessageSequence})"); continue; } ByteSpan sharedSecret = new byte[this.nextEpoch.Handshake.SharedKeySize()]; if (!this.nextEpoch.Handshake.VerifyServerMessageAndGenerateSharedKey(sharedSecret, payload, this.nextEpoch.ServerPublicKey)) { this.logger.WriteError("Dropping malformed ServerKeyExchangeMessage"); return false; } // Generate the session master secret ByteSpan randomSeed = new byte[2 * Random.Size]; this.nextEpoch.ClientRandom.CopyTo(randomSeed); this.nextEpoch.ServerRandom.CopyTo(randomSeed.Slice(Random.Size)); const int MasterSecretSize = 48; ByteSpan masterSecret = new byte[MasterSecretSize]; PrfSha256.ExpandSecret( masterSecret , sharedSecret , PrfLabel.MASTER_SECRET , randomSeed ); // Create record protection for the upcoming epoch switch (this.nextEpoch.SelectedCipherSuite) { case CipherSuite.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256: this.nextEpoch.RecordProtection = new Aes128GcmRecordProtection( masterSecret , this.nextEpoch.ServerRandom , this.nextEpoch.ClientRandom ); break; default: ///NOTE(mendsley): this _should_ not /// happen on a well-formed client. Debug.Assert(false, "SeverHello processing already approved this ciphersuite"); this.logger.WriteError($"Dropping malformed ServerKeyExchangeMessage: Could not create record protection"); return false; } this.nextEpoch.State = HandshakeState.ExpectingServerHelloDone; this.nextEpoch.MasterSecret = masterSecret; // Append ServerKeyExchange to the verification stream this.nextEpoch.VerificationStream.AddData(originalPayload); break; case HandshakeType.ServerHelloDone: if (this.nextEpoch.State != HandshakeState.ExpectingServerHelloDone) { this.logger.WriteError($"Dropping unexpected ServerHelloDone handshake message State({this.nextEpoch.State})"); continue; } else if (handshake.MessageSequence != 4) { this.logger.WriteError($"Dropping bad-sequence ServerHelloDone MessageSequence({handshake.MessageSequence})"); continue; } this.nextEpoch.State = HandshakeState.ExpectingChangeCipherSpec; // Append ServerHelloDone to the verification stream this.nextEpoch.VerificationStream.AddData(originalPayload); this.SendClientKeyExchangeFlight(isRetransmit: false); break; case HandshakeType.Finished: if (this.nextEpoch.State != HandshakeState.ExpectingFinished) { this.logger.WriteError($"Dropping unexpected Finished handshake message State({this.nextEpoch.State})"); continue; } else if (payload.Length != Finished.Size) { this.logger.WriteError($"Dropping malformed Finished handshake message Size({payload.Length})"); continue; } else if (handshake.MessageSequence != 7) { this.logger.WriteError($"Dropping bad-sequence Finished MessageSequence({handshake.MessageSequence})"); continue; } // Verify the digest from the server if (1 != Crypto.Const.ConstantCompareSpans(payload, this.nextEpoch.ServerVerification)) { this.logger.WriteError("Dropping non-verified Finished handshake message"); return false; } ++this.nextEpoch.Epoch; this.nextEpoch.State = HandshakeState.Established; this.nextEpoch.NegotiationStartTime = DateTime.MinValue; this.nextEpoch.NextPacketResendTime = DateTime.MinValue; this.nextEpoch.ServerVerification.SecureClear(); this.nextEpoch.MasterSecret.SecureClear(); this.FlushQueuedApplicationData(); break; // Drop messages we do not support case HandshakeType.CertificateRequest: case HandshakeType.HelloRequest: this.logger.WriteError($"Dropping unsupported handshake message MessageType({handshake.MessageType})"); break; // Drop messages that originate from the client case HandshakeType.ClientHello: case HandshakeType.ClientKeyExchange: case HandshakeType.CertificateVerify: this.logger.WriteError($"Dropping client handshake message MessageType({handshake.MessageType})"); break; } } return true; } private bool IsFragmentOverlapping(List fragments, uint newOffset, uint newLength) { foreach (var frag in fragments) { // New fragment overlaps an existing one if (newOffset <= frag.Offset && frag.Offset < newOffset + newLength) { return true; } // Existing fragment overlaps this new one if (frag.Offset <= newOffset && newOffset < frag.Offset + frag.Length) { return true; } } return false; } /// /// Send (resend) a ClientHello message to the server /// protected virtual void SendClientHello(bool isRetransmit) { #if DEBUG var verb = isRetransmit ? "Resending" : "Sending"; this.logger.WriteVerbose($"{verb} ClientHello in state: {this.nextEpoch.State}. Epoch: {this.epoch} Cookie: {this.nextEpoch.Cookie} Random: {this.nextEpoch.ClientRandom}"); #endif // Describe our ClientHello flight ClientHello clientHello = new ClientHello(); clientHello.ClientProtocolVersion = DtlsVersion; clientHello.Random = this.nextEpoch.ClientRandom; clientHello.Cookie = this.nextEpoch.Cookie; clientHello.Session = new HazelDtlsSessionInfo(this.HazelSessionVersion); clientHello.CipherSuites = new byte[2]; clientHello.CipherSuites.WriteBigEndian16((ushort)CipherSuite.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256); clientHello.SupportedCurves = new byte[2]; clientHello.SupportedCurves.WriteBigEndian16((ushort)NamedCurve.x25519); Handshake handshake = new Handshake(); handshake.MessageType = HandshakeType.ClientHello; handshake.Length = (uint)clientHello.CalculateSize(); handshake.MessageSequence = 0; handshake.FragmentOffset = 0; handshake.FragmentLength = handshake.Length; // Describe the record int plaintextLength = (int)(Handshake.Size + handshake.Length); Record outgoingRecord = new Record(); outgoingRecord.ContentType = ContentType.Handshake; outgoingRecord.ProtocolVersion = DtlsVersion; outgoingRecord.Epoch = this.epoch; outgoingRecord.SequenceNumber = this.currentEpoch.NextOutgoingSequence; outgoingRecord.Length = (ushort)this.currentEpoch.RecordProtection.GetEncryptedSize(plaintextLength); ++this.currentEpoch.NextOutgoingSequence; // Convert the record to wire format ByteSpan packet = new byte[Record.Size + outgoingRecord.Length]; ByteSpan writer = packet; outgoingRecord.Encode(packet); writer = writer.Slice(Record.Size); handshake.Encode(writer); writer = writer.Slice(Handshake.Size); clientHello.Encode(writer); // If this is our first valid attempt at contacting the server: // - Reset our verification stream // - Write ClientHello to the verification stream // - We next expect a ServerHello // // ClientHello+Cookie triggers many sequential packets in response // It's important to make forward progress as the packets may be reordered in-flight // But with enough resends, we will read them all in an appropriate order if (!isRetransmit) { this.nextEpoch.VerificationStream.Reset(); this.nextEpoch.VerificationStream.AddData( packet.Slice(Record.Size, Handshake.Size + (int)handshake.Length) ); this.nextEpoch.State = HandshakeState.ExpectingServerHello; } // Protect the record this.currentEpoch.RecordProtection.EncryptClientPlaintext( packet.Slice(Record.Size, outgoingRecord.Length), packet.Slice(Record.Size, plaintextLength), ref outgoingRecord ); if (this.nextEpoch.NegotiationStartTime == DateTime.MinValue) this.nextEpoch.NegotiationStartTime = DateTime.UtcNow; this.nextEpoch.NextPacketResendTime = DateTime.UtcNow + this.handshakeResendTimeout; base.WriteBytesToConnection(packet.GetUnderlyingArray(), packet.Length); } protected void Test_SendClientHello(Func encodeCallback) { // Reset our verification stream this.nextEpoch.VerificationStream.Reset(); // Describe our ClientHello flight ClientHello clientHello = new ClientHello(); clientHello.Random = this.nextEpoch.ClientRandom; clientHello.Cookie = this.nextEpoch.Cookie; clientHello.CipherSuites = new byte[2]; clientHello.CipherSuites.WriteBigEndian16((ushort)CipherSuite.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256); clientHello.SupportedCurves = new byte[2]; clientHello.SupportedCurves.WriteBigEndian16((ushort)NamedCurve.x25519); Handshake handshake = new Handshake(); handshake.MessageType = HandshakeType.ClientHello; handshake.Length = (uint)clientHello.CalculateSize(); handshake.MessageSequence = 0; handshake.FragmentOffset = 0; handshake.FragmentLength = handshake.Length; // Describe the record int plaintextLength = (int)(Handshake.Size + handshake.Length); Record outgoingRecord = new Record(); outgoingRecord.ContentType = ContentType.Handshake; outgoingRecord.ProtocolVersion = DtlsVersion; outgoingRecord.Epoch = this.epoch; outgoingRecord.SequenceNumber = this.currentEpoch.NextOutgoingSequence; outgoingRecord.Length = (ushort)this.currentEpoch.RecordProtection.GetEncryptedSize(plaintextLength); ++this.currentEpoch.NextOutgoingSequence; // Convert the record to wire format ByteSpan packet = new byte[Record.Size + outgoingRecord.Length]; ByteSpan writer = packet; outgoingRecord.Encode(packet); writer = writer.Slice(Record.Size); handshake.Encode(writer); writer = writer.Slice(Handshake.Size); writer = encodeCallback(clientHello, writer); // Write ClientHello to the verification stream this.nextEpoch.VerificationStream.AddData( packet.Slice( Record.Size , Handshake.Size + (int)handshake.Length ) ); // Protect the record this.currentEpoch.RecordProtection.EncryptClientPlaintext( packet.Slice(Record.Size, outgoingRecord.Length), packet.Slice(Record.Size, plaintextLength), ref outgoingRecord ); this.nextEpoch.State = HandshakeState.ExpectingServerHello; if (this.nextEpoch.NegotiationStartTime == DateTime.MinValue) this.nextEpoch.NegotiationStartTime = DateTime.UtcNow; this.nextEpoch.NextPacketResendTime = DateTime.UtcNow + this.handshakeResendTimeout; base.WriteBytesToConnection(packet.GetUnderlyingArray(), packet.Length); } /// /// Send (resend) the ClientKeyExchange flight /// /// /// True if this is a retransmit of the flight. Otherwise, /// false /// protected virtual void SendClientKeyExchangeFlight(bool isRetransmit) { #if DEBUG var verb = isRetransmit ? "Resending" : "Sending"; this.logger.WriteVerbose($"{verb} ClientKeyExchangeFlight in state: {this.nextEpoch.State}"); #endif if (this.nextEpoch.State == HandshakeState.Established) { return; } // Describe our flight Handshake keyExchangeHandshake = new Handshake(); keyExchangeHandshake.MessageType = HandshakeType.ClientKeyExchange; keyExchangeHandshake.Length = (ushort)this.nextEpoch.Handshake.CalculateClientMessageSize(); keyExchangeHandshake.MessageSequence = 5; keyExchangeHandshake.FragmentOffset = 0; keyExchangeHandshake.FragmentLength = keyExchangeHandshake.Length; Record keyExchangeRecord = new Record(); keyExchangeRecord.ContentType = ContentType.Handshake; keyExchangeRecord.ProtocolVersion = DtlsVersion; keyExchangeRecord.Epoch = this.epoch; keyExchangeRecord.SequenceNumber = this.currentEpoch.NextOutgoingSequence; keyExchangeRecord.Length = (ushort)this.currentEpoch.RecordProtection.GetEncryptedSize(Handshake.Size + (int)keyExchangeHandshake.Length); ++this.currentEpoch.NextOutgoingSequence; Record changeCipherSpecRecord = new Record(); changeCipherSpecRecord.ContentType = ContentType.ChangeCipherSpec; changeCipherSpecRecord.ProtocolVersion = DtlsVersion; changeCipherSpecRecord.Epoch = this.epoch; changeCipherSpecRecord.SequenceNumber = this.currentEpoch.NextOutgoingSequence; changeCipherSpecRecord.Length = (ushort)this.currentEpoch.RecordProtection.GetEncryptedSize(ChangeCipherSpec.Size); ++this.currentEpoch.NextOutgoingSequence; Handshake finishedHandshake = new Handshake(); finishedHandshake.MessageType = HandshakeType.Finished; finishedHandshake.Length = Finished.Size; finishedHandshake.MessageSequence = 6; finishedHandshake.FragmentOffset = 0; finishedHandshake.FragmentLength = finishedHandshake.Length; Record finishedRecord = new Record(); finishedRecord.ContentType = ContentType.Handshake; finishedRecord.ProtocolVersion = DtlsVersion; finishedRecord.Epoch = this.nextEpoch.Epoch; finishedRecord.SequenceNumber = this.nextEpoch.NextOutgoingSequence; finishedRecord.Length = (ushort)this.nextEpoch.RecordProtection.GetEncryptedSize(Handshake.Size + (int)finishedHandshake.Length); ++this.nextEpoch.NextOutgoingSequence; // Encode flight to wire format int packetLength = 0 + Record.Size + keyExchangeRecord.Length + Record.Size + changeCipherSpecRecord.Length + Record.Size + finishedRecord.Length; ; ByteSpan packet = new byte[packetLength]; ByteSpan writer = packet; keyExchangeRecord.Encode(writer); writer = writer.Slice(Record.Size); keyExchangeHandshake.Encode(writer); writer = writer.Slice(Handshake.Size); this.nextEpoch.Handshake.EncodeClientKeyExchangeMessage(writer); ByteSpan startOfChangeCipherSpecRecord = packet.Slice(Record.Size + keyExchangeRecord.Length); writer = startOfChangeCipherSpecRecord; changeCipherSpecRecord.Encode(writer); writer = writer.Slice(Record.Size); ChangeCipherSpec.Encode(writer); writer = writer.Slice(ChangeCipherSpec.Size); ByteSpan startOfFinishedRecord = startOfChangeCipherSpecRecord.Slice(Record.Size + changeCipherSpecRecord.Length); writer = startOfFinishedRecord; finishedRecord.Encode(writer); writer = writer.Slice(Record.Size); finishedHandshake.Encode(writer); writer = writer.Slice(Handshake.Size); // Interject here to writer our client key exchange // message into the verification stream if (!isRetransmit) { this.nextEpoch.VerificationStream.AddData( packet.Slice( Record.Size , Handshake.Size + (int)keyExchangeHandshake.Length ) ); } // Calculate the hash of the verification stream ByteSpan handshakeHash = new byte[Sha256Stream.DigestSize]; this.nextEpoch.VerificationStream.CopyOrCalculateFinalHash(handshakeHash); // Expand our master secret into Finished digests for the client and server PrfSha256.ExpandSecret( this.nextEpoch.ServerVerification , this.nextEpoch.MasterSecret , PrfLabel.SERVER_FINISHED , handshakeHash ); PrfSha256.ExpandSecret( writer.Slice(0, Finished.Size) , this.nextEpoch.MasterSecret , PrfLabel.CLIENT_FINISHED , handshakeHash ); writer = writer.Slice(Finished.Size); // Protect the ClientKeyExchange record this.currentEpoch.RecordProtection.EncryptClientPlaintext( packet.Slice(Record.Size, keyExchangeRecord.Length), packet.Slice(Record.Size, Handshake.Size + (int)keyExchangeHandshake.Length), ref keyExchangeRecord ); // Protect the ChangeCipherSpec record this.currentEpoch.RecordProtection.EncryptClientPlaintext( startOfChangeCipherSpecRecord.Slice(Record.Size, changeCipherSpecRecord.Length), startOfChangeCipherSpecRecord.Slice(Record.Size, ChangeCipherSpec.Size), ref changeCipherSpecRecord ); // Protect the Finished record this.nextEpoch.RecordProtection.EncryptClientPlaintext( startOfFinishedRecord.Slice(Record.Size, finishedRecord.Length), startOfFinishedRecord.Slice(Record.Size, Handshake.Size + (int)finishedHandshake.Length), ref finishedRecord ); this.nextEpoch.State = HandshakeState.ExpectingChangeCipherSpec; this.nextEpoch.NextPacketResendTime = DateTime.UtcNow + this.handshakeResendTimeout; #if DEBUG if (DropClientKeyExchangeFlight()) { return; } #endif base.WriteBytesToConnection(packet.GetUnderlyingArray(), packet.Length); } protected virtual bool DropClientKeyExchangeFlight() { return false; } } }