aboutsummaryrefslogtreecommitdiff
path: root/Tools/Hazel-Networking/Hazel.UnitTests
diff options
context:
space:
mode:
Diffstat (limited to 'Tools/Hazel-Networking/Hazel.UnitTests')
-rw-r--r--Tools/Hazel-Networking/Hazel.UnitTests/BroadcastTests.cs37
-rw-r--r--Tools/Hazel-Networking/Hazel.UnitTests/Crypto/AesGcmTest.cs255
-rw-r--r--Tools/Hazel-Networking/Hazel.UnitTests/Crypto/Sha256Tests.cs70
-rw-r--r--Tools/Hazel-Networking/Hazel.UnitTests/Crypto/X25519Tests.cs297
-rw-r--r--Tools/Hazel-Networking/Hazel.UnitTests/Dtls/AesGcmRecordProtectedTests.cs205
-rw-r--r--Tools/Hazel-Networking/Hazel.UnitTests/Dtls/DtlsConnectionTests.cs1192
-rw-r--r--Tools/Hazel-Networking/Hazel.UnitTests/Dtls/TestDtlsHandshakeDropUnityConnection.cs27
-rw-r--r--Tools/Hazel-Networking/Hazel.UnitTests/Dtls/X25519EcdheRsaSha256Tests.cs226
-rw-r--r--Tools/Hazel-Networking/Hazel.UnitTests/Hazel.UnitTests.csproj16
-rw-r--r--Tools/Hazel-Networking/Hazel.UnitTests/MessageReaderTests.cs689
-rw-r--r--Tools/Hazel-Networking/Hazel.UnitTests/MessageWriterTests.cs213
-rw-r--r--Tools/Hazel-Networking/Hazel.UnitTests/SocketCapture.cs292
-rw-r--r--Tools/Hazel-Networking/Hazel.UnitTests/StatisticsTests.cs159
-rw-r--r--Tools/Hazel-Networking/Hazel.UnitTests/StressTests.cs74
-rw-r--r--Tools/Hazel-Networking/Hazel.UnitTests/TestHelper.cs337
-rw-r--r--Tools/Hazel-Networking/Hazel.UnitTests/TestLogger.cs66
-rw-r--r--Tools/Hazel-Networking/Hazel.UnitTests/ThreadLimitedUdpConnectionTests.cs786
-rw-r--r--Tools/Hazel-Networking/Hazel.UnitTests/UPnPTests.cs54
-rw-r--r--Tools/Hazel-Networking/Hazel.UnitTests/UdpConnectionTestHarness.cs60
-rw-r--r--Tools/Hazel-Networking/Hazel.UnitTests/UdpConnectionTests.cs514
-rw-r--r--Tools/Hazel-Networking/Hazel.UnitTests/UdpReliabilityTests.cs116
-rw-r--r--Tools/Hazel-Networking/Hazel.UnitTests/UnityUdpConnectionTests.cs489
-rw-r--r--Tools/Hazel-Networking/Hazel.UnitTests/Utils.cs99
23 files changed, 6273 insertions, 0 deletions
diff --git a/Tools/Hazel-Networking/Hazel.UnitTests/BroadcastTests.cs b/Tools/Hazel-Networking/Hazel.UnitTests/BroadcastTests.cs
new file mode 100644
index 0000000..d6ba247
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel.UnitTests/BroadcastTests.cs
@@ -0,0 +1,37 @@
+using Hazel.Udp;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using System;
+using System.Threading;
+
+namespace Hazel.UnitTests
+{
+ [TestClass]
+ public class BroadcastTests
+ {
+ [TestMethod]
+ public void CanStart()
+ {
+ const string TestData = "pwerowerower";
+
+ using (UdpBroadcaster caster = new UdpBroadcaster(47777))
+ using (UdpBroadcastListener listener = new UdpBroadcastListener(47777))
+ {
+ listener.StartListen();
+
+ caster.SetData(TestData);
+
+ caster.Broadcast();
+ Thread.Sleep(1000);
+
+ var pkt = listener.GetPackets();
+ foreach (var p in pkt)
+ {
+ Console.WriteLine($"{p.Data} {p.Sender}");
+ Assert.AreEqual(TestData, p.Data);
+ }
+
+ Assert.IsTrue(pkt.Length >= 1);
+ }
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel.UnitTests/Crypto/AesGcmTest.cs b/Tools/Hazel-Networking/Hazel.UnitTests/Crypto/AesGcmTest.cs
new file mode 100644
index 0000000..9620973
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel.UnitTests/Crypto/AesGcmTest.cs
@@ -0,0 +1,255 @@
+using Hazel.Crypto;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using System.Text;
+
+namespace Hazel.UnitTests.Crypto
+{
+ [TestClass]
+ public class AesGcmTest
+ {
+ [TestMethod]
+ public void Example1()
+ {
+ byte[] key = Utils.HexToBytes("FEFFE992 8665731C 6D6A8F94 67308308");
+ byte[] nonce = Utils.HexToBytes("CAFEBABE FACEDBAD DECAF888");
+ byte[] associatedData = Utils.HexToBytes("");
+
+ using (Aes128Gcm aes = new Aes128Gcm(key))
+ {
+ byte[] plaintext = Utils.HexToBytes("");
+ byte[] ciphertextBytes = new byte[plaintext.Length + Aes128Gcm.CiphertextOverhead];
+ aes.Seal(ciphertextBytes, nonce, plaintext, associatedData);
+
+ CollectionAssert.AreEqual(Utils.HexToBytes("3247184B 3C4F69A4 4DBCD228 87BBB418"), ciphertextBytes);
+ }
+
+ using (Aes128Gcm aes = new Aes128Gcm(key))
+ {
+ byte[] ciphertext = Utils.HexToBytes("3247184B 3C4F69A4 4DBCD228 87BBB418");
+ byte[] plaintext = new byte[ciphertext.Length - Aes128Gcm.CiphertextOverhead];
+ bool result = aes.Open(plaintext, nonce, ciphertext, associatedData);
+ Assert.IsTrue(result);
+ CollectionAssert.AreEqual(Utils.HexToBytes(""), plaintext);
+ }
+ }
+
+ [TestMethod]
+ public void Example2()
+ {
+ byte[] key = Utils.HexToBytes("FEFFE992 8665731C 6D6A8F94 67308308");
+ byte[] nonce = Utils.HexToBytes("CAFEBABE FACEDBAD DECAF888");
+ byte[] associatedData = Utils.HexToBytes("");
+
+ using (Aes128Gcm aes = new Aes128Gcm(key))
+ {
+ byte[] plaintext = Utils.HexToBytes(@"
+ D9313225 F88406E5 A55909C5 AFF5269A
+ 86A7A953 1534F7DA 2E4C303D 8A318A72
+ 1C3C0C95 95680953 2FCF0E24 49A6B525
+ B16AEDF5 AA0DE657 BA637B39 1AAFD255
+ ");
+ byte[] ciphertextBytes = new byte[plaintext.Length + Aes128Gcm.CiphertextOverhead];
+ aes.Seal(ciphertextBytes, nonce, plaintext, associatedData);
+
+ CollectionAssert.AreEqual(Utils.HexToBytes(@"
+ 42831EC2 21777424 4B7221B7 84D0D49C
+ E3AA212F 2C02A4E0 35C17E23 29ACA12E
+ 21D514B2 5466931C 7D8F6A5A AC84AA05
+ 1BA30B39 6A0AAC97 3D58E091 473F5985
+
+ 4D5C2AF3 27CD64A6 2CF35ABD 2BA6FAB4
+ "), ciphertextBytes);
+ }
+
+ using (Aes128Gcm aes = new Aes128Gcm(key))
+ {
+ byte[] ciphertext = Utils.HexToBytes(@"
+ 42831EC2 21777424 4B7221B7 84D0D49C
+ E3AA212F 2C02A4E0 35C17E23 29ACA12E
+ 21D514B2 5466931C 7D8F6A5A AC84AA05
+ 1BA30B39 6A0AAC97 3D58E091 473F5985
+
+ 4D5C2AF3 27CD64A6 2CF35ABD 2BA6FAB4
+ ");
+ byte[] plaintext = new byte[ciphertext.Length - Aes128Gcm.CiphertextOverhead];
+ bool result = aes.Open(plaintext, nonce, ciphertext, associatedData);
+ Assert.IsTrue(result);
+
+ CollectionAssert.AreEqual(Utils.HexToBytes(@"
+ D9313225 F88406E5 A55909C5 AFF5269A
+ 86A7A953 1534F7DA 2E4C303D 8A318A72
+ 1C3C0C95 95680953 2FCF0E24 49A6B525
+ B16AEDF5 AA0DE657 BA637B39 1AAFD255
+ "), plaintext);
+ }
+ }
+
+ [TestMethod]
+ public void Example3()
+ {
+ byte[] key = Utils.HexToBytes("FEFFE992 8665731C 6D6A8F94 67308308");
+ byte[] nonce = Utils.HexToBytes("CAFEBABE FACEDBAD DECAF888");
+ byte[] associatedData = Utils.HexToBytes(@"
+ 3AD77BB4 0D7A3660 A89ECAF3 2466EF97
+ F5D3D585 03B9699D E785895A 96FDBAAF
+ 43B1CD7F 598ECE23 881B00E3 ED030688
+ 7B0C785E 27E8AD3F 82232071 04725DD4
+ ");
+
+ using (Aes128Gcm aes = new Aes128Gcm(key))
+ {
+ byte[] plaintext = Utils.HexToBytes("");
+ byte[] ciphertextBytes = new byte[plaintext.Length + Aes128Gcm.CiphertextOverhead];
+ aes.Seal(ciphertextBytes, nonce, plaintext, associatedData);
+
+ CollectionAssert.AreEqual(Utils.HexToBytes(@"
+ 5F91D771 23EF5EB9 99791384 9B8DC1E9
+ "), ciphertextBytes);
+ }
+
+ using (Aes128Gcm aes = new Aes128Gcm(key))
+ {
+ byte[] ciphertext = Utils.HexToBytes(@"
+ 5F91D771 23EF5EB9 99791384 9B8DC1E9
+ ");
+ byte[] plaintext = new byte[ciphertext.Length - Aes128Gcm.CiphertextOverhead];
+ bool result = aes.Open(plaintext, nonce, ciphertext, associatedData);
+ Assert.IsTrue(result);
+
+ CollectionAssert.AreEqual(Utils.HexToBytes(""), plaintext);
+ }
+ }
+
+ [TestMethod]
+ public void Example4()
+ {
+ byte[] key = Utils.HexToBytes("FEFFE992 8665731C 6D6A8F94 67308308");
+ byte[] nonce = Utils.HexToBytes("CAFEBABE FACEDBAD DECAF888");
+ byte[] associatedData = Utils.HexToBytes(@"
+ 3AD77BB4 0D7A3660 A89ECAF3 2466EF97
+ F5D3D585 03B9699D E785895A 96FDBAAF
+ 43B1CD7F 598ECE23 881B00E3 ED030688
+ 7B0C785E 27E8AD3F 82232071 04725DD4
+ ");
+
+ using (Aes128Gcm aes = new Aes128Gcm(key))
+ {
+ byte[] plaintext = Utils.HexToBytes(@"
+ D9313225 F88406E5 A55909C5 AFF5269A
+ 86A7A953 1534F7DA 2E4C303D 8A318A72
+ 1C3C0C95 95680953 2FCF0E24 49A6B525
+ B16AEDF5 AA0DE657 BA637B39 1AAFD255
+ ");
+ byte[] ciphertextBytes = new byte[plaintext.Length + Aes128Gcm.CiphertextOverhead];
+ aes.Seal(ciphertextBytes, nonce, plaintext, associatedData);
+
+ CollectionAssert.AreEqual(Utils.HexToBytes(@"
+ 42831EC2 21777424 4B7221B7 84D0D49C
+ E3AA212F 2C02A4E0 35C17E23 29ACA12E
+ 21D514B2 5466931C 7D8F6A5A AC84AA05
+ 1BA30B39 6A0AAC97 3D58E091 473F5985
+
+ 64C02329 04AF398A 5B67C10B 53A5024D
+ "), ciphertextBytes);
+ }
+
+ using (Aes128Gcm aes = new Aes128Gcm(key))
+ {
+ byte[] ciphertext = Utils.HexToBytes(@"
+ 42831EC2 21777424 4B7221B7 84D0D49C
+ E3AA212F 2C02A4E0 35C17E23 29ACA12E
+ 21D514B2 5466931C 7D8F6A5A AC84AA05
+ 1BA30B39 6A0AAC97 3D58E091 473F5985
+
+ 64C02329 04AF398A 5B67C10B 53A5024D
+ ");
+ byte[] plaintext = new byte[ciphertext.Length - Aes128Gcm.CiphertextOverhead];
+ bool result = aes.Open(plaintext, nonce, ciphertext, associatedData);
+ Assert.IsTrue(result);
+
+ CollectionAssert.AreEqual(Utils.HexToBytes(@"
+ D9313225 F88406E5 A55909C5 AFF5269A
+ 86A7A953 1534F7DA 2E4C303D 8A318A72
+ 1C3C0C95 95680953 2FCF0E24 49A6B525
+ B16AEDF5 AA0DE657 BA637B39 1AAFD255
+ "), plaintext);
+ }
+ }
+
+ [TestMethod]
+ public void TestReuseToDecrypt()
+ {
+ byte[] key = Utils.HexToBytes("FEFFE992 8665731C 6D6A8F94 67308308");
+ byte[] nonce = Utils.HexToBytes("CAFEBABE FACEDBAD DECAF888");
+ byte[] associatedData = Utils.HexToBytes(@"
+ 3AD77BB4 0D7A3660 A89ECAF3 2466EF97
+ F5D3D585 03B9699D E785895A 96FDBAAF
+ 43B1CD7F 598ECE23 881B00E3 ED030688
+ 7B0C785E 27E8AD3F 82232071 04725DD4
+ ");
+
+ using (Aes128Gcm aes = new Aes128Gcm(key))
+ {
+ byte[] plaintext = Utils.HexToBytes("");
+ byte[] ciphertextBytes = new byte[plaintext.Length + Aes128Gcm.CiphertextOverhead];
+ aes.Seal(ciphertextBytes, nonce, plaintext, associatedData);
+
+ CollectionAssert.AreEqual(Utils.HexToBytes(@"
+ 5F91D771 23EF5EB9 99791384 9B8DC1E9
+ "), ciphertextBytes);
+
+ byte[] ciphertext = Utils.HexToBytes(@"
+ 5F91D771 23EF5EB9 99791384 9B8DC1E9
+ ");
+ plaintext = new byte[ciphertext.Length - Aes128Gcm.CiphertextOverhead];
+ bool result = aes.Open(plaintext, nonce, ciphertext, associatedData);
+ Assert.IsTrue(result);
+
+ CollectionAssert.AreEqual(Utils.HexToBytes(""), plaintext);
+ }
+ }
+
+ [TestMethod]
+ public void TestPlaintextSmallerThanBlock()
+ {
+ byte[] key = Utils.HexToBytes("FEFFE992 8665731C 6D6A8F94 67308308");
+ byte[] nonce = Utils.HexToBytes("CAFEBABE FACEDBAD DECAF888");
+ byte[] originalPlaintext = Encoding.UTF8.GetBytes("Lorem ipsum");
+ Assert.IsTrue(originalPlaintext.Length < 16);
+
+ using (Aes128Gcm aes = new Aes128Gcm(key))
+ {
+ byte[] ciphertext = new byte[originalPlaintext.Length + Aes128Gcm.CiphertextOverhead];
+ aes.Seal(ciphertext, nonce, originalPlaintext, ByteSpan.Empty);
+
+ byte[] plaintext = new byte[originalPlaintext.Length];
+ bool result = aes.Open(plaintext, nonce, ciphertext, ByteSpan.Empty);
+ Assert.IsTrue(result);
+
+ CollectionAssert.AreEqual(originalPlaintext, plaintext);
+ }
+ }
+
+ [TestMethod]
+ public void TestPlaintextLargerThanBlockMultiple()
+ {
+ byte[] key = Utils.HexToBytes("FEFFE992 8665731C 6D6A8F94 67308308");
+ byte[] nonce = Utils.HexToBytes("CAFEBABE FACEDBAD DECAF888");
+ byte[] originalPlaintext = Encoding.UTF8.GetBytes("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.");
+ Assert.IsTrue(originalPlaintext.Length > 16);
+ Assert.IsTrue((originalPlaintext.Length % 16) != 0);
+
+ using (Aes128Gcm aes = new Aes128Gcm(key))
+ {
+ byte[] ciphertext = new byte[originalPlaintext.Length + Aes128Gcm.CiphertextOverhead];
+ aes.Seal(ciphertext, nonce, originalPlaintext, ByteSpan.Empty);
+
+ byte[] plaintext = new byte[originalPlaintext.Length];
+ bool result = aes.Open(plaintext, nonce, ciphertext, ByteSpan.Empty);
+ Assert.IsTrue(result);
+
+ CollectionAssert.AreEqual(originalPlaintext, plaintext);
+ }
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel.UnitTests/Crypto/Sha256Tests.cs b/Tools/Hazel-Networking/Hazel.UnitTests/Crypto/Sha256Tests.cs
new file mode 100644
index 0000000..5309988
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel.UnitTests/Crypto/Sha256Tests.cs
@@ -0,0 +1,70 @@
+using Hazel.Crypto;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using System.Text;
+
+namespace Hazel.UnitTests.Crypto
+{
+ [TestClass]
+ public class Sha256Tests
+ {
+ [TestMethod]
+ public void TestOneBlockMessage()
+ {
+ ByteSpan message = Encoding.ASCII.GetBytes(
+ "abc"
+ );
+ byte[] expectedDigest = Utils.HexToBytes(
+ "ba7816bf 8f01cfea 414140de 5dae2223 b00361a3 96177a9c b410ff61 f20015ad"
+ );
+ byte[] actualDigest = new byte[Sha256Stream.DigestSize];
+
+ using (Sha256Stream sha256 = new Sha256Stream())
+ {
+ sha256.AddData(message);
+ sha256.CopyOrCalculateFinalHash(actualDigest);
+ }
+
+ CollectionAssert.AreEqual(expectedDigest, actualDigest);
+ }
+
+ [TestMethod]
+ public void TestMultiBlockMessage()
+ {
+ ByteSpan message = Encoding.ASCII.GetBytes(
+ "abcdbcdecdefdefgefghfghighijhijkijkljklmklmnlmnomnopnopq"
+ );
+ byte[] expectedDigest = Utils.HexToBytes(
+ "248d6a61 d20638b8 e5c02693 0c3e6039 a33ce459 64ff2167 f6ecedd4 19db06c1"
+ );
+ byte[] actualDigest = new byte[Sha256Stream.DigestSize];
+
+ using (Sha256Stream sha256 = new Sha256Stream())
+ {
+ sha256.AddData(message);
+ sha256.CopyOrCalculateFinalHash(actualDigest);
+ }
+
+ CollectionAssert.AreEqual(expectedDigest, actualDigest);
+ }
+
+ [TestMethod]
+ public void TestLongMessage()
+ {
+ ByteSpan message = Encoding.ASCII.GetBytes(
+ new string('a', 1000000)
+ );
+ byte[] expectedDigest = Utils.HexToBytes(
+ "cdc76e5c 9914fb92 81a1c7e2 84d73e67 f1809a48 a497200e 046d39cc c7112cd0"
+ );
+ byte[] actualDigest = new byte[Sha256Stream.DigestSize];
+
+ using (Sha256Stream sha256 = new Sha256Stream())
+ {
+ sha256.AddData(message);
+ sha256.CopyOrCalculateFinalHash(actualDigest);
+ }
+
+ CollectionAssert.AreEqual(expectedDigest, actualDigest);
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel.UnitTests/Crypto/X25519Tests.cs b/Tools/Hazel-Networking/Hazel.UnitTests/Crypto/X25519Tests.cs
new file mode 100644
index 0000000..8d9a583
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel.UnitTests/Crypto/X25519Tests.cs
@@ -0,0 +1,297 @@
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+using Hazel.Crypto;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using System.Security.Cryptography;
+
+namespace Hazel.UnitTests.Crypto
+{
+ [TestClass]
+ public class X25519Tests
+ {
+ private static readonly byte[][] LowOrderPoints = new byte[][]
+ {
+ new byte[]{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
+ new byte[]{0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
+ new byte[]{0xe0, 0xeb, 0x7a, 0x7c, 0x3b, 0x41, 0xb8, 0xae, 0x16, 0x56, 0xe3, 0xfa, 0xf1, 0x9f, 0xc4, 0x6a, 0xda, 0x09, 0x8d, 0xeb, 0x9c, 0x32, 0xb1, 0xfd, 0x86, 0x62, 0x05, 0x16, 0x5f, 0x49, 0xb8, 0x00},
+ new byte[]{0x5f, 0x9c, 0x95, 0xbc, 0xa3, 0x50, 0x8c, 0x24, 0xb1, 0xd0, 0xb1, 0x55, 0x9c, 0x83, 0xef, 0x5b, 0x04, 0x44, 0x5c, 0xc4, 0x58, 0x1c, 0x8e, 0x86, 0xd8, 0x22, 0x4e, 0xdd, 0xd0, 0x9f, 0x11, 0x57},
+ new byte[]{0xec, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f},
+ new byte[]{0xed, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f},
+ new byte[]{0xee, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f},
+ };
+
+ [TestMethod]
+ public void TestLowOrderPoints()
+ {
+ using (RandomNumberGenerator random = RandomNumberGenerator.Create())
+ {
+ byte[] scalar = new byte[X25519.KeySize];
+ random.GetBytes(scalar);
+
+ for (int ii = 0, nn = LowOrderPoints.Length; ii != nn; ++ii)
+ {
+ ByteSpan output = new byte[X25519.KeySize];
+ bool result = X25519.Func(output, scalar, LowOrderPoints[ii]);
+ Assert.IsFalse(result, $"Multiplication by low order point {ii} succeeded: should have failed");
+ }
+ }
+ }
+
+ [TestMethod]
+ public void TestVectors()
+ {
+ for (int ii = 0, nn = TestVectorData.Length; ii != nn; ++ii)
+ {
+ byte[] actual = new byte[32];
+ bool result = X25519.Func(actual, TestVectorData[ii].In, TestVectorData[ii].Base);
+ Assert.IsTrue(result);
+ CollectionAssert.AreEqual(TestVectorData[ii].Expect, actual, $"Test vector {ii} mismatch");
+ }
+ }
+
+ [TestMethod]
+ public void TestAgreement()
+ {
+ using (RandomNumberGenerator random = RandomNumberGenerator.Create())
+ {
+ byte[] clientPrivateKey = new byte[X25519.KeySize];
+ random.GetBytes(clientPrivateKey);
+
+ byte[] clientPublicKey = new byte[X25519.KeySize];
+ X25519.Func(clientPublicKey, clientPrivateKey);
+
+ byte[] serverPrivateKey = new byte[X25519.KeySize];
+ random.GetBytes(serverPrivateKey);
+
+ byte[] serverPublickey = new byte[X25519.KeySize];
+ X25519.Func(serverPublickey, serverPrivateKey);
+
+ // client key aggreement
+ byte[] clientSharedSecret = new byte[X25519.KeySize];
+ Assert.IsTrue(X25519.Func(clientSharedSecret, clientPrivateKey, serverPublickey));
+
+ // server key agreement
+ byte[] serverSharedSecret = new byte[X25519.KeySize];
+ Assert.IsTrue(X25519.Func(serverSharedSecret, serverPrivateKey, clientPublicKey));
+
+ CollectionAssert.AreEqual(clientSharedSecret, serverSharedSecret);
+ }
+ }
+
+ private struct TestVector
+ {
+ public byte[] In;
+ public byte[] Base;
+ public byte[] Expect;
+ }
+
+ private static readonly TestVector[] TestVectorData =
+ {
+ new TestVector {
+ In = new byte[]{0x66, 0x8f, 0xb9, 0xf7, 0x6a, 0xd9, 0x71, 0xc8, 0x1a, 0xc9, 0x0, 0x7, 0x1a, 0x15, 0x60, 0xbc, 0xe2, 0xca, 0x0, 0xca, 0xc7, 0xe6, 0x7a, 0xf9, 0x93, 0x48, 0x91, 0x37, 0x61, 0x43, 0x40, 0x14},
+ Base = new byte[]{0xdb, 0x5f, 0x32, 0xb7, 0xf8, 0x41, 0xe7, 0xa1, 0xa0, 0x9, 0x68, 0xef, 0xfd, 0xed, 0x12, 0x73, 0x5f, 0xc4, 0x7a, 0x3e, 0xb1, 0x3b, 0x57, 0x9a, 0xac, 0xad, 0xea, 0xe8, 0x9, 0x39, 0xa7, 0xdd},
+ Expect = new byte[]{0x9, 0xd, 0x85, 0xe5, 0x99, 0xea, 0x8e, 0x2b, 0xee, 0xb6, 0x13, 0x4, 0xd3, 0x7b, 0xe1, 0xe, 0xc5, 0xc9, 0x5, 0xf9, 0x92, 0x7d, 0x32, 0xf4, 0x2a, 0x9a, 0xa, 0xfb, 0x3e, 0xb, 0x40, 0x74},
+ },
+ new TestVector {
+ In = new byte[]{ 0x63, 0x66, 0x95, 0xe3, 0x4f, 0x75, 0xb9, 0xa2, 0x79, 0xc8, 0x70, 0x6f, 0xad, 0x12, 0x89, 0xf2, 0xc0, 0xb1, 0xe2, 0x2e, 0x16, 0xf8, 0xb8, 0x86, 0x17, 0x29, 0xc1, 0xa, 0x58, 0x29, 0x58, 0xaf},
+ Base = new byte[]{ 0x9, 0xd, 0x7, 0x1, 0xf8, 0xfd, 0xe2, 0x8f, 0x70, 0x4, 0x3b, 0x83, 0xf2, 0x34, 0x62, 0x25, 0x41, 0x9b, 0x18, 0xa7, 0xf2, 0x7e, 0x9e, 0x3d, 0x2b, 0xfd, 0x4, 0xe1, 0xf, 0x3d, 0x21, 0x3e},
+ Expect = new byte[]{ 0xbf, 0x26, 0xec, 0x7e, 0xc4, 0x13, 0x6, 0x17, 0x33, 0xd4, 0x40, 0x70, 0xea, 0x67, 0xca, 0xb0, 0x2a, 0x85, 0xdc, 0x1b, 0xe8, 0xcf, 0xe1, 0xff, 0x73, 0xd5, 0x41, 0xcc, 0x8, 0x32, 0x55, 0x6},
+ },
+ new TestVector {
+ In = new byte[]{ 0x73, 0x41, 0x81, 0xcd, 0x1a, 0x94, 0x6, 0x52, 0x2a, 0x56, 0xfe, 0x25, 0xe4, 0x3e, 0xcb, 0xf0, 0x29, 0x5d, 0xb5, 0xdd, 0xd0, 0x60, 0x9b, 0x3c, 0x2b, 0x4e, 0x79, 0xc0, 0x6f, 0x8b, 0xd4, 0x6d},
+ Base = new byte[]{ 0xf8, 0xa8, 0x42, 0x1c, 0x7d, 0x21, 0xa9, 0x2d, 0xb3, 0xed, 0xe9, 0x79, 0xe1, 0xfa, 0x6a, 0xcb, 0x6, 0x2b, 0x56, 0xb1, 0x88, 0x5c, 0x71, 0xc5, 0x11, 0x53, 0xcc, 0xb8, 0x80, 0xac, 0x73, 0x15},
+ Expect = new byte[]{ 0x11, 0x76, 0xd0, 0x16, 0x81, 0xf2, 0xcf, 0x92, 0x9d, 0xa2, 0xc7, 0xa3, 0xdf, 0x66, 0xb5, 0xd7, 0x72, 0x9f, 0xd4, 0x22, 0x22, 0x6f, 0xd6, 0x37, 0x42, 0x16, 0xbf, 0x7e, 0x2, 0xfd, 0xf, 0x62},
+ },
+ new TestVector {
+ In = new byte[]{ 0x1f, 0x70, 0x39, 0x1f, 0x6b, 0xa8, 0x58, 0x12, 0x94, 0x13, 0xbd, 0x80, 0x1b, 0x12, 0xac, 0xbf, 0x66, 0x23, 0x62, 0x82, 0x5c, 0xa2, 0x50, 0x9c, 0x81, 0x87, 0x59, 0xa, 0x2b, 0xe, 0x61, 0x72},
+ Base = new byte[]{ 0xd3, 0xea, 0xd0, 0x7a, 0x0, 0x8, 0xf4, 0x45, 0x2, 0xd5, 0x80, 0x8b, 0xff, 0xc8, 0x97, 0x9f, 0x25, 0xa8, 0x59, 0xd5, 0xad, 0xf4, 0x31, 0x2e, 0xa4, 0x87, 0x48, 0x9c, 0x30, 0xe0, 0x1b, 0x3b},
+ Expect = new byte[]{ 0xf8, 0x48, 0x2f, 0x2e, 0x9e, 0x58, 0xbb, 0x6, 0x7e, 0x86, 0xb2, 0x87, 0x24, 0xb3, 0xc0, 0xa3, 0xbb, 0xb5, 0x7, 0x3e, 0x4c, 0x6a, 0xcd, 0x93, 0xdf, 0x54, 0x5e, 0xff, 0xdb, 0xba, 0x50, 0x5f},
+ },
+ new TestVector {
+ In = new byte[]{ 0x3a, 0x7a, 0xe6, 0xcf, 0x8b, 0x88, 0x9d, 0x2b, 0x7a, 0x60, 0xa4, 0x70, 0xad, 0x6a, 0xd9, 0x99, 0x20, 0x6b, 0xf5, 0x7d, 0x90, 0x30, 0xdd, 0xf7, 0xf8, 0x68, 0xc, 0x8b, 0x1a, 0x64, 0x5d, 0xaa},
+ Base = new byte[]{ 0x4d, 0x25, 0x4c, 0x80, 0x83, 0xd8, 0x7f, 0x1a, 0x9b, 0x3e, 0xa7, 0x31, 0xef, 0xcf, 0xf8, 0xa6, 0xf2, 0x31, 0x2d, 0x6f, 0xed, 0x68, 0xe, 0xf8, 0x29, 0x18, 0x51, 0x61, 0xc8, 0xfc, 0x50, 0x60},
+ Expect = new byte[]{ 0x47, 0xb3, 0x56, 0xd5, 0x81, 0x8d, 0xe8, 0xef, 0xac, 0x77, 0x4b, 0x71, 0x4c, 0x42, 0xc4, 0x4b, 0xe6, 0x85, 0x23, 0xdd, 0x57, 0xdb, 0xd7, 0x39, 0x62, 0xd5, 0xa5, 0x26, 0x31, 0x87, 0x62, 0x37},
+ },
+ new TestVector {
+ In = new byte[]{ 0x20, 0x31, 0x61, 0xc3, 0x15, 0x9a, 0x87, 0x6a, 0x2b, 0xea, 0xec, 0x29, 0xd2, 0x42, 0x7f, 0xb0, 0xc7, 0xc3, 0xd, 0x38, 0x2c, 0xd0, 0x13, 0xd2, 0x7c, 0xc3, 0xd3, 0x93, 0xdb, 0xd, 0xaf, 0x6f},
+ Base = new byte[]{ 0x6a, 0xb9, 0x5d, 0x1a, 0xbe, 0x68, 0xc0, 0x9b, 0x0, 0x5c, 0x3d, 0xb9, 0x4, 0x2c, 0xc9, 0x1a, 0xc8, 0x49, 0xf7, 0xe9, 0x4a, 0x2a, 0x4a, 0x9b, 0x89, 0x36, 0x78, 0x97, 0xb, 0x7b, 0x95, 0xbf},
+ Expect = new byte[]{ 0x11, 0xed, 0xae, 0xdc, 0x95, 0xff, 0x78, 0xf5, 0x63, 0xa1, 0xc8, 0xf1, 0x55, 0x91, 0xc0, 0x71, 0xde, 0xa0, 0x92, 0xb4, 0xd7, 0xec, 0xaa, 0xc8, 0xe0, 0x38, 0x7b, 0x5a, 0x16, 0xc, 0x4e, 0x5d},
+ },
+ new TestVector {
+ In = new byte[]{ 0x13, 0xd6, 0x54, 0x91, 0xfe, 0x75, 0xf2, 0x3, 0xa0, 0x8, 0xb4, 0x41, 0x5a, 0xbc, 0x60, 0xd5, 0x32, 0xe6, 0x95, 0xdb, 0xd2, 0xf1, 0xe8, 0x3, 0xac, 0xcb, 0x34, 0xb2, 0xb7, 0x2c, 0x3d, 0x70},
+ Base = new byte[]{ 0x2e, 0x78, 0x4e, 0x4, 0xca, 0x0, 0x73, 0x33, 0x62, 0x56, 0xa8, 0x39, 0x25, 0x5e, 0xd2, 0xf7, 0xd4, 0x79, 0x6a, 0x64, 0xcd, 0xc3, 0x7f, 0x1e, 0xb0, 0xe5, 0xc4, 0xc8, 0xd1, 0xd1, 0xe0, 0xf5},
+ Expect = new byte[]{ 0x56, 0x3e, 0x8c, 0x9a, 0xda, 0xa7, 0xd7, 0x31, 0x1, 0xb0, 0xf2, 0xea, 0xd3, 0xca, 0xe1, 0xea, 0x5d, 0x8f, 0xcd, 0x5c, 0xd3, 0x60, 0x80, 0xbb, 0x8e, 0x6e, 0xc0, 0x3d, 0x61, 0x45, 0x9, 0x17},
+ },
+ new TestVector {
+ In = new byte[]{ 0x68, 0x6f, 0x7d, 0xa9, 0x3b, 0xf2, 0x68, 0xe5, 0x88, 0x6, 0x98, 0x31, 0xf0, 0x47, 0x16, 0x3f, 0x33, 0x58, 0x99, 0x89, 0xd0, 0x82, 0x6e, 0x98, 0x8, 0xfb, 0x67, 0x8e, 0xd5, 0x7e, 0x67, 0x49},
+ Base = new byte[]{ 0x8b, 0x54, 0x9b, 0x2d, 0xf6, 0x42, 0xd3, 0xb2, 0x5f, 0xe8, 0x38, 0xf, 0x8c, 0xc4, 0x37, 0x5f, 0x99, 0xb7, 0xbb, 0x4d, 0x27, 0x5f, 0x77, 0x9f, 0x3b, 0x7c, 0x81, 0xb8, 0xa2, 0xbb, 0xc1, 0x29},
+ Expect = new byte[]{ 0x1, 0x47, 0x69, 0x65, 0x42, 0x6b, 0x61, 0x71, 0x74, 0x9a, 0x8a, 0xdd, 0x92, 0x35, 0x2, 0x5c, 0xe5, 0xf5, 0x57, 0xfe, 0x40, 0x9, 0xf7, 0x39, 0x30, 0x44, 0xeb, 0xbb, 0x8a, 0xe9, 0x52, 0x79},
+ },
+ new TestVector {
+ In = new byte[]{ 0x82, 0xd6, 0x1c, 0xce, 0xdc, 0x80, 0x6a, 0x60, 0x60, 0xa3, 0x34, 0x9a, 0x5e, 0x87, 0xcb, 0xc7, 0xac, 0x11, 0x5e, 0x4f, 0x87, 0x77, 0x62, 0x50, 0xae, 0x25, 0x60, 0x98, 0xa7, 0xc4, 0x49, 0x59},
+ Base = new byte[]{ 0x8b, 0x6b, 0x9d, 0x8, 0xf6, 0x1f, 0xc9, 0x1f, 0xe8, 0xb3, 0x29, 0x53, 0xc4, 0x23, 0x40, 0xf0, 0x7, 0xb5, 0x71, 0xdc, 0xb0, 0xa5, 0x6d, 0x10, 0x72, 0x4e, 0xce, 0xf9, 0x95, 0xc, 0xfb, 0x25},
+ Expect = new byte[]{ 0x9c, 0x49, 0x94, 0x1f, 0x9c, 0x4f, 0x18, 0x71, 0xfa, 0x40, 0x91, 0xfe, 0xd7, 0x16, 0xd3, 0x49, 0x99, 0xc9, 0x52, 0x34, 0xed, 0xf2, 0xfd, 0xfb, 0xa6, 0xd1, 0x4a, 0x5a, 0xfe, 0x9e, 0x5, 0x58},
+ },
+ new TestVector {
+ In = new byte[]{ 0x7d, 0xc7, 0x64, 0x4, 0x83, 0x13, 0x97, 0xd5, 0x88, 0x4f, 0xdf, 0x6f, 0x97, 0xe1, 0x74, 0x4c, 0x9e, 0xb1, 0x18, 0xa3, 0x1a, 0x7b, 0x23, 0xf8, 0xd7, 0x9f, 0x48, 0xce, 0x9c, 0xad, 0x15, 0x4b},
+ Base = new byte[]{ 0x1a, 0xcd, 0x29, 0x27, 0x84, 0xf4, 0x79, 0x19, 0xd4, 0x55, 0xf8, 0x87, 0x44, 0x83, 0x58, 0x61, 0xb, 0xb9, 0x45, 0x96, 0x70, 0xeb, 0x99, 0xde, 0xe4, 0x60, 0x5, 0xf6, 0x89, 0xca, 0x5f, 0xb6},
+ Expect = new byte[]{ 0x0, 0xf4, 0x3c, 0x2, 0x2e, 0x94, 0xea, 0x38, 0x19, 0xb0, 0x36, 0xae, 0x2b, 0x36, 0xb2, 0xa7, 0x61, 0x36, 0xaf, 0x62, 0x8a, 0x75, 0x1f, 0xe5, 0xd0, 0x1e, 0x3, 0xd, 0x44, 0x25, 0x88, 0x59},
+ },
+ new TestVector {
+ In = new byte[]{ 0xfb, 0xc4, 0x51, 0x1d, 0x23, 0xa6, 0x82, 0xae, 0x4e, 0xfd, 0x8, 0xc8, 0x17, 0x9c, 0x1c, 0x6, 0x7f, 0x9c, 0x8b, 0xe7, 0x9b, 0xbc, 0x4e, 0xff, 0x5c, 0xe2, 0x96, 0xc6, 0xbc, 0x1f, 0xf4, 0x45},
+ Base = new byte[]{ 0x55, 0xca, 0xff, 0x21, 0x81, 0xf2, 0x13, 0x6b, 0xe, 0xd0, 0xe1, 0xe2, 0x99, 0x44, 0x48, 0xe1, 0x6c, 0xc9, 0x70, 0x64, 0x6a, 0x98, 0x3d, 0x14, 0xd, 0xc4, 0xea, 0xb3, 0xd9, 0x4c, 0x28, 0x4e},
+ Expect = new byte[]{ 0xae, 0x39, 0xd8, 0x16, 0x53, 0x23, 0x45, 0x79, 0x4d, 0x26, 0x91, 0xe0, 0x80, 0x1c, 0xaa, 0x52, 0x5f, 0xc3, 0x63, 0x4d, 0x40, 0x2c, 0xe9, 0x58, 0xb, 0x33, 0x38, 0xb4, 0x6f, 0x8b, 0xb9, 0x72},
+ },
+ new TestVector {
+ In = new byte[]{ 0x4e, 0x6, 0xc, 0xe1, 0xc, 0xeb, 0xf0, 0x95, 0x9, 0x87, 0x16, 0xc8, 0x66, 0x19, 0xeb, 0x9f, 0x7d, 0xf6, 0x65, 0x24, 0x69, 0x8b, 0xa7, 0x98, 0x8c, 0x3b, 0x90, 0x95, 0xd9, 0xf5, 0x1, 0x34},
+ Base = new byte[]{ 0x57, 0x73, 0x3f, 0x2d, 0x86, 0x96, 0x90, 0xd0, 0xd2, 0xed, 0xae, 0xc9, 0x52, 0x3d, 0xaa, 0x2d, 0xa9, 0x54, 0x45, 0xf4, 0x4f, 0x57, 0x83, 0xc1, 0xfa, 0xec, 0x6c, 0x3a, 0x98, 0x28, 0x18, 0xf3},
+ Expect = new byte[]{ 0xa6, 0x1e, 0x74, 0x55, 0x2c, 0xce, 0x75, 0xf5, 0xe9, 0x72, 0xe4, 0x24, 0xf2, 0xcc, 0xb0, 0x9c, 0x83, 0xbc, 0x1b, 0x67, 0x1, 0x47, 0x48, 0xf0, 0x2c, 0x37, 0x1a, 0x20, 0x9e, 0xf2, 0xfb, 0x2c},
+ },
+ new TestVector {
+ In = new byte[]{ 0x5c, 0x49, 0x2c, 0xba, 0x2c, 0xc8, 0x92, 0x48, 0x8a, 0x9c, 0xeb, 0x91, 0x86, 0xc2, 0xaa, 0xc2, 0x2f, 0x1, 0x5b, 0xf3, 0xef, 0x8d, 0x3e, 0xcc, 0x9c, 0x41, 0x76, 0x97, 0x62, 0x61, 0xaa, 0xb1},
+ Base = new byte[]{ 0x67, 0x97, 0xc2, 0xe7, 0xdc, 0x92, 0xcc, 0xbe, 0x7c, 0x5, 0x6b, 0xec, 0x35, 0xa, 0xb6, 0xd3, 0xbd, 0x2a, 0x2c, 0x6b, 0xc5, 0xa8, 0x7, 0xbb, 0xca, 0xe1, 0xf6, 0xc2, 0xaf, 0x80, 0x36, 0x44},
+ Expect = new byte[]{ 0xfc, 0xf3, 0x7, 0xdf, 0xbc, 0x19, 0x2, 0xb, 0x28, 0xa6, 0x61, 0x8c, 0x6c, 0x62, 0x2f, 0x31, 0x7e, 0x45, 0x96, 0x7d, 0xac, 0xf4, 0xae, 0x4a, 0xa, 0x69, 0x9a, 0x10, 0x76, 0x9f, 0xde, 0x14},
+ },
+ new TestVector {
+ In = new byte[]{ 0xea, 0x33, 0x34, 0x92, 0x96, 0x5, 0x5a, 0x4e, 0x8b, 0x19, 0x2e, 0x3c, 0x23, 0xc5, 0xf4, 0xc8, 0x44, 0x28, 0x2a, 0x3b, 0xfc, 0x19, 0xec, 0xc9, 0xdc, 0x64, 0x6a, 0x42, 0xc3, 0x8d, 0xc2, 0x48},
+ Base = new byte[]{ 0x2c, 0x75, 0xd8, 0x51, 0x42, 0xec, 0xad, 0x3e, 0x69, 0x44, 0x70, 0x4, 0x54, 0xc, 0x1c, 0x23, 0x54, 0x8f, 0xc8, 0xf4, 0x86, 0x25, 0x1b, 0x8a, 0x19, 0x46, 0x3f, 0x3d, 0xf6, 0xf8, 0xac, 0x61},
+ Expect = new byte[]{ 0x5d, 0xca, 0xb6, 0x89, 0x73, 0xf9, 0x5b, 0xd3, 0xae, 0x4b, 0x34, 0xfa, 0xb9, 0x49, 0xfb, 0x7f, 0xb1, 0x5a, 0xf1, 0xd8, 0xca, 0xe2, 0x8c, 0xd6, 0x99, 0xf9, 0xc1, 0xaa, 0x33, 0x37, 0x34, 0x2f},
+ },
+ new TestVector {
+ In = new byte[]{ 0x4f, 0x29, 0x79, 0xb1, 0xec, 0x86, 0x19, 0xe4, 0x5c, 0xa, 0xb, 0x2b, 0x52, 0x9, 0x34, 0x54, 0x1a, 0xb9, 0x44, 0x7, 0xb6, 0x4d, 0x19, 0xa, 0x76, 0xf3, 0x23, 0x14, 0xef, 0xe1, 0x84, 0xe7},
+ Base = new byte[]{ 0xf7, 0xca, 0xe1, 0x8d, 0x8d, 0x36, 0xa7, 0xf5, 0x61, 0x17, 0xb8, 0xb7, 0xe, 0x25, 0x52, 0x27, 0x7f, 0xfc, 0x99, 0xdf, 0x87, 0x56, 0xb5, 0xe1, 0x38, 0xbf, 0x63, 0x68, 0xbc, 0x87, 0xf7, 0x4c},
+ Expect = new byte[]{ 0xe4, 0xe6, 0x34, 0xeb, 0xb4, 0xfb, 0x66, 0x4f, 0xe8, 0xb2, 0xcf, 0xa1, 0x61, 0x5f, 0x0, 0xe6, 0x46, 0x6f, 0xff, 0x73, 0x2c, 0xe1, 0xf8, 0xa0, 0xc8, 0xd2, 0x72, 0x74, 0x31, 0xd1, 0x6f, 0x14},
+ },
+ new TestVector {
+ In = new byte[]{ 0xf5, 0xd8, 0xa9, 0x27, 0x90, 0x1d, 0x4f, 0xa4, 0x24, 0x90, 0x86, 0xb7, 0xff, 0xec, 0x24, 0xf5, 0x29, 0x7d, 0x80, 0x11, 0x8e, 0x4a, 0xc9, 0xd3, 0xfc, 0x9a, 0x82, 0x37, 0x95, 0x1e, 0x3b, 0x7f},
+ Base = new byte[]{ 0x3c, 0x23, 0x5e, 0xdc, 0x2, 0xf9, 0x11, 0x56, 0x41, 0xdb, 0xf5, 0x16, 0xd5, 0xde, 0x8a, 0x73, 0x5d, 0x6e, 0x53, 0xe2, 0x2a, 0xa2, 0xac, 0x14, 0x36, 0x56, 0x4, 0x5f, 0xf2, 0xe9, 0x52, 0x49},
+ Expect = new byte[]{ 0xab, 0x95, 0x15, 0xab, 0x14, 0xaf, 0x9d, 0x27, 0xe, 0x1d, 0xae, 0xc, 0x56, 0x80, 0xcb, 0xc8, 0x88, 0xb, 0xd8, 0xa8, 0xe7, 0xeb, 0x67, 0xb4, 0xda, 0x42, 0xa6, 0x61, 0x96, 0x1e, 0xfc, 0xb},
+ },
+ };
+ }
+
+ [TestClass]
+ public class X25519FieldTests
+ {
+ private readonly byte[] A = {0x21, 0xDD, 0xB0, 0x43, 0xCF, 0xB2, 0xB3, 0xFE, 0xC4, 0xCC, 0xA3, 0x8B, 0xBA, 0x3D, 0xE1, 0x92, 0xDF, 0xEA, 0x85, 0xCE, 0x2B, 0x4A, 0xD8, 0x44, 0x95, 0xAA, 0xB1, 0x3A, 0x5B, 0x62, 0x87, 0x8E};
+ private readonly byte[] B = {0x22, 0x8B, 0x2F, 0x48, 0x4B, 0x86, 0xAC, 0x9C, 0xA4, 0x7B, 0x64, 0xC4, 0x62, 0x76, 0x34, 0x7C, 0x67, 0xBD, 0x59, 0x6F, 0x8D, 0x18, 0x41, 0x4D, 0x96, 0x31, 0xA5, 0x5B, 0x3B, 0xA5, 0x7E, 0xC7};
+
+ [TestMethod]
+ public void Zero()
+ {
+ byte[] actual = new byte[X25519.KeySize];
+ X25519.FieldElement fe = X25519.FieldElement.Zero();
+ fe.CopyTo(actual);
+
+ byte[] expected = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00};
+ CollectionAssert.AreEqual(expected, actual);
+ }
+
+ [TestMethod]
+ public void One()
+ {
+ byte[] actual = new byte[X25519.KeySize];
+ X25519.FieldElement fe = X25519.FieldElement.One();
+ fe.CopyTo(actual);
+
+ byte[] expected = {0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00};
+ CollectionAssert.AreEqual(expected, actual);
+ }
+
+ [TestMethod]
+ public void Add()
+ {
+ X25519.FieldElement a = X25519.FieldElement.FromBytes(A);
+ X25519.FieldElement b = X25519.FieldElement.FromBytes(B);
+
+ byte[] actual = new byte[X25519.KeySize];
+ X25519.FieldElement c = new X25519.FieldElement();
+ X25519.FieldElement.Add(ref c, ref a, ref b);
+ c.CopyTo(actual);
+
+ byte[] expected = {0x43, 0x68, 0xE0, 0x8B, 0x1A, 0x39, 0x60, 0x9B, 0x69, 0x48, 0x08, 0x50, 0x1D, 0xB4, 0x15, 0x0F, 0x47, 0xA8, 0xDF, 0x3D, 0xB9, 0x62, 0x19, 0x92, 0x2B, 0xDC, 0x56, 0x96, 0x96, 0x07, 0x06, 0x56};
+ CollectionAssert.AreEqual(expected, actual);
+ }
+
+ [TestMethod]
+ public void Sub()
+ {
+ X25519.FieldElement a = X25519.FieldElement.FromBytes(A);
+ X25519.FieldElement b = X25519.FieldElement.FromBytes(B);
+
+ byte[] actual = new byte[X25519.KeySize];
+ X25519.FieldElement c = new X25519.FieldElement();
+ X25519.FieldElement.Sub(ref c, ref a, ref b);
+ c.CopyTo(actual);
+
+ byte[] expected = {0xEC, 0x51, 0x81, 0xFB, 0x83, 0x2C, 0x07, 0x62, 0x20, 0x51, 0x3F, 0xC7, 0x57, 0xC7, 0xAC, 0x16, 0x78, 0x2D, 0x2C, 0x5F, 0x9E, 0x31, 0x97, 0xF7, 0xFE, 0x78, 0x0C, 0xDF, 0x1F, 0xBD, 0x08, 0x47};
+ CollectionAssert.AreEqual(expected, actual);
+ }
+
+ [TestMethod]
+ public void Multiply()
+ {
+ X25519.FieldElement a = X25519.FieldElement.FromBytes(A);
+ X25519.FieldElement b = X25519.FieldElement.FromBytes(B);
+
+ byte[] actual = new byte[X25519.KeySize];
+ X25519.FieldElement c = new X25519.FieldElement();
+ X25519.FieldElement.Multiply(ref c, ref a, ref b);
+ c.CopyTo(actual);
+
+ byte[] expected = {0x1E, 0xBE, 0xBD, 0xE0, 0xEC, 0xB1, 0x3C, 0xDB, 0x50, 0x6E, 0xD6, 0x50, 0x02, 0x1A, 0x59, 0x99, 0xC1, 0xC0, 0xFC, 0xE0, 0xBF, 0xDB, 0x64, 0xB0, 0x3E, 0xB3, 0x2D, 0x43, 0x8B, 0x66, 0x43, 0x3C};
+ CollectionAssert.AreEqual(expected, actual);
+ }
+
+ [TestMethod]
+ public void Square()
+ {
+ X25519.FieldElement a = X25519.FieldElement.FromBytes(A);
+
+ byte[] actual = new byte[X25519.KeySize];
+ X25519.FieldElement c = new X25519.FieldElement();
+ X25519.FieldElement.Square(ref c, ref a);
+ c.CopyTo(actual);
+
+ byte[] expected = {0xAE, 0xB2, 0x22, 0xD4, 0x72, 0xF7, 0xF4, 0x09, 0xBB, 0x9A, 0xA9, 0x99, 0xEB, 0x7F, 0xC4, 0xE1, 0x4C, 0x0A, 0x53, 0xEB, 0x3C, 0xFF, 0x5C, 0xE2, 0xF6, 0x92, 0x46, 0x53, 0x29, 0xE1, 0x5D, 0x7A};
+ CollectionAssert.AreEqual(expected, actual);
+
+
+ a = X25519.FieldElement.FromBytes(A);
+ X25519.FieldElement.Square(ref a, ref a);
+ a.CopyTo(actual);
+
+ CollectionAssert.AreEqual(expected, actual);
+ }
+
+ [TestMethod]
+ public void Multiply121666()
+ {
+ X25519.FieldElement a = X25519.FieldElement.FromBytes(A);
+
+ byte[] actual = new byte[X25519.KeySize];
+ X25519.FieldElement c = new X25519.FieldElement();
+ X25519.FieldElement.Multiply121666(ref c, ref a);
+ c.CopyTo(actual);
+
+ byte[] expected = {0x65, 0x3E, 0xE9, 0x9D, 0x08, 0xAC, 0x1A, 0x17, 0x61, 0x4F, 0x2C, 0xED, 0x30, 0x0B, 0x9B, 0xCB, 0x2B, 0x63, 0x53, 0xB9, 0x7D, 0x67, 0x62, 0x11, 0x39, 0xF1, 0x50, 0xC9, 0x6C, 0xA1, 0x66, 0x72};
+ CollectionAssert.AreEqual(expected, actual);
+ }
+
+ [TestMethod]
+ public void Invert()
+ {
+ X25519.FieldElement a = X25519.FieldElement.FromBytes(A);
+
+ byte[] actual = new byte[X25519.KeySize];
+ X25519.FieldElement c = new X25519.FieldElement();
+ X25519.FieldElement.Invert(ref c, ref a);
+ c.CopyTo(actual);
+
+ byte[] expected = {0x8E, 0x66, 0x2F, 0x60, 0xFC, 0xCD, 0x3A, 0x11, 0x36, 0xF5, 0xD9, 0xE6, 0x94, 0x28, 0x04, 0x2A, 0x6B, 0x5D, 0xC4, 0x72, 0x82, 0x30, 0xF3, 0x09, 0xC0, 0x24, 0xDE, 0xCD, 0x60, 0x3F, 0x5D, 0x17};
+ CollectionAssert.AreEqual(expected, actual);
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel.UnitTests/Dtls/AesGcmRecordProtectedTests.cs b/Tools/Hazel-Networking/Hazel.UnitTests/Dtls/AesGcmRecordProtectedTests.cs
new file mode 100644
index 0000000..8e71f2d
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel.UnitTests/Dtls/AesGcmRecordProtectedTests.cs
@@ -0,0 +1,205 @@
+using Hazel.Dtls;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using System.Security.Cryptography;
+using System.Text;
+
+namespace Hazel.UnitTests.Dtls
+{
+ [TestClass]
+ public class AesGcmRecordProtectedTests
+ {
+ private readonly ByteSpan masterSecret;
+ private readonly ByteSpan serverRandom;
+ private readonly ByteSpan clientRandom;
+
+ private const string TestMessage = "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.";
+
+ public AesGcmRecordProtectedTests()
+ {
+ this.masterSecret = new byte[48];
+ this.serverRandom = new byte[32];
+ this.clientRandom = new byte[32];
+
+ using (RandomNumberGenerator random = RandomNumberGenerator.Create())
+ {
+ random.GetBytes(this.masterSecret.GetUnderlyingArray());
+ random.GetBytes(this.serverRandom.GetUnderlyingArray());
+ random.GetBytes(this.clientRandom.GetUnderlyingArray());
+ }
+ }
+
+ [TestMethod]
+ public void ServerCanEncryptAndDecryptData()
+ {
+ using (Aes128GcmRecordProtection recordProtection = new Aes128GcmRecordProtection(this.masterSecret, this.serverRandom, this.clientRandom))
+ {
+ byte[] messageAsBytes = Encoding.UTF8.GetBytes(TestMessage);
+
+ Record record = new Record();
+ record.ContentType = ContentType.ApplicationData;
+ record.ProtocolVersion = ProtocolVersion.DTLS1_2;
+ record.Epoch = 1;
+ record.SequenceNumber = 124;
+ record.Length = (ushort)recordProtection.GetEncryptedSize(messageAsBytes.Length);
+
+ ByteSpan encrypted = new byte[record.Length];
+ recordProtection.EncryptServerPlaintext(encrypted, messageAsBytes, ref record);
+
+ ByteSpan plaintext = new byte[recordProtection.GetDecryptedSize(encrypted.Length)];
+ bool couldDecrypt = recordProtection.DecryptCiphertextFromServer(plaintext, encrypted, ref record);
+ Assert.IsTrue(couldDecrypt);
+ Assert.AreEqual(messageAsBytes.Length, plaintext.Length);
+ Assert.AreEqual(TestMessage, Encoding.UTF8.GetString(plaintext.GetUnderlyingArray(), plaintext.Offset, plaintext.Length));
+ }
+ }
+
+ [TestMethod]
+ public void ClientCanEncryptAndDecryptData()
+ {
+ using (Aes128GcmRecordProtection recordProtection = new Aes128GcmRecordProtection(this.masterSecret, this.serverRandom, this.clientRandom))
+ {
+ byte[] messageAsBytes = Encoding.UTF8.GetBytes(TestMessage);
+
+ Record record = new Record();
+ record.ContentType = ContentType.ApplicationData;
+ record.ProtocolVersion = ProtocolVersion.DTLS1_2;
+ record.Epoch = 1;
+ record.SequenceNumber = 124;
+ record.Length = (ushort)recordProtection.GetEncryptedSize(messageAsBytes.Length);
+
+ ByteSpan encrypted = new byte[record.Length];
+ recordProtection.EncryptClientPlaintext(encrypted, messageAsBytes, ref record);
+
+ ByteSpan plaintext = new byte[recordProtection.GetDecryptedSize(encrypted.Length)];
+ bool couldDecrypt = recordProtection.DecryptCiphertextFromClient(plaintext, encrypted, ref record);
+ Assert.IsTrue(couldDecrypt);
+ Assert.AreEqual(messageAsBytes.Length, plaintext.Length);
+ Assert.AreEqual(TestMessage, Encoding.UTF8.GetString(plaintext.GetUnderlyingArray(), plaintext.Offset, plaintext.Length));
+ }
+ }
+
+ [TestMethod]
+ public void ServerDecryptionFailsWhenRecordModified()
+ {
+ using (Aes128GcmRecordProtection recordProtection = new Aes128GcmRecordProtection(this.masterSecret, this.serverRandom, this.clientRandom))
+ {
+ byte[] messageAsBytes = Encoding.UTF8.GetBytes(TestMessage);
+
+ Record originalRecord = new Record();
+ originalRecord.ContentType = ContentType.ApplicationData;
+ originalRecord.ProtocolVersion = ProtocolVersion.DTLS1_2;
+ originalRecord.Epoch = 1;
+ originalRecord.SequenceNumber = 124;
+ originalRecord.Length = (ushort)recordProtection.GetEncryptedSize(messageAsBytes.Length);
+
+ ByteSpan encrypted = new byte[originalRecord.Length];
+ recordProtection.EncryptServerPlaintext(encrypted, messageAsBytes, ref originalRecord);
+
+ ByteSpan plaintext = new byte[recordProtection.GetDecryptedSize(encrypted.Length)];
+
+ Record record = originalRecord;
+ record.ContentType = ContentType.Handshake;
+ bool couldDecrypt = recordProtection.DecryptCiphertextFromServer(plaintext, encrypted, ref record);
+ Assert.IsFalse(couldDecrypt);
+
+ record = originalRecord;
+ record.Epoch++;
+ couldDecrypt = recordProtection.DecryptCiphertextFromServer(plaintext, encrypted, ref record);
+ Assert.IsFalse(couldDecrypt);
+
+ record = originalRecord;
+ record.SequenceNumber++;
+ couldDecrypt = recordProtection.DecryptCiphertextFromServer(plaintext, encrypted, ref record);
+ Assert.IsFalse(couldDecrypt);
+ }
+ }
+
+ [TestMethod]
+ public void ClientDecryptionFailsWhenRecordModified()
+ {
+ using (Aes128GcmRecordProtection recordProtection = new Aes128GcmRecordProtection(this.masterSecret, this.serverRandom, this.clientRandom))
+ {
+ byte[] messageAsBytes = Encoding.UTF8.GetBytes(TestMessage);
+
+ Record originalRecord = new Record();
+ originalRecord.ContentType = ContentType.ApplicationData;
+ originalRecord.ProtocolVersion = ProtocolVersion.DTLS1_2;
+ originalRecord.Epoch = 1;
+ originalRecord.SequenceNumber = 124;
+ originalRecord.Length = (ushort)recordProtection.GetEncryptedSize(messageAsBytes.Length);
+
+ ByteSpan encrypted = new byte[originalRecord.Length];
+ recordProtection.EncryptClientPlaintext(encrypted, messageAsBytes, ref originalRecord);
+
+ ByteSpan plaintext = new byte[recordProtection.GetDecryptedSize(encrypted.Length)];
+
+ Record record = originalRecord;
+ record.ContentType = ContentType.Handshake;
+ bool couldDecrypt = recordProtection.DecryptCiphertextFromClient(plaintext, encrypted, ref record);
+ Assert.IsFalse(couldDecrypt);
+
+ record = originalRecord;
+ record.Epoch++;
+ couldDecrypt = recordProtection.DecryptCiphertextFromClient(plaintext, encrypted, ref record);
+ Assert.IsFalse(couldDecrypt);
+
+ record = originalRecord;
+ record.SequenceNumber++;
+ couldDecrypt = recordProtection.DecryptCiphertextFromClient(plaintext, encrypted, ref record);
+ Assert.IsFalse(couldDecrypt);
+ }
+ }
+
+ [TestMethod]
+ public void ServerEncryptionCanoverlap()
+ {
+ using (Aes128GcmRecordProtection recordProtection = new Aes128GcmRecordProtection(this.masterSecret, this.serverRandom, this.clientRandom))
+ {
+ ByteSpan messageAsBytes = Encoding.UTF8.GetBytes(TestMessage);
+
+ Record record = new Record();
+ record.ContentType = ContentType.ApplicationData;
+ record.ProtocolVersion = ProtocolVersion.DTLS1_2;
+ record.Epoch = 1;
+ record.SequenceNumber = 124;
+ record.Length = (ushort)recordProtection.GetEncryptedSize(messageAsBytes.Length);
+
+ ByteSpan encrypted = new byte[record.Length];
+ messageAsBytes.CopyTo(encrypted);
+ recordProtection.EncryptServerPlaintext(encrypted, encrypted.Slice(0, messageAsBytes.Length), ref record);
+
+ ByteSpan plaintext = encrypted.Slice(0, recordProtection.GetDecryptedSize(record.Length));
+ bool couldDecrypt = recordProtection.DecryptCiphertextFromServer(plaintext, encrypted, ref record);
+ Assert.IsTrue(couldDecrypt);
+ Assert.AreEqual(messageAsBytes.Length, plaintext.Length);
+ Assert.AreEqual(TestMessage, Encoding.UTF8.GetString(plaintext.GetUnderlyingArray(), plaintext.Offset, plaintext.Length));
+ }
+ }
+
+ [TestMethod]
+ public void ClientEncryptionCanoverlap()
+ {
+ using (Aes128GcmRecordProtection recordProtection = new Aes128GcmRecordProtection(this.masterSecret, this.serverRandom, this.clientRandom))
+ {
+ ByteSpan messageAsBytes = Encoding.UTF8.GetBytes(TestMessage);
+
+ Record record = new Record();
+ record.ContentType = ContentType.ApplicationData;
+ record.ProtocolVersion = ProtocolVersion.DTLS1_2;
+ record.Epoch = 1;
+ record.SequenceNumber = 124;
+ record.Length = (ushort)recordProtection.GetEncryptedSize(messageAsBytes.Length);
+
+ ByteSpan encrypted = new byte[record.Length];
+ messageAsBytes.CopyTo(encrypted);
+ recordProtection.EncryptClientPlaintext(encrypted, encrypted.Slice(0, messageAsBytes.Length), ref record);
+
+ ByteSpan plaintext = encrypted.Slice(0, recordProtection.GetDecryptedSize(record.Length));
+ bool couldDecrypt = recordProtection.DecryptCiphertextFromClient(plaintext, encrypted, ref record);
+ Assert.IsTrue(couldDecrypt);
+ Assert.AreEqual(messageAsBytes.Length, plaintext.Length);
+ Assert.AreEqual(TestMessage, Encoding.UTF8.GetString(plaintext.GetUnderlyingArray(), plaintext.Offset, plaintext.Length));
+ }
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel.UnitTests/Dtls/DtlsConnectionTests.cs b/Tools/Hazel-Networking/Hazel.UnitTests/Dtls/DtlsConnectionTests.cs
new file mode 100644
index 0000000..55e0c91
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel.UnitTests/Dtls/DtlsConnectionTests.cs
@@ -0,0 +1,1192 @@
+using Hazel.Dtls;
+using Hazel.Udp;
+using Hazel.Udp.FewerThreads;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using System;
+using System.Net;
+using System.Net.Sockets;
+using System.Security.Cryptography;
+using System.Security.Cryptography.X509Certificates;
+using System.Threading;
+
+namespace Hazel.UnitTests.Dtls
+{
+ [TestClass]
+ public class DtlsConnectionTests
+ {
+ // Created with command line
+ // openssl req -newkey rsa:2048 -nodes -keyout key.pem -x509 -days 100000 -out certificate.pem
+ const string TestCertificate =
+@"-----BEGIN CERTIFICATE-----
+MIIDbTCCAlWgAwIBAgIUREHeZ36f23eBFQ1T3sJsBwHlSBEwDQYJKoZIhvcNAQEL
+BQAwRTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM
+GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDAgFw0yMTAyMDIxNDE4MTBaGA8yMjk0
+MTExODE0MTgxMFowRTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUx
+ITAfBgNVBAoMGEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDCCASIwDQYJKoZIhvcN
+AQEBBQADggEPADCCAQoCggEBAMeHCR6Y6GFwH7ZnouxPLmqyCIJSCcfaGIuBU3k+
+MG2ZyXKhhhwclL8arx5x1cGmQFvPm5wXGKSiLFChj+bW5XN7xBAc5e9KVBCEabrr
+BY+X9r0a421Yjqn4F47IA2sQ6OygnttYIt0pgeEoQZhGvmc2ZfkELkptIHMavIsx
+B/R0tYgtquruWveIWMtr4O/AuPxkH750SO1OxwU8gj6QXSqskrxvhl9GBzAwBKaF
+W6t7yjR7eFqaGh7B55p4t5zrfYKCVgeyj5Yzr/xdvv3Q3H+0pex+JTMWrpsTaavq
+F2RZYbpTOofuiTwdWbAHnXW1aFSCCIrEdEs9X2FxB73V0fcCAwEAAaNTMFEwHQYD
+VR0OBBYEFETIkxnzoLXO2GcEgxTZgN8ypKowMB8GA1UdIwQYMBaAFETIkxnzoLXO
+2GcEgxTZgN8ypKowMA8GA1UdEwEB/wQFMAMBAf8wDQYJKoZIhvcNAQELBQADggEB
+ACZl7WQec9xLTK0paBIkVUqZKucDCXQH0JC7z4ENbiRtQvWQm6xhAlDo8Tr8oUzj
+0/lft/g6wIo8dJ4jZ/iCSHnKz8qO80Gs/x5NISe9A/8Us1kq8y4nO40QW6xtQMH7
+j74pcfsGKDCaMFSQZnSc93a3ZMEuVPxdI5+qsvFIeC9xxRHUNo245eLqsJAe8s1c
+22Uoeu3gepozrPcIPAHADGr/CFp1HLkg9nFrTcatlNAF/N0PmLjmk/NIx/8h7n7Q
+5vapNkhcyCHsW8XB5ulKmF88QZ5BdvPmtSey0t/n8ru98615G5Wb4TS2MaprzYL3
+5ACeQOohFzevcQrEjjzkZAI=
+-----END CERTIFICATE-----
+";
+
+ const string TestPrivateKey =
+@"-----BEGIN PRIVATE KEY-----
+MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDHhwkemOhhcB+2
+Z6LsTy5qsgiCUgnH2hiLgVN5PjBtmclyoYYcHJS/Gq8ecdXBpkBbz5ucFxikoixQ
+oY/m1uVze8QQHOXvSlQQhGm66wWPl/a9GuNtWI6p+BeOyANrEOjsoJ7bWCLdKYHh
+KEGYRr5nNmX5BC5KbSBzGryLMQf0dLWILarq7lr3iFjLa+DvwLj8ZB++dEjtTscF
+PII+kF0qrJK8b4ZfRgcwMASmhVure8o0e3hamhoeweeaeLec632CglYHso+WM6/8
+Xb790Nx/tKXsfiUzFq6bE2mr6hdkWWG6UzqH7ok8HVmwB511tWhUggiKxHRLPV9h
+cQe91dH3AgMBAAECggEAVq+jVajHJTYqgPwLu7E3EGHi8oOj/jESAuIgGwfa0HNF
+I0lr06DTOyfjt01ruiN5yKmtCKa8LSLMMAfRVlA9BexapUl42HqphTeSHARpuRYj
+u8sHzgTwjoXb7kuVuJlzKQMroU5sbzvOUr1Dql3p8TugGA0p82nv9DJEghC+TQT2
+GnCKhsnmwE7lb8h0z3G3yxdv3yZE0X6oFzllBGCb1O6cwsDeYsv+SyjnyUwJROGz
+/VkC1+B48ALm4DhA5QIUoaRhO8vaCa7dacQTkXw1hVdLcaS9slIdXxbb9GbJvI0c
+baqimIkE02VUUpmlOIKUpf1sRXy1aJFpDSvWsTNLaQKBgQD8TrcVUF7oOhX5hQer
+qfNDFPvCBiWlT+8tnJraauaD1sJLD5jpRWPDu5dZ96CSZVpD1E3GQm+x58CSUknG
+AUHyHfEHTpzx7elVeUj7gianmVu8r9mVKtErwPLDJ4AUhMJjEPX2Plmh9GgFck78
+s2gfIvxdI+znvkH9JwGBznTIRQKBgQDKcpO7wiu025ZyPWR2p+qUh2ZlvBBr/6rg
+GxbFE5VraIS6zSDVOcxjPLc1pVZ/vM2WGbda0eziLpvsyauTXMCueBiNBRWZb5E4
+NK81IgbgZc4VWN9xA00cfOzO4Bjt6990BdOxiQQ1XOz1cN8DFTfsA81qR+nIne58
+LhL0DmFLCwKBgCwsU92FbrhVwxcmdUtWu+JYwCMeFGU283cW3f2zjZwzc1zU5D6j
+CW5xX3Q+6Hv5Bq6tcthtNUT+gDad9ZCXE8ah+1r+Jngs4Rc33tE53i6lqOwGFaAK
+GQkCBP6p4cC15ZqWk5mDHQo/0h5x/uY7OtWIuIpOCeIg60i5FYh2bvfJAoGAPQ7t
+i7V2ZSfNaksl37upPn7P3WMpOMl1if3hkjLj3+84CPcRLf4urMeFIkLpocEZ6Gl9
+KYEjBtyz3mi8vMc+veAu12lvKEXD8MXDCi1nEYri6wFQ8s7iFPOAoKxqGGgJjv6q
+6GLAyC9ssGIIgO+HXEGRVLq3wfAQG5fx03X61h0CgYEAiz3f8xiIR6PC4Gn5iMWu
+wmIDk3EnxdaA+7AwN/M037jbmKfzLxA1n8jXYM+to4SJx8Fxo7MD5EDhq6UoYmPU
+tGe4Ite2N9jzxG7xQrVuIx6Cg4t+E7uZ1eZuhbQ1WpqCXPIFOtXuc4szXfwD4Z+p
+IsdbLCwHYD3GVgk/D7NVxyU=
+-----END PRIVATE KEY-----
+";
+ private static X509Certificate2 GetCertificateForServer()
+ {
+ RSA privateKey = Utils.DecodeRSAKeyFromPEM(TestPrivateKey);
+ return new X509Certificate2(Utils.DecodePEM(TestCertificate)).CopyWithPrivateKey(privateKey);
+ }
+
+ private static X509Certificate2Collection GetCertificateForClient()
+ {
+ X509Certificate2 publicCertificate = new X509Certificate2(Utils.DecodePEM(TestCertificate));
+
+ X509Certificate2Collection clientCertificates = new X509Certificate2Collection();
+ clientCertificates.Add(publicCertificate);
+ return clientCertificates;
+ }
+
+ protected DtlsConnectionListener CreateListener(int numWorkers, IPEndPoint endPoint, ILogger logger, IPMode ipMode = IPMode.IPv4)
+ {
+ DtlsConnectionListener listener = new DtlsConnectionListener(2, endPoint, logger, ipMode);
+ listener.SetCertificate(GetCertificateForServer());
+ return listener;
+
+ }
+
+ protected DtlsUnityConnection CreateConnection(IPEndPoint endPoint, ILogger logger, IPMode ipMode = IPMode.IPv4)
+ {
+ DtlsUnityConnection connection = new DtlsUnityConnection(logger, endPoint, ipMode);
+ connection.SetValidServerCertificates(GetCertificateForClient());
+ return connection;
+ }
+
+ [TestMethod]
+ public void DtlsServerDisposeDisconnectsTest()
+ {
+ IPEndPoint ep = new IPEndPoint(IPAddress.Loopback, 27510);
+
+ bool serverConnected = false;
+ bool serverDisconnected = false;
+ bool clientDisconnected = false;
+
+ Semaphore signal = new Semaphore(0, int.MaxValue);
+
+ using (var listener = (DtlsConnectionListener)CreateListener(2, new IPEndPoint(IPAddress.Any, ep.Port), new TestLogger()))
+ using (var connection = CreateConnection(ep, new TestLogger()))
+ {
+ listener.NewConnection += (evt) =>
+ {
+ serverConnected = true;
+ signal.Release();
+ evt.Connection.Disconnected += (o, et) => {
+ serverDisconnected = true;
+ };
+ };
+ connection.Disconnected += (o, evt) => {
+ clientDisconnected = true;
+ signal.Release();
+ };
+
+ listener.Start();
+ connection.Connect();
+
+ // wait for the client to connect
+ signal.WaitOne(10);
+
+ listener.Dispose();
+
+ // wait for the client to disconnect
+ signal.WaitOne(100);
+
+ Assert.IsTrue(serverConnected);
+ Assert.IsTrue(clientDisconnected);
+ Assert.IsFalse(serverDisconnected);
+ Assert.AreEqual(0, listener.PeerCount);
+ }
+ }
+
+ class MalformedDTLSListener : DtlsConnectionListener
+ {
+ public MalformedDTLSListener(int numWorkers, IPEndPoint endPoint, ILogger logger, IPMode ipMode = IPMode.IPv4)
+ : base(numWorkers, endPoint, logger, ipMode)
+ {
+ }
+
+ public void InjectPacket(ByteSpan packet, IPEndPoint peerAddress, ConnectionId connectionId)
+ {
+ MessageReader reader = MessageReader.GetSized(packet.Length);
+ reader.Length = packet.Length;
+ Array.Copy(packet.GetUnderlyingArray(), packet.Offset, reader.Buffer, reader.Offset, packet.Length);
+
+ base.ProcessIncomingMessageFromOtherThread(reader, peerAddress, connectionId);
+ }
+ }
+
+ class MalformedDTLSClient : DtlsUnityConnection
+ {
+ public Func<ClientHello, ByteSpan, ByteSpan> EncodeClientHelloCallback = Test_CompressionLengthOverrunClientHello;
+
+ public MalformedDTLSClient(ILogger logger, IPEndPoint remoteEndPoint, IPMode ipMode = IPMode.IPv4) : base(logger, remoteEndPoint, ipMode)
+ {
+
+ }
+
+ protected override void SendClientHello(bool isResend)
+ {
+ Test_SendClientHello(EncodeClientHelloCallback);
+ }
+
+ public static ByteSpan Test_CompressionLengthOverrunClientHello(ClientHello clientHello, ByteSpan writer)
+ {
+ ByteSpanBigEndianExtensions.WriteBigEndian16(writer, (ushort)ProtocolVersion.DTLS1_2);
+ writer = writer.Slice(2);
+
+ clientHello.Random.CopyTo(writer);
+ writer = writer.Slice(Hazel.Dtls.Random.Size);
+
+ // Do not encode session ids
+ writer[0] = (byte)0;
+ writer = writer.Slice(1);
+
+ writer[0] = (byte)clientHello.Cookie.Length;
+ clientHello.Cookie.CopyTo(writer.Slice(1));
+ writer = writer.Slice(1 + clientHello.Cookie.Length);
+
+ ByteSpanBigEndianExtensions.WriteBigEndian16(writer, (ushort)clientHello.CipherSuites.Length);
+ clientHello.CipherSuites.CopyTo(writer.Slice(2));
+ writer = writer.Slice(2 + clientHello.CipherSuites.Length);
+
+ // ============ Here is the corruption. writer[0] should be 1. ============
+ writer[0] = 255;
+ writer[1] = (byte)CompressionMethod.Null;
+ writer = writer.Slice(2);
+
+ // Extensions size
+ ByteSpanBigEndianExtensions.WriteBigEndian16(writer, (ushort)(6 + clientHello.SupportedCurves.Length));
+ writer = writer.Slice(2);
+
+ // Supported curves extension
+ ByteSpanBigEndianExtensions.WriteBigEndian16(writer, (ushort)ExtensionType.EllipticCurves);
+ ByteSpanBigEndianExtensions.WriteBigEndian16(writer, (ushort)(2 + clientHello.SupportedCurves.Length), 2);
+ ByteSpanBigEndianExtensions.WriteBigEndian16(writer, (ushort)clientHello.SupportedCurves.Length, 4);
+ clientHello.SupportedCurves.CopyTo(writer.Slice(6));
+
+ return writer;
+ }
+ }
+
+ [TestMethod]
+ public void TestMalformedApplicationData()
+ {
+ IPEndPoint ep = new IPEndPoint(IPAddress.Loopback, 27510);
+
+ IPEndPoint connectionEndPoint = ep;
+ DtlsConnectionListener.ConnectionId connectionId = new ThreadLimitedUdpConnectionListener.ConnectionId();
+
+ Semaphore signal = new Semaphore(0, int.MaxValue);
+
+ using (MalformedDTLSListener listener = new MalformedDTLSListener(2, new IPEndPoint(IPAddress.Any, ep.Port), new TestLogger()))
+ using (DtlsUnityConnection connection = new DtlsUnityConnection(new TestLogger(), ep))
+ {
+ listener.SetCertificate(GetCertificateForServer());
+ connection.SetValidServerCertificates(GetCertificateForClient());
+
+ listener.NewConnection += (evt) =>
+ {
+ connectionEndPoint = evt.Connection.EndPoint;
+ connectionId = ((ThreadLimitedUdpServerConnection)evt.Connection).ConnectionId;
+
+ signal.Release();
+ evt.Connection.Disconnected += (o, et) => {
+ };
+ };
+ connection.Disconnected += (o, evt) => {
+ signal.Release();
+ };
+
+ listener.Start();
+ connection.Connect();
+
+ // wait for the client to connect
+ signal.WaitOne(10);
+
+ ByteSpan data = new byte[5] { 0x01, 0x02, 0x03, 0x04, 0x05 };
+
+ Record record = new Record();
+ record.ContentType = ContentType.ApplicationData;
+ record.ProtocolVersion = ProtocolVersion.DTLS1_2;
+ record.Epoch = 1;
+ record.SequenceNumber = 10;
+ record.Length = (ushort)data.Length;
+
+ ByteSpan encoded = new byte[Record.Size + data.Length];
+ record.Encode(encoded);
+ data.CopyTo(encoded.Slice(Record.Size));
+
+ listener.InjectPacket(encoded, connectionEndPoint, connectionId);
+
+ // wait for the client to disconnect
+ listener.Dispose();
+ signal.WaitOne(100);
+ }
+ }
+
+ [TestMethod]
+ public void TestMalformedConnectionData()
+ {
+ IPEndPoint ep = new IPEndPoint(IPAddress.Loopback, 27510);
+
+ IPEndPoint connectionEndPoint = ep;
+ DtlsConnectionListener.ConnectionId connectionId = new ThreadLimitedUdpConnectionListener.ConnectionId();
+
+ Semaphore signal = new Semaphore(0, int.MaxValue);
+
+ using (DtlsConnectionListener listener = new DtlsConnectionListener(2, new IPEndPoint(IPAddress.Any, ep.Port), new TestLogger()))
+ using (MalformedDTLSClient connection = new MalformedDTLSClient(new TestLogger(), ep))
+ {
+ listener.SetCertificate(GetCertificateForServer());
+ connection.SetValidServerCertificates(GetCertificateForClient());
+
+ listener.NewConnection += (evt) =>
+ {
+ connectionEndPoint = evt.Connection.EndPoint;
+ connectionId = ((ThreadLimitedUdpServerConnection)evt.Connection).ConnectionId;
+
+ signal.Release();
+ evt.Connection.Disconnected += (o, et) => {
+ };
+ };
+ connection.Disconnected += (o, evt) => {
+ signal.Release();
+ };
+
+ listener.Start();
+ connection.Connect();
+
+ Assert.IsTrue(listener.ReceiveThreadRunning, "Listener should be able to handle a malformed hello packet");
+ Assert.AreEqual(ConnectionState.NotConnected, connection.State);
+
+ Assert.AreEqual(0, listener.PeerCount);
+
+ // wait for the client to disconnect
+ listener.Dispose();
+ signal.WaitOne(100);
+ }
+ }
+
+
+ [TestMethod]
+ public void TestReorderedHandshakePacketsConnects()
+ {
+ IPEndPoint captureEndPoint = new IPEndPoint(IPAddress.Loopback, 27511);
+ IPEndPoint listenerEndPoint = new IPEndPoint(IPAddress.Loopback, 27510);
+
+ bool serverConnected = false;
+ bool serverDisconnected = false;
+ bool clientDisconnected = false;
+
+ Semaphore signal = new Semaphore(0, int.MaxValue);
+
+ var logger = new TestLogger("Throttle");
+
+ using (SocketCapture capture = new SocketCapture(captureEndPoint, listenerEndPoint, logger))
+ using (DtlsConnectionListener listener = new DtlsConnectionListener(2, new IPEndPoint(IPAddress.Any, listenerEndPoint.Port), new TestLogger("Server")))
+ using (DtlsUnityConnection connection = new DtlsUnityConnection(new TestLogger("Client "), captureEndPoint))
+ {
+ Semaphore listenerToConnectionThrottle = new Semaphore(0, int.MaxValue);
+ capture.SendToLocalSemaphore = listenerToConnectionThrottle;
+ Thread throttleThread = new Thread(() => {
+ // HelloVerifyRequest
+ capture.AssertPacketsToLocalCountEquals(1);
+ listenerToConnectionThrottle.Release(1);
+
+ // ServerHello, Server Certificate (Fragment)
+ // Server Cert
+ // ServerKeyExchange, ServerHelloDone
+ capture.AssertPacketsToLocalCountEquals(3);
+ capture.ReorderPacketsForLocal(list => list.Swap(0, 1));
+ listenerToConnectionThrottle.Release(3);
+ capture.AssertPacketsToLocalCountEquals(0);
+
+ // Same flight, let's swap the ServerKeyExchange to the front
+ capture.AssertPacketsToLocalCountEquals(3);
+ capture.ReorderPacketsForLocal(list => list.Swap(0, 2));
+ listenerToConnectionThrottle.Release(3);
+ capture.AssertPacketsToLocalCountEquals(0);
+
+ // Same flight, no swap we do matters as long as the ServerKeyExchange gets through.
+ capture.AssertPacketsToLocalCountEquals(3);
+ capture.ReorderPacketsForLocal(list => list.Reverse());
+
+ capture.SendToLocalSemaphore = null;
+ listenerToConnectionThrottle.Release(1);
+ });
+ throttleThread.Start();
+
+ listener.SetCertificate(GetCertificateForServer());
+ connection.SetValidServerCertificates(GetCertificateForClient());
+
+ listener.NewConnection += (evt) =>
+ {
+ serverConnected = true;
+ signal.Release();
+ evt.Connection.Disconnected += (o, et) => {
+ serverDisconnected = true;
+ };
+ };
+ connection.Disconnected += (o, evt) => {
+ clientDisconnected = true;
+ signal.Release();
+ };
+
+ listener.Start();
+ connection.Connect();
+
+ // wait for the client to connect
+ signal.WaitOne(10);
+
+ listener.Dispose();
+
+ // wait for the client to disconnect
+ signal.WaitOne(100);
+
+ Assert.IsTrue(serverConnected);
+ Assert.IsTrue(clientDisconnected);
+ Assert.IsFalse(serverDisconnected);
+ }
+ }
+
+
+ [TestMethod]
+ public void TestResentClientHelloConnects()
+ {
+ IPEndPoint captureEndPoint = new IPEndPoint(IPAddress.Loopback, 27511);
+ IPEndPoint listenerEndPoint = new IPEndPoint(IPAddress.Loopback, 27510);
+
+ bool serverConnected = false;
+ bool serverDisconnected = false;
+ bool clientDisconnected = false;
+
+ Semaphore signal = new Semaphore(0, int.MaxValue);
+
+ var logger = new TestLogger("Throttle");
+
+ using (SocketCapture capture = new SocketCapture(captureEndPoint, listenerEndPoint, logger))
+ using (DtlsConnectionListener listener = new DtlsConnectionListener(2, new IPEndPoint(IPAddress.Any, listenerEndPoint.Port), new TestLogger("Server")))
+ using (DtlsUnityConnection connection = new DtlsUnityConnection(new TestLogger("Client "), captureEndPoint))
+ {
+ Semaphore listenerToConnectionThrottle = new Semaphore(0, int.MaxValue);
+ capture.SendToLocalSemaphore = listenerToConnectionThrottle;
+ Thread throttleThread = new Thread(() => {
+ // Trigger resend of HelloVerifyRequest
+ capture.DiscardPacketForLocal();
+
+ capture.AssertPacketsToLocalCountEquals(1);
+ listenerToConnectionThrottle.Release(1);
+
+ // ServerHello, ServerCertificate
+ // ServerCertificate
+ // ServerKeyExchange, ServerHelloDone
+ capture.AssertPacketsToLocalCountEquals(3);
+ listenerToConnectionThrottle.Release(3);
+
+ // Trigger a resend of ServerKeyExchange, ServerHelloDone
+ capture.DiscardPacketForLocal();
+
+ // From here, flush everything. We recover or not.
+ capture.SendToLocalSemaphore = null;
+ listenerToConnectionThrottle.Release(1);
+ });
+ throttleThread.Start();
+
+ listener.SetCertificate(GetCertificateForServer());
+ connection.SetValidServerCertificates(GetCertificateForClient());
+
+ listener.NewConnection += (evt) =>
+ {
+ serverConnected = true;
+ signal.Release();
+ evt.Connection.Disconnected += (o, et) => {
+ serverDisconnected = true;
+ };
+ };
+ connection.Disconnected += (o, evt) => {
+ clientDisconnected = true;
+ signal.Release();
+ };
+
+ listener.Start();
+ connection.Connect();
+
+ // wait for the client to connect
+ signal.WaitOne(10);
+
+ listener.Dispose();
+
+ // wait for the client to disconnect
+ signal.WaitOne(100);
+
+ Assert.IsTrue(serverConnected);
+ Assert.IsTrue(clientDisconnected);
+ Assert.IsFalse(serverDisconnected);
+ }
+ }
+
+ [TestMethod]
+ public void TestResentServerHelloConnects()
+ {
+ IPEndPoint captureEndPoint = new IPEndPoint(IPAddress.Loopback, 27511);
+ IPEndPoint listenerEndPoint = new IPEndPoint(IPAddress.Loopback, 27510);
+
+ bool serverConnected = false;
+ bool serverDisconnected = false;
+ bool clientDisconnected = false;
+
+ Semaphore signal = new Semaphore(0, int.MaxValue);
+
+ using (SocketCapture capture = new SocketCapture(captureEndPoint, listenerEndPoint))
+ using (DtlsConnectionListener listener = new DtlsConnectionListener(2, new IPEndPoint(IPAddress.Any, listenerEndPoint.Port), new TestLogger("Server")))
+ using (DtlsUnityConnection connection = new DtlsUnityConnection(new TestLogger("Client "), captureEndPoint))
+ {
+ Semaphore listenerToConnectionThrottle = new Semaphore(0, int.MaxValue);
+ capture.SendToLocalSemaphore = listenerToConnectionThrottle;
+ Thread throttleThread = new Thread(() => {
+ // HelloVerifyRequest
+ capture.AssertPacketsToLocalCountEquals(1);
+ listenerToConnectionThrottle.Release(1);
+
+ // ServerHello, Server Certificate
+ // Server Certificate
+ // ServerKeyExchange, ServerHelloDone
+ capture.AssertPacketsToLocalCountEquals(3);
+ capture.DiscardPacketForLocal();
+ listenerToConnectionThrottle.Release(2);
+
+ // Wait for the resends and recover
+ capture.AssertPacketsToLocalCountEquals(3);
+
+ capture.SendToLocalSemaphore = null;
+ listenerToConnectionThrottle.Release(3);
+ });
+ throttleThread.Start();
+
+ listener.SetCertificate(GetCertificateForServer());
+ connection.SetValidServerCertificates(GetCertificateForClient());
+
+ listener.NewConnection += (evt) =>
+ {
+ serverConnected = true;
+ signal.Release();
+ evt.Connection.Disconnected += (o, et) => {
+ serverDisconnected = true;
+ };
+ };
+ connection.Disconnected += (o, evt) => {
+ clientDisconnected = true;
+ signal.Release();
+ };
+
+ listener.Start();
+ connection.Connect();
+
+ // wait for the client to connect
+ signal.WaitOne(10);
+
+ listener.Dispose();
+
+ // wait for the client to disconnect
+ signal.WaitOne(100);
+
+ Assert.IsTrue(serverConnected);
+ Assert.IsTrue(clientDisconnected);
+ Assert.IsFalse(serverDisconnected);
+ }
+ }
+
+ [TestMethod]
+ public void TestConnectionSuccessAfterClientKeyExchangeFlightDropped()
+ {
+ IPEndPoint captureEndPoint = new IPEndPoint(IPAddress.Loopback, 27511);
+ IPEndPoint listenerEndPoint = new IPEndPoint(IPAddress.Loopback, 27510);
+
+ bool serverConnected = false;
+ bool serverDisconnected = false;
+ bool clientDisconnected = false;
+
+ Semaphore signal = new Semaphore(0, int.MaxValue);
+
+ using (SocketCapture capture = new SocketCapture(captureEndPoint, listenerEndPoint))
+ using (DtlsConnectionListener listener = new DtlsConnectionListener(2, new IPEndPoint(IPAddress.Any, listenerEndPoint.Port), new TestLogger()))
+ using (TestDtlsHandshakeDropUnityConnection connection = new TestDtlsHandshakeDropUnityConnection(new TestLogger(), captureEndPoint))
+ {
+ connection.DropSendClientKeyExchangeFlightCount = 1;
+
+ listener.SetCertificate(GetCertificateForServer());
+ connection.SetValidServerCertificates(GetCertificateForClient());
+
+ listener.NewConnection += (evt) =>
+ {
+ serverConnected = true;
+ signal.Release();
+ evt.Connection.Disconnected += (o, et) => {
+ serverDisconnected = true;
+ };
+ };
+ connection.Disconnected += (o, evt) => {
+ clientDisconnected = true;
+ signal.Release();
+ };
+
+ listener.Start();
+ connection.Connect();
+
+ // wait for the client to connect
+ signal.WaitOne(10);
+
+ listener.Dispose();
+
+ // wait for the client to disconnect
+ signal.WaitOne(100);
+
+ Assert.IsTrue(serverConnected);
+ Assert.IsTrue(clientDisconnected);
+ Assert.IsFalse(serverDisconnected);
+ }
+ }
+
+ /// <summary>
+ /// Tests the keepalive functionality from the client,
+ /// </summary>
+ [TestMethod]
+ public void PingDisconnectClientTest()
+ {
+#if DEBUG
+ IPEndPoint ep = new IPEndPoint(IPAddress.Loopback, 27510);
+ using (DtlsConnectionListener listener = (DtlsConnectionListener)CreateListener(2, new IPEndPoint(IPAddress.Any, ep.Port), new TestLogger()))
+ {
+ // Adjust the ping rate to end the test faster
+ listener.NewConnection += (evt) =>
+ {
+ var conn = (ThreadLimitedUdpServerConnection)evt.Connection;
+ conn.KeepAliveInterval = 100;
+ conn.MissingPingsUntilDisconnect = 3;
+ };
+
+ listener.Start();
+
+ for (int i = 0; i < 5; ++i)
+ {
+ using (DtlsUnityConnection connection = (DtlsUnityConnection)CreateConnection(ep, new TestLogger()))
+ {
+ connection.KeepAliveInterval = 100;
+ connection.MissingPingsUntilDisconnect = 3;
+ connection.Connect();
+
+ Thread.Sleep(10);
+
+ // After connecting, quietly stop responding to all messages to fake connection loss.
+ connection.TestDropRate = 1;
+
+ Thread.Sleep(500); //Enough time for ~3 keep alive packets
+
+ Assert.AreEqual(ConnectionState.NotConnected, connection.State);
+ }
+ }
+
+ listener.DisconnectOldConnections(TimeSpan.FromMilliseconds(500), null);
+
+ Assert.AreEqual(0, listener.PeerCount, "All clients disconnected, peer count should be zero.");
+ }
+#else
+ Assert.Inconclusive("Only works in DEBUG");
+#endif
+ }
+
+ [TestMethod]
+ public void ServerDisposeDisconnectsTest()
+ {
+ IPEndPoint ep = new IPEndPoint(IPAddress.Loopback, 4296);
+
+ bool serverConnected = false;
+ bool serverDisconnected = false;
+ bool clientDisconnected = false;
+
+ using (ThreadLimitedUdpConnectionListener listener = this.CreateListener(2, new IPEndPoint(IPAddress.Any, 4296), new TestLogger("SERVER")))
+ using (UdpConnection connection = this.CreateConnection(ep, new TestLogger("CLIENT")))
+ {
+ listener.NewConnection += (evt) =>
+ {
+ serverConnected = true;
+ evt.Connection.Disconnected += (o, et) => serverDisconnected = true;
+ };
+ connection.Disconnected += (o, evt) => clientDisconnected = true;
+
+ listener.Start();
+ connection.Connect();
+
+ Thread.Sleep(100); // Gotta wait for the server to set up the events.
+ listener.Dispose();
+ Thread.Sleep(100);
+
+ Assert.IsTrue(serverConnected);
+ Assert.IsTrue(clientDisconnected);
+ Assert.IsFalse(serverDisconnected);
+ }
+ }
+
+ [TestMethod]
+ public void ClientDisposeDisconnectTest()
+ {
+ IPEndPoint ep = new IPEndPoint(IPAddress.Loopback, 4296);
+
+ bool serverConnected = false;
+ bool serverDisconnected = false;
+ bool clientDisconnected = false;
+
+ using (ThreadLimitedUdpConnectionListener listener = this.CreateListener(2, new IPEndPoint(IPAddress.Any, 4296), new TestLogger()))
+ using (UdpConnection connection = this.CreateConnection(ep, new TestLogger()))
+ {
+ listener.NewConnection += (evt) =>
+ {
+ serverConnected = true;
+ evt.Connection.Disconnected += (o, et) => serverDisconnected = true;
+ };
+
+ connection.Disconnected += (o, et) => clientDisconnected = true;
+
+ listener.Start();
+ connection.Connect();
+
+ Thread.Sleep(100); // Gotta wait for the server to set up the events.
+ connection.Dispose();
+
+ Thread.Sleep(100);
+
+ Assert.IsTrue(serverConnected);
+ Assert.IsTrue(serverDisconnected);
+ Assert.IsFalse(clientDisconnected);
+ }
+ }
+
+ /// <summary>
+ /// Tests the fields on UdpConnection.
+ /// </summary>
+ [TestMethod]
+ public void DtlsFieldTest()
+ {
+ IPEndPoint ep = new IPEndPoint(IPAddress.Loopback, 4296);
+
+ using (ThreadLimitedUdpConnectionListener listener = this.CreateListener(2, new IPEndPoint(IPAddress.Any, 4296), new TestLogger()))
+ using (UdpConnection connection = this.CreateConnection(ep, new TestLogger()))
+ {
+ listener.Start();
+
+ connection.Connect();
+
+ //Connection fields
+ Assert.AreEqual(ep, connection.EndPoint);
+
+ //UdpConnection fields
+ Assert.AreEqual(new IPEndPoint(IPAddress.Loopback, 4296), connection.EndPoint);
+ Assert.AreEqual(1, connection.Statistics.DataBytesSent);
+ Assert.AreEqual(0, connection.Statistics.DataBytesReceived);
+ }
+ }
+
+ [TestMethod]
+ public void DtlsHandshakeTest()
+ {
+ byte[] TestData = new byte[] { 1, 2, 3, 4, 5, 6 };
+ using (ThreadLimitedUdpConnectionListener listener = this.CreateListener(2, new IPEndPoint(IPAddress.Any, 4296), new TestLogger()))
+ using (UdpConnection connection = this.CreateConnection(new IPEndPoint(IPAddress.Loopback, 4296), new TestLogger()))
+ {
+ listener.Start();
+
+ MessageReader output = null;
+ listener.NewConnection += delegate (NewConnectionEventArgs e)
+ {
+ output = e.HandshakeData.Duplicate();
+ };
+
+ connection.Connect(TestData);
+
+ Thread.Sleep(10);
+ for (int i = 0; i < TestData.Length; ++i)
+ {
+ Assert.AreEqual(TestData[i], output.ReadByte());
+ }
+ }
+ }
+
+ [TestMethod]
+ public void DtlsUnreliableMessageSendTest()
+ {
+ byte[] TestData = new byte[] { 1, 2, 3, 4, 5, 6 };
+ using (ThreadLimitedUdpConnectionListener listener = this.CreateListener(2, new IPEndPoint(IPAddress.Any, 4296), new TestLogger()))
+ using (UdpConnection connection = this.CreateConnection(new IPEndPoint(IPAddress.Loopback, 4296), new TestLogger()))
+ {
+ MessageReader output = null;
+ listener.NewConnection += delegate (NewConnectionEventArgs e)
+ {
+ e.Connection.DataReceived += delegate (DataReceivedEventArgs evt)
+ {
+ output = evt.Message.Duplicate();
+ };
+ };
+
+ listener.Start();
+ connection.Connect();
+
+ for (int i = 0; i < 4; ++i)
+ {
+ var msg = MessageWriter.Get(SendOption.None);
+ msg.Write(TestData);
+ connection.Send(msg);
+ msg.Recycle();
+ }
+
+ Thread.Sleep(10);
+ for (int i = 0; i < TestData.Length; ++i)
+ {
+ Assert.AreEqual(TestData[i], output.ReadByte());
+ }
+ }
+ }
+
+ /// <summary>
+ /// Tests IPv4 connectivity.
+ /// </summary>
+ [TestMethod]
+ public void DtlsIPv4ConnectionTest()
+ {
+ using (ThreadLimitedUdpConnectionListener listener = this.CreateListener(2, new IPEndPoint(IPAddress.Any, 4296), new TestLogger()))
+ using (UdpConnection connection = this.CreateConnection(new IPEndPoint(IPAddress.Loopback, 4296), new TestLogger()))
+ {
+ listener.Start();
+
+ connection.Connect();
+
+ Assert.AreEqual(ConnectionState.Connected, connection.State);
+ }
+ }
+
+ [TestMethod]
+ public void DtlsSessionV0ConnectionTest()
+ {
+ using (ThreadLimitedUdpConnectionListener listener = this.CreateListener(2, new IPEndPoint(IPAddress.Any, 4296), new TestLogger()))
+ using (DtlsUnityConnection connection = this.CreateConnection(new IPEndPoint(IPAddress.Loopback, 4296), new TestLogger()))
+ {
+ connection.HazelSessionVersion = 0;
+ listener.Start();
+
+ connection.Connect();
+
+ Assert.AreEqual(ConnectionState.Connected, connection.State);
+ }
+ }
+
+ private class MultipleClientHelloDtlsConnection : DtlsUnityConnection
+ {
+ public MultipleClientHelloDtlsConnection(ILogger logger, IPEndPoint remoteEndPoint, IPMode ipMode = IPMode.IPv4) : base(logger, remoteEndPoint, ipMode)
+ {
+ }
+
+ protected override void SendClientHello(bool isRetransmit)
+ {
+ base.SendClientHello(isRetransmit);
+ base.SendClientHello(true);
+ }
+ }
+
+
+ private class MultipleClientKeyExchangeFlightDtlsConnection : DtlsUnityConnection
+ {
+ public MultipleClientKeyExchangeFlightDtlsConnection(ILogger logger, IPEndPoint remoteEndPoint, IPMode ipMode = IPMode.IPv4) : base(logger, remoteEndPoint, ipMode)
+ {
+ }
+
+ protected override void SendClientKeyExchangeFlight(bool isRetransmit)
+ {
+ base.SendClientKeyExchangeFlight(isRetransmit);
+ base.SendClientKeyExchangeFlight(true);
+ base.SendClientKeyExchangeFlight(true);
+ }
+ }
+
+ /// <summary>
+ /// Tests IPv4 resilience to multiple hellos.
+ /// </summary>
+ [TestMethod]
+ public void ConnectLikeAJerkTest()
+ {
+ using (DtlsConnectionListener listener = this.CreateListener(2, new IPEndPoint(IPAddress.Any, 4296), new TestLogger("Server")))
+ using (MultipleClientHelloDtlsConnection client = new MultipleClientHelloDtlsConnection(new TestLogger("Client "), new IPEndPoint(IPAddress.Loopback, 4296), IPMode.IPv4))
+ {
+ client.SetValidServerCertificates(GetCertificateForClient());
+
+ int connects = 0;
+ listener.NewConnection += (obj) =>
+ {
+ Interlocked.Increment(ref connects);
+ };
+
+ listener.Start();
+ client.Connect(null, 1000);
+
+ Thread.Sleep(2000);
+
+ Assert.AreEqual(0, listener.ReceiveQueueLength);
+ Assert.IsTrue(connects <= 1, $"Too many connections: {connects}");
+ Assert.AreEqual(ConnectionState.Connected, client.State);
+ Assert.IsTrue(client.HandshakeComplete);
+ }
+ }
+
+ /// <summary>
+ /// Tests IPv4 resilience to multiple ClientKeyExchange packets.
+ /// </summary>
+ [TestMethod]
+ public void HandshakeLikeAJerkTest()
+ {
+ using (DtlsConnectionListener listener = this.CreateListener(2, new IPEndPoint(IPAddress.Any, 4296), new TestLogger("Server")))
+ using (MultipleClientKeyExchangeFlightDtlsConnection client = new MultipleClientKeyExchangeFlightDtlsConnection(new TestLogger("Client "), new IPEndPoint(IPAddress.Loopback, 4296), IPMode.IPv4))
+ {
+ client.SetValidServerCertificates(GetCertificateForClient());
+
+ int connects = 0;
+ listener.NewConnection += (obj) =>
+ {
+ Interlocked.Increment(ref connects);
+ };
+
+ listener.Start();
+ client.Connect();
+
+ Thread.Sleep(500);
+
+ Assert.AreEqual(0, listener.ReceiveQueueLength);
+ Assert.IsTrue(connects <= 1, $"Too many connections: {connects}");
+ Assert.AreEqual(ConnectionState.Connected, client.State);
+ Assert.IsTrue(client.HandshakeComplete);
+ }
+ }
+
+ /// <summary>
+ /// Tests dual mode connectivity.
+ /// </summary>
+ [TestMethod]
+ public void MixedConnectionTest()
+ {
+ using (ThreadLimitedUdpConnectionListener listener2 = this.CreateListener(4, new IPEndPoint(IPAddress.IPv6Any, 4296), new TestLogger(), IPMode.IPv6))
+ {
+ listener2.Start();
+
+ listener2.NewConnection += (evt) =>
+ {
+ Console.WriteLine($"Connection: {evt.Connection.EndPoint}");
+ };
+
+ using (UdpConnection connection = this.CreateConnection(new IPEndPoint(IPAddress.Parse("127.0.0.1"), 4296), new TestLogger()))
+ {
+ connection.Connect();
+ Assert.AreEqual(ConnectionState.Connected, connection.State);
+ }
+
+ using (UdpConnection connection2 = this.CreateConnection(new IPEndPoint(IPAddress.IPv6Loopback, 4296), new TestLogger(), IPMode.IPv6))
+ {
+ connection2.Connect();
+ Assert.AreEqual(ConnectionState.Connected, connection2.State);
+ }
+ }
+ }
+
+ /// <summary>
+ /// Tests dual mode connectivity.
+ /// </summary>
+ [TestMethod]
+ public void DtlsIPv6ConnectionTest()
+ {
+ using (ThreadLimitedUdpConnectionListener listener = this.CreateListener(2, new IPEndPoint(IPAddress.IPv6Any, 4296), new TestLogger(), IPMode.IPv6))
+ {
+ listener.Start();
+
+ using (UdpConnection connection = this.CreateConnection(new IPEndPoint(IPAddress.Parse("127.0.0.1"), 4296), new TestLogger(), IPMode.IPv6))
+ {
+ connection.Connect();
+ }
+ }
+ }
+
+ /// <summary>
+ /// Tests server to client unreliable communication on the UdpConnection.
+ /// </summary>
+ [TestMethod]
+ public void DtlsUnreliableServerToClientTest()
+ {
+ using (ThreadLimitedUdpConnectionListener listener = this.CreateListener(2, new IPEndPoint(IPAddress.Any, 4296), new TestLogger()))
+ using (UdpConnection connection = this.CreateConnection(new IPEndPoint(IPAddress.Loopback, 4296), new TestLogger()))
+ {
+ TestHelper.RunServerToClientTest(listener, connection, 10, SendOption.None);
+ }
+ }
+
+ /// <summary>
+ /// Tests server to client reliable communication on the UdpConnection.
+ /// </summary>
+ [TestMethod]
+ public void DtlsReliableServerToClientTest()
+ {
+ using (ThreadLimitedUdpConnectionListener listener = this.CreateListener(2, new IPEndPoint(IPAddress.Any, 4296), new TestLogger()))
+ using (UdpConnection connection = this.CreateConnection(new IPEndPoint(IPAddress.Loopback, 4296), new TestLogger()))
+ {
+ TestHelper.RunServerToClientTest(listener, connection, 10, SendOption.Reliable);
+ }
+ }
+
+ /// <summary>
+ /// Tests server to client unreliable communication on the UdpConnection.
+ /// </summary>
+ [TestMethod]
+ public void DtlsUnreliableClientToServerTest()
+ {
+ using (ThreadLimitedUdpConnectionListener listener = this.CreateListener(2, new IPEndPoint(IPAddress.Any, 4296), new TestLogger()))
+ using (UdpConnection connection = this.CreateConnection(new IPEndPoint(IPAddress.Loopback, 4296), new TestLogger()))
+ {
+ TestHelper.RunClientToServerTest(listener, connection, 10, SendOption.None);
+ }
+ }
+
+ /// <summary>
+ /// Tests server to client reliable communication on the UdpConnection.
+ /// </summary>
+ [TestMethod]
+ public void DtlsReliableClientToServerTest()
+ {
+ using (ThreadLimitedUdpConnectionListener listener = this.CreateListener(2, new IPEndPoint(IPAddress.Any, 4296), new TestLogger()))
+ using (UdpConnection connection = this.CreateConnection(new IPEndPoint(IPAddress.Loopback, 4296), new TestLogger()))
+ {
+ TestHelper.RunClientToServerTest(listener, connection, 10, SendOption.Reliable);
+ }
+ }
+
+ [TestMethod]
+ public void KeepAliveClientTest()
+ {
+ using (ThreadLimitedUdpConnectionListener listener = this.CreateListener(2, new IPEndPoint(IPAddress.Any, 4296), new TestLogger()))
+ using (UdpConnection connection = this.CreateConnection(new IPEndPoint(IPAddress.Loopback, 4296), new TestLogger()))
+ {
+ listener.Start();
+
+ connection.Connect();
+ connection.KeepAliveInterval = 100;
+
+ Thread.Sleep(1050); //Enough time for ~10 keep alive packets
+
+ Assert.AreEqual(ConnectionState.Connected, connection.State);
+ Assert.IsTrue(
+ connection.Statistics.TotalBytesSent >= 500 &&
+ connection.Statistics.TotalBytesSent <= 675,
+ "Sent: " + connection.Statistics.TotalBytesSent
+ );
+ }
+ }
+
+ /// <summary>
+ /// Tests the keepalive functionality from the client,
+ /// </summary>
+ [TestMethod]
+ public void KeepAliveServerTest()
+ {
+ ManualResetEvent mutex = new ManualResetEvent(false);
+
+ using (ThreadLimitedUdpConnectionListener listener = this.CreateListener(2, new IPEndPoint(IPAddress.Any, 4296), new TestLogger()))
+ using (UdpConnection connection = this.CreateConnection(new IPEndPoint(IPAddress.Loopback, 4296), new TestLogger()))
+ {
+ UdpConnection client = null;
+ listener.NewConnection += delegate (NewConnectionEventArgs args)
+ {
+ client = (UdpConnection)args.Connection;
+ client.KeepAliveInterval = 100;
+
+ Thread timeoutThread = new Thread(() =>
+ {
+ Thread.Sleep(1050); //Enough time for ~10 keep alive packets
+ mutex.Set();
+ });
+ timeoutThread.Start();
+ };
+
+ listener.Start();
+
+ connection.Connect();
+
+ mutex.WaitOne();
+
+ Assert.AreEqual(ConnectionState.Connected, client.State);
+
+ Assert.IsTrue(
+ client.Statistics.TotalBytesSent >= 27 &&
+ client.Statistics.TotalBytesSent <= 50,
+ "Sent: " + client.Statistics.TotalBytesSent
+ );
+ }
+ }
+
+ /// <summary>
+ /// Tests disconnection from the client.
+ /// </summary>
+ [TestMethod]
+ public void ClientDisconnectTest()
+ {
+ using (var listener = this.CreateListener(2, new IPEndPoint(IPAddress.Any, 4296), new TestLogger("Server")))
+ using (var connection = this.CreateConnection(new IPEndPoint(IPAddress.Loopback, 4296), new TestLogger("Client")))
+ {
+ ManualResetEvent mutex = new ManualResetEvent(false);
+ ManualResetEvent mutex2 = new ManualResetEvent(false);
+
+ listener.NewConnection += delegate (NewConnectionEventArgs args)
+ {
+ args.Connection.Disconnected += delegate (object sender2, DisconnectedEventArgs args2)
+ {
+ mutex2.Set();
+ };
+
+ mutex.Set();
+ };
+
+ listener.Start();
+
+ connection.Connect();
+
+ Assert.AreEqual(ConnectionState.Connected, connection.State);
+ mutex.WaitOne(1000);
+ Assert.AreEqual(ConnectionState.Connected, connection.State);
+
+ connection.Disconnect("Testing");
+
+ mutex2.WaitOne(1000);
+ Assert.AreEqual(ConnectionState.NotConnected, connection.State);
+ }
+ }
+
+ /// <summary>
+ /// Tests disconnection from the server.
+ /// </summary>
+ [TestMethod]
+ public void ServerDisconnectTest()
+ {
+ using (ThreadLimitedUdpConnectionListener listener = this.CreateListener(2, new IPEndPoint(IPAddress.Any, 4296), new TestLogger("Server")))
+ using (UdpConnection connection = this.CreateConnection(new IPEndPoint(IPAddress.Loopback, 4296), new TestLogger("Client")))
+ {
+ SemaphoreSlim mutex = new SemaphoreSlim(0, 100);
+ ManualResetEventSlim serverMutex = new ManualResetEventSlim(false);
+
+ connection.Disconnected += delegate (object sender, DisconnectedEventArgs args)
+ {
+ mutex.Release();
+ };
+
+ listener.NewConnection += delegate (NewConnectionEventArgs args)
+ {
+ mutex.Release();
+
+ // This has to be on a new thread because the client will go straight from Connecting to NotConnected
+ ThreadPool.QueueUserWorkItem(_ =>
+ {
+ serverMutex.Wait(500);
+ args.Connection.Disconnect("Testing");
+ });
+ };
+
+ listener.Start();
+
+ connection.Connect();
+
+ mutex.Wait(500);
+ Assert.AreEqual(ConnectionState.Connected, connection.State);
+
+ serverMutex.Set();
+
+ mutex.Wait(500);
+ Assert.AreEqual(ConnectionState.NotConnected, connection.State);
+ }
+ }
+
+ /// <summary>
+ /// Tests disconnection from the server.
+ /// </summary>
+ [TestMethod]
+ public void ServerExtraDataDisconnectTest()
+ {
+ using (ThreadLimitedUdpConnectionListener listener = this.CreateListener(2, new IPEndPoint(IPAddress.Any, 4296), new TestLogger()))
+ using (UdpConnection connection = this.CreateConnection(new IPEndPoint(IPAddress.Loopback, 4296), new TestLogger()))
+ {
+ string received = null;
+ ManualResetEvent mutex = new ManualResetEvent(false);
+
+ connection.Disconnected += delegate (object sender, DisconnectedEventArgs args)
+ {
+ // We don't own the message, we have to read the string now
+ received = args.Message.ReadString();
+ mutex.Set();
+ };
+
+ listener.NewConnection += delegate (NewConnectionEventArgs args)
+ {
+ MessageWriter writer = MessageWriter.Get(SendOption.None);
+ writer.Write("Goodbye");
+ args.Connection.Disconnect("Testing", writer);
+ };
+
+ listener.Start();
+
+ connection.Connect();
+
+ mutex.WaitOne(5000);
+
+ Assert.IsNotNull(received);
+ Assert.AreEqual("Goodbye", received);
+ }
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel.UnitTests/Dtls/TestDtlsHandshakeDropUnityConnection.cs b/Tools/Hazel-Networking/Hazel.UnitTests/Dtls/TestDtlsHandshakeDropUnityConnection.cs
new file mode 100644
index 0000000..f33cd14
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel.UnitTests/Dtls/TestDtlsHandshakeDropUnityConnection.cs
@@ -0,0 +1,27 @@
+using Hazel.Dtls;
+using System.Net;
+
+namespace Hazel.UnitTests.Dtls
+{
+ internal class TestDtlsHandshakeDropUnityConnection : DtlsUnityConnection
+ {
+ public int DropSendClientKeyExchangeFlightCount = 0;
+
+ public TestDtlsHandshakeDropUnityConnection(ILogger logger, IPEndPoint remoteEndPoint, IPMode ipMode = IPMode.IPv4) : base(logger, remoteEndPoint, ipMode)
+ {
+
+ }
+
+ protected override bool DropClientKeyExchangeFlight()
+ {
+ if (DropSendClientKeyExchangeFlightCount > 0)
+ {
+ this.logger.WriteInfo($"Dropping SendClientKeyExchangeFlight");
+ --DropSendClientKeyExchangeFlightCount;
+ return true;
+ }
+
+ return false;
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel.UnitTests/Dtls/X25519EcdheRsaSha256Tests.cs b/Tools/Hazel-Networking/Hazel.UnitTests/Dtls/X25519EcdheRsaSha256Tests.cs
new file mode 100644
index 0000000..33cd6dc
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel.UnitTests/Dtls/X25519EcdheRsaSha256Tests.cs
@@ -0,0 +1,226 @@
+using Hazel.Dtls;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using System.Security.Cryptography;
+
+namespace Hazel.UnitTests.Dtls
+{
+ [TestClass]
+ public class X25519EcdheRsaSha256Tests
+ {
+ private readonly RandomNumberGenerator random = RandomNumberGenerator.Create();
+ private readonly RSA privateKey = RSA.Create();
+ private readonly RSA publicKey;
+
+ public X25519EcdheRsaSha256Tests()
+ {
+ RSAParameters keyParameters = this.privateKey.ExportParameters(false);
+ this.publicKey = RSA.Create();
+ this.publicKey.ImportParameters(keyParameters);
+ }
+
+ [TestMethod]
+ public void SmallServerDataFails()
+ {
+ byte[] data;
+
+ using (X25519EcdheRsaSha256 cipherSuite = new X25519EcdheRsaSha256(this.random))
+ {
+ int expectedSize = cipherSuite.CalculateServerMessageSize(this.privateKey);
+ Assert.IsTrue(expectedSize/2 > 1);
+
+ data = new byte[expectedSize/2];
+ random.GetBytes(data);
+ }
+
+ using (X25519EcdheRsaSha256 cipherSuite = new X25519EcdheRsaSha256(this.random))
+ {
+ byte[] sharedKey = new byte[cipherSuite.SharedKeySize()];
+ Assert.IsFalse(cipherSuite.VerifyServerMessageAndGenerateSharedKey(sharedKey, data, this.publicKey));
+ }
+ }
+
+ [TestMethod]
+ public void LargeServerDataFails()
+ {
+ byte[] data;
+
+ using (X25519EcdheRsaSha256 cipherSuite = new X25519EcdheRsaSha256(this.random))
+ {
+ int expectedSize = cipherSuite.CalculateServerMessageSize(this.privateKey);
+ Assert.IsTrue(expectedSize > 0);
+
+ data = new byte[expectedSize * 2];
+ random.GetBytes(data);
+ }
+
+ using (X25519EcdheRsaSha256 cipherSuite = new X25519EcdheRsaSha256(this.random))
+ {
+ byte[] sharedKey = new byte[cipherSuite.SharedKeySize()];
+ Assert.IsFalse(cipherSuite.VerifyServerMessageAndGenerateSharedKey(sharedKey, data, this.publicKey));
+ }
+ }
+
+ [TestMethod]
+ public void RandomServerDataFails()
+ {
+ byte[] data;
+
+ using (X25519EcdheRsaSha256 cipherSuite = new X25519EcdheRsaSha256(this.random))
+ {
+ int expectedSize = cipherSuite.CalculateServerMessageSize(this.privateKey);
+ Assert.IsTrue(expectedSize > 0);
+
+ data = new byte[expectedSize];
+ random.GetBytes(data);
+ }
+
+ using (X25519EcdheRsaSha256 cipherSuite = new X25519EcdheRsaSha256(this.random))
+ {
+ byte[] sharedKey = new byte[cipherSuite.SharedKeySize()];
+ Assert.IsFalse(cipherSuite.VerifyServerMessageAndGenerateSharedKey(sharedKey, data, this.publicKey));
+ }
+ }
+
+ [TestMethod]
+ public void SmallClientDataFails()
+ {
+ byte[] data;
+
+ using (X25519EcdheRsaSha256 cipherSuite = new X25519EcdheRsaSha256(this.random))
+ {
+ int expectedSize = cipherSuite.CalculateClientMessageSize();
+ Assert.IsTrue(expectedSize / 2 > 1);
+
+ data = new byte[expectedSize / 2];
+ random.GetBytes(data);
+ }
+
+ using (X25519EcdheRsaSha256 cipherSuite = new X25519EcdheRsaSha256(this.random))
+ {
+ byte[] sharedKey = new byte[cipherSuite.SharedKeySize()];
+ Assert.IsFalse(cipherSuite.VerifyClientMessageAndGenerateSharedKey(sharedKey, data));
+ }
+ }
+
+ [TestMethod]
+ public void LargeClientDataFails()
+ {
+ byte[] data;
+
+ using (X25519EcdheRsaSha256 cipherSuite = new X25519EcdheRsaSha256(this.random))
+ {
+ int expectedSize = cipherSuite.CalculateClientMessageSize();
+ Assert.IsTrue(expectedSize > 0);
+
+ data = new byte[expectedSize * 2];
+ random.GetBytes(data);
+ }
+
+ using (X25519EcdheRsaSha256 cipherSuite = new X25519EcdheRsaSha256(this.random))
+ {
+ byte[] sharedKey = new byte[cipherSuite.SharedKeySize()];
+ Assert.IsFalse(cipherSuite.VerifyClientMessageAndGenerateSharedKey(sharedKey, data));
+ }
+ }
+
+ [TestMethod]
+ public void RandomClientDataFails()
+ {
+ byte[] data;
+
+ using (X25519EcdheRsaSha256 cipherSuite = new X25519EcdheRsaSha256(this.random))
+ {
+ int expectedSize = cipherSuite.CalculateClientMessageSize();
+ Assert.IsTrue(expectedSize > 0);
+
+ data = new byte[expectedSize];
+ random.GetBytes(data);
+ }
+
+ using (X25519EcdheRsaSha256 cipherSuite = new X25519EcdheRsaSha256(this.random))
+ {
+ byte[] sharedKey = new byte[cipherSuite.SharedKeySize()];
+ Assert.IsFalse(cipherSuite.VerifyClientMessageAndGenerateSharedKey(sharedKey, data));
+ }
+ }
+
+ [TestMethod]
+ public void RandomSignatureFails()
+ {
+ byte[] data;
+
+ using (X25519EcdheRsaSha256 cipherSuite = new X25519EcdheRsaSha256(this.random))
+ {
+ int expectedSize = cipherSuite.CalculateServerMessageSize(this.privateKey);
+ Assert.IsTrue(expectedSize > 0);
+
+ data = new byte[expectedSize];
+ cipherSuite.EncodeServerKeyExchangeMessage(data, this.privateKey);
+ }
+
+ // overwrite signature with random data
+ byte[] randomSignature = new byte[this.privateKey.KeySize/8];
+ random.GetBytes(randomSignature);
+ new ByteSpan(randomSignature).CopyTo(new ByteSpan(data, data.Length - randomSignature.Length, randomSignature.Length));
+
+ using (X25519EcdheRsaSha256 cipherSuite = new X25519EcdheRsaSha256(this.random))
+ {
+ byte[] sharedKey = new byte[cipherSuite.SharedKeySize()];
+ Assert.IsFalse(cipherSuite.VerifyServerMessageAndGenerateSharedKey(sharedKey, data, this.publicKey));
+ }
+ }
+
+ [TestMethod]
+ public void VerifySignature()
+ {
+ byte[] data;
+
+ using (X25519EcdheRsaSha256 cipherSuite = new X25519EcdheRsaSha256(this.random))
+ {
+ int expectedSize = cipherSuite.CalculateServerMessageSize(this.privateKey);
+ Assert.IsTrue(expectedSize > 0);
+
+ data = new byte[expectedSize];
+ cipherSuite.EncodeServerKeyExchangeMessage(data, this.privateKey);
+ }
+
+ using (X25519EcdheRsaSha256 cipherSuite = new X25519EcdheRsaSha256(this.random))
+ {
+ byte[] sharedKey = new byte[cipherSuite.SharedKeySize()];
+ Assert.IsTrue(cipherSuite.VerifyServerMessageAndGenerateSharedKey(sharedKey, data, this.publicKey));
+ }
+ }
+
+ [TestMethod]
+ public void GeneratesSameSharedKey()
+ {
+ byte[] serverSharedSecret;
+ byte[] clientSharedSecret;
+
+ using (X25519EcdheRsaSha256 serverCipherSuite = new X25519EcdheRsaSha256(this.random))
+ {
+ int expectedSize = serverCipherSuite.CalculateServerMessageSize(this.privateKey);
+ Assert.IsTrue(expectedSize > 0);
+
+ byte[] serverKeyExchangeMessage = new byte[expectedSize];
+ serverCipherSuite.EncodeServerKeyExchangeMessage(serverKeyExchangeMessage, this.privateKey);
+
+ byte[] clientKeyExchange;
+
+ using (X25519EcdheRsaSha256 clientCipherSuite = new X25519EcdheRsaSha256(this.random))
+ {
+ clientSharedSecret = new byte[clientCipherSuite.SharedKeySize()];
+ Assert.IsTrue(clientCipherSuite.VerifyServerMessageAndGenerateSharedKey(clientSharedSecret, serverKeyExchangeMessage, this.publicKey));
+
+ clientKeyExchange = new byte[clientCipherSuite.CalculateClientMessageSize()];
+ clientCipherSuite.EncodeClientKeyExchangeMessage(clientKeyExchange);
+ }
+
+ serverSharedSecret = new byte[serverCipherSuite.SharedKeySize()];
+ Assert.IsTrue(serverCipherSuite.VerifyClientMessageAndGenerateSharedKey(serverSharedSecret, clientKeyExchange));
+ }
+
+ CollectionAssert.AreEqual(serverSharedSecret, clientSharedSecret);
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel.UnitTests/Hazel.UnitTests.csproj b/Tools/Hazel-Networking/Hazel.UnitTests/Hazel.UnitTests.csproj
new file mode 100644
index 0000000..f39bede
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel.UnitTests/Hazel.UnitTests.csproj
@@ -0,0 +1,16 @@
+<Project Sdk="Microsoft.NET.Sdk">
+
+ <PropertyGroup>
+ <TargetFramework>net472</TargetFramework>
+ </PropertyGroup>
+
+ <ItemGroup>
+ <PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.0.0" />
+ <PackageReference Include="MSTest.TestAdapter" Version="2.2.8" />
+ <PackageReference Include="MSTest.TestFramework" Version="2.2.8" />
+
+ <ProjectReference Include="..\Hazel\Hazel.csproj" />
+ <PackageReference Include="Portable.BouncyCastle" Version="1.9.0" />
+ </ItemGroup>
+
+</Project>
diff --git a/Tools/Hazel-Networking/Hazel.UnitTests/MessageReaderTests.cs b/Tools/Hazel-Networking/Hazel.UnitTests/MessageReaderTests.cs
new file mode 100644
index 0000000..6cf4ba8
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel.UnitTests/MessageReaderTests.cs
@@ -0,0 +1,689 @@
+using System;
+using System.IO;
+using System.Linq;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+
+namespace Hazel.UnitTests
+{
+ [TestClass]
+ public class MessageReaderTests
+ {
+ [TestMethod]
+ public void ReadProperInt()
+ {
+ const int Test1 = int.MaxValue;
+ const int Test2 = int.MinValue;
+
+ var msg = new MessageWriter(128);
+ msg.StartMessage(1);
+ msg.Write(Test1);
+ msg.Write(Test2);
+ msg.EndMessage();
+
+ Assert.AreEqual(11, msg.Length);
+ Assert.AreEqual(msg.Length, msg.Position);
+
+ MessageReader reader = MessageReader.Get(msg.Buffer, 0);
+ Assert.AreEqual(Test1, reader.ReadInt32());
+ Assert.AreEqual(Test2, reader.ReadInt32());
+ }
+
+ [TestMethod]
+ public void ReadProperBool()
+ {
+ const bool Test1 = true;
+ const bool Test2 = false;
+
+ var msg = new MessageWriter(128);
+ msg.StartMessage(1);
+ msg.Write(Test1);
+ msg.Write(Test2);
+ msg.EndMessage();
+
+ Assert.AreEqual(5, msg.Length);
+ Assert.AreEqual(msg.Length, msg.Position);
+
+ MessageReader reader = MessageReader.Get(msg.Buffer, 0);
+
+ Assert.AreEqual(Test1, reader.ReadBoolean());
+ Assert.AreEqual(Test2, reader.ReadBoolean());
+
+ }
+
+ [TestMethod]
+ public void ReadProperString()
+ {
+ const string Test1 = "Hello";
+ string Test2 = new string(' ', 1024);
+ var msg = new MessageWriter(2048);
+ msg.StartMessage(1);
+ msg.Write(Test1);
+ msg.Write(Test2);
+ msg.Write(string.Empty);
+ msg.EndMessage();
+
+ Assert.AreEqual(msg.Length, msg.Position);
+
+ MessageReader reader = MessageReader.Get(msg.Buffer, 0);
+
+ Assert.AreEqual(Test1, reader.ReadString());
+ Assert.AreEqual(Test2, reader.ReadString());
+ Assert.AreEqual(string.Empty, reader.ReadString());
+
+ }
+
+ [TestMethod]
+ public void ReadProperFloat()
+ {
+ const float Test1 = 12.34f;
+
+ var msg = new MessageWriter(2048);
+ msg.StartMessage(1);
+ msg.Write(Test1);
+ msg.EndMessage();
+
+ Assert.AreEqual(7, msg.Length);
+ Assert.AreEqual(msg.Length, msg.Position);
+
+ MessageReader reader = MessageReader.Get(msg.Buffer, 0);
+
+ Assert.AreEqual(Test1, reader.ReadSingle());
+ }
+
+ [TestMethod]
+ public void RemoveMessageWorks()
+ {
+ const byte Test0 = 11;
+ const byte Test3 = 33;
+ const byte Test4 = 44;
+ const byte Test5 = 55;
+
+ var msg = new MessageWriter(2048);
+ msg.StartMessage(0);
+ msg.Write(Test0);
+ msg.EndMessage();
+
+ msg.StartMessage(12);
+ msg.StartMessage(23);
+
+ msg.StartMessage(34);
+ msg.Write(Test3);
+ msg.EndMessage();
+
+ msg.StartMessage(45);
+ msg.Write(Test4);
+ msg.EndMessage();
+
+ msg.EndMessage();
+ msg.EndMessage();
+
+ msg.StartMessage(56);
+ msg.Write(Test5);
+ msg.EndMessage();
+
+ MessageReader reader = MessageReader.Get(msg.Buffer);
+ reader.Length = msg.Length;
+
+ var zero = reader.ReadMessage();
+
+ var one = reader.ReadMessage();
+ var two = one.ReadMessage();
+ var three = two.ReadMessage();
+ two.RemoveMessage(three);
+
+ // Reader becomes invalid
+ Assert.AreNotEqual(Test3, three.ReadByte());
+
+ // Unrealistic, but nice. Earlier data is not affected
+ Assert.AreEqual(Test0, zero.ReadByte());
+
+ // Continuing to read depth-first works
+ var four = two.ReadMessage();
+ Assert.AreEqual(Test4, four.ReadByte());
+
+ var five = reader.ReadMessage();
+ Assert.AreEqual(Test5, five.ReadByte());
+ }
+
+ [TestMethod]
+ public void InsertMessageWorks()
+ {
+ const byte Test0 = 11;
+ const byte Test3 = 33;
+ const byte Test4 = 44;
+ const byte Test5 = 55;
+ const byte TestInsert = 66;
+
+ var msg = new MessageWriter(2048);
+ msg.StartMessage(0);
+ msg.Write(Test0);
+ msg.EndMessage();
+
+ msg.StartMessage(12);
+ msg.StartMessage(23);
+
+ msg.StartMessage(34);
+ msg.Write(Test3);
+ msg.EndMessage();
+
+ msg.StartMessage(45);
+ msg.Write(Test4);
+ msg.EndMessage();
+
+ msg.EndMessage();
+ msg.EndMessage();
+
+ msg.StartMessage(56);
+ msg.Write(Test5);
+ msg.EndMessage();
+
+ MessageReader reader = MessageReader.Get(msg.Buffer);
+
+ MessageWriter writer = MessageWriter.Get(SendOption.Reliable);
+ writer.StartMessage(5);
+ writer.Write(TestInsert);
+ writer.EndMessage();
+
+ reader.ReadMessage();
+ var one = reader.ReadMessage();
+ var two = one.ReadMessage();
+ var three = two.ReadMessage();
+
+ two.InsertMessage(three, writer);
+
+ //set the position back to zero to read back the updated message
+ reader.Position = 0;
+
+ var zero = reader.ReadMessage();
+ Assert.AreEqual(Test0, zero.ReadByte());
+ one = reader.ReadMessage();
+ two = one.ReadMessage();
+ var insert = two.ReadMessage();
+ Assert.AreEqual(TestInsert, insert.ReadByte());
+ three = two.ReadMessage();
+ Assert.AreEqual(Test3, three.ReadByte());
+ var four = two.ReadMessage();
+ Assert.AreEqual(Test4, four.ReadByte());
+
+ var five = reader.ReadMessage();
+ Assert.AreEqual(Test5, five.ReadByte());
+ }
+
+ [TestMethod]
+ public void InsertMessageWorksWithSendOptionNone()
+ {
+ const byte Test0 = 11;
+ const byte Test3 = 33;
+ const byte Test4 = 44;
+ const byte Test5 = 55;
+ const byte TestInsert = 66;
+
+ var msg = new MessageWriter(2048);
+ msg.StartMessage(0);
+ msg.Write(Test0);
+ msg.EndMessage();
+
+ msg.StartMessage(12);
+ msg.StartMessage(23);
+
+ msg.StartMessage(34);
+ msg.Write(Test3);
+ msg.EndMessage();
+
+ msg.StartMessage(45);
+ msg.Write(Test4);
+ msg.EndMessage();
+
+ msg.EndMessage();
+ msg.EndMessage();
+
+ msg.StartMessage(56);
+ msg.Write(Test5);
+ msg.EndMessage();
+
+ MessageReader reader = MessageReader.Get(msg.Buffer);
+
+ MessageWriter writer = MessageWriter.Get(SendOption.None);
+ writer.StartMessage(5);
+ writer.Write(TestInsert);
+ writer.EndMessage();
+
+ reader.ReadMessage();
+ var one = reader.ReadMessage();
+ var two = one.ReadMessage();
+ var three = two.ReadMessage();
+
+ two.InsertMessage(three, writer);
+
+ //set the position back to zero to read back the updated message
+ reader.Position = 0;
+
+ var zero = reader.ReadMessage();
+ Assert.AreEqual(Test0, zero.ReadByte());
+ one = reader.ReadMessage();
+ two = one.ReadMessage();
+ var insert = two.ReadMessage();
+ Assert.AreEqual(TestInsert, insert.ReadByte());
+ three = two.ReadMessage();
+ Assert.AreEqual(Test3, three.ReadByte());
+ var four = two.ReadMessage();
+ Assert.AreEqual(Test4, four.ReadByte());
+
+ var five = reader.ReadMessage();
+ Assert.AreEqual(Test5, five.ReadByte());
+
+ }
+
+ [TestMethod]
+ public void InsertMessageWithoutStartMessageInWriter()
+ {
+ const byte Test0 = 11;
+ const byte Test3 = 33;
+ const byte Test4 = 44;
+ const byte Test5 = 55;
+ const byte TestInsert = 66;
+
+ var msg = new MessageWriter(2048);
+ msg.StartMessage(0);
+ msg.Write(Test0);
+ msg.EndMessage();
+
+ msg.StartMessage(12);
+ msg.StartMessage(23);
+
+ msg.StartMessage(34);
+ msg.Write(Test3);
+ msg.EndMessage();
+
+ msg.StartMessage(45);
+ msg.Write(Test4);
+ msg.EndMessage();
+
+ msg.EndMessage();
+ msg.EndMessage();
+
+ msg.StartMessage(56);
+ msg.Write(Test5);
+ msg.EndMessage();
+
+ MessageReader reader = MessageReader.Get(msg.Buffer);
+
+ MessageWriter writer = MessageWriter.Get(SendOption.Reliable);
+ writer.Write(TestInsert);
+
+ reader.ReadMessage();
+ var one = reader.ReadMessage();
+ var two = one.ReadMessage();
+ var three = two.ReadMessage();
+
+ two.InsertMessage(three, writer);
+
+ //set the position back to zero to read back the updated message
+ reader.Position = 0;
+
+ var zero = reader.ReadMessage();
+ Assert.AreEqual(Test0, zero.ReadByte());
+ one = reader.ReadMessage();
+ two = one.ReadMessage();
+ Assert.AreEqual(TestInsert, two.ReadByte());
+ three = two.ReadMessage();
+ Assert.AreEqual(Test3, three.ReadByte());
+ var four = two.ReadMessage();
+ Assert.AreEqual(Test4, four.ReadByte());
+
+ var five = reader.ReadMessage();
+ Assert.AreEqual(Test5, five.ReadByte());
+ }
+
+ [TestMethod]
+ public void InsertMessageWithMultipleMessagesInWriter()
+ {
+ const byte Test0 = 11;
+ const byte Test3 = 33;
+ const byte Test4 = 44;
+ const byte Test5 = 55;
+ const byte TestInsert = 66;
+ const byte TestInsert2 = 77;
+
+ var msg = new MessageWriter(2048);
+ msg.StartMessage(0);
+ msg.Write(Test0);
+ msg.EndMessage();
+
+ msg.StartMessage(12);
+ msg.StartMessage(23);
+
+ msg.StartMessage(34);
+ msg.Write(Test3);
+ msg.EndMessage();
+
+ msg.StartMessage(45);
+ msg.Write(Test4);
+ msg.EndMessage();
+
+ msg.EndMessage();
+ msg.EndMessage();
+
+ msg.StartMessage(56);
+ msg.Write(Test5);
+ msg.EndMessage();
+
+ MessageReader reader = MessageReader.Get(msg.Buffer);
+
+ MessageWriter writer = MessageWriter.Get(SendOption.Reliable);
+ writer.StartMessage(5);
+ writer.Write(TestInsert);
+ writer.EndMessage();
+
+ writer.StartMessage(6);
+ writer.Write(TestInsert2);
+ writer.EndMessage();
+
+ reader.ReadMessage();
+ var one = reader.ReadMessage();
+ var two = one.ReadMessage();
+ var three = two.ReadMessage();
+
+ two.InsertMessage(three, writer);
+
+ //set the position back to zero to read back the updated message
+ reader.Position = 0;
+
+ var zero = reader.ReadMessage();
+ Assert.AreEqual(Test0, zero.ReadByte());
+ one = reader.ReadMessage();
+ two = one.ReadMessage();
+ var insert = two.ReadMessage();
+ Assert.AreEqual(TestInsert, insert.ReadByte());
+ var insert2 = two.ReadMessage();
+ Assert.AreEqual(TestInsert2, insert2.ReadByte());
+ three = two.ReadMessage();
+ Assert.AreEqual(Test3, three.ReadByte());
+ var four = two.ReadMessage();
+ Assert.AreEqual(Test4, four.ReadByte());
+
+ var five = reader.ReadMessage();
+ Assert.AreEqual(Test5, five.ReadByte());
+ }
+
+ [TestMethod]
+ public void InsertMessageMultipleInsertsWithoutReset()
+ {
+ const byte Test0 = 11;
+ const byte Test3 = 33;
+ const byte Test4 = 44;
+ const byte Test5 = 55;
+ const byte Test6 = 66;
+ const byte TestInsert = 77;
+ const byte TestInsert2 = 88;
+
+ var msg = new MessageWriter(2048);
+ msg.StartMessage(0);
+ msg.Write(Test0);
+ msg.EndMessage();
+
+ msg.StartMessage(12);
+ msg.StartMessage(23);
+
+ msg.StartMessage(34);
+ msg.Write(Test3);
+ msg.EndMessage();
+
+ msg.StartMessage(45);
+ msg.Write(Test4);
+ msg.EndMessage();
+
+ msg.EndMessage();
+
+ msg.StartMessage(56);
+ msg.Write(Test5);
+ msg.EndMessage();
+
+ msg.EndMessage();
+
+ msg.StartMessage(67);
+ msg.Write(Test6);
+ msg.EndMessage();
+
+ MessageReader reader = MessageReader.Get(msg.Buffer);
+
+ MessageWriter writer = MessageWriter.Get(SendOption.Reliable);
+ writer.StartMessage(5);
+ writer.Write(TestInsert);
+ writer.EndMessage();
+
+ MessageWriter writer2 = MessageWriter.Get(SendOption.Reliable);
+ writer2.StartMessage(6);
+ writer2.Write(TestInsert2);
+ writer2.EndMessage();
+
+ reader.ReadMessage();
+ var one = reader.ReadMessage();
+ var two = one.ReadMessage();
+ var three = two.ReadMessage();
+
+ two.InsertMessage(three, writer);
+
+ // three becomes invalid
+ Assert.AreNotEqual(Test3, three.ReadByte());
+
+ // Continuing to read works
+ var four = two.ReadMessage();
+ Assert.AreEqual(Test4, four.ReadByte());
+
+ var five = one.ReadMessage();
+ Assert.AreEqual(Test5, five.ReadByte());
+
+ reader.InsertMessage(one, writer2);
+
+ var six = reader.ReadMessage();
+ Assert.AreEqual(Test6, six.ReadByte());
+ }
+
+ [TestMethod]
+ public void CopySubMessage()
+ {
+ const byte Test1 = 12;
+ const byte Test2 = 146;
+
+ var msg = new MessageWriter(2048);
+ msg.StartMessage(1);
+
+ msg.StartMessage(2);
+ msg.Write(Test1);
+ msg.Write(Test2);
+ msg.EndMessage();
+
+ msg.EndMessage();
+
+ MessageReader handleMessage = MessageReader.Get(msg.Buffer, 0);
+ Assert.AreEqual(1, handleMessage.Tag);
+
+ var parentReader = MessageReader.Get(handleMessage);
+
+ handleMessage.Recycle();
+ SetZero(handleMessage);
+
+ Assert.AreEqual(1, parentReader.Tag);
+
+ for (int i = 0; i < 5; ++i)
+ {
+
+ var reader = parentReader.ReadMessage();
+ Assert.AreEqual(2, reader.Tag);
+ Assert.AreEqual(Test1, reader.ReadByte());
+ Assert.AreEqual(Test2, reader.ReadByte());
+
+ var temp = parentReader;
+ parentReader = MessageReader.CopyMessageIntoParent(reader);
+
+ temp.Recycle();
+ SetZero(temp);
+ SetZero(reader);
+ }
+ }
+
+ [TestMethod]
+ public void ReadMessageLength()
+ {
+ var msg = new MessageWriter(2048);
+ msg.StartMessage(1);
+ msg.Write(65534);
+ msg.StartMessage(2);
+ msg.Write("HO");
+ msg.EndMessage();
+ msg.StartMessage(2);
+ msg.EndMessage();
+ msg.EndMessage();
+
+ Assert.AreEqual(msg.Length, msg.Position);
+
+ MessageReader reader = MessageReader.Get(msg.Buffer, 0);
+ Assert.AreEqual(1, reader.Tag);
+ Assert.AreEqual(65534, reader.ReadInt32()); // Content
+
+ var sub = reader.ReadMessage();
+ Assert.AreEqual(3, sub.Length);
+ Assert.AreEqual(2, sub.Tag);
+ Assert.AreEqual("HO", sub.ReadString());
+
+ sub = reader.ReadMessage();
+ Assert.AreEqual(0, sub.Length);
+ Assert.AreEqual(2, sub.Tag);
+ }
+
+ [TestMethod]
+ public void ReadMessageAsNewBufferLength()
+ {
+ var msg = new MessageWriter(2048);
+ msg.StartMessage(1);
+ msg.Write(65534);
+ msg.StartMessage(2);
+ msg.Write("HO");
+ msg.EndMessage();
+ msg.StartMessage(232);
+ msg.EndMessage();
+ msg.EndMessage();
+
+ Assert.AreEqual(msg.Length, msg.Position);
+
+ MessageReader reader = MessageReader.Get(msg.Buffer, 0);
+ Assert.AreEqual(1, reader.Tag);
+ Assert.AreEqual(65534, reader.ReadInt32()); // Content
+
+ var sub = reader.ReadMessageAsNewBuffer();
+ Assert.AreEqual(0, sub.Position);
+ Assert.AreEqual(0, sub.Offset);
+
+ Assert.AreEqual(3, sub.Length);
+ Assert.AreEqual(2, sub.Tag);
+ Assert.AreEqual("HO", sub.ReadString());
+
+ sub.Recycle();
+
+ sub = reader.ReadMessageAsNewBuffer();
+ Assert.AreEqual(0, sub.Position);
+ Assert.AreEqual(0, sub.Offset);
+
+ Assert.AreEqual(0, sub.Length);
+ Assert.AreEqual(232, sub.Tag);
+ sub.Recycle();
+ }
+
+ [TestMethod]
+ public void ReadStringProtectsAgainstOverrun()
+ {
+ const string TestDataFromAPreviousPacket = "You shouldn't be able to see this data";
+
+ // An extra byte from the length of TestData when written via MessageWriter
+ int DataLength = TestDataFromAPreviousPacket.Length + 1;
+
+ // THE BUG
+ //
+ // No bound checks. When the server wants to read a string from
+ // an offset, it reads the packed int at that offset, treats it
+ // as a length and then proceeds to read the data that comes after
+ // it without any bound checks. This can be chained with something
+ // else to create an infoleak.
+
+ MessageWriter writer = MessageWriter.Get(SendOption.None);
+
+ // This will be our malicious "string length"
+ writer.WritePacked(DataLength);
+
+ // This is data from a "previous packet"
+ writer.Write(TestDataFromAPreviousPacket);
+
+ byte[] testData = writer.ToByteArray(includeHeader: false);
+
+ // One extra byte for the MessageWriter header, one more for the malicious data
+ Assert.AreEqual(DataLength + 1, testData.Length);
+
+ var dut = MessageReader.Get(testData);
+
+ // If Length is short by even a byte, ReadString should obey that.
+ dut.Length--;
+
+ try
+ {
+ dut.ReadString();
+ Assert.Fail("ReadString is expected to throw");
+ }
+ catch (InvalidDataException) { }
+ }
+
+ [TestMethod]
+ public void ReadMessageProtectsAgainstOverrun()
+ {
+ const string TestDataFromAPreviousPacket = "You shouldn't be able to see this data";
+
+ // An extra byte from the length of TestData when written via MessageWriter
+ // Extra 3 bytes for the length + tag header for ReadMessage.
+ int DataLength = TestDataFromAPreviousPacket.Length + 1 + 3;
+
+ // THE BUG
+ //
+ // No bound checks. When the server wants to read a message, it
+ // reads the uint16 at that offset, treats it as a length without any bound checks.
+ // This can be allow a later ReadString or ReadBytes to create an infoleak.
+
+ MessageWriter writer = MessageWriter.Get(SendOption.None);
+
+ // This is the malicious length. No data in this message, so it should be zero.
+ writer.Write((ushort)1);
+ writer.Write((byte)0); // Tag
+
+ // This is data from a "previous packet"
+ writer.Write(TestDataFromAPreviousPacket);
+
+ byte[] testData = writer.ToByteArray(includeHeader: false);
+
+ Assert.AreEqual(DataLength, testData.Length);
+
+ var outer = MessageReader.Get(testData);
+
+ // Length is just the malicious message header.
+ outer.Length = 3;
+
+ try
+ {
+ outer.ReadMessage();
+ Assert.Fail("ReadMessage is expected to throw");
+ }
+ catch (InvalidDataException) { }
+ }
+
+ [TestMethod]
+ public void GetLittleEndian()
+ {
+ Assert.IsTrue(MessageWriter.IsLittleEndian());
+ }
+
+ private void SetZero(MessageReader reader)
+ {
+ for (int i = 0; i < reader.Buffer.Length; ++i)
+ reader.Buffer[i] = 0;
+ }
+ }
+
+} \ No newline at end of file
diff --git a/Tools/Hazel-Networking/Hazel.UnitTests/MessageWriterTests.cs b/Tools/Hazel-Networking/Hazel.UnitTests/MessageWriterTests.cs
new file mode 100644
index 0000000..b292a5d
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel.UnitTests/MessageWriterTests.cs
@@ -0,0 +1,213 @@
+using System;
+using System.IO;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+
+namespace Hazel.UnitTests
+{
+ [TestClass]
+ public class MessageWriterTests
+ {
+
+ [TestMethod]
+ public void CancelMessages()
+ {
+ var msg = new MessageWriter(128);
+
+ msg.StartMessage(1);
+ msg.Write(32);
+
+ msg.StartMessage(2);
+ msg.Write(2);
+ msg.CancelMessage();
+
+ Assert.AreEqual(7, msg.Length);
+ Assert.IsFalse(msg.HasBytes(7));
+
+ msg.CancelMessage();
+
+ Assert.AreEqual(0, msg.Length);
+ Assert.IsFalse(msg.HasBytes(1));
+ }
+
+ [TestMethod]
+ public void HasBytes()
+ {
+ var msg = new MessageWriter(128);
+
+ msg.StartMessage(1);
+ msg.Write(32);
+
+ msg.StartMessage(2);
+ msg.Write(2);
+ msg.EndMessage();
+
+ // Assert.AreEqual(7, msg.Length);
+ Assert.IsTrue(msg.HasBytes(7));
+ }
+
+ [TestMethod]
+ public void WriteProperInt()
+ {
+ const int Test1 = int.MaxValue;
+ const int Test2 = int.MinValue;
+
+ var msg = new MessageWriter(128);
+ msg.Write(Test1);
+ msg.Write(Test2);
+
+ Assert.AreEqual(8, msg.Length);
+ Assert.AreEqual(msg.Length, msg.Position);
+
+ using (MemoryStream m = new MemoryStream(msg.Buffer, 0, msg.Length))
+ using (BinaryReader reader = new BinaryReader(m))
+ {
+ Assert.AreEqual(Test1, reader.ReadInt32());
+ Assert.AreEqual(Test2, reader.ReadInt32());
+ }
+ }
+
+ [TestMethod]
+ public void WriteProperBool()
+ {
+ const bool Test1 = true;
+ const bool Test2 = false;
+
+ var msg = new MessageWriter(128);
+ msg.Write(Test1);
+ msg.Write(Test2);
+
+ Assert.AreEqual(2, msg.Length);
+ Assert.AreEqual(msg.Length, msg.Position);
+
+ using (MemoryStream m = new MemoryStream(msg.Buffer, 0, msg.Length))
+ using (BinaryReader reader = new BinaryReader(m))
+ {
+ Assert.AreEqual(Test1, reader.ReadBoolean());
+ Assert.AreEqual(Test2, reader.ReadBoolean());
+ }
+ }
+
+ [TestMethod]
+ public void WriteProperString()
+ {
+ const string Test1 = "Hello";
+ string Test2 = new string(' ', 1024);
+ var msg = new MessageWriter(2048);
+ msg.Write(Test1);
+ msg.Write(Test2);
+ msg.Write(string.Empty);
+
+ Assert.AreEqual(msg.Length, msg.Position);
+
+ using (MemoryStream m = new MemoryStream(msg.Buffer, 0, msg.Length))
+ using (BinaryReader reader = new BinaryReader(m))
+ {
+ Assert.AreEqual(Test1, reader.ReadString());
+ Assert.AreEqual(Test2, reader.ReadString());
+ Assert.AreEqual(string.Empty, reader.ReadString());
+ }
+ }
+
+ [TestMethod]
+ public void WriteProperFloat()
+ {
+ const float Test1 = 12.34f;
+
+ var msg = new MessageWriter(2048);
+ msg.Write(Test1);
+
+ Assert.AreEqual(msg.Length, msg.Position);
+
+ using (MemoryStream m = new MemoryStream(msg.Buffer, 0, msg.Length))
+ using (BinaryReader reader = new BinaryReader(m))
+ {
+ Assert.AreEqual(Test1, reader.ReadSingle());
+ }
+ }
+
+ [TestMethod]
+ public void WritePackedUint()
+ {
+ var msg = new MessageWriter(2048);
+ msg.StartMessage(0);
+ msg.WritePacked(8u);
+ msg.WritePacked(250u);
+ msg.WritePacked(68000u);
+ msg.EndMessage();
+
+ Assert.AreEqual(3 + 1 + 2 + 3, msg.Position);
+ Assert.AreEqual(msg.Length, msg.Position);
+
+ MessageReader reader = MessageReader.Get(msg.Buffer, 0);
+
+ Assert.AreEqual(8u, reader.ReadPackedUInt32());
+ Assert.AreEqual(250u, reader.ReadPackedUInt32());
+ Assert.AreEqual(68000u, reader.ReadPackedUInt32());
+ }
+
+ [TestMethod]
+ public void WritePackedInt()
+ {
+ var msg = new MessageWriter(2048);
+ msg.StartMessage(0);
+ msg.WritePacked(8);
+ msg.WritePacked(250);
+ msg.WritePacked(68000);
+ msg.WritePacked(60168000);
+ msg.WritePacked(-68000);
+ msg.WritePacked(-250);
+ msg.WritePacked(-8);
+
+ msg.WritePacked(0);
+ msg.WritePacked(-1);
+ msg.WritePacked(int.MinValue);
+ msg.WritePacked(int.MaxValue);
+ msg.EndMessage();
+
+ Assert.AreEqual(3 + 1 + 2 + 3 + 4 + 5 + 5 + 5 + 1 + 5 + 5 + 5, msg.Position);
+ Assert.AreEqual(msg.Length, msg.Position);
+
+ MessageReader reader = MessageReader.Get(msg.Buffer, 0);
+
+ Assert.AreEqual(8, reader.ReadPackedInt32());
+ Assert.AreEqual(250, reader.ReadPackedInt32());
+ Assert.AreEqual(68000, reader.ReadPackedInt32());
+ Assert.AreEqual(60168000, reader.ReadPackedInt32());
+
+ Assert.AreEqual(-68000, reader.ReadPackedInt32());
+ Assert.AreEqual(-250, reader.ReadPackedInt32());
+ Assert.AreEqual(-8, reader.ReadPackedInt32());
+
+ Assert.AreEqual(0, reader.ReadPackedInt32());
+ Assert.AreEqual(-1, reader.ReadPackedInt32());
+ Assert.AreEqual(int.MinValue, reader.ReadPackedInt32());
+ Assert.AreEqual(int.MaxValue, reader.ReadPackedInt32());
+ }
+
+ [TestMethod]
+ public void WritesMessageLength()
+ {
+ var msg = new MessageWriter(2048);
+ msg.StartMessage(1);
+ msg.Write(65534);
+ msg.EndMessage();
+
+ Assert.AreEqual(2 + 1 + 4, msg.Position);
+ Assert.AreEqual(msg.Length, msg.Position);
+
+ using (MemoryStream m = new MemoryStream(msg.Buffer, 0, msg.Length))
+ using (BinaryReader reader = new BinaryReader(m))
+ {
+ Assert.AreEqual(4, reader.ReadUInt16()); // Length After Type and Target
+ Assert.AreEqual(1, reader.ReadByte()); // Type
+ Assert.AreEqual(65534, reader.ReadInt32()); // Content
+ }
+ }
+
+ [TestMethod]
+ public void GetLittleEndian()
+ {
+ Assert.IsTrue(MessageWriter.IsLittleEndian());
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel.UnitTests/SocketCapture.cs b/Tools/Hazel-Networking/Hazel.UnitTests/SocketCapture.cs
new file mode 100644
index 0000000..584d08c
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel.UnitTests/SocketCapture.cs
@@ -0,0 +1,292 @@
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using System;
+using System.Collections;
+using System.Collections.Concurrent;
+using System.Collections.Generic;
+using System.Linq;
+using System.Net;
+using System.Net.Sockets;
+using System.Text;
+using System.Threading;
+
+namespace Hazel.UnitTests
+{
+ /// <summary>
+ /// Acts as an intermediate between to sockets.
+ ///
+ /// Use SendToLocalSemaphore and SendToRemoteSemaphore for
+ /// explicit control of packet flow.
+ /// </summary>
+ public class SocketCapture : IDisposable
+ {
+ private IPEndPoint localEndPoint;
+ private readonly IPEndPoint remoteEndPoint;
+
+ private Socket captureSocket;
+
+ private Thread receiveThread;
+ private Thread forLocalThread;
+ private Thread forRemoteThread;
+
+ private ILogger logger;
+
+ private readonly BlockingCollection<ByteSpan> forLocal = new BlockingCollection<ByteSpan>();
+ private readonly BlockingCollection<ByteSpan> forRemote = new BlockingCollection<ByteSpan>();
+
+ /// <summary>
+ /// Useful for debug logging, prefer <see cref="AssertPacketsToLocalCountEquals(int)"/> for assertions
+ /// </summary>
+ public int PacketsForLocalCount => this.forLocal.Count;
+
+ /// <summary>
+ /// Useful for debug logging, prefer <see cref="AssertPacketsToRemoteCountEquals(int)"/> for assertions
+ /// </summary>
+ public int PacketsForRemoteCount => this.forRemote.Count;
+
+ public Semaphore SendToLocalSemaphore = null;
+ public Semaphore SendToRemoteSemaphore = null;
+
+ private CancellationTokenSource cancellationSource = new CancellationTokenSource();
+ private readonly CancellationToken cancellationToken;
+
+ public SocketCapture(IPEndPoint captureEndpoint, IPEndPoint remoteEndPoint, ILogger logger = null)
+ {
+ this.logger = logger ?? new NullLogger();
+ this.cancellationToken = this.cancellationSource.Token;
+
+ this.remoteEndPoint = remoteEndPoint;
+
+ this.captureSocket = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp);
+ this.captureSocket.Bind(captureEndpoint);
+
+ this.receiveThread = new Thread(this.ReceiveLoop);
+ this.receiveThread.Start();
+
+ this.forLocalThread = new Thread(this.SendToLocalLoop);
+ this.forLocalThread.Start();
+
+ this.forRemoteThread = new Thread(this.SendToRemoteLoop);
+ this.forRemoteThread.Start();
+ }
+
+ public void Dispose()
+ {
+ if (this.cancellationSource != null)
+ {
+ this.cancellationSource.Cancel();
+ this.cancellationSource.Dispose();
+ this.cancellationSource = null;
+ }
+
+ if (this.captureSocket != null)
+ {
+ this.captureSocket.Close();
+ this.captureSocket.Dispose();
+ this.captureSocket = null;
+ }
+
+ if (this.receiveThread != null)
+ {
+ this.receiveThread.Join();
+ this.receiveThread = null;
+ }
+
+ if (this.forLocalThread != null)
+ {
+ this.forLocalThread.Join();
+ this.forLocalThread = null;
+ }
+
+ if (this.forRemoteThread != null)
+ {
+ this.forRemoteThread.Join();
+ this.forRemoteThread = null;
+ }
+
+ GC.SuppressFinalize(this);
+ }
+
+ private void ReceiveLoop()
+ {
+ try
+ {
+ IPEndPoint fromEndPoint = new IPEndPoint(IPAddress.Any, 0);
+
+ for (; ; )
+ {
+ byte[] buffer = new byte[2000];
+ EndPoint endPoint = fromEndPoint;
+ int read = this.captureSocket.ReceiveFrom(buffer, ref endPoint);
+ if (read > 0)
+ {
+ // from the remote endpoint?
+ if (IPEndPoint.Equals(endPoint, remoteEndPoint))
+ {
+ this.forLocal.Add(new ByteSpan(buffer, 0, read));
+ }
+ else
+ {
+ this.localEndPoint = (IPEndPoint)endPoint;
+ this.forRemote.Add(new ByteSpan(buffer, 0, read));
+ }
+ }
+ }
+ }
+ catch (SocketException)
+ {
+ }
+ finally
+ {
+ this.forLocal.CompleteAdding();
+ this.forRemote.CompleteAdding();
+ }
+ }
+
+ private void SendToRemoteLoop()
+ {
+ while (!this.cancellationToken.IsCancellationRequested)
+ {
+ if (this.SendToRemoteSemaphore != null)
+ {
+ if (!this.SendToRemoteSemaphore.WaitOne(100))
+ {
+ continue;
+ }
+ }
+
+ if (this.forRemote.TryTake(out var packet))
+ {
+ this.logger.WriteInfo($"Passed 1 packet of {packet.Length} bytes to remote");
+ this.captureSocket.SendTo(packet.GetUnderlyingArray(), packet.Offset, packet.Length, SocketFlags.None, this.remoteEndPoint);
+ }
+ }
+ }
+
+ private void SendToLocalLoop()
+ {
+ while (!this.cancellationToken.IsCancellationRequested)
+ {
+ if (this.SendToLocalSemaphore != null)
+ {
+ if (!this.SendToLocalSemaphore.WaitOne(100))
+ {
+ continue;
+ }
+ }
+
+ if (this.forLocal.TryTake(out var packet))
+ {
+ this.logger.WriteInfo($"Passed 1 packet of {packet.Length} bytes to local");
+ this.captureSocket.SendTo(packet.GetUnderlyingArray(), packet.Offset, packet.Length, SocketFlags.None, this.localEndPoint);
+ }
+ }
+ }
+
+ public void AssertPacketsToLocalCountEquals(int pktCnt)
+ {
+ DateTime start = DateTime.UtcNow;
+ while (this.forLocal.Count != pktCnt)
+ {
+ if ((DateTime.UtcNow - start).TotalSeconds >= 5)
+ {
+ Assert.AreEqual(pktCnt, this.forLocal.Count);
+ }
+
+ Thread.Yield();
+ }
+ }
+
+ public void AssertPacketsToRemoteCountEquals(int pktCnt)
+ {
+ DateTime start = DateTime.UtcNow;
+ while (this.forRemote.Count != pktCnt)
+ {
+ if ((DateTime.UtcNow - start).TotalSeconds >= 5)
+ {
+ Assert.AreEqual(pktCnt, this.forRemote.Count);
+ }
+
+ Thread.Yield();
+ }
+ }
+
+ public void DiscardPacketForLocal(int numToDiscard = 1)
+ {
+ for (int i = 0; i < numToDiscard; ++i)
+ {
+ this.forLocal.Take();
+ }
+ }
+
+ public void DiscardPacketForRemote(int numToDiscard = 1)
+ {
+ for (int i = 0; i < numToDiscard; ++i)
+ {
+ this.forRemote.Take();
+ }
+ }
+
+ public ByteSpan PeekPacketForLocal()
+ {
+ return this.forLocal.First();
+ }
+
+ public void ReorderPacketsForRemote(Action<List<ByteSpan>> reorderCallback)
+ {
+ List<ByteSpan> buffer = new List<ByteSpan>();
+ while (this.forRemote.TryTake(out var pkt)) buffer.Add(pkt);
+ reorderCallback(buffer);
+ foreach (var item in buffer)
+ {
+ this.forRemote.Add(item);
+ }
+ }
+
+ public void ReorderPacketsForLocal(Action<List<ByteSpan>> reorderCallback)
+ {
+ List<ByteSpan> buffer = new List<ByteSpan>();
+ while (this.forLocal.TryTake(out var pkt)) buffer.Add(pkt);
+ reorderCallback(buffer);
+ foreach (var item in buffer)
+ {
+ this.forLocal.Add(item);
+ }
+ }
+
+ internal void ReleasePacketsForRemote(int numToSend)
+ {
+ var newExpected = this.forRemote.Count - numToSend;
+ this.SendToRemoteSemaphore.Release(numToSend);
+ this.AssertPacketsToRemoteCountEquals(newExpected);
+ }
+
+ internal void ReleasePacketsToLocal(int numToSend)
+ {
+ var newExpected = this.forLocal.Count - numToSend;
+ this.SendToLocalSemaphore.Release(numToSend);
+ this.AssertPacketsToLocalCountEquals(newExpected);
+ }
+
+ internal string PacketsForRemoteToString()
+ {
+ StringBuilder sb = new StringBuilder();
+ foreach (var item in this.forRemote)
+ {
+ sb.AppendLine(item.ToString());
+ }
+
+ return sb.ToString();
+ }
+
+ internal string PacketsForLocalToString()
+ {
+ StringBuilder sb = new StringBuilder();
+ foreach(var item in this.forLocal)
+ {
+ sb.AppendLine(item.ToString());
+ }
+
+ return sb.ToString();
+ }
+ }
+} \ No newline at end of file
diff --git a/Tools/Hazel-Networking/Hazel.UnitTests/StatisticsTests.cs b/Tools/Hazel-Networking/Hazel.UnitTests/StatisticsTests.cs
new file mode 100644
index 0000000..8452a82
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel.UnitTests/StatisticsTests.cs
@@ -0,0 +1,159 @@
+using System;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+
+namespace Hazel.UnitTests
+{
+ [TestClass]
+ public class StatisticsTests
+ {
+ [TestMethod]
+ public void SendTests()
+ {
+ ConnectionStatistics statistics = new ConnectionStatistics();
+
+ statistics.LogUnreliableSend(10);
+
+ Assert.AreEqual(1, statistics.MessagesSent);
+ Assert.AreEqual(1, statistics.UnreliableMessagesSent);
+ Assert.AreEqual(0, statistics.ReliableMessagesSent);
+ Assert.AreEqual(0, statistics.FragmentedMessagesSent);
+ Assert.AreEqual(0, statistics.AcknowledgementMessagesSent);
+ Assert.AreEqual(0, statistics.HelloMessagesSent);
+
+ Assert.AreEqual(10, statistics.DataBytesSent);
+
+ statistics.LogReliableSend(5);
+
+ Assert.AreEqual(2, statistics.MessagesSent);
+ Assert.AreEqual(1, statistics.UnreliableMessagesSent);
+ Assert.AreEqual(1, statistics.ReliableMessagesSent);
+ Assert.AreEqual(0, statistics.FragmentedMessagesSent);
+ Assert.AreEqual(0, statistics.AcknowledgementMessagesSent);
+ Assert.AreEqual(0, statistics.HelloMessagesSent);
+
+ Assert.AreEqual(15, statistics.DataBytesSent);
+
+ statistics.LogFragmentedSend(6);
+
+ Assert.AreEqual(3, statistics.MessagesSent);
+ Assert.AreEqual(1, statistics.UnreliableMessagesSent);
+ Assert.AreEqual(1, statistics.ReliableMessagesSent);
+ Assert.AreEqual(1, statistics.FragmentedMessagesSent);
+ Assert.AreEqual(0, statistics.AcknowledgementMessagesSent);
+ Assert.AreEqual(0, statistics.HelloMessagesSent);
+
+ Assert.AreEqual(21, statistics.DataBytesSent);
+
+ statistics.LogAcknowledgementSend();
+
+ Assert.AreEqual(4, statistics.MessagesSent);
+ Assert.AreEqual(1, statistics.UnreliableMessagesSent);
+ Assert.AreEqual(1, statistics.ReliableMessagesSent);
+ Assert.AreEqual(1, statistics.FragmentedMessagesSent);
+ Assert.AreEqual(1, statistics.AcknowledgementMessagesSent);
+ Assert.AreEqual(0, statistics.HelloMessagesSent);
+
+ Assert.AreEqual(21, statistics.DataBytesSent);
+
+ statistics.LogHelloSend();
+
+ Assert.AreEqual(5, statistics.MessagesSent);
+ Assert.AreEqual(1, statistics.UnreliableMessagesSent);
+ Assert.AreEqual(1, statistics.ReliableMessagesSent);
+ Assert.AreEqual(1, statistics.FragmentedMessagesSent);
+ Assert.AreEqual(1, statistics.AcknowledgementMessagesSent);
+ Assert.AreEqual(1, statistics.HelloMessagesSent);
+
+ Assert.AreEqual(21, statistics.DataBytesSent);
+
+ Assert.AreEqual(0, statistics.MessagesReceived);
+ Assert.AreEqual(0, statistics.UnreliableMessagesReceived);
+ Assert.AreEqual(0, statistics.ReliableMessagesReceived);
+ Assert.AreEqual(0, statistics.FragmentedMessagesReceived);
+ Assert.AreEqual(0, statistics.AcknowledgementMessagesReceived);
+ Assert.AreEqual(0, statistics.HelloMessagesReceived);
+
+ Assert.AreEqual(0, statistics.DataBytesReceived);
+ Assert.AreEqual(0, statistics.TotalBytesReceived);
+
+ statistics.LogPacketSend(11);
+ Assert.AreEqual(11, statistics.TotalBytesSent);
+ }
+
+ [TestMethod]
+ public void ReceiveTests()
+ {
+ ConnectionStatistics statistics = new ConnectionStatistics();
+
+ statistics.LogUnreliableReceive(10, 11);
+
+ Assert.AreEqual(1, statistics.MessagesReceived);
+ Assert.AreEqual(1, statistics.UnreliableMessagesReceived);
+ Assert.AreEqual(0, statistics.ReliableMessagesReceived);
+ Assert.AreEqual(0, statistics.FragmentedMessagesReceived);
+ Assert.AreEqual(0, statistics.AcknowledgementMessagesReceived);
+ Assert.AreEqual(0, statistics.HelloMessagesReceived);
+
+ Assert.AreEqual(10, statistics.DataBytesReceived);
+ Assert.AreEqual(11, statistics.TotalBytesReceived);
+
+ statistics.LogReliableReceive(5, 8);
+
+ Assert.AreEqual(2, statistics.MessagesReceived);
+ Assert.AreEqual(1, statistics.UnreliableMessagesReceived);
+ Assert.AreEqual(1, statistics.ReliableMessagesReceived);
+ Assert.AreEqual(0, statistics.FragmentedMessagesReceived);
+ Assert.AreEqual(0, statistics.AcknowledgementMessagesReceived);
+ Assert.AreEqual(0, statistics.HelloMessagesReceived);
+
+ Assert.AreEqual(15, statistics.DataBytesReceived);
+ Assert.AreEqual(19, statistics.TotalBytesReceived);
+
+ statistics.LogFragmentedReceive(6, 10);
+
+ Assert.AreEqual(3, statistics.MessagesReceived);
+ Assert.AreEqual(1, statistics.UnreliableMessagesReceived);
+ Assert.AreEqual(1, statistics.ReliableMessagesReceived);
+ Assert.AreEqual(1, statistics.FragmentedMessagesReceived);
+ Assert.AreEqual(0, statistics.AcknowledgementMessagesReceived);
+ Assert.AreEqual(0, statistics.HelloMessagesReceived);
+
+ Assert.AreEqual(21, statistics.DataBytesReceived);
+ Assert.AreEqual(29, statistics.TotalBytesReceived);
+
+ statistics.LogAcknowledgementReceive(4);
+
+ Assert.AreEqual(4, statistics.MessagesReceived);
+ Assert.AreEqual(1, statistics.UnreliableMessagesReceived);
+ Assert.AreEqual(1, statistics.ReliableMessagesReceived);
+ Assert.AreEqual(1, statistics.FragmentedMessagesReceived);
+ Assert.AreEqual(1, statistics.AcknowledgementMessagesReceived);
+ Assert.AreEqual(0, statistics.HelloMessagesReceived);
+
+ Assert.AreEqual(21, statistics.DataBytesReceived);
+ Assert.AreEqual(33, statistics.TotalBytesReceived);
+
+ statistics.LogHelloReceive(7);
+
+ Assert.AreEqual(5, statistics.MessagesReceived);
+ Assert.AreEqual(1, statistics.UnreliableMessagesReceived);
+ Assert.AreEqual(1, statistics.ReliableMessagesReceived);
+ Assert.AreEqual(1, statistics.FragmentedMessagesReceived);
+ Assert.AreEqual(1, statistics.AcknowledgementMessagesReceived);
+ Assert.AreEqual(1, statistics.HelloMessagesReceived);
+
+ Assert.AreEqual(21, statistics.DataBytesReceived);
+ Assert.AreEqual(40, statistics.TotalBytesReceived);
+
+ Assert.AreEqual(0, statistics.MessagesSent);
+ Assert.AreEqual(0, statistics.UnreliableMessagesSent);
+ Assert.AreEqual(0, statistics.ReliableMessagesSent);
+ Assert.AreEqual(0, statistics.FragmentedMessagesSent);
+ Assert.AreEqual(0, statistics.AcknowledgementMessagesSent);
+ Assert.AreEqual(0, statistics.HelloMessagesSent);
+
+ Assert.AreEqual(0, statistics.DataBytesSent);
+ Assert.AreEqual(0, statistics.TotalBytesSent);
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel.UnitTests/StressTests.cs b/Tools/Hazel-Networking/Hazel.UnitTests/StressTests.cs
new file mode 100644
index 0000000..c92b0b7
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel.UnitTests/StressTests.cs
@@ -0,0 +1,74 @@
+using System;
+using System.Linq;
+using System.Net;
+using System.Net.Sockets;
+using System.Threading;
+using System.Threading.Tasks;
+using Hazel.Dtls;
+using Hazel.Udp;
+using Hazel.Udp.FewerThreads;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+
+namespace Hazel.UnitTests
+{
+ [TestClass]
+ public class StressTests
+ {
+ // [TestMethod]
+ public void StressTestOpeningConnections()
+ {
+ // Start a listener in another process, or even better,
+ // adjust the target IP and start listening on another computer.
+ var ep = new IPEndPoint(IPAddress.Loopback, 22023);
+ Parallel.For(0, 10000,
+ new ParallelOptions { MaxDegreeOfParallelism = 64 },
+ (i) => {
+
+ var connection = new UdpClientConnection(new TestLogger(), ep);
+ connection.KeepAliveInterval = 50;
+
+ connection.Connect(new byte[5]);
+ });
+ }
+
+ // This was a thing that happened to us a DDoS. Mildly instructional that we straight up ignore it.
+ public void SourceAmpAttack()
+ {
+ var localEp = new IPEndPoint(IPAddress.Any, 11710);
+ var serverEp = new IPEndPoint(IPAddress.Loopback, 11710);
+ using (ThreadLimitedUdpConnectionListener listener = new ThreadLimitedUdpConnectionListener(4, localEp, new ConsoleLogger(true)))
+ {
+ listener.Start();
+
+ Socket socket = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp);
+ socket.DontFragment = false;
+
+ try
+ {
+ const int SIO_UDP_CONNRESET = -1744830452;
+ socket.IOControl(SIO_UDP_CONNRESET, new byte[1], null);
+ }
+ catch { } // Only necessary on Windows
+
+ string byteAsHex = "f23c 92d1 c277 001b 54c2 50c1 0800 4500 0035 7488 0000 3b11 2637 062f ac75 2d4f 0506 a7ea 5607 0021 5e07 ffff ffff 5453 6f75 7263 6520 456e 6769 6e65 2051 7565 7279 00";
+ byte[] bytes = StringToByteArray(byteAsHex.Replace(" ", ""));
+ socket.SendTo(bytes, serverEp);
+
+ while (socket.Poll(50000, SelectMode.SelectRead))
+ {
+ byte[] buffer = new byte[1024];
+ int len = socket.Receive(buffer);
+ Console.WriteLine($"got {len} bytes: " + string.Join(" ", buffer.Select(b => b.ToString("X"))));
+ Console.WriteLine($"got {len} bytes: " + string.Join(" ", buffer.Select(b => (char)b)));
+ }
+ }
+ }
+
+ public static byte[] StringToByteArray(string hex)
+ {
+ return Enumerable.Range(0, hex.Length / 2)
+ .Select(x => Convert.ToByte(hex.Substring(x * 2, 2), 16))
+ .ToArray();
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel.UnitTests/TestHelper.cs b/Tools/Hazel-Networking/Hazel.UnitTests/TestHelper.cs
new file mode 100644
index 0000000..f3e9dfb
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel.UnitTests/TestHelper.cs
@@ -0,0 +1,337 @@
+using System;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+
+using Hazel;
+using System.Net;
+using System.Threading;
+using System.Diagnostics;
+using Hazel.Udp.FewerThreads;
+
+namespace Hazel.UnitTests
+{
+ [TestClass]
+ public static class TestHelper
+ {
+ /// <summary>
+ /// Runs a general test on the given listener and connection.
+ /// </summary>
+ /// <param name="listener">The listener to test.</param>
+ /// <param name="connection">The connection to test.</param>
+ internal static void RunServerToClientTest(ThreadLimitedUdpConnectionListener listener, Connection connection, int dataSize, SendOption sendOption)
+ {
+ //Setup meta stuff
+ MessageWriter data = BuildData(sendOption, dataSize);
+ ManualResetEvent mutex = new ManualResetEvent(false);
+
+ //Setup listener
+ listener.NewConnection += delegate (NewConnectionEventArgs ncArgs)
+ {
+ ncArgs.Connection.Send(data);
+ };
+
+ listener.Start();
+
+ DataReceivedEventArgs? result = null;
+ //Setup conneciton
+ connection.DataReceived += delegate (DataReceivedEventArgs a)
+ {
+ Trace.WriteLine("Data was received correctly.");
+
+ try
+ {
+ result = a;
+ }
+ finally
+ {
+ mutex.Set();
+ }
+ };
+
+ connection.Connect();
+
+ //Wait until data is received
+ mutex.WaitOne();
+
+ var dataReader = ConvertToMessageReader(data);
+ Assert.AreEqual(dataReader.Length, result.Value.Message.Length);
+ for (int i = 0; i < dataReader.Length; i++)
+ {
+ Assert.AreEqual(dataReader.ReadByte(), result.Value.Message.ReadByte());
+ }
+
+ Assert.AreEqual(sendOption, result.Value.SendOption);
+ }
+
+ /// <summary>
+ /// Runs a general test on the given listener and connection.
+ /// </summary>
+ /// <param name="listener">The listener to test.</param>
+ /// <param name="connection">The connection to test.</param>
+ internal static void RunServerToClientTest(NetworkConnectionListener listener, Connection connection, int dataSize, SendOption sendOption)
+ {
+ //Setup meta stuff
+ MessageWriter data = BuildData(sendOption, dataSize);
+ ManualResetEvent mutex = new ManualResetEvent(false);
+
+ //Setup listener
+ listener.NewConnection += delegate (NewConnectionEventArgs ncArgs)
+ {
+ ncArgs.Connection.Send(data);
+ };
+
+ listener.Start();
+
+ DataReceivedEventArgs? result = null;
+ //Setup conneciton
+ connection.DataReceived += delegate (DataReceivedEventArgs a)
+ {
+ Trace.WriteLine("Data was received correctly.");
+
+ try
+ {
+ result = a;
+ }
+ finally
+ {
+ mutex.Set();
+ }
+ };
+
+ connection.Connect();
+
+ //Wait until data is received
+ mutex.WaitOne();
+
+ var dataReader = ConvertToMessageReader(data);
+ Assert.AreEqual(dataReader.Length, result.Value.Message.Length);
+ for (int i = 0; i < dataReader.Length; i++)
+ {
+ Assert.AreEqual(dataReader.ReadByte(), result.Value.Message.ReadByte());
+ }
+
+ Assert.AreEqual(sendOption, result.Value.SendOption);
+ }
+
+ /// <summary>
+ /// Runs a general test on the given listener and connection.
+ /// </summary>
+ /// <param name="listener">The listener to test.</param>
+ /// <param name="connection">The connection to test.</param>
+ internal static void RunClientToServerTest(NetworkConnectionListener listener, Connection connection, int dataSize, SendOption sendOption)
+ {
+ //Setup meta stuff
+ MessageWriter data = BuildData(sendOption, dataSize);
+ ManualResetEvent mutex = new ManualResetEvent(false);
+ ManualResetEvent mutex2 = new ManualResetEvent(false);
+
+ //Setup listener
+ DataReceivedEventArgs? result = null;
+ listener.NewConnection += delegate (NewConnectionEventArgs args)
+ {
+ args.Connection.DataReceived += delegate (DataReceivedEventArgs innerArgs)
+ {
+ Trace.WriteLine("Data was received correctly.");
+
+ result = innerArgs;
+
+ mutex2.Set();
+ };
+
+ mutex.Set();
+ };
+
+ listener.Start();
+
+ //Connect
+ connection.Connect();
+
+ mutex.WaitOne();
+
+ connection.Send(data);
+
+ //Wait until data is received
+ mutex2.WaitOne();
+
+ var dataReader = ConvertToMessageReader(data);
+ Assert.AreEqual(dataReader.Length, result.Value.Message.Length);
+ for (int i = 0; i < data.Length; i++)
+ {
+ Assert.AreEqual(dataReader.ReadByte(), result.Value.Message.ReadByte());
+ }
+
+ Assert.AreEqual(sendOption, result.Value.SendOption);
+ }
+
+
+ /// <summary>
+ /// Runs a general test on the given listener and connection.
+ /// </summary>
+ /// <param name="listener">The listener to test.</param>
+ /// <param name="connection">The connection to test.</param>
+ internal static void RunClientToServerTest(ThreadLimitedUdpConnectionListener listener, Connection connection, int dataSize, SendOption sendOption)
+ {
+ //Setup meta stuff
+ MessageWriter data = BuildData(sendOption, dataSize);
+ ManualResetEvent mutex = new ManualResetEvent(false);
+ ManualResetEvent mutex2 = new ManualResetEvent(false);
+
+ //Setup listener
+ DataReceivedEventArgs? result = null;
+ listener.NewConnection += delegate (NewConnectionEventArgs args)
+ {
+ args.Connection.DataReceived += delegate (DataReceivedEventArgs innerArgs)
+ {
+ Trace.WriteLine("Data was received correctly.");
+
+ result = innerArgs;
+
+ mutex2.Set();
+ };
+
+ mutex.Set();
+ };
+
+ listener.Start();
+
+ //Connect
+ connection.Connect();
+
+ Assert.IsTrue(mutex.WaitOne(100), "Timeout while connecting");
+
+ connection.Send(data);
+
+ //Wait until data is received
+ Assert.IsTrue(mutex2.WaitOne(100), "Timeout while sending data");
+
+ var dataReader = ConvertToMessageReader(data);
+ Assert.AreEqual(dataReader.Length, result.Value.Message.Length);
+ for (int i = 0; i < dataReader.Length; i++)
+ {
+ Assert.AreEqual(dataReader.ReadByte(), result.Value.Message.ReadByte());
+ }
+
+ Assert.AreEqual(sendOption, result.Value.SendOption);
+ }
+
+ /// <summary>
+ /// Runs a server disconnect test on the given listener and connection.
+ /// </summary>
+ /// <param name="listener">The listener to test.</param>
+ /// <param name="connection">The connection to test.</param>
+ internal static void RunServerDisconnectTest(NetworkConnectionListener listener, Connection connection)
+ {
+ ManualResetEvent mutex = new ManualResetEvent(false);
+
+ connection.Disconnected += delegate (object sender, DisconnectedEventArgs args)
+ {
+ mutex.Set();
+ };
+
+ listener.NewConnection += delegate (NewConnectionEventArgs args)
+ {
+ args.Connection.Disconnect("Testing");
+ };
+
+ listener.Start();
+
+ connection.Connect();
+
+ mutex.WaitOne();
+ }
+
+ /// <summary>
+ /// Runs a client disconnect test on the given listener and connection.
+ /// </summary>
+ /// <param name="listener">The listener to test.</param>
+ /// <param name="connection">The connection to test.</param>
+ internal static void RunClientDisconnectTest(NetworkConnectionListener listener, Connection connection)
+ {
+ ManualResetEvent mutex = new ManualResetEvent(false);
+ ManualResetEvent mutex2 = new ManualResetEvent(false);
+
+ listener.NewConnection += delegate (NewConnectionEventArgs args)
+ {
+ args.Connection.Disconnected += delegate (object sender2, DisconnectedEventArgs args2)
+ {
+ mutex2.Set();
+ };
+
+ mutex.Set();
+ };
+
+ listener.Start();
+
+ connection.Connect();
+
+ mutex.WaitOne();
+
+ connection.Disconnect("Testing");
+
+ mutex2.WaitOne();
+ }
+
+ /// <summary>
+ /// Ensures a client sends a disconnect packet to the server on Dispose.
+ /// </summary>
+ /// <param name="listener">The listener to test.</param>
+ /// <param name="connection">The connection to test.</param>
+ internal static void RunClientDisconnectOnDisposeTest(NetworkConnectionListener listener, Connection connection)
+ {
+ ManualResetEvent mutex = new ManualResetEvent(false);
+ ManualResetEvent mutex2 = new ManualResetEvent(false);
+
+ listener.NewConnection += delegate (NewConnectionEventArgs args)
+ {
+ args.Connection.Disconnected += delegate (object sender2, DisconnectedEventArgs args2)
+ {
+ mutex2.Set();
+ };
+
+ mutex.Set();
+ };
+
+ listener.Start();
+
+ connection.Connect();
+
+ if (!mutex.WaitOne(TimeSpan.FromSeconds(1)))
+ {
+ Assert.Fail("Timeout waiting for client connection");
+ }
+
+ connection.Dispose();
+
+ if (!mutex2.WaitOne(TimeSpan.FromSeconds(1)))
+ {
+ Assert.Fail("Timeout waiting for client disconnect packet");
+ }
+ }
+
+ private static MessageReader ConvertToMessageReader(MessageWriter writer)
+ {
+ var output = new MessageReader();
+ output.Buffer = writer.Buffer;
+ output.Offset = writer.SendOption == SendOption.None ? 1 : 3;
+ output.Length = writer.Length - output.Offset;
+ output.Position = 0;
+
+ return output;
+ }
+
+ /// <summary>
+ /// Builds new data of increaseing value bytes.
+ /// </summary>
+ /// <param name="dataSize">The number of bytes to generate.</param>
+ /// <returns>The data.</returns>
+ static MessageWriter BuildData(SendOption sendOption, int dataSize)
+ {
+ var output = MessageWriter.Get(sendOption);
+ for (int i = 0; i < dataSize; i++)
+ {
+ output.Write((byte)i);
+ }
+
+ return output;
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel.UnitTests/TestLogger.cs b/Tools/Hazel-Networking/Hazel.UnitTests/TestLogger.cs
new file mode 100644
index 0000000..01ca893
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel.UnitTests/TestLogger.cs
@@ -0,0 +1,66 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+
+namespace Hazel.UnitTests
+{
+ public class TestLogger : ILogger
+ {
+ private readonly string prefix;
+
+ public TestLogger(string prefix = "")
+ {
+ this.prefix = prefix;
+ }
+
+ public void WriteVerbose(string msg)
+ {
+ if (string.IsNullOrEmpty(this.prefix))
+ {
+ Console.WriteLine($"[VERBOSE] {msg}");
+ }
+ else
+ {
+ Console.WriteLine($"[{this.prefix}][VERBOSE] {msg}");
+ }
+ }
+
+ public void WriteWarning(string msg)
+ {
+ if (string.IsNullOrEmpty(this.prefix))
+ {
+ Console.WriteLine($"[WARN] {msg}");
+ }
+ else
+ {
+ Console.WriteLine($"[{this.prefix}][WARN] {msg}");
+ }
+ }
+
+ public void WriteError(string msg)
+ {
+ if (string.IsNullOrEmpty(this.prefix))
+ {
+ Console.WriteLine($"[ERROR] {msg}");
+ }
+ else
+ {
+ Console.WriteLine($"[{this.prefix}][ERROR] {msg}");
+ }
+ }
+
+ public void WriteInfo(string msg)
+ {
+ if (string.IsNullOrEmpty(this.prefix))
+ {
+ Console.WriteLine($"[INFO] {msg}");
+ }
+ else
+ {
+ Console.WriteLine($"[{this.prefix}][INFO] {msg}");
+ }
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel.UnitTests/ThreadLimitedUdpConnectionTests.cs b/Tools/Hazel-Networking/Hazel.UnitTests/ThreadLimitedUdpConnectionTests.cs
new file mode 100644
index 0000000..8b9998f
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel.UnitTests/ThreadLimitedUdpConnectionTests.cs
@@ -0,0 +1,786 @@
+using System;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using System.Net;
+using System.Threading;
+using Hazel.Udp;
+using Hazel.Udp.FewerThreads;
+using System.Net.Sockets;
+using System.Linq;
+using System.Collections;
+using System.Collections.Generic;
+
+namespace Hazel.UnitTests
+{
+ [TestClass]
+ public class ThreadLimitedUdpConnectionTests
+ {
+ protected ThreadLimitedUdpConnectionListener CreateListener(int numWorkers, IPEndPoint endPoint, ILogger logger, IPMode ipMode = IPMode.IPv4)
+ {
+ return new ThreadLimitedUdpConnectionListener(numWorkers, endPoint, logger, ipMode);
+ }
+
+ protected UdpConnection CreateConnection(IPEndPoint endPoint, ILogger logger, IPMode ipMode = IPMode.IPv4)
+ {
+ return new UdpClientConnection(logger, endPoint, ipMode);
+ }
+
+ [TestMethod]
+ public void ServerDisposeDisconnectsTest()
+ {
+ IPEndPoint ep = new IPEndPoint(IPAddress.Loopback, 4296);
+
+ bool serverConnected = false;
+ bool serverDisconnected = false;
+ bool clientDisconnected = false;
+
+ using (ThreadLimitedUdpConnectionListener listener = this.CreateListener(2, new IPEndPoint(IPAddress.Any, 4296), new TestLogger("SERVER")))
+ using (UdpConnection connection = this.CreateConnection(ep, new TestLogger("CLIENT")))
+ {
+ listener.NewConnection += (evt) =>
+ {
+ serverConnected = true;
+ evt.Connection.Disconnected += (o, et) => serverDisconnected = true;
+ };
+ connection.Disconnected += (o, evt) => clientDisconnected = true;
+
+ listener.Start();
+ connection.Connect();
+
+ Thread.Sleep(100); // Gotta wait for the server to set up the events.
+ listener.Dispose();
+ Thread.Sleep(100);
+
+ Assert.IsTrue(serverConnected);
+ Assert.IsTrue(clientDisconnected);
+ Assert.IsFalse(serverDisconnected);
+ }
+ }
+
+ [TestMethod]
+ public void ClientDisposeDisconnectTest()
+ {
+ IPEndPoint ep = new IPEndPoint(IPAddress.Loopback, 4296);
+
+ bool serverConnected = false;
+ bool serverDisconnected = false;
+ bool clientDisconnected = false;
+
+ using (ThreadLimitedUdpConnectionListener listener = this.CreateListener(2, new IPEndPoint(IPAddress.Any, 4296), new TestLogger()))
+ using (UdpConnection connection = this.CreateConnection(ep, new TestLogger()))
+ {
+ listener.NewConnection += (evt) =>
+ {
+ serverConnected = true;
+ evt.Connection.Disconnected += (o, et) => serverDisconnected = true;
+ };
+
+ connection.Disconnected += (o, et) => clientDisconnected = true;
+
+ listener.Start();
+ connection.Connect();
+
+ Thread.Sleep(100); // Gotta wait for the server to set up the events.
+ connection.Dispose();
+
+ Thread.Sleep(100);
+
+ Assert.IsTrue(serverConnected);
+ Assert.IsTrue(serverDisconnected);
+ Assert.IsFalse(clientDisconnected);
+ }
+ }
+
+ /// <summary>
+ /// Tests the fields on UdpConnection.
+ /// </summary>
+ [TestMethod]
+ public void UdpFieldTest()
+ {
+ IPEndPoint ep = new IPEndPoint(IPAddress.Loopback, 4296);
+
+ using (ThreadLimitedUdpConnectionListener listener = this.CreateListener(2, new IPEndPoint(IPAddress.Any, 4296), new TestLogger()))
+ using (UdpConnection connection = this.CreateConnection(ep, new TestLogger()))
+ {
+ listener.Start();
+
+ connection.Connect();
+
+ //Connection fields
+ Assert.AreEqual(ep, connection.EndPoint);
+
+ //UdpConnection fields
+ Assert.AreEqual(new IPEndPoint(IPAddress.Loopback, 4296), connection.EndPoint);
+ Assert.AreEqual(1, connection.Statistics.DataBytesSent);
+ Assert.AreEqual(0, connection.Statistics.DataBytesReceived);
+ }
+ }
+
+ [TestMethod]
+ public void UdpHandshakeTest()
+ {
+ byte[] TestData = new byte[] { 1, 2, 3, 4, 5, 6 };
+ using (ThreadLimitedUdpConnectionListener listener = this.CreateListener(2, new IPEndPoint(IPAddress.Any, 4296), new TestLogger()))
+ using (UdpConnection connection = this.CreateConnection(new IPEndPoint(IPAddress.Loopback, 4296), new TestLogger()))
+ {
+ listener.Start();
+
+ MessageReader output = null;
+ listener.NewConnection += delegate (NewConnectionEventArgs e)
+ {
+ output = e.HandshakeData.Duplicate();
+ };
+
+ connection.Connect(TestData);
+
+ Thread.Sleep(10);
+ for (int i = 0; i < TestData.Length; ++i)
+ {
+ Assert.AreEqual(TestData[i], output.ReadByte());
+ }
+ }
+ }
+
+ [TestMethod]
+ public void UdpUnreliableMessageSendTest()
+ {
+ byte[] TestData = new byte[] { 1, 2, 3, 4, 5, 6 };
+ using (ThreadLimitedUdpConnectionListener listener = this.CreateListener(2, new IPEndPoint(IPAddress.Any, 4296), new TestLogger()))
+ using (UdpConnection connection = this.CreateConnection(new IPEndPoint(IPAddress.Loopback, 4296), new TestLogger()))
+ {
+ MessageReader output = null;
+ listener.NewConnection += delegate (NewConnectionEventArgs e)
+ {
+ e.Connection.DataReceived += delegate (DataReceivedEventArgs evt)
+ {
+ output = evt.Message.Duplicate();
+ };
+ };
+
+ listener.Start();
+ connection.Connect();
+
+ for (int i = 0; i < 4; ++i)
+ {
+ var msg = MessageWriter.Get(SendOption.None);
+ msg.Write(TestData);
+ connection.Send(msg);
+ msg.Recycle();
+ }
+
+ Thread.Sleep(10);
+ for (int i = 0; i < TestData.Length; ++i)
+ {
+ Assert.AreEqual(TestData[i], output.ReadByte());
+ }
+ }
+ }
+
+ [TestMethod]
+ public void UdpReliableMessageResendTest()
+ {
+ byte[] TestData = new byte[] { 1, 2, 3, 4, 5, 6 };
+
+ var listenerEp = new IPEndPoint(IPAddress.Loopback, 4296);
+ var captureEp = new IPEndPoint(IPAddress.Loopback, 4297);
+
+ using (SocketCapture capture = new SocketCapture(captureEp, listenerEp, new TestLogger()))
+ using (ThreadLimitedUdpConnectionListener listener = this.CreateListener(2, new IPEndPoint(IPAddress.Any, listenerEp.Port), new TestLogger()))
+ using (UdpConnection connection = this.CreateConnection(captureEp, new TestLogger()))
+ using (SemaphoreSlim readLock = new SemaphoreSlim(0, 1))
+ {
+ connection.ResendTimeoutMs = 100;
+ connection.KeepAliveInterval = Timeout.Infinite; // Don't let pings interfere.
+
+ MessageReader output = null;
+ listener.NewConnection += delegate (NewConnectionEventArgs e)
+ {
+ var udpConn = (UdpConnection)e.Connection;
+ udpConn.KeepAliveInterval = Timeout.Infinite; // Don't let pings interfere.
+
+ e.Connection.DataReceived += delegate (DataReceivedEventArgs evt)
+ {
+ output = evt.Message.Duplicate();
+ readLock.Release();
+ };
+ };
+
+ listener.Start();
+ connection.Connect();
+
+ capture.AssertPacketsToRemoteCountEquals(0);
+
+ const int NumberOfPacketsToResend = 4;
+ const int NumberOfTimesToResend = 3;
+ using (capture.SendToRemoteSemaphore = new Semaphore(0, int.MaxValue))
+ {
+ for (int pktCnt = 0; pktCnt < NumberOfPacketsToResend; ++pktCnt)
+ {
+ Console.WriteLine("Send blocked pkt");
+ var msg = MessageWriter.Get(SendOption.Reliable);
+ msg.Write(TestData);
+ connection.Send(msg);
+ msg.Recycle();
+
+ for (int drops = 0; drops < NumberOfTimesToResend; ++drops)
+ {
+ capture.AssertPacketsToRemoteCountEquals(1);
+ capture.DiscardPacketForRemote();
+ }
+
+ capture.AssertPacketsToRemoteCountEquals(1);
+ capture.SendToRemoteSemaphore.Release(); // Actually let it send.
+
+ Assert.IsTrue(readLock.Wait(1000));
+ for (int i = 0; i < TestData.Length; ++i)
+ {
+ Assert.AreEqual(TestData[i], output.ReadByte());
+ }
+
+ output = null;
+ }
+ }
+
+ Assert.AreEqual(NumberOfPacketsToResend * NumberOfTimesToResend, connection.Statistics.MessagesResent);
+ }
+ }
+
+ [TestMethod]
+ public void UdpReliableMessageAckTest()
+ {
+ byte[] TestData = new byte[] { 1, 2, 3, 4, 5, 6 };
+
+ var listenerEp = new IPEndPoint(IPAddress.Loopback, 4296);
+ var captureEp = new IPEndPoint(IPAddress.Loopback, 4297);
+
+ using (SocketCapture capture = new SocketCapture(captureEp, listenerEp, new TestLogger()))
+ using (ThreadLimitedUdpConnectionListener listener = this.CreateListener(2, new IPEndPoint(IPAddress.Any, listenerEp.Port), new TestLogger()))
+ using (UdpConnection connection = this.CreateConnection(captureEp, new TestLogger()))
+ {
+ connection.ResendTimeoutMs = 100;
+ connection.KeepAliveInterval = Timeout.Infinite; // Don't let pings interfere.
+
+ listener.NewConnection += delegate (NewConnectionEventArgs e)
+ {
+ var udpConn = (UdpConnection)e.Connection;
+ udpConn.KeepAliveInterval = Timeout.Infinite; // Don't let pings interfere.
+ };
+
+ listener.Start();
+ connection.Connect();
+
+ capture.AssertPacketsToLocalCountEquals(0);
+
+ const int NumberOfPacketsToSend = 4;
+ using (capture.SendToLocalSemaphore = new Semaphore(0, int.MaxValue))
+ {
+ for (int pktCnt = 0; pktCnt < NumberOfPacketsToSend; ++pktCnt)
+ {
+ Console.WriteLine("Send blocked pkt");
+ var msg = MessageWriter.Get(SendOption.Reliable);
+ msg.Write(TestData);
+ connection.Send(msg);
+
+ msg.Recycle();
+
+ capture.AssertPacketsToLocalCountEquals(1);
+
+ var ack = capture.PeekPacketForLocal();
+ Assert.AreEqual(10, ack[0]); // enum SendOptionInternal.Acknowledgement
+ Assert.AreEqual(0, ack[1]);
+ Assert.AreEqual(pktCnt + 1, ack[2]);
+ Assert.AreEqual(255, ack[3]);
+
+ capture.SendToLocalSemaphore.Release(); // Actually let it send.
+ capture.AssertPacketsToLocalCountEquals(0);
+ }
+ }
+
+ // +1 for Hello packet
+ Thread.Sleep(100); // The final ack has to actually be processed.
+ Assert.AreEqual(1 + NumberOfPacketsToSend, connection.Statistics.ReliablePacketsAcknowledged);
+ }
+ }
+
+ [TestMethod]
+ public void UdpReliableMessageAckWithDropTest()
+ {
+ byte[] TestData = new byte[] { 1, 2, 3, 4, 5, 6 };
+
+ var listenerEp = new IPEndPoint(IPAddress.Loopback, 4296);
+ var captureEp = new IPEndPoint(IPAddress.Loopback, 4297);
+
+ using (SocketCapture capture = new SocketCapture(captureEp, listenerEp, new TestLogger()))
+ using (ThreadLimitedUdpConnectionListener listener = this.CreateListener(2, new IPEndPoint(IPAddress.Any, listenerEp.Port), new TestLogger()))
+ using (UdpConnection connection = this.CreateConnection(captureEp, new TestLogger()))
+ {
+ connection.ResendTimeoutMs = 10000; // No resends please
+ connection.KeepAliveInterval = Timeout.Infinite; // Don't let pings interfere.
+
+ listener.NewConnection += delegate (NewConnectionEventArgs e)
+ {
+ var udpConn = (UdpConnection)e.Connection;
+ udpConn.KeepAliveInterval = Timeout.Infinite; // Don't let pings interfere.
+ };
+
+ listener.Start();
+ connection.Connect();
+
+ capture.AssertPacketsToLocalCountEquals(0);
+
+ using (capture.SendToRemoteSemaphore = new Semaphore(0, int.MaxValue))
+ using (capture.SendToLocalSemaphore = new Semaphore(0, int.MaxValue))
+ {
+ // Send 3 packets to remote
+ for (int pktCnt = 0; pktCnt < 3; ++pktCnt)
+ {
+ var msg = MessageWriter.Get(SendOption.Reliable);
+ msg.Write(TestData);
+ connection.Send(msg);
+ msg.Recycle();
+ }
+
+ // Drop the middle packet
+ capture.AssertPacketsToRemoteCountEquals(3);
+ capture.ReorderPacketsForRemote(list => list.Sort(SortByPacketId.Instance));
+ Console.WriteLine(capture.PacketsForRemoteToString());
+
+ capture.ReleasePacketsForRemote(1);
+ capture.DiscardPacketForRemote();
+ capture.ReleasePacketsForRemote(1);
+
+ // Receive 2 acks
+ capture.AssertPacketsToLocalCountEquals(2);
+ capture.ReorderPacketsForLocal(list => list.Sort(SortByPacketId.Instance));
+ Console.WriteLine(capture.PacketsForLocalToString());
+
+ var ack1 = capture.PeekPacketForLocal();
+ Assert.AreEqual(10, ack1[0]); // enum SendOptionInternal.Acknowledgement
+ Assert.AreEqual(0, ack1[1]);
+ Assert.AreEqual(1, ack1[2]);
+ Assert.AreEqual(255, ack1[3]);
+ capture.ReleasePacketsToLocal(1);
+
+ var ack2 = capture.PeekPacketForLocal();
+ Assert.AreEqual(10, ack2[0]); // enum SendOptionInternal.Acknowledgement
+ Assert.AreEqual(0, ack2[1]);
+ Assert.AreEqual(3, ack2[2]);
+ Assert.AreEqual(254, ack2[3]); // The server is expecting packet 2
+ capture.ReleasePacketsToLocal(1);
+ }
+
+ // +1 for Hello packet, +2 for reliable
+ Thread.Sleep(100); // The final ack has to actually be processed.
+ Assert.AreEqual(3, connection.Statistics.ReliablePacketsAcknowledged);
+ }
+ }
+
+ [TestMethod]
+ public void UdpReliableMessageAckFillsDroppedAcksTest()
+ {
+ byte[] TestData = new byte[] { 1, 2, 3, 4, 5, 6 };
+
+ var listenerEp = new IPEndPoint(IPAddress.Loopback, 4296);
+ var captureEp = new IPEndPoint(IPAddress.Loopback, 4297);
+
+ using (SocketCapture capture = new SocketCapture(captureEp, listenerEp, new TestLogger()))
+ using (ThreadLimitedUdpConnectionListener listener = this.CreateListener(2, new IPEndPoint(IPAddress.Any, listenerEp.Port), new TestLogger()))
+ using (UdpConnection connection = this.CreateConnection(captureEp, new TestLogger("Client")))
+ {
+ connection.ResendTimeoutMs = 10000; // No resends please
+ connection.KeepAliveInterval = Timeout.Infinite; // Don't let pings interfere.
+
+ listener.NewConnection += delegate (NewConnectionEventArgs e)
+ {
+ var udpConn = (UdpConnection)e.Connection;
+ udpConn.KeepAliveInterval = Timeout.Infinite; // Don't let pings interfere.
+ };
+
+ listener.Start();
+ connection.Connect();
+
+ capture.AssertPacketsToLocalCountEquals(0);
+
+ using (capture.SendToLocalSemaphore = new Semaphore(0, int.MaxValue))
+ {
+ // Send 4 packets to remote
+ for (int pktCnt = 0; pktCnt < 4; ++pktCnt)
+ {
+ var msg = MessageWriter.Get(SendOption.Reliable);
+ msg.Write(TestData);
+ connection.Send(msg);
+ msg.Recycle();
+ }
+
+ // Receive 4 acks, Drop the middle two
+ capture.AssertPacketsToLocalCountEquals(4);
+ capture.ReorderPacketsForLocal(list => list.Sort(SortByPacketId.Instance));
+
+ var ack1 = capture.PeekPacketForLocal();
+ Assert.AreEqual(10, ack1[0]); // enum SendOptionInternal.Acknowledgement
+ Assert.AreEqual(0, ack1[1]);
+ Assert.AreEqual(1, ack1[2]);
+ Assert.AreEqual(255, ack1[3]);
+ capture.ReleasePacketsToLocal(1);
+
+ capture.DiscardPacketForLocal(2);
+
+ var ack4 = capture.PeekPacketForLocal();
+ Assert.AreEqual(10, ack4[0]); // enum SendOptionInternal.Acknowledgement
+ Assert.AreEqual(0, ack4[1]);
+ Assert.AreEqual(4, ack4[2]);
+ Assert.AreEqual(255, ack4[3]);
+ capture.ReleasePacketsToLocal(1);
+ }
+
+ // +1 for Hello packet, +4 for reliable despite the dropped ack
+ Thread.Sleep(100); // The final ack has to actually be processed.
+ Assert.AreEqual(3, connection.Statistics.AcknowledgementMessagesReceived);
+ Assert.AreEqual(5, connection.Statistics.ReliablePacketsAcknowledged);
+ }
+ }
+
+ private class SortByPacketId : IComparer<ByteSpan>
+ {
+ public static SortByPacketId Instance = new SortByPacketId();
+
+ public int Compare(ByteSpan x, ByteSpan y)
+ {
+ ushort xId = BitConverter.ToUInt16(x.GetUnderlyingArray(), 1);
+ ushort yId = BitConverter.ToUInt16(y.GetUnderlyingArray(), 1);
+ return xId.CompareTo(yId);
+ }
+ }
+
+ /// <summary>
+ /// Tests IPv4 connectivity.
+ /// </summary>
+ [TestMethod]
+ public void UdpIPv4ConnectionTest()
+ {
+ using (ThreadLimitedUdpConnectionListener listener = this.CreateListener(2, new IPEndPoint(IPAddress.Any, 4296), new TestLogger()))
+ using (UdpConnection connection = this.CreateConnection(new IPEndPoint(IPAddress.Loopback, 4296), new TestLogger()))
+ {
+ listener.Start();
+
+ connection.Connect();
+
+ Assert.AreEqual(ConnectionState.Connected, connection.State);
+
+ Console.Write($"Client sent {connection.Statistics.TotalBytesSent} bytes ");
+ }
+ }
+
+ /// <summary>
+ /// Tests IPv4 resilience to multiple hellos.
+ /// </summary>
+ [TestMethod]
+ public void ConnectLikeAJerkTest()
+ {
+ using (ThreadLimitedUdpConnectionListener listener = this.CreateListener(2, new IPEndPoint(IPAddress.Any, 4296), new TestLogger()))
+ using (Socket socket = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp))
+ {
+ int connects = 0;
+ listener.NewConnection += (obj) =>
+ {
+ Interlocked.Increment(ref connects);
+ };
+
+ listener.Start();
+
+ socket.Bind(new IPEndPoint(IPAddress.Any, 0));
+ var bytes = new byte[2];
+ bytes[0] = (byte)UdpSendOption.Hello;
+ for (int i = 0; i < 10; ++i)
+ {
+ socket.SendTo(bytes, new IPEndPoint(IPAddress.Loopback, 4296));
+ }
+
+ Thread.Sleep(500);
+
+ Assert.AreEqual(0, listener.ReceiveQueueLength);
+ Assert.IsTrue(connects <= 1, $"Too many connections: {connects}");
+ }
+ }
+
+ /// <summary>
+ /// Tests dual mode connectivity.
+ /// </summary>
+ [TestMethod]
+ public void MixedConnectionTest()
+ {
+
+ using (ThreadLimitedUdpConnectionListener listener2 = this.CreateListener(4, new IPEndPoint(IPAddress.IPv6Any, 4296), new ConsoleLogger(true), IPMode.IPv6))
+ {
+ listener2.Start();
+
+ listener2.NewConnection += (evt) =>
+ {
+ Console.WriteLine($"Connection: {evt.Connection.EndPoint}");
+ };
+
+ using (UdpConnection connection = this.CreateConnection(new IPEndPoint(IPAddress.Parse("127.0.0.1"), 4296), new TestLogger()))
+ {
+ connection.Connect();
+ Assert.AreEqual(ConnectionState.Connected, connection.State);
+ }
+
+ using (UdpConnection connection2 = this.CreateConnection(new IPEndPoint(IPAddress.IPv6Loopback, 4296), new TestLogger(), IPMode.IPv6))
+ {
+ connection2.Connect();
+ Assert.AreEqual(ConnectionState.Connected, connection2.State);
+ }
+ }
+ }
+
+ /// <summary>
+ /// Tests dual mode connectivity.
+ /// </summary>
+ [TestMethod]
+ public void UdpIPv6ConnectionTest()
+ {
+ using (ThreadLimitedUdpConnectionListener listener = this.CreateListener(2, new IPEndPoint(IPAddress.IPv6Any, 4296), new TestLogger(), IPMode.IPv6))
+ {
+ listener.Start();
+
+ using (UdpConnection connection = this.CreateConnection(new IPEndPoint(IPAddress.Parse("127.0.0.1"), 4296), new TestLogger(), IPMode.IPv6))
+ {
+ connection.Connect();
+ }
+ }
+ }
+
+ /// <summary>
+ /// Tests server to client unreliable communication on the UdpConnection.
+ /// </summary>
+ [TestMethod]
+ public void UdpUnreliableServerToClientTest()
+ {
+ using (ThreadLimitedUdpConnectionListener listener = this.CreateListener(2, new IPEndPoint(IPAddress.Any, 4296), new TestLogger()))
+ using (UdpConnection connection = this.CreateConnection(new IPEndPoint(IPAddress.Loopback, 4296), new TestLogger()))
+ {
+ TestHelper.RunServerToClientTest(listener, connection, 10, SendOption.None);
+ }
+ }
+
+ /// <summary>
+ /// Tests server to client reliable communication on the UdpConnection.
+ /// </summary>
+ [TestMethod]
+ public void UdpReliableServerToClientTest()
+ {
+ using (ThreadLimitedUdpConnectionListener listener = this.CreateListener(2, new IPEndPoint(IPAddress.Any, 4296), new TestLogger()))
+ using (UdpConnection connection = this.CreateConnection(new IPEndPoint(IPAddress.Loopback, 4296), new TestLogger()))
+ {
+ TestHelper.RunServerToClientTest(listener, connection, 10, SendOption.Reliable);
+ }
+ }
+
+ /// <summary>
+ /// Tests server to client unreliable communication on the UdpConnection.
+ /// </summary>
+ [TestMethod]
+ public void UdpUnreliableClientToServerTest()
+ {
+ using (ThreadLimitedUdpConnectionListener listener = this.CreateListener(2, new IPEndPoint(IPAddress.Any, 4296), new TestLogger()))
+ using (UdpConnection connection = this.CreateConnection(new IPEndPoint(IPAddress.Loopback, 4296), new TestLogger()))
+ {
+ TestHelper.RunClientToServerTest(listener, connection, 10, SendOption.None);
+ }
+ }
+
+ /// <summary>
+ /// Tests server to client reliable communication on the UdpConnection.
+ /// </summary>
+ [TestMethod]
+ public void UdpReliableClientToServerTest()
+ {
+ using (ThreadLimitedUdpConnectionListener listener = this.CreateListener(2, new IPEndPoint(IPAddress.Any, 4296), new TestLogger()))
+ using (UdpConnection connection = this.CreateConnection(new IPEndPoint(IPAddress.Loopback, 4296), new TestLogger()))
+ {
+ TestHelper.RunClientToServerTest(listener, connection, 10, SendOption.Reliable);
+ }
+ }
+
+ /// <summary>
+ /// Tests the keepalive functionality from the client,
+ /// </summary>
+ [TestMethod]
+ public virtual void KeepAliveClientTest()
+ {
+ using (ThreadLimitedUdpConnectionListener listener = this.CreateListener(2, new IPEndPoint(IPAddress.Any, 4296), new TestLogger()))
+ using (UdpConnection connection = this.CreateConnection(new IPEndPoint(IPAddress.Loopback, 4296), new TestLogger()))
+ {
+ listener.Start();
+
+ connection.Connect();
+ connection.KeepAliveInterval = 100;
+
+ Thread.Sleep(1050); //Enough time for ~10 keep alive packets
+
+ Assert.AreEqual(ConnectionState.Connected, connection.State);
+ Assert.IsTrue(
+ connection.Statistics.TotalBytesSent >= 30 &&
+ connection.Statistics.TotalBytesSent <= 50,
+ "Sent: " + connection.Statistics.TotalBytesSent
+ );
+ }
+ }
+
+ /// <summary>
+ /// Tests the keepalive functionality from the client,
+ /// </summary>
+ [TestMethod]
+ public void KeepAliveServerTest()
+ {
+ ManualResetEvent mutex = new ManualResetEvent(false);
+
+ using (ThreadLimitedUdpConnectionListener listener = this.CreateListener(2, new IPEndPoint(IPAddress.Any, 4296), new TestLogger()))
+ using (UdpConnection connection = this.CreateConnection(new IPEndPoint(IPAddress.Loopback, 4296), new TestLogger()))
+ {
+ UdpConnection client = null;
+ listener.NewConnection += delegate (NewConnectionEventArgs args)
+ {
+ client = (UdpConnection)args.Connection;
+ client.KeepAliveInterval = 100;
+
+ Thread timeoutThread = new Thread(() =>
+ {
+ Thread.Sleep(1050); //Enough time for ~10 keep alive packets
+ mutex.Set();
+ });
+ timeoutThread.Start();
+ };
+
+ listener.Start();
+
+ connection.Connect();
+
+ mutex.WaitOne();
+
+ Assert.AreEqual(ConnectionState.Connected, client.State);
+
+ Assert.IsTrue(
+ client.Statistics.TotalBytesSent >= 27 &&
+ client.Statistics.TotalBytesSent <= 50,
+ "Sent: " + client.Statistics.TotalBytesSent
+ );
+ }
+ }
+
+ /// <summary>
+ /// Tests disconnection from the client.
+ /// </summary>
+ [TestMethod]
+ public void ClientDisconnectTest()
+ {
+ using (ThreadLimitedUdpConnectionListener listener = this.CreateListener(2, new IPEndPoint(IPAddress.Any, 4296), new TestLogger()))
+ using (UdpConnection connection = this.CreateConnection(new IPEndPoint(IPAddress.Loopback, 4296), new TestLogger()))
+ {
+ ManualResetEvent mutex = new ManualResetEvent(false);
+ ManualResetEvent mutex2 = new ManualResetEvent(false);
+
+ listener.NewConnection += delegate (NewConnectionEventArgs args)
+ {
+ args.Connection.Disconnected += delegate (object sender2, DisconnectedEventArgs args2)
+ {
+ mutex2.Set();
+ };
+
+ mutex.Set();
+ };
+
+ listener.Start();
+
+ connection.Connect();
+
+ mutex.WaitOne(1000);
+ Assert.AreEqual(ConnectionState.Connected, connection.State);
+
+ connection.Disconnect("Testing");
+
+ mutex2.WaitOne(1000);
+ Assert.AreEqual(ConnectionState.NotConnected, connection.State);
+ }
+ }
+
+ /// <summary>
+ /// Tests disconnection from the server.
+ /// </summary>
+ [TestMethod]
+ public void ServerDisconnectTest()
+ {
+ using (ThreadLimitedUdpConnectionListener listener = this.CreateListener(2, new IPEndPoint(IPAddress.Any, 4296), new TestLogger()))
+ using (UdpConnection connection = this.CreateConnection(new IPEndPoint(IPAddress.Loopback, 4296), new TestLogger()))
+ {
+ SemaphoreSlim mutex = new SemaphoreSlim(0, 100);
+ ManualResetEventSlim serverMutex = new ManualResetEventSlim(false);
+
+ connection.Disconnected += delegate (object sender, DisconnectedEventArgs args)
+ {
+ mutex.Release();
+ };
+
+ listener.NewConnection += delegate (NewConnectionEventArgs args)
+ {
+ mutex.Release();
+
+ // This has to be on a new thread because the client will go straight from Connecting to NotConnected
+ ThreadPool.QueueUserWorkItem(_ =>
+ {
+ serverMutex.Wait(500);
+ args.Connection.Disconnect("Testing");
+ });
+ };
+
+ listener.Start();
+
+ connection.Connect();
+
+ mutex.Wait(500);
+ Assert.AreEqual(ConnectionState.Connected, connection.State);
+
+ serverMutex.Set();
+
+ mutex.Wait(500);
+ Assert.AreEqual(ConnectionState.NotConnected, connection.State);
+ }
+ }
+
+ /// <summary>
+ /// Tests disconnection from the server.
+ /// </summary>
+ [TestMethod]
+ public void ServerExtraDataDisconnectTest()
+ {
+ using (ThreadLimitedUdpConnectionListener listener = this.CreateListener(2, new IPEndPoint(IPAddress.Any, 4296), new TestLogger()))
+ using (UdpConnection connection = this.CreateConnection(new IPEndPoint(IPAddress.Loopback, 4296), new TestLogger()))
+ {
+ string received = null;
+ ManualResetEvent mutex = new ManualResetEvent(false);
+
+ connection.Disconnected += delegate (object sender, DisconnectedEventArgs args)
+ {
+ // We don't own the message, we have to read the string now
+ received = args.Message.ReadString();
+ mutex.Set();
+ };
+
+ listener.NewConnection += delegate (NewConnectionEventArgs args)
+ {
+ MessageWriter writer = MessageWriter.Get(SendOption.None);
+ writer.Write("Goodbye");
+ args.Connection.Disconnect("Testing", writer);
+ };
+
+ listener.Start();
+
+ connection.Connect();
+
+ mutex.WaitOne(5000);
+
+ Assert.IsNotNull(received);
+ Assert.AreEqual("Goodbye", received);
+ }
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel.UnitTests/UPnPTests.cs b/Tools/Hazel-Networking/Hazel.UnitTests/UPnPTests.cs
new file mode 100644
index 0000000..3bee911
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel.UnitTests/UPnPTests.cs
@@ -0,0 +1,54 @@
+using System;
+using Hazel.UPnP;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+
+namespace Hazel.UnitTests
+{
+ // [TestClass]
+ // TODO: These tests are super flaky because of hardware differences. Not sure what can be done.
+ public class UPnPTests
+ {
+ [TestMethod]
+ public void CanForwardPort()
+ {
+ using (UPnPHelper dut = new UPnPHelper(Logger.Instance))
+ {
+ Assert.IsTrue(dut.ForwardPort(22023, "Hazel Test"));
+ }
+ }
+
+ [TestMethod]
+ public void CanDeletePort()
+ {
+ using (UPnPHelper dut = new UPnPHelper(Logger.Instance))
+ {
+ Assert.IsTrue(dut.DeleteForwardingRule(22023));
+ }
+ }
+ }
+
+ public class Logger : ILogger
+ {
+ public static readonly ILogger Instance = new Logger();
+
+ public void WriteVerbose(string msg)
+ {
+ Console.WriteLine(msg);
+ }
+
+ public void WriteWarning(string msg)
+ {
+ Console.WriteLine(msg);
+ }
+
+ public void WriteError(string msg)
+ {
+ Console.WriteLine(msg);
+ }
+
+ public void WriteInfo(string msg)
+ {
+ Console.WriteLine(msg);
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel.UnitTests/UdpConnectionTestHarness.cs b/Tools/Hazel-Networking/Hazel.UnitTests/UdpConnectionTestHarness.cs
new file mode 100644
index 0000000..1e865a8
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel.UnitTests/UdpConnectionTestHarness.cs
@@ -0,0 +1,60 @@
+using Hazel.Udp;
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+
+namespace Hazel.UnitTests
+{
+ internal class UdpConnectionTestHarness : UdpConnection
+ {
+ public List<MessageReader> BytesSent = new List<MessageReader>();
+
+ public UdpConnectionTestHarness() : base(new TestLogger())
+ {
+ }
+
+ public ushort ReliableReceiveLast => this.reliableReceiveLast;
+
+
+ public override void Connect(byte[] bytes = null, int timeout = 5000)
+ {
+ this.State = ConnectionState.Connected;
+ }
+
+ public override void ConnectAsync(byte[] bytes = null)
+ {
+ this.State = ConnectionState.Connected;
+ }
+
+ protected override bool SendDisconnect(MessageWriter writer)
+ {
+ lock (this)
+ {
+ if (this.State != ConnectionState.Connected)
+ {
+ return false;
+ }
+
+ this.State = ConnectionState.NotConnected;
+ }
+
+ return true;
+ }
+
+ protected override void WriteBytesToConnection(byte[] bytes, int length)
+ {
+ this.BytesSent.Add(MessageReader.Get(bytes));
+ }
+
+ public void Test_Receive(MessageWriter msg)
+ {
+ byte[] buffer = new byte[msg.Length];
+ Buffer.BlockCopy(msg.Buffer, 0, buffer, 0, msg.Length);
+
+ var data = MessageReader.Get(buffer);
+ this.HandleReceive(data, data.Length);
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel.UnitTests/UdpConnectionTests.cs b/Tools/Hazel-Networking/Hazel.UnitTests/UdpConnectionTests.cs
new file mode 100644
index 0000000..1ccd561
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel.UnitTests/UdpConnectionTests.cs
@@ -0,0 +1,514 @@
+using System;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using System.Net;
+using System.Threading;
+using Hazel.Udp;
+using System.Net.Sockets;
+using System.Threading.Tasks;
+
+namespace Hazel.UnitTests
+{
+ [TestClass]
+ public class UdpConnectionTests
+ {
+ [TestMethod]
+ public void ServerDisposeDisconnectsTest()
+ {
+ IPEndPoint ep = new IPEndPoint(IPAddress.Loopback, 4296);
+
+ bool serverConnected = false;
+ bool serverDisconnected = false;
+ bool clientDisconnected = false;
+
+ using (UdpConnectionListener listener = new UdpConnectionListener(new IPEndPoint(IPAddress.Any, 4296)))
+ using (UdpConnection connection = new UdpClientConnection(new TestLogger("Client"), ep))
+ {
+ listener.NewConnection += (evt) =>
+ {
+ serverConnected = true;
+ evt.Connection.Disconnected += (o, et) => serverDisconnected = true;
+ };
+ connection.Disconnected += (o, evt) => clientDisconnected = true;
+
+ listener.Start();
+ connection.Connect();
+
+ Thread.Sleep(100); // Gotta wait for the server to set up the events.
+ listener.Dispose();
+ Thread.Sleep(100);
+
+ Assert.IsTrue(serverConnected);
+ Assert.IsTrue(clientDisconnected);
+ Assert.IsFalse(serverDisconnected);
+ }
+ }
+
+ [TestMethod]
+ public void ClientServerDisposeDisconnectsTest()
+ {
+ IPEndPoint ep = new IPEndPoint(IPAddress.Loopback, 4296);
+
+ bool serverConnected = false;
+ bool serverDisconnected = false;
+ bool clientDisconnected = false;
+
+ using (UdpConnectionListener listener = new UdpConnectionListener(new IPEndPoint(IPAddress.Any, 4296)))
+ using (UdpConnection connection = new UdpClientConnection(new TestLogger("Client"), ep))
+ {
+ listener.NewConnection += (evt) =>
+ {
+ serverConnected = true;
+ evt.Connection.Disconnected += (o, et) => serverDisconnected = true;
+ };
+
+ connection.Disconnected += (o, et) => clientDisconnected = true;
+
+ listener.Start();
+ connection.Connect();
+
+ Thread.Sleep(100); // Gotta wait for the server to set up the events.
+ connection.Dispose();
+
+ Thread.Sleep(100);
+
+ Assert.IsTrue(serverConnected);
+ Assert.IsTrue(serverDisconnected);
+ Assert.IsFalse(clientDisconnected);
+ }
+ }
+
+ /// <summary>
+ /// Tests the fields on UdpConnection.
+ /// </summary>
+ [TestMethod]
+ public void UdpFieldTest()
+ {
+ IPEndPoint ep = new IPEndPoint(IPAddress.Loopback, 4296);
+
+ using (UdpConnectionListener listener = new UdpConnectionListener(new IPEndPoint(IPAddress.Any, 4296)))
+ using (UdpConnection connection = new UdpClientConnection(new TestLogger("Client"), ep))
+ {
+ listener.Start();
+
+ connection.Connect();
+
+ //Connection fields
+ Assert.AreEqual(ep, connection.EndPoint);
+
+ //UdpConnection fields
+ Assert.AreEqual(new IPEndPoint(IPAddress.Loopback, 4296), connection.EndPoint);
+ Assert.AreEqual(1, connection.Statistics.DataBytesSent);
+ Assert.AreEqual(0, connection.Statistics.DataBytesReceived);
+ }
+ }
+
+ [TestMethod]
+ public void UdpHandshakeTest()
+ {
+ byte[] TestData = new byte[] { 1, 2, 3, 4, 5, 6 };
+ using (UdpConnectionListener listener = new UdpConnectionListener(new IPEndPoint(IPAddress.Any, 4296)))
+ using (UdpConnection connection = new UdpClientConnection(new TestLogger("Client"), new IPEndPoint(IPAddress.Loopback, 4296)))
+ {
+ listener.Start();
+
+ MessageReader output = null;
+ listener.NewConnection += delegate (NewConnectionEventArgs e)
+ {
+ output = e.HandshakeData.Duplicate();
+ };
+
+ connection.Connect(TestData);
+
+ Thread.Sleep(10);
+ for (int i = 0; i < TestData.Length; ++i)
+ {
+ Assert.AreEqual(TestData[i], output.ReadByte());
+ }
+ }
+ }
+
+ [TestMethod]
+ public void UdpUnreliableMessageSendTest()
+ {
+ byte[] TestData = new byte[] { 1, 2, 3, 4, 5, 6 };
+ using (UdpConnectionListener listener = new UdpConnectionListener(new IPEndPoint(IPAddress.Any, 4296)))
+ using (UdpConnection connection = new UdpClientConnection(new TestLogger("Client"), new IPEndPoint(IPAddress.Loopback, 4296)))
+ {
+ MessageReader output = null;
+ listener.NewConnection += delegate (NewConnectionEventArgs e)
+ {
+ e.Connection.DataReceived += delegate (DataReceivedEventArgs evt)
+ {
+ output = evt.Message;
+ };
+ };
+
+ listener.Start();
+ connection.Connect();
+
+ for (int i = 0; i < 4; ++i)
+ {
+ var msg = MessageWriter.Get(SendOption.None);
+ msg.Write(TestData);
+ connection.Send(msg);
+ msg.Recycle();
+ }
+
+ Thread.Sleep(10);
+ for (int i = 0; i < TestData.Length; ++i)
+ {
+ Assert.AreEqual(TestData[i], output.ReadByte());
+ }
+ }
+ }
+
+ /// <summary>
+ /// Tests IPv4 connectivity.
+ /// </summary>
+ [TestMethod]
+ public void UdpIPv4ConnectionTest()
+ {
+ using (UdpConnectionListener listener = new UdpConnectionListener(new IPEndPoint(IPAddress.Any, 4296)))
+ using (UdpConnection connection = new UdpClientConnection(new TestLogger("Client"), new IPEndPoint(IPAddress.Loopback, 4296)))
+ {
+ listener.Start();
+
+ connection.Connect();
+ }
+ }
+
+ /// <summary>
+ /// Tests dual mode connectivity.
+ /// </summary>
+ [TestMethod]
+ public void MixedConnectionTest()
+ {
+ using (UdpConnectionListener listener2 = new UdpConnectionListener(new IPEndPoint(IPAddress.IPv6Any, 4296), IPMode.IPv6))
+ {
+ listener2.Start();
+
+ listener2.NewConnection += (evt) =>
+ {
+ Console.WriteLine("v6 connection: " + ((NetworkConnection)evt.Connection).GetIP4Address());
+ };
+
+ using (UdpConnection connection = new UdpClientConnection(new TestLogger("Client"), new IPEndPoint(IPAddress.Parse("127.0.0.1"), 4296)))
+ {
+ connection.Connect();
+ Assert.AreEqual(ConnectionState.Connected, connection.State);
+ }
+
+ using (UdpConnection connection = new UdpClientConnection(new TestLogger("Client"), new IPEndPoint(IPAddress.IPv6Loopback, 4296), IPMode.IPv6))
+ {
+ connection.Connect();
+ Assert.AreEqual(ConnectionState.Connected, connection.State);
+ }
+ }
+ }
+
+ /// <summary>
+ /// Tests IPv4 resilience to non-hello connections.
+ /// </summary>
+ [TestMethod]
+ public void FalseConnectionTest()
+ {
+ using (UdpConnectionListener listener = new UdpConnectionListener(new IPEndPoint(IPAddress.Any, 4296)))
+ using (Socket socket = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp))
+ {
+ int connects = 0;
+ listener.NewConnection += (obj) =>
+ {
+ Interlocked.Increment(ref connects);
+ };
+
+ listener.Start();
+
+ socket.Bind(new IPEndPoint(IPAddress.Any, 0));
+ var bytes = new byte[2];
+ bytes[0] = (byte)32;
+ for (int i = 0; i < 10; ++i)
+ {
+ socket.SendTo(bytes, new IPEndPoint(IPAddress.Loopback, 4296));
+ }
+
+ Thread.Sleep(500);
+
+ Assert.AreEqual(0, connects);
+ }
+ }
+
+ /// <summary>
+ /// Tests IPv4 resilience to multiple hellos.
+ /// </summary>
+ [TestMethod]
+ public void ConnectLikeAJerkTest()
+ {
+ using (UdpConnectionListener listener = new UdpConnectionListener(new IPEndPoint(IPAddress.Any, 4296)))
+ using (Socket socket = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp))
+ {
+ int connects = 0;
+ listener.NewConnection += (obj) =>
+ {
+ Interlocked.Increment(ref connects);
+ };
+
+ listener.Start();
+
+ socket.Bind(new IPEndPoint(IPAddress.Any, 0));
+ var bytes = new byte[2];
+ bytes[0] = (byte)UdpSendOption.Hello;
+ for (int i = 0; i < 10; ++i)
+ {
+ socket.SendTo(bytes, new IPEndPoint(IPAddress.Loopback, 4296));
+ }
+
+ Thread.Sleep(500);
+
+ Assert.AreEqual(1, connects);
+ }
+ }
+
+ /// <summary>
+ /// Tests dual mode connectivity.
+ /// </summary>
+ [TestMethod]
+ public void UdpIPv6ConnectionTest()
+ {
+ using (UdpConnectionListener listener = new UdpConnectionListener(new IPEndPoint(IPAddress.IPv6Any, 4296), IPMode.IPv6))
+ {
+ listener.Start();
+
+ using (UdpConnection connection = new UdpClientConnection(new TestLogger("Client"), new IPEndPoint(IPAddress.Parse("127.0.0.1"), 4296), IPMode.IPv6))
+ {
+ connection.Connect();
+ }
+ }
+ }
+
+ /// <summary>
+ /// Tests server to client unreliable communication on the UdpConnection.
+ /// </summary>
+ [TestMethod]
+ public void UdpUnreliableServerToClientTest()
+ {
+ using (UdpConnectionListener listener = new UdpConnectionListener(new IPEndPoint(IPAddress.Any, 4296)))
+ using (UdpConnection connection = new UdpClientConnection(new TestLogger("Client"), new IPEndPoint(IPAddress.Loopback, 4296)))
+ {
+ TestHelper.RunServerToClientTest(listener, connection, 10, SendOption.None);
+ }
+ }
+
+ /// <summary>
+ /// Tests server to client reliable communication on the UdpConnection.
+ /// </summary>
+ [TestMethod]
+ public void UdpReliableServerToClientTest()
+ {
+ using (UdpConnectionListener listener = new UdpConnectionListener(new IPEndPoint(IPAddress.Any, 4296)))
+ using (UdpConnection connection = new UdpClientConnection(new TestLogger("Client"), new IPEndPoint(IPAddress.Loopback, 4296)))
+ {
+ TestHelper.RunServerToClientTest(listener, connection, 10, SendOption.Reliable);
+ }
+ }
+
+ /// <summary>
+ /// Tests server to client unreliable communication on the UdpConnection.
+ /// </summary>
+ [TestMethod]
+ public void UdpUnreliableClientToServerTest()
+ {
+ using (UdpConnectionListener listener = new UdpConnectionListener(new IPEndPoint(IPAddress.Any, 4296)))
+ using (UdpConnection connection = new UdpClientConnection(new TestLogger("Client"), new IPEndPoint(IPAddress.Loopback, 4296)))
+ {
+ TestHelper.RunClientToServerTest(listener, connection, 10, SendOption.None);
+ }
+ }
+
+ /// <summary>
+ /// Tests server to client reliable communication on the UdpConnection.
+ /// </summary>
+ [TestMethod]
+ public void UdpReliableClientToServerTest()
+ {
+ using (UdpConnectionListener listener = new UdpConnectionListener(new IPEndPoint(IPAddress.Any, 4296)))
+ using (UdpConnection connection = new UdpClientConnection(new TestLogger("Client"), new IPEndPoint(IPAddress.Loopback, 4296)))
+ {
+ TestHelper.RunClientToServerTest(listener, connection, 10, SendOption.Reliable);
+ }
+ }
+
+ /// <summary>
+ /// Tests the keepalive functionality from the client,
+ /// </summary>
+ [TestMethod]
+ public void PingDisconnectClientTest()
+ {
+#if DEBUG
+ using (UdpConnectionListener listener = new UdpConnectionListener(new IPEndPoint(IPAddress.Any, 4296)))
+ using (UdpConnection connection = new UdpClientConnection(new TestLogger("Client"), new IPEndPoint(IPAddress.Loopback, 4296)))
+ {
+ listener.Start();
+
+ connection.Connect();
+
+ // After connecting, quietly stop responding to all messages to fake connection loss.
+ Thread.Sleep(10);
+ listener.TestDropRate = 1;
+
+ connection.KeepAliveInterval = 100;
+
+ Thread.Sleep(1050); //Enough time for ~10 keep alive packets
+
+ Assert.AreEqual(ConnectionState.NotConnected, connection.State);
+ Assert.AreEqual(3 * connection.MissingPingsUntilDisconnect + 4, connection.Statistics.TotalBytesSent); // + 4 for connecting overhead
+ }
+#else
+ Assert.Inconclusive("Only works in DEBUG");
+#endif
+ }
+
+ /// <summary>
+ /// Tests the keepalive functionality from the client,
+ /// </summary>
+ [TestMethod]
+ public void KeepAliveClientTest()
+ {
+ using (UdpConnectionListener listener = new UdpConnectionListener(new IPEndPoint(IPAddress.Any, 4296)))
+ using (UdpConnection connection = new UdpClientConnection(new TestLogger("Client"), new IPEndPoint(IPAddress.Loopback, 4296)))
+ {
+ listener.Start();
+
+ connection.Connect();
+ connection.KeepAliveInterval = 100;
+
+ Thread.Sleep(1050); //Enough time for ~10 keep alive packets
+
+ Assert.AreEqual(ConnectionState.Connected, connection.State);
+ Assert.IsTrue(
+ connection.Statistics.TotalBytesSent >= 30 &&
+ connection.Statistics.TotalBytesSent <= 50,
+ "Sent: " + connection.Statistics.TotalBytesSent
+ );
+ }
+ }
+
+ /// <summary>
+ /// Tests the keepalive functionality from the client,
+ /// </summary>
+ [TestMethod]
+ public void KeepAliveServerTest()
+ {
+ ManualResetEvent mutex = new ManualResetEvent(false);
+
+ using (UdpConnectionListener listener = new UdpConnectionListener(new IPEndPoint(IPAddress.Any, 4296)))
+ using (UdpConnection connection = new UdpClientConnection(new TestLogger("Client"), new IPEndPoint(IPAddress.Loopback, 4296)))
+ {
+ UdpConnection client = null;
+ listener.NewConnection += delegate (NewConnectionEventArgs args)
+ {
+ client = (UdpConnection)args.Connection;
+ client.KeepAliveInterval = 100;
+
+ Thread.Sleep(1050); //Enough time for ~10 keep alive packets
+
+ mutex.Set();
+ };
+
+ listener.Start();
+
+ connection.Connect();
+
+ mutex.WaitOne();
+
+ Assert.AreEqual(ConnectionState.Connected, client.State);
+
+ Assert.IsTrue(
+ client.Statistics.TotalBytesSent >= 27 &&
+ client.Statistics.TotalBytesSent <= 50,
+ "Sent: " + client.Statistics.TotalBytesSent
+ );
+ }
+ }
+
+ /// <summary>
+ /// Tests disconnection from the client.
+ /// </summary>
+ [TestMethod]
+ public void ClientDisconnectTest()
+ {
+ using (UdpConnectionListener listener = new UdpConnectionListener(new IPEndPoint(IPAddress.Any, 4296)))
+ using (UdpConnection connection = new UdpClientConnection(new TestLogger("Client"), new IPEndPoint(IPAddress.Loopback, 4296)))
+ {
+ TestHelper.RunClientDisconnectTest(listener, connection);
+ }
+ }
+
+ /// <summary>
+ /// Test that a disconnect is sent when the client is disposed.
+ /// </summary>
+ public void ClientDisconnectOnDisposeTest()
+ {
+ using (UdpConnectionListener listener = new UdpConnectionListener(new IPEndPoint(IPAddress.Any, 4296)))
+ using (UdpConnection connection = new UdpClientConnection(new TestLogger("Client"), new IPEndPoint(IPAddress.Loopback, 4296)))
+ {
+ TestHelper.RunClientDisconnectOnDisposeTest(listener, connection);
+ }
+ }
+
+ /// <summary>
+ /// Tests disconnection from the server.
+ /// </summary>
+ [TestMethod]
+ public void ServerDisconnectTest()
+ {
+ using (UdpConnectionListener listener = new UdpConnectionListener(new IPEndPoint(IPAddress.Any, 4296)))
+ using (UdpConnection connection = new UdpClientConnection(new TestLogger("Client"), new IPEndPoint(IPAddress.Loopback, 4296)))
+ {
+ TestHelper.RunServerDisconnectTest(listener, connection);
+ }
+ }
+
+ /// <summary>
+ /// Tests disconnection from the server.
+ /// </summary>
+ [TestMethod]
+ public void ServerExtraDataDisconnectTest()
+ {
+ using (UdpConnectionListener listener = new UdpConnectionListener(new IPEndPoint(IPAddress.Any, 4296)))
+ using (UdpConnection connection = new UdpClientConnection(new TestLogger("Client"), new IPEndPoint(IPAddress.Loopback, 4296)))
+ {
+ string received = null;
+ ManualResetEvent mutex = new ManualResetEvent(false);
+
+ connection.Disconnected += delegate (object sender, DisconnectedEventArgs args)
+ {
+ // We don't own the message, we have to read the string now
+ received = args.Message.ReadString();
+ mutex.Set();
+ };
+
+ listener.NewConnection += delegate (NewConnectionEventArgs args)
+ {
+ // As it turns out, the UdpConnectionListener can have an issue on loopback where the disconnect can happen before the hello confirm
+ // Tossing it on a different thread makes this test more reliable. Perhaps something to think about elsewhere though.
+ Task.Run(async () =>
+ {
+ await Task.Delay(1);
+ MessageWriter writer = MessageWriter.Get(SendOption.None);
+ writer.Write("Goodbye");
+ args.Connection.Disconnect("Testing", writer);
+ });
+ };
+
+ listener.Start();
+
+ connection.Connect();
+
+ mutex.WaitOne();
+
+ Assert.IsNotNull(received);
+ Assert.AreEqual("Goodbye", received);
+ }
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel.UnitTests/UdpReliabilityTests.cs b/Tools/Hazel-Networking/Hazel.UnitTests/UdpReliabilityTests.cs
new file mode 100644
index 0000000..ede7698
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel.UnitTests/UdpReliabilityTests.cs
@@ -0,0 +1,116 @@
+using System;
+using System.Collections.Generic;
+using Hazel.Udp;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+
+namespace Hazel.UnitTests
+{
+ [TestClass]
+ public class UdpReliabilityTests
+ {
+ [TestMethod]
+ public void TestReliableWrapOffByOne()
+ {
+ List<MessageReader> messagesReceived = new List<MessageReader>();
+
+ UdpConnectionTestHarness dut = new UdpConnectionTestHarness();
+ dut.DataReceived += evt =>
+ {
+ messagesReceived.Add(evt.Message);
+ };
+
+ MessageWriter data = MessageWriter.Get(SendOption.Reliable);
+
+ Assert.AreEqual(ushort.MaxValue, dut.ReliableReceiveLast);
+
+ SetReliableId(data, 10);
+ dut.Test_Receive(data);
+
+ // This message may not be received if there is an off-by-one error when marking missed pkts up to 10.
+ SetReliableId(data, 9);
+ dut.Test_Receive(data);
+
+ // Both messages should be received.
+ Assert.AreEqual(2, messagesReceived.Count);
+ messagesReceived.Clear();
+
+ Assert.AreEqual(2, dut.BytesSent.Count);
+ dut.BytesSent.Clear();
+ }
+
+ [TestMethod]
+ public void TestThatAllMessagesAreReceived()
+ {
+ List<MessageReader> messagesReceived = new List<MessageReader>();
+
+ UdpConnectionTestHarness dut = new UdpConnectionTestHarness();
+ dut.DataReceived += evt =>
+ {
+ messagesReceived.Add(evt.Message);
+ };
+
+ MessageWriter data = MessageWriter.Get(SendOption.Reliable);
+
+ for (int i = 0; i < ushort.MaxValue * 2; ++i)
+ {
+ // Send a new message, it should be received and ack'd
+ SetReliableId(data, i);
+ dut.Test_Receive(data);
+
+ // Resend an old message, it should be ignored
+ if (i > 2)
+ {
+ SetReliableId(data, i - 1);
+ dut.Test_Receive(data);
+
+ // It should still be ack'd
+ Assert.AreEqual(2, dut.BytesSent.Count);
+ dut.BytesSent.RemoveAt(1);
+ }
+
+ Assert.AreEqual(1, messagesReceived.Count);
+ messagesReceived.Clear();
+
+ Assert.AreEqual(1, dut.BytesSent.Count);
+ dut.BytesSent.Clear();
+ }
+ }
+
+ [TestMethod]
+ public void TestAcksForNotReceivedMessages()
+ {
+ List<MessageReader> messagesReceived = new List<MessageReader>();
+
+ UdpConnectionTestHarness dut = new UdpConnectionTestHarness();
+ dut.DataReceived += evt =>
+ {
+ messagesReceived.Add(evt.Message);
+ };
+
+ MessageWriter data = MessageWriter.Get(SendOption.Reliable);
+
+ SetReliableId(data, 1);
+ dut.Test_Receive(data);
+
+ SetReliableId(data, 3);
+ dut.Test_Receive(data);
+
+ MessageReader ackPacket = dut.BytesSent[1];
+ // Must be ack
+ Assert.AreEqual(4, ackPacket.Length);
+
+ byte recentPackets = ackPacket.Buffer[3];
+ // Last packet was not received
+ Assert.AreEqual(0, recentPackets & 1);
+ // The packet before that was.
+ Assert.AreEqual(1, (recentPackets >> 1) & 1);
+ }
+
+ private static void SetReliableId(MessageWriter data, int i)
+ {
+ ushort id = (ushort)i;
+ data.Buffer[1] = (byte)(id >> 8);
+ data.Buffer[2] = (byte)id;
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel.UnitTests/UnityUdpConnectionTests.cs b/Tools/Hazel-Networking/Hazel.UnitTests/UnityUdpConnectionTests.cs
new file mode 100644
index 0000000..0745578
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel.UnitTests/UnityUdpConnectionTests.cs
@@ -0,0 +1,489 @@
+using System;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using System.Net;
+using System.Threading;
+using Hazel.Udp;
+using System.Net.Sockets;
+using System.Threading.Tasks;
+
+namespace Hazel.UnitTests
+{
+ [TestClass]
+ public class UnityUdpConnectionTests
+ {
+ private ILogger logger = new ConsoleLogger(true);
+
+ [TestMethod]
+ public void ServerDisposeDisconnectsTest()
+ {
+ IPEndPoint ep = new IPEndPoint(IPAddress.Loopback, 4296);
+
+ bool serverConnected = false;
+ bool serverDisconnected = false;
+ bool clientDisconnected = false;
+
+ using (UdpConnectionListener listener = new UdpConnectionListener(new IPEndPoint(IPAddress.Any, 4296)))
+ using (UdpConnection connection = new UnityUdpClientConnection(logger, ep))
+ {
+ listener.NewConnection += (evt) =>
+ {
+ serverConnected = true;
+ evt.Connection.Disconnected += (o, et) => serverDisconnected = true;
+ };
+ connection.Disconnected += (o, evt) => clientDisconnected = true;
+
+ listener.Start();
+ connection.Connect();
+
+ Thread.Sleep(100); // Gotta wait for the server to set up the events.
+ listener.Dispose();
+ Thread.Sleep(100);
+
+ Assert.IsTrue(serverConnected);
+ Assert.IsTrue(clientDisconnected);
+ Assert.IsFalse(serverDisconnected);
+ }
+ }
+
+ [TestMethod]
+ public void ClientServerDisposeDisconnectsTest()
+ {
+ IPEndPoint ep = new IPEndPoint(IPAddress.Loopback, 4296);
+
+ bool serverConnected = false;
+ bool serverDisconnected = false;
+ bool clientDisconnected = false;
+
+ using (UdpConnectionListener listener = new UdpConnectionListener(new IPEndPoint(IPAddress.Any, 4296)))
+ using (UdpConnection connection = new UnityUdpClientConnection(logger, ep))
+ {
+ listener.NewConnection += (evt) =>
+ {
+ serverConnected = true;
+ evt.Connection.Disconnected += (o, et) => serverDisconnected = true;
+ };
+
+ connection.Disconnected += (o, et) => clientDisconnected = true;
+
+ listener.Start();
+ connection.Connect();
+
+ Thread.Sleep(100); // Gotta wait for the server to set up the events.
+ connection.Dispose();
+
+ Thread.Sleep(100);
+
+ Assert.IsTrue(serverConnected);
+ Assert.IsTrue(serverDisconnected);
+ Assert.IsFalse(clientDisconnected);
+ }
+ }
+
+ /// <summary>
+ /// Tests the fields on UdpConnection.
+ /// </summary>
+ [TestMethod]
+ public void UdpFieldTest()
+ {
+ IPEndPoint ep = new IPEndPoint(IPAddress.Loopback, 4296);
+
+ using (UdpConnectionListener listener = new UdpConnectionListener(new IPEndPoint(IPAddress.Any, 4296)))
+ using (UdpConnection connection = new UnityUdpClientConnection(logger, ep))
+ {
+ listener.Start();
+
+ connection.Connect();
+
+ //Connection fields
+ Assert.AreEqual(ep, connection.EndPoint);
+
+ //UdpConnection fields
+ Assert.AreEqual(new IPEndPoint(IPAddress.Loopback, 4296), connection.EndPoint);
+ Assert.AreEqual(1, connection.Statistics.DataBytesSent);
+ Assert.AreEqual(0, connection.Statistics.DataBytesReceived);
+ }
+ }
+
+ [TestMethod]
+ public void UdpHandshakeTest()
+ {
+ byte[] TestData = new byte[] { 1, 2, 3, 4, 5, 6 };
+
+ using (ManualResetEventSlim mutex = new ManualResetEventSlim(false))
+ using (UdpConnectionListener listener = new UdpConnectionListener(new IPEndPoint(IPAddress.Any, 4296)))
+ using (UdpConnection connection = new UnityUdpClientConnection(logger, new IPEndPoint(IPAddress.Loopback, 4296)))
+ {
+ listener.Start();
+
+ MessageReader output = null;
+ listener.NewConnection += delegate (NewConnectionEventArgs e)
+ {
+ output = e.HandshakeData.Duplicate();
+ mutex.Set();
+ };
+
+ connection.Connect(TestData);
+ mutex.Wait(5000);
+
+ for (int i = 0; i < TestData.Length; ++i)
+ {
+ Assert.AreEqual(TestData[i], output.ReadByte());
+ }
+ }
+ }
+
+ [TestMethod]
+ public void UdpUnreliableMessageSendTest()
+ {
+ byte[] TestData = new byte[] { 1, 2, 3, 4, 5, 6 };
+ using (UdpConnectionListener listener = new UdpConnectionListener(new IPEndPoint(IPAddress.Any, 4296)))
+ using (UdpConnection connection = new UnityUdpClientConnection(logger, new IPEndPoint(IPAddress.Loopback, 4296)))
+ {
+ MessageReader output = null;
+ listener.NewConnection += delegate (NewConnectionEventArgs e)
+ {
+ e.Connection.DataReceived += delegate (DataReceivedEventArgs evt)
+ {
+ output = evt.Message;
+ };
+ };
+
+ listener.Start();
+ connection.Connect();
+
+ for (int i = 0; i < 4; ++i)
+ {
+ var msg = MessageWriter.Get(SendOption.None);
+ msg.Write(TestData);
+ connection.Send(msg);
+ msg.Recycle();
+ }
+
+ Thread.Sleep(10);
+ for (int i = 0; i < TestData.Length; ++i)
+ {
+ Assert.AreEqual(TestData[i], output.ReadByte());
+ }
+ }
+ }
+
+ /// <summary>
+ /// Tests IPv4 connectivity.
+ /// </summary>
+ [TestMethod]
+ public void UdpIPv4ConnectionTest()
+ {
+ using (UdpConnectionListener listener = new UdpConnectionListener(new IPEndPoint(IPAddress.Any, 4296)))
+ using (UdpConnection connection = new UnityUdpClientConnection(logger, new IPEndPoint(IPAddress.Loopback, 4296)))
+ {
+ listener.Start();
+
+ connection.Connect();
+ }
+ }
+
+ /// <summary>
+ /// Tests dual mode connectivity.
+ /// </summary>
+ [TestMethod]
+ public void MixedConnectionTest()
+ {
+ using (UdpConnectionListener listener2 = new UdpConnectionListener(new IPEndPoint(IPAddress.IPv6Any, 4296), IPMode.IPv6))
+ {
+ listener2.Start();
+
+ listener2.NewConnection += (evt) =>
+ {
+ Console.WriteLine("v6 connection: " + ((NetworkConnection)evt.Connection).GetIP4Address());
+ };
+
+ using (UdpConnection connection = new UnityUdpClientConnection(logger, new IPEndPoint(IPAddress.Parse("127.0.0.1"), 4296)))
+ {
+ connection.Connect();
+ Assert.AreEqual(ConnectionState.Connected, connection.State);
+ }
+
+ using (UdpConnection connection = new UnityUdpClientConnection(logger, new IPEndPoint(IPAddress.IPv6Loopback, 4296), IPMode.IPv6))
+ {
+ connection.Connect();
+ Assert.AreEqual(ConnectionState.Connected, connection.State);
+ }
+ }
+ }
+
+ /// <summary>
+ /// Tests IPv4 resilience to non-hello connections.
+ /// </summary>
+ [TestMethod]
+ public void FalseConnectionTest()
+ {
+ using (UdpConnectionListener listener = new UdpConnectionListener(new IPEndPoint(IPAddress.Any, 4296)))
+ using (Socket socket = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp))
+ {
+ int connects = 0;
+ listener.NewConnection += (obj) =>
+ {
+ Interlocked.Increment(ref connects);
+ };
+
+ listener.Start();
+
+ socket.Bind(new IPEndPoint(IPAddress.Any, 0));
+ var bytes = new byte[2];
+ bytes[0] = (byte)32;
+ for (int i = 0; i < 10; ++i)
+ {
+ socket.SendTo(bytes, new IPEndPoint(IPAddress.Loopback, 4296));
+ }
+
+ Thread.Sleep(500);
+
+ Assert.AreEqual(0, connects);
+ }
+ }
+
+ /// <summary>
+ /// Tests IPv4 resilience to multiple hellos.
+ /// </summary>
+ [TestMethod]
+ public void ConnectLikeAJerkTest()
+ {
+ using (UdpConnectionListener listener = new UdpConnectionListener(new IPEndPoint(IPAddress.Any, 4296)))
+ using (Socket socket = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp))
+ {
+ int connects = 0;
+ listener.NewConnection += (obj) =>
+ {
+ Interlocked.Increment(ref connects);
+ };
+
+ listener.Start();
+
+ socket.Bind(new IPEndPoint(IPAddress.Any, 0));
+ var bytes = new byte[2];
+ bytes[0] = (byte)UdpSendOption.Hello;
+ for (int i = 0; i < 10; ++i)
+ {
+ socket.SendTo(bytes, new IPEndPoint(IPAddress.Loopback, 4296));
+ }
+
+ Thread.Sleep(500);
+
+ Assert.AreEqual(1, connects);
+ }
+ }
+
+ /// <summary>
+ /// Tests dual mode connectivity.
+ /// </summary>
+ [TestMethod]
+ public void UdpIPv6ConnectionTest()
+ {
+ using (UdpConnectionListener listener = new UdpConnectionListener(new IPEndPoint(IPAddress.IPv6Any, 4296), IPMode.IPv6))
+ {
+ listener.Start();
+
+ using (UdpConnection connection = new UnityUdpClientConnection(logger, new IPEndPoint(IPAddress.Parse("127.0.0.1"), 4296), IPMode.IPv6))
+ {
+ connection.Connect();
+ }
+ }
+ }
+
+ /// <summary>
+ /// Tests server to client unreliable communication on the UdpConnection.
+ /// </summary>
+ [TestMethod]
+ public void UdpUnreliableServerToClientTest()
+ {
+ using (UdpConnectionListener listener = new UdpConnectionListener(new IPEndPoint(IPAddress.Any, 4296)))
+ using (UdpConnection connection = new UnityUdpClientConnection(logger, new IPEndPoint(IPAddress.Loopback, 4296)))
+ {
+ TestHelper.RunServerToClientTest(listener, connection, 10, SendOption.None);
+ }
+ }
+
+ /// <summary>
+ /// Tests server to client reliable communication on the UdpConnection.
+ /// </summary>
+ [TestMethod]
+ public void UdpReliableServerToClientTest()
+ {
+ using (UdpConnectionListener listener = new UdpConnectionListener(new IPEndPoint(IPAddress.Any, 4296)))
+ using (UdpConnection connection = new UnityUdpClientConnection(logger, new IPEndPoint(IPAddress.Loopback, 4296)))
+ {
+ TestHelper.RunServerToClientTest(listener, connection, 10, SendOption.Reliable);
+ }
+ }
+
+ /// <summary>
+ /// Tests server to client unreliable communication on the UdpConnection.
+ /// </summary>
+ [TestMethod]
+ public void UdpUnreliableClientToServerTest()
+ {
+ using (UdpConnectionListener listener = new UdpConnectionListener(new IPEndPoint(IPAddress.Any, 4296)))
+ using (UdpConnection connection = new UnityUdpClientConnection(logger, new IPEndPoint(IPAddress.Loopback, 4296)))
+ {
+ TestHelper.RunClientToServerTest(listener, connection, 10, SendOption.None);
+ }
+ }
+
+ /// <summary>
+ /// Tests server to client reliable communication on the UdpConnection.
+ /// </summary>
+ [TestMethod]
+ public void UdpReliableClientToServerTest()
+ {
+ using (UdpConnectionListener listener = new UdpConnectionListener(new IPEndPoint(IPAddress.Any, 4296)))
+ using (UdpConnection connection = new UnityUdpClientConnection(logger, new IPEndPoint(IPAddress.Loopback, 4296)))
+ {
+ TestHelper.RunClientToServerTest(listener, connection, 10, SendOption.Reliable);
+ }
+ }
+
+ /// <summary>
+ /// Tests the keepalive functionality from the client,
+ /// </summary>
+ [TestMethod]
+ public void KeepAliveClientTest()
+ {
+ using (UdpConnectionListener listener = new UdpConnectionListener(new IPEndPoint(IPAddress.Any, 4296)))
+ using (UdpConnection connection = new UnityUdpClientConnection(logger, new IPEndPoint(IPAddress.Loopback, 4296)))
+ {
+ listener.Start();
+
+ connection.Connect();
+ connection.KeepAliveInterval = 100;
+
+ Thread.Sleep(1050); //Enough time for ~10 keep alive packets
+
+ Assert.AreEqual(ConnectionState.Connected, connection.State);
+ Assert.IsTrue(
+ connection.Statistics.TotalBytesSent >= 30 &&
+ connection.Statistics.TotalBytesSent <= 50,
+ "Sent: " + connection.Statistics.TotalBytesSent
+ );
+ }
+ }
+
+ /// <summary>
+ /// Tests the keepalive functionality from the client,
+ /// </summary>
+ [TestMethod]
+ public void KeepAliveServerTest()
+ {
+ ManualResetEvent mutex = new ManualResetEvent(false);
+
+ using (UdpConnectionListener listener = new UdpConnectionListener(new IPEndPoint(IPAddress.Any, 4296)))
+ using (UdpConnection connection = new UnityUdpClientConnection(logger, new IPEndPoint(IPAddress.Loopback, 4296)))
+ {
+ UdpConnection client = null;
+ listener.NewConnection += delegate (NewConnectionEventArgs args)
+ {
+ client = (UdpConnection)args.Connection;
+ client.KeepAliveInterval = 100;
+
+ Thread.Sleep(1050); //Enough time for ~10 keep alive packets
+
+ mutex.Set();
+ };
+
+ listener.Start();
+
+ connection.Connect();
+
+ mutex.WaitOne();
+
+ Assert.AreEqual(ConnectionState.Connected, client.State);
+
+ Assert.IsTrue(
+ client.Statistics.TotalBytesSent >= 27 &&
+ client.Statistics.TotalBytesSent <= 50,
+ "Sent: " + client.Statistics.TotalBytesSent
+ );
+ }
+ }
+
+ /// <summary>
+ /// Tests disconnection from the client.
+ /// </summary>
+ [TestMethod]
+ public void ClientDisconnectTest()
+ {
+ using (UdpConnectionListener listener = new UdpConnectionListener(new IPEndPoint(IPAddress.Any, 4296)))
+ using (UdpConnection connection = new UnityUdpClientConnection(logger, new IPEndPoint(IPAddress.Loopback, 4296)))
+ {
+ TestHelper.RunClientDisconnectTest(listener, connection);
+ }
+ }
+
+ /// <summary>
+ /// Test that a disconnect is sent when the client is disposed.
+ /// </summary>
+ public void ClientDisconnectOnDisposeTest()
+ {
+ using (UdpConnectionListener listener = new UdpConnectionListener(new IPEndPoint(IPAddress.Any, 4296)))
+ using (UdpConnection connection = new UnityUdpClientConnection(logger, new IPEndPoint(IPAddress.Loopback, 4296)))
+ {
+ TestHelper.RunClientDisconnectOnDisposeTest(listener, connection);
+ }
+ }
+
+ /// <summary>
+ /// Tests disconnection from the server.
+ /// </summary>
+ [TestMethod]
+ public void ServerDisconnectTest()
+ {
+ using (UdpConnectionListener listener = new UdpConnectionListener(new IPEndPoint(IPAddress.Any, 4296)))
+ using (UdpConnection connection = new UnityUdpClientConnection(logger, new IPEndPoint(IPAddress.Loopback, 4296)))
+ {
+ TestHelper.RunServerDisconnectTest(listener, connection);
+ }
+ }
+
+ /// <summary>
+ /// Tests disconnection from the server.
+ /// </summary>
+ [TestMethod]
+ public void ServerExtraDataDisconnectTest()
+ {
+ using (UdpConnectionListener listener = new UdpConnectionListener(new IPEndPoint(IPAddress.Any, 4296)))
+ using (UdpConnection connection = new UnityUdpClientConnection(logger, new IPEndPoint(IPAddress.Loopback, 4296)))
+ {
+ string received = null;
+ ManualResetEvent mutex = new ManualResetEvent(false);
+
+ connection.Disconnected += delegate (object sender, DisconnectedEventArgs args)
+ {
+ // We don't own the message, we have to read the string now
+ received = args.Message.ReadString();
+ mutex.Set();
+ };
+
+ listener.NewConnection += delegate (NewConnectionEventArgs args)
+ {
+ // As it turns out, the UdpConnectionListener can have an issue on loopback where the disconnect can happen before the hello confirm
+ // Tossing it on a different thread makes this test more reliable. Perhaps something to think about elsewhere though.
+ Task.Run(async () =>
+ {
+ await Task.Delay(1);
+ MessageWriter writer = MessageWriter.Get(SendOption.None);
+ writer.Write("Goodbye");
+ args.Connection.Disconnect("Testing", writer);
+ });
+ };
+
+ listener.Start();
+
+ connection.Connect();
+
+ mutex.WaitOne();
+
+ Assert.IsNotNull(received);
+ Assert.AreEqual("Goodbye", received);
+ }
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel.UnitTests/Utils.cs b/Tools/Hazel-Networking/Hazel.UnitTests/Utils.cs
new file mode 100644
index 0000000..df62b80
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel.UnitTests/Utils.cs
@@ -0,0 +1,99 @@
+using Org.BouncyCastle.Crypto;
+using Org.BouncyCastle.Crypto.Parameters;
+using Org.BouncyCastle.OpenSsl;
+using Org.BouncyCastle.Security;
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.IO;
+using System.Security.Cryptography;
+
+namespace Hazel.UnitTests
+{
+ static class Utils
+ {
+ /// <summary>
+ /// Hex encode a byte array (lower case)
+ /// </summary>
+ public static string BytesToHex(byte[] data)
+ {
+ string chars = "0123456789abcdef";
+
+ StringBuilder sb = new StringBuilder(data.Length * 2);
+ for (int ii = 0, nn = data.Length; ii != nn; ++ii)
+ {
+ sb.Append(chars[data[ii] >> 4]);
+ sb.Append(chars[data[ii] & 0xF]);
+ }
+
+ return sb.ToString().ToLower();
+ }
+
+ /// <summary>
+ /// Decode a hex string to a byte array (lowercase)
+ /// </summary>
+ public static byte[] HexToBytes(string hex)
+ {
+ hex = hex.ToLower();
+ hex = hex = string.Concat(hex.Where(c => !char.IsWhiteSpace(c)));
+
+ byte[] output = new byte[hex.Length / 2];
+
+ for (int ii = 0; ii != hex.Length; ++ii)
+ {
+ byte nibble;
+
+ char c = hex[ii];
+ if (c >= 'a')
+ {
+ nibble = (byte)(0x0A + c - 'a');
+ }
+ else
+ {
+ nibble = (byte)(c - '0');
+ }
+
+ if ((ii & 1) == 0)
+ {
+ output[ii / 2] = (byte)(nibble << 4);
+ }
+ else
+ {
+ output[ii / 2] |= nibble;
+ }
+ }
+
+ return output;
+ }
+
+ public static byte[] DecodePEM(string pemData)
+ {
+ List<byte> result = new List<byte>();
+
+ pemData = pemData.Replace("\r", "");
+ string[] lines = pemData.Split('\n');
+ foreach (string line in lines)
+ {
+ if (line.StartsWith("-----"))
+ {
+ continue;
+ }
+
+ byte[] lineData = Convert.FromBase64String(line);
+ result.AddRange(lineData);
+ }
+
+ return result.ToArray();
+ }
+
+ public static RSA DecodeRSAKeyFromPEM(string pemData)
+ {
+ var pemReader = new PemReader(new StringReader(pemData));
+ var parameters = DotNetUtilities.ToRSAParameters((RsaPrivateCrtKeyParameters)pemReader.ReadObject());
+ var rsa = RSA.Create();
+ rsa.ImportParameters(parameters);
+ return rsa;
+ }
+ }
+}