aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--Tools/Hazel-Networking/.gitattributes63
-rw-r--r--Tools/Hazel-Networking/.gitignore194
-rw-r--r--Tools/Hazel-Networking/Hazel Networking Protocol.docxbin0 -> 22391 bytes
-rw-r--r--Tools/Hazel-Networking/Hazel.UnitTests/BroadcastTests.cs37
-rw-r--r--Tools/Hazel-Networking/Hazel.UnitTests/Crypto/AesGcmTest.cs255
-rw-r--r--Tools/Hazel-Networking/Hazel.UnitTests/Crypto/Sha256Tests.cs70
-rw-r--r--Tools/Hazel-Networking/Hazel.UnitTests/Crypto/X25519Tests.cs297
-rw-r--r--Tools/Hazel-Networking/Hazel.UnitTests/Dtls/AesGcmRecordProtectedTests.cs205
-rw-r--r--Tools/Hazel-Networking/Hazel.UnitTests/Dtls/DtlsConnectionTests.cs1192
-rw-r--r--Tools/Hazel-Networking/Hazel.UnitTests/Dtls/TestDtlsHandshakeDropUnityConnection.cs27
-rw-r--r--Tools/Hazel-Networking/Hazel.UnitTests/Dtls/X25519EcdheRsaSha256Tests.cs226
-rw-r--r--Tools/Hazel-Networking/Hazel.UnitTests/Hazel.UnitTests.csproj16
-rw-r--r--Tools/Hazel-Networking/Hazel.UnitTests/MessageReaderTests.cs689
-rw-r--r--Tools/Hazel-Networking/Hazel.UnitTests/MessageWriterTests.cs213
-rw-r--r--Tools/Hazel-Networking/Hazel.UnitTests/SocketCapture.cs292
-rw-r--r--Tools/Hazel-Networking/Hazel.UnitTests/StatisticsTests.cs159
-rw-r--r--Tools/Hazel-Networking/Hazel.UnitTests/StressTests.cs74
-rw-r--r--Tools/Hazel-Networking/Hazel.UnitTests/TestHelper.cs337
-rw-r--r--Tools/Hazel-Networking/Hazel.UnitTests/TestLogger.cs66
-rw-r--r--Tools/Hazel-Networking/Hazel.UnitTests/ThreadLimitedUdpConnectionTests.cs786
-rw-r--r--Tools/Hazel-Networking/Hazel.UnitTests/UPnPTests.cs54
-rw-r--r--Tools/Hazel-Networking/Hazel.UnitTests/UdpConnectionTestHarness.cs60
-rw-r--r--Tools/Hazel-Networking/Hazel.UnitTests/UdpConnectionTests.cs514
-rw-r--r--Tools/Hazel-Networking/Hazel.UnitTests/UdpReliabilityTests.cs116
-rw-r--r--Tools/Hazel-Networking/Hazel.UnitTests/UnityUdpConnectionTests.cs489
-rw-r--r--Tools/Hazel-Networking/Hazel.UnitTests/Utils.cs99
-rw-r--r--Tools/Hazel-Networking/Hazel.sln31
-rw-r--r--Tools/Hazel-Networking/Hazel/ByteSpan.cs191
-rw-r--r--Tools/Hazel-Networking/Hazel/ByteSpanExtensions.cs131
-rw-r--r--Tools/Hazel-Networking/Hazel/Connection.cs234
-rw-r--r--Tools/Hazel-Networking/Hazel/ConnectionListener.cs160
-rw-r--r--Tools/Hazel-Networking/Hazel/ConnectionState.cs28
-rw-r--r--Tools/Hazel-Networking/Hazel/ConnectionStatistics.cs574
-rw-r--r--Tools/Hazel-Networking/Hazel/Crypto/AesGcm.cs369
-rw-r--r--Tools/Hazel-Networking/Hazel/Crypto/Const.cs82
-rw-r--r--Tools/Hazel-Networking/Hazel/Crypto/CryptoProvider.cs36
-rw-r--r--Tools/Hazel-Networking/Hazel/Crypto/DefaultAes.cs49
-rw-r--r--Tools/Hazel-Networking/Hazel/Crypto/IAes.cs27
-rw-r--r--Tools/Hazel-Networking/Hazel/Crypto/Sha256Stream.cs86
-rw-r--r--Tools/Hazel-Networking/Hazel/Crypto/SpanCryptoExtensions.cs36
-rw-r--r--Tools/Hazel-Networking/Hazel/Crypto/X25519.cs844
-rw-r--r--Tools/Hazel-Networking/Hazel/DataReceivedEventArgs.cs29
-rw-r--r--Tools/Hazel-Networking/Hazel/DisconnectedEventArgs.cs24
-rw-r--r--Tools/Hazel-Networking/Hazel/Dtls/AesGcmRecordProtection.cs147
-rw-r--r--Tools/Hazel-Networking/Hazel/Dtls/DtlsConnectionListener.cs1424
-rw-r--r--Tools/Hazel-Networking/Hazel/Dtls/DtlsUnityConnection.cs1246
-rw-r--r--Tools/Hazel-Networking/Hazel/Dtls/Handshake.cs734
-rw-r--r--Tools/Hazel-Networking/Hazel/Dtls/IHandshakeCipherSuite.cs63
-rw-r--r--Tools/Hazel-Networking/Hazel/Dtls/IRecordProtection.cs84
-rw-r--r--Tools/Hazel-Networking/Hazel/Dtls/NullRecordProtection.cs66
-rw-r--r--Tools/Hazel-Networking/Hazel/Dtls/PrfSha256.cs84
-rw-r--r--Tools/Hazel-Networking/Hazel/Dtls/Record.cs123
-rw-r--r--Tools/Hazel-Networking/Hazel/Dtls/ThreadedHmacHelper.cs159
-rw-r--r--Tools/Hazel-Networking/Hazel/Dtls/X25519EcdheRsaSha256.cs202
-rw-r--r--Tools/Hazel-Networking/Hazel/Extensions.cs34
-rw-r--r--Tools/Hazel-Networking/Hazel/FewerThreads/HazelThreadPool.cs44
-rw-r--r--Tools/Hazel-Networking/Hazel/FewerThreads/ThreadLimitedUdpConnectionListener.cs402
-rw-r--r--Tools/Hazel-Networking/Hazel/FewerThreads/ThreadLimitedUdpServerConnection.cs110
-rw-r--r--Tools/Hazel-Networking/Hazel/Hazel.csproj14
-rw-r--r--Tools/Hazel-Networking/Hazel/HazelException.cs24
-rw-r--r--Tools/Hazel-Networking/Hazel/IPMode.cs30
-rw-r--r--Tools/Hazel-Networking/Hazel/IRecyclable.cs29
-rw-r--r--Tools/Hazel-Networking/Hazel/ListenerStatistics.cs23
-rw-r--r--Tools/Hazel-Networking/Hazel/MessageReader.cs452
-rw-r--r--Tools/Hazel-Networking/Hazel/MessageWriter.cs365
-rw-r--r--Tools/Hazel-Networking/Hazel/NetworkConnection.cs117
-rw-r--r--Tools/Hazel-Networking/Hazel/NetworkConnectionListener.cs26
-rw-r--r--Tools/Hazel-Networking/Hazel/NewConnectionEventArgs.cs22
-rw-r--r--Tools/Hazel-Networking/Hazel/ObjectPool.cs108
-rw-r--r--Tools/Hazel-Networking/Hazel/SendErrors.cs15
-rw-r--r--Tools/Hazel-Networking/Hazel/SendOption.cs35
-rw-r--r--Tools/Hazel-Networking/Hazel/UPnP/ILogger.cs65
-rw-r--r--Tools/Hazel-Networking/Hazel/UPnP/NetUtility.cs158
-rw-r--r--Tools/Hazel-Networking/Hazel/UPnP/UPnPHelper.cs347
-rw-r--r--Tools/Hazel-Networking/Hazel/Udp/SendOptionInternal.cs39
-rw-r--r--Tools/Hazel-Networking/Hazel/Udp/UdpBroadcastListener.cs157
-rw-r--r--Tools/Hazel-Networking/Hazel/Udp/UdpBroadcaster.cs127
-rw-r--r--Tools/Hazel-Networking/Hazel/Udp/UdpClientConnection.cs364
-rw-r--r--Tools/Hazel-Networking/Hazel/Udp/UdpConnection.KeepAlive.cs167
-rw-r--r--Tools/Hazel-Networking/Hazel/Udp/UdpConnection.Reliable.cs490
-rw-r--r--Tools/Hazel-Networking/Hazel/Udp/UdpConnection.cs259
-rw-r--r--Tools/Hazel-Networking/Hazel/Udp/UdpConnectionListener.cs339
-rw-r--r--Tools/Hazel-Networking/Hazel/Udp/UdpServerConnection.cs108
-rw-r--r--Tools/Hazel-Networking/Hazel/Udp/UnityUdpClientConnection.cs353
-rw-r--r--Tools/Hazel-Networking/LICENSE22
-rw-r--r--Tools/Hazel-Networking/README.md47
86 files changed, 18685 insertions, 0 deletions
diff --git a/Tools/Hazel-Networking/.gitattributes b/Tools/Hazel-Networking/.gitattributes
new file mode 100644
index 0000000..1ff0c42
--- /dev/null
+++ b/Tools/Hazel-Networking/.gitattributes
@@ -0,0 +1,63 @@
+###############################################################################
+# Set default behavior to automatically normalize line endings.
+###############################################################################
+* text=auto
+
+###############################################################################
+# Set default behavior for command prompt diff.
+#
+# This is need for earlier builds of msysgit that does not have it on by
+# default for csharp files.
+# Note: This is only used by command line
+###############################################################################
+#*.cs diff=csharp
+
+###############################################################################
+# Set the merge driver for project and solution files
+#
+# Merging from the command prompt will add diff markers to the files if there
+# are conflicts (Merging from VS is not affected by the settings below, in VS
+# the diff markers are never inserted). Diff markers may cause the following
+# file extensions to fail to load in VS. An alternative would be to treat
+# these files as binary and thus will always conflict and require user
+# intervention with every merge. To do so, just uncomment the entries below
+###############################################################################
+#*.sln merge=binary
+#*.csproj merge=binary
+#*.vbproj merge=binary
+#*.vcxproj merge=binary
+#*.vcproj merge=binary
+#*.dbproj merge=binary
+#*.fsproj merge=binary
+#*.lsproj merge=binary
+#*.wixproj merge=binary
+#*.modelproj merge=binary
+#*.sqlproj merge=binary
+#*.wwaproj merge=binary
+
+###############################################################################
+# behavior for image files
+#
+# image files are treated as binary by default.
+###############################################################################
+#*.jpg binary
+#*.png binary
+#*.gif binary
+
+###############################################################################
+# diff behavior for common document formats
+#
+# Convert binary document formats to text before diffing them. This feature
+# is only available from the command line. Turn it on by uncommenting the
+# entries below.
+###############################################################################
+#*.doc diff=astextplain
+#*.DOC diff=astextplain
+#*.docx diff=astextplain
+#*.DOCX diff=astextplain
+#*.dot diff=astextplain
+#*.DOT diff=astextplain
+#*.pdf diff=astextplain
+#*.PDF diff=astextplain
+#*.rtf diff=astextplain
+#*.RTF diff=astextplain
diff --git a/Tools/Hazel-Networking/.gitignore b/Tools/Hazel-Networking/.gitignore
new file mode 100644
index 0000000..9bf6cfb
--- /dev/null
+++ b/Tools/Hazel-Networking/.gitignore
@@ -0,0 +1,194 @@
+## Ignore Visual Studio temporary files, build results, and
+## files generated by popular Visual Studio add-ons.
+
+# User-specific files
+*.suo
+*.user
+*.sln.docstates
+.vs/
+
+# Build results
+[Dd]ebug/
+[Dd]ebugPublic/
+[Rr]elease/
+x64/
+build/
+bld/
+[Bb]in/
+[Oo]bj/
+
+#Sandcastle generated documentation
+[Hh]elp/
+
+# Roslyn cache directories
+*.ide/
+
+# MSTest test Results
+[Tt]est[Rr]esult*/
+[Bb]uild[Ll]og.*
+
+#NUNIT
+*.VisualState.xml
+TestResult.xml
+
+# Build Results of an ATL Project
+[Dd]ebugPS/
+[Rr]eleasePS/
+dlldata.c
+
+*_i.c
+*_p.c
+*_i.h
+*.ilk
+*.meta
+*.obj
+*.pch
+*.pdb
+*.pgc
+*.pgd
+*.rsp
+*.sbr
+*.tlb
+*.tli
+*.tlh
+*.tmp
+*.tmp_proj
+*.log
+*.vspscc
+*.vssscc
+.builds
+*.pidb
+*.svclog
+*.scc
+
+# Chutzpah Test files
+_Chutzpah*
+
+# Visual C++ cache files
+ipch/
+*.aps
+*.ncb
+*.opensdf
+*.sdf
+*.cachefile
+
+# Visual Studio profiler
+*.psess
+*.vsp
+*.vspx
+
+# TFS 2012 Local Workspace
+$tf/
+
+# Guidance Automation Toolkit
+*.gpState
+
+# ReSharper is a .NET coding add-in
+_ReSharper*/
+*.[Rr]e[Ss]harper
+*.DotSettings.user
+
+# JustCode is a .NET coding addin-in
+.JustCode
+
+# TeamCity is a build add-in
+_TeamCity*
+
+# DotCover is a Code Coverage Tool
+*.dotCover
+
+# NCrunch
+_NCrunch_*
+.*crunch*.local.xml
+
+# MightyMoose
+*.mm.*
+AutoTest.Net/
+
+# Web workbench (sass)
+.sass-cache/
+
+# Installshield output folder
+[Ee]xpress/
+
+# DocProject is a documentation generator add-in
+DocProject/buildhelp/
+DocProject/Help/*.HxT
+DocProject/Help/*.HxC
+DocProject/Help/*.hhc
+DocProject/Help/*.hhk
+DocProject/Help/*.hhp
+DocProject/Help/Html2
+DocProject/Help/html
+
+# Click-Once directory
+publish/
+
+# Publish Web Output
+*.[Pp]ublish.xml
+*.azurePubxml
+## TODO: Comment the next line if you want to checkin your
+## web deploy settings but do note that will include unencrypted
+## passwords
+#*.pubxml
+
+# NuGet Packages Directory
+packages/*
+## TODO: If the tool you use requires repositories.config
+## uncomment the next line
+#!packages/repositories.config
+
+# Enable "build/" folder in the NuGet Packages folder since
+# NuGet packages use it for MSBuild targets.
+# This line needs to be after the ignore of the build folder
+# (and the packages folder if the line above has been uncommented)
+!packages/build/
+
+# Windows Azure Build Output
+csx/
+*.build.csdef
+
+# Windows Store app package directory
+AppPackages/
+
+# Others
+sql/
+*.Cache
+ClientBin/
+[Ss]tyle[Cc]op.*
+~$*
+*~
+*.dbmdl
+*.dbproj.schemaview
+*.pfx
+*.publishsettings
+node_modules/
+
+# RIA/Silverlight projects
+Generated_Code/
+
+# Backup & report files from converting an old project file
+# to a newer Visual Studio version. Backup files are not needed,
+# because we have git ;-)
+_UpgradeReport_Files/
+Backup*/
+UpgradeLog*.XML
+UpgradeLog*.htm
+
+# SQL Server files
+*.mdf
+*.ldf
+
+# Business Intelligence projects
+*.rdl.data
+*.bim.layout
+*.bim_*.settings
+
+# Microsoft Fakes
+FakesAssemblies/
+
+# LightSwitch generated files
+GeneratedArtifacts/
+_Pvt_Extensions/
+ModelManifest.xml
+*.ide
diff --git a/Tools/Hazel-Networking/Hazel Networking Protocol.docx b/Tools/Hazel-Networking/Hazel Networking Protocol.docx
new file mode 100644
index 0000000..615c0cd
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel Networking Protocol.docx
Binary files differ
diff --git a/Tools/Hazel-Networking/Hazel.UnitTests/BroadcastTests.cs b/Tools/Hazel-Networking/Hazel.UnitTests/BroadcastTests.cs
new file mode 100644
index 0000000..d6ba247
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel.UnitTests/BroadcastTests.cs
@@ -0,0 +1,37 @@
+using Hazel.Udp;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using System;
+using System.Threading;
+
+namespace Hazel.UnitTests
+{
+ [TestClass]
+ public class BroadcastTests
+ {
+ [TestMethod]
+ public void CanStart()
+ {
+ const string TestData = "pwerowerower";
+
+ using (UdpBroadcaster caster = new UdpBroadcaster(47777))
+ using (UdpBroadcastListener listener = new UdpBroadcastListener(47777))
+ {
+ listener.StartListen();
+
+ caster.SetData(TestData);
+
+ caster.Broadcast();
+ Thread.Sleep(1000);
+
+ var pkt = listener.GetPackets();
+ foreach (var p in pkt)
+ {
+ Console.WriteLine($"{p.Data} {p.Sender}");
+ Assert.AreEqual(TestData, p.Data);
+ }
+
+ Assert.IsTrue(pkt.Length >= 1);
+ }
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel.UnitTests/Crypto/AesGcmTest.cs b/Tools/Hazel-Networking/Hazel.UnitTests/Crypto/AesGcmTest.cs
new file mode 100644
index 0000000..9620973
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel.UnitTests/Crypto/AesGcmTest.cs
@@ -0,0 +1,255 @@
+using Hazel.Crypto;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using System.Text;
+
+namespace Hazel.UnitTests.Crypto
+{
+ [TestClass]
+ public class AesGcmTest
+ {
+ [TestMethod]
+ public void Example1()
+ {
+ byte[] key = Utils.HexToBytes("FEFFE992 8665731C 6D6A8F94 67308308");
+ byte[] nonce = Utils.HexToBytes("CAFEBABE FACEDBAD DECAF888");
+ byte[] associatedData = Utils.HexToBytes("");
+
+ using (Aes128Gcm aes = new Aes128Gcm(key))
+ {
+ byte[] plaintext = Utils.HexToBytes("");
+ byte[] ciphertextBytes = new byte[plaintext.Length + Aes128Gcm.CiphertextOverhead];
+ aes.Seal(ciphertextBytes, nonce, plaintext, associatedData);
+
+ CollectionAssert.AreEqual(Utils.HexToBytes("3247184B 3C4F69A4 4DBCD228 87BBB418"), ciphertextBytes);
+ }
+
+ using (Aes128Gcm aes = new Aes128Gcm(key))
+ {
+ byte[] ciphertext = Utils.HexToBytes("3247184B 3C4F69A4 4DBCD228 87BBB418");
+ byte[] plaintext = new byte[ciphertext.Length - Aes128Gcm.CiphertextOverhead];
+ bool result = aes.Open(plaintext, nonce, ciphertext, associatedData);
+ Assert.IsTrue(result);
+ CollectionAssert.AreEqual(Utils.HexToBytes(""), plaintext);
+ }
+ }
+
+ [TestMethod]
+ public void Example2()
+ {
+ byte[] key = Utils.HexToBytes("FEFFE992 8665731C 6D6A8F94 67308308");
+ byte[] nonce = Utils.HexToBytes("CAFEBABE FACEDBAD DECAF888");
+ byte[] associatedData = Utils.HexToBytes("");
+
+ using (Aes128Gcm aes = new Aes128Gcm(key))
+ {
+ byte[] plaintext = Utils.HexToBytes(@"
+ D9313225 F88406E5 A55909C5 AFF5269A
+ 86A7A953 1534F7DA 2E4C303D 8A318A72
+ 1C3C0C95 95680953 2FCF0E24 49A6B525
+ B16AEDF5 AA0DE657 BA637B39 1AAFD255
+ ");
+ byte[] ciphertextBytes = new byte[plaintext.Length + Aes128Gcm.CiphertextOverhead];
+ aes.Seal(ciphertextBytes, nonce, plaintext, associatedData);
+
+ CollectionAssert.AreEqual(Utils.HexToBytes(@"
+ 42831EC2 21777424 4B7221B7 84D0D49C
+ E3AA212F 2C02A4E0 35C17E23 29ACA12E
+ 21D514B2 5466931C 7D8F6A5A AC84AA05
+ 1BA30B39 6A0AAC97 3D58E091 473F5985
+
+ 4D5C2AF3 27CD64A6 2CF35ABD 2BA6FAB4
+ "), ciphertextBytes);
+ }
+
+ using (Aes128Gcm aes = new Aes128Gcm(key))
+ {
+ byte[] ciphertext = Utils.HexToBytes(@"
+ 42831EC2 21777424 4B7221B7 84D0D49C
+ E3AA212F 2C02A4E0 35C17E23 29ACA12E
+ 21D514B2 5466931C 7D8F6A5A AC84AA05
+ 1BA30B39 6A0AAC97 3D58E091 473F5985
+
+ 4D5C2AF3 27CD64A6 2CF35ABD 2BA6FAB4
+ ");
+ byte[] plaintext = new byte[ciphertext.Length - Aes128Gcm.CiphertextOverhead];
+ bool result = aes.Open(plaintext, nonce, ciphertext, associatedData);
+ Assert.IsTrue(result);
+
+ CollectionAssert.AreEqual(Utils.HexToBytes(@"
+ D9313225 F88406E5 A55909C5 AFF5269A
+ 86A7A953 1534F7DA 2E4C303D 8A318A72
+ 1C3C0C95 95680953 2FCF0E24 49A6B525
+ B16AEDF5 AA0DE657 BA637B39 1AAFD255
+ "), plaintext);
+ }
+ }
+
+ [TestMethod]
+ public void Example3()
+ {
+ byte[] key = Utils.HexToBytes("FEFFE992 8665731C 6D6A8F94 67308308");
+ byte[] nonce = Utils.HexToBytes("CAFEBABE FACEDBAD DECAF888");
+ byte[] associatedData = Utils.HexToBytes(@"
+ 3AD77BB4 0D7A3660 A89ECAF3 2466EF97
+ F5D3D585 03B9699D E785895A 96FDBAAF
+ 43B1CD7F 598ECE23 881B00E3 ED030688
+ 7B0C785E 27E8AD3F 82232071 04725DD4
+ ");
+
+ using (Aes128Gcm aes = new Aes128Gcm(key))
+ {
+ byte[] plaintext = Utils.HexToBytes("");
+ byte[] ciphertextBytes = new byte[plaintext.Length + Aes128Gcm.CiphertextOverhead];
+ aes.Seal(ciphertextBytes, nonce, plaintext, associatedData);
+
+ CollectionAssert.AreEqual(Utils.HexToBytes(@"
+ 5F91D771 23EF5EB9 99791384 9B8DC1E9
+ "), ciphertextBytes);
+ }
+
+ using (Aes128Gcm aes = new Aes128Gcm(key))
+ {
+ byte[] ciphertext = Utils.HexToBytes(@"
+ 5F91D771 23EF5EB9 99791384 9B8DC1E9
+ ");
+ byte[] plaintext = new byte[ciphertext.Length - Aes128Gcm.CiphertextOverhead];
+ bool result = aes.Open(plaintext, nonce, ciphertext, associatedData);
+ Assert.IsTrue(result);
+
+ CollectionAssert.AreEqual(Utils.HexToBytes(""), plaintext);
+ }
+ }
+
+ [TestMethod]
+ public void Example4()
+ {
+ byte[] key = Utils.HexToBytes("FEFFE992 8665731C 6D6A8F94 67308308");
+ byte[] nonce = Utils.HexToBytes("CAFEBABE FACEDBAD DECAF888");
+ byte[] associatedData = Utils.HexToBytes(@"
+ 3AD77BB4 0D7A3660 A89ECAF3 2466EF97
+ F5D3D585 03B9699D E785895A 96FDBAAF
+ 43B1CD7F 598ECE23 881B00E3 ED030688
+ 7B0C785E 27E8AD3F 82232071 04725DD4
+ ");
+
+ using (Aes128Gcm aes = new Aes128Gcm(key))
+ {
+ byte[] plaintext = Utils.HexToBytes(@"
+ D9313225 F88406E5 A55909C5 AFF5269A
+ 86A7A953 1534F7DA 2E4C303D 8A318A72
+ 1C3C0C95 95680953 2FCF0E24 49A6B525
+ B16AEDF5 AA0DE657 BA637B39 1AAFD255
+ ");
+ byte[] ciphertextBytes = new byte[plaintext.Length + Aes128Gcm.CiphertextOverhead];
+ aes.Seal(ciphertextBytes, nonce, plaintext, associatedData);
+
+ CollectionAssert.AreEqual(Utils.HexToBytes(@"
+ 42831EC2 21777424 4B7221B7 84D0D49C
+ E3AA212F 2C02A4E0 35C17E23 29ACA12E
+ 21D514B2 5466931C 7D8F6A5A AC84AA05
+ 1BA30B39 6A0AAC97 3D58E091 473F5985
+
+ 64C02329 04AF398A 5B67C10B 53A5024D
+ "), ciphertextBytes);
+ }
+
+ using (Aes128Gcm aes = new Aes128Gcm(key))
+ {
+ byte[] ciphertext = Utils.HexToBytes(@"
+ 42831EC2 21777424 4B7221B7 84D0D49C
+ E3AA212F 2C02A4E0 35C17E23 29ACA12E
+ 21D514B2 5466931C 7D8F6A5A AC84AA05
+ 1BA30B39 6A0AAC97 3D58E091 473F5985
+
+ 64C02329 04AF398A 5B67C10B 53A5024D
+ ");
+ byte[] plaintext = new byte[ciphertext.Length - Aes128Gcm.CiphertextOverhead];
+ bool result = aes.Open(plaintext, nonce, ciphertext, associatedData);
+ Assert.IsTrue(result);
+
+ CollectionAssert.AreEqual(Utils.HexToBytes(@"
+ D9313225 F88406E5 A55909C5 AFF5269A
+ 86A7A953 1534F7DA 2E4C303D 8A318A72
+ 1C3C0C95 95680953 2FCF0E24 49A6B525
+ B16AEDF5 AA0DE657 BA637B39 1AAFD255
+ "), plaintext);
+ }
+ }
+
+ [TestMethod]
+ public void TestReuseToDecrypt()
+ {
+ byte[] key = Utils.HexToBytes("FEFFE992 8665731C 6D6A8F94 67308308");
+ byte[] nonce = Utils.HexToBytes("CAFEBABE FACEDBAD DECAF888");
+ byte[] associatedData = Utils.HexToBytes(@"
+ 3AD77BB4 0D7A3660 A89ECAF3 2466EF97
+ F5D3D585 03B9699D E785895A 96FDBAAF
+ 43B1CD7F 598ECE23 881B00E3 ED030688
+ 7B0C785E 27E8AD3F 82232071 04725DD4
+ ");
+
+ using (Aes128Gcm aes = new Aes128Gcm(key))
+ {
+ byte[] plaintext = Utils.HexToBytes("");
+ byte[] ciphertextBytes = new byte[plaintext.Length + Aes128Gcm.CiphertextOverhead];
+ aes.Seal(ciphertextBytes, nonce, plaintext, associatedData);
+
+ CollectionAssert.AreEqual(Utils.HexToBytes(@"
+ 5F91D771 23EF5EB9 99791384 9B8DC1E9
+ "), ciphertextBytes);
+
+ byte[] ciphertext = Utils.HexToBytes(@"
+ 5F91D771 23EF5EB9 99791384 9B8DC1E9
+ ");
+ plaintext = new byte[ciphertext.Length - Aes128Gcm.CiphertextOverhead];
+ bool result = aes.Open(plaintext, nonce, ciphertext, associatedData);
+ Assert.IsTrue(result);
+
+ CollectionAssert.AreEqual(Utils.HexToBytes(""), plaintext);
+ }
+ }
+
+ [TestMethod]
+ public void TestPlaintextSmallerThanBlock()
+ {
+ byte[] key = Utils.HexToBytes("FEFFE992 8665731C 6D6A8F94 67308308");
+ byte[] nonce = Utils.HexToBytes("CAFEBABE FACEDBAD DECAF888");
+ byte[] originalPlaintext = Encoding.UTF8.GetBytes("Lorem ipsum");
+ Assert.IsTrue(originalPlaintext.Length < 16);
+
+ using (Aes128Gcm aes = new Aes128Gcm(key))
+ {
+ byte[] ciphertext = new byte[originalPlaintext.Length + Aes128Gcm.CiphertextOverhead];
+ aes.Seal(ciphertext, nonce, originalPlaintext, ByteSpan.Empty);
+
+ byte[] plaintext = new byte[originalPlaintext.Length];
+ bool result = aes.Open(plaintext, nonce, ciphertext, ByteSpan.Empty);
+ Assert.IsTrue(result);
+
+ CollectionAssert.AreEqual(originalPlaintext, plaintext);
+ }
+ }
+
+ [TestMethod]
+ public void TestPlaintextLargerThanBlockMultiple()
+ {
+ byte[] key = Utils.HexToBytes("FEFFE992 8665731C 6D6A8F94 67308308");
+ byte[] nonce = Utils.HexToBytes("CAFEBABE FACEDBAD DECAF888");
+ byte[] originalPlaintext = Encoding.UTF8.GetBytes("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.");
+ Assert.IsTrue(originalPlaintext.Length > 16);
+ Assert.IsTrue((originalPlaintext.Length % 16) != 0);
+
+ using (Aes128Gcm aes = new Aes128Gcm(key))
+ {
+ byte[] ciphertext = new byte[originalPlaintext.Length + Aes128Gcm.CiphertextOverhead];
+ aes.Seal(ciphertext, nonce, originalPlaintext, ByteSpan.Empty);
+
+ byte[] plaintext = new byte[originalPlaintext.Length];
+ bool result = aes.Open(plaintext, nonce, ciphertext, ByteSpan.Empty);
+ Assert.IsTrue(result);
+
+ CollectionAssert.AreEqual(originalPlaintext, plaintext);
+ }
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel.UnitTests/Crypto/Sha256Tests.cs b/Tools/Hazel-Networking/Hazel.UnitTests/Crypto/Sha256Tests.cs
new file mode 100644
index 0000000..5309988
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel.UnitTests/Crypto/Sha256Tests.cs
@@ -0,0 +1,70 @@
+using Hazel.Crypto;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using System.Text;
+
+namespace Hazel.UnitTests.Crypto
+{
+ [TestClass]
+ public class Sha256Tests
+ {
+ [TestMethod]
+ public void TestOneBlockMessage()
+ {
+ ByteSpan message = Encoding.ASCII.GetBytes(
+ "abc"
+ );
+ byte[] expectedDigest = Utils.HexToBytes(
+ "ba7816bf 8f01cfea 414140de 5dae2223 b00361a3 96177a9c b410ff61 f20015ad"
+ );
+ byte[] actualDigest = new byte[Sha256Stream.DigestSize];
+
+ using (Sha256Stream sha256 = new Sha256Stream())
+ {
+ sha256.AddData(message);
+ sha256.CopyOrCalculateFinalHash(actualDigest);
+ }
+
+ CollectionAssert.AreEqual(expectedDigest, actualDigest);
+ }
+
+ [TestMethod]
+ public void TestMultiBlockMessage()
+ {
+ ByteSpan message = Encoding.ASCII.GetBytes(
+ "abcdbcdecdefdefgefghfghighijhijkijkljklmklmnlmnomnopnopq"
+ );
+ byte[] expectedDigest = Utils.HexToBytes(
+ "248d6a61 d20638b8 e5c02693 0c3e6039 a33ce459 64ff2167 f6ecedd4 19db06c1"
+ );
+ byte[] actualDigest = new byte[Sha256Stream.DigestSize];
+
+ using (Sha256Stream sha256 = new Sha256Stream())
+ {
+ sha256.AddData(message);
+ sha256.CopyOrCalculateFinalHash(actualDigest);
+ }
+
+ CollectionAssert.AreEqual(expectedDigest, actualDigest);
+ }
+
+ [TestMethod]
+ public void TestLongMessage()
+ {
+ ByteSpan message = Encoding.ASCII.GetBytes(
+ new string('a', 1000000)
+ );
+ byte[] expectedDigest = Utils.HexToBytes(
+ "cdc76e5c 9914fb92 81a1c7e2 84d73e67 f1809a48 a497200e 046d39cc c7112cd0"
+ );
+ byte[] actualDigest = new byte[Sha256Stream.DigestSize];
+
+ using (Sha256Stream sha256 = new Sha256Stream())
+ {
+ sha256.AddData(message);
+ sha256.CopyOrCalculateFinalHash(actualDigest);
+ }
+
+ CollectionAssert.AreEqual(expectedDigest, actualDigest);
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel.UnitTests/Crypto/X25519Tests.cs b/Tools/Hazel-Networking/Hazel.UnitTests/Crypto/X25519Tests.cs
new file mode 100644
index 0000000..8d9a583
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel.UnitTests/Crypto/X25519Tests.cs
@@ -0,0 +1,297 @@
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+using Hazel.Crypto;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using System.Security.Cryptography;
+
+namespace Hazel.UnitTests.Crypto
+{
+ [TestClass]
+ public class X25519Tests
+ {
+ private static readonly byte[][] LowOrderPoints = new byte[][]
+ {
+ new byte[]{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
+ new byte[]{0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
+ new byte[]{0xe0, 0xeb, 0x7a, 0x7c, 0x3b, 0x41, 0xb8, 0xae, 0x16, 0x56, 0xe3, 0xfa, 0xf1, 0x9f, 0xc4, 0x6a, 0xda, 0x09, 0x8d, 0xeb, 0x9c, 0x32, 0xb1, 0xfd, 0x86, 0x62, 0x05, 0x16, 0x5f, 0x49, 0xb8, 0x00},
+ new byte[]{0x5f, 0x9c, 0x95, 0xbc, 0xa3, 0x50, 0x8c, 0x24, 0xb1, 0xd0, 0xb1, 0x55, 0x9c, 0x83, 0xef, 0x5b, 0x04, 0x44, 0x5c, 0xc4, 0x58, 0x1c, 0x8e, 0x86, 0xd8, 0x22, 0x4e, 0xdd, 0xd0, 0x9f, 0x11, 0x57},
+ new byte[]{0xec, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f},
+ new byte[]{0xed, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f},
+ new byte[]{0xee, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f},
+ };
+
+ [TestMethod]
+ public void TestLowOrderPoints()
+ {
+ using (RandomNumberGenerator random = RandomNumberGenerator.Create())
+ {
+ byte[] scalar = new byte[X25519.KeySize];
+ random.GetBytes(scalar);
+
+ for (int ii = 0, nn = LowOrderPoints.Length; ii != nn; ++ii)
+ {
+ ByteSpan output = new byte[X25519.KeySize];
+ bool result = X25519.Func(output, scalar, LowOrderPoints[ii]);
+ Assert.IsFalse(result, $"Multiplication by low order point {ii} succeeded: should have failed");
+ }
+ }
+ }
+
+ [TestMethod]
+ public void TestVectors()
+ {
+ for (int ii = 0, nn = TestVectorData.Length; ii != nn; ++ii)
+ {
+ byte[] actual = new byte[32];
+ bool result = X25519.Func(actual, TestVectorData[ii].In, TestVectorData[ii].Base);
+ Assert.IsTrue(result);
+ CollectionAssert.AreEqual(TestVectorData[ii].Expect, actual, $"Test vector {ii} mismatch");
+ }
+ }
+
+ [TestMethod]
+ public void TestAgreement()
+ {
+ using (RandomNumberGenerator random = RandomNumberGenerator.Create())
+ {
+ byte[] clientPrivateKey = new byte[X25519.KeySize];
+ random.GetBytes(clientPrivateKey);
+
+ byte[] clientPublicKey = new byte[X25519.KeySize];
+ X25519.Func(clientPublicKey, clientPrivateKey);
+
+ byte[] serverPrivateKey = new byte[X25519.KeySize];
+ random.GetBytes(serverPrivateKey);
+
+ byte[] serverPublickey = new byte[X25519.KeySize];
+ X25519.Func(serverPublickey, serverPrivateKey);
+
+ // client key aggreement
+ byte[] clientSharedSecret = new byte[X25519.KeySize];
+ Assert.IsTrue(X25519.Func(clientSharedSecret, clientPrivateKey, serverPublickey));
+
+ // server key agreement
+ byte[] serverSharedSecret = new byte[X25519.KeySize];
+ Assert.IsTrue(X25519.Func(serverSharedSecret, serverPrivateKey, clientPublicKey));
+
+ CollectionAssert.AreEqual(clientSharedSecret, serverSharedSecret);
+ }
+ }
+
+ private struct TestVector
+ {
+ public byte[] In;
+ public byte[] Base;
+ public byte[] Expect;
+ }
+
+ private static readonly TestVector[] TestVectorData =
+ {
+ new TestVector {
+ In = new byte[]{0x66, 0x8f, 0xb9, 0xf7, 0x6a, 0xd9, 0x71, 0xc8, 0x1a, 0xc9, 0x0, 0x7, 0x1a, 0x15, 0x60, 0xbc, 0xe2, 0xca, 0x0, 0xca, 0xc7, 0xe6, 0x7a, 0xf9, 0x93, 0x48, 0x91, 0x37, 0x61, 0x43, 0x40, 0x14},
+ Base = new byte[]{0xdb, 0x5f, 0x32, 0xb7, 0xf8, 0x41, 0xe7, 0xa1, 0xa0, 0x9, 0x68, 0xef, 0xfd, 0xed, 0x12, 0x73, 0x5f, 0xc4, 0x7a, 0x3e, 0xb1, 0x3b, 0x57, 0x9a, 0xac, 0xad, 0xea, 0xe8, 0x9, 0x39, 0xa7, 0xdd},
+ Expect = new byte[]{0x9, 0xd, 0x85, 0xe5, 0x99, 0xea, 0x8e, 0x2b, 0xee, 0xb6, 0x13, 0x4, 0xd3, 0x7b, 0xe1, 0xe, 0xc5, 0xc9, 0x5, 0xf9, 0x92, 0x7d, 0x32, 0xf4, 0x2a, 0x9a, 0xa, 0xfb, 0x3e, 0xb, 0x40, 0x74},
+ },
+ new TestVector {
+ In = new byte[]{ 0x63, 0x66, 0x95, 0xe3, 0x4f, 0x75, 0xb9, 0xa2, 0x79, 0xc8, 0x70, 0x6f, 0xad, 0x12, 0x89, 0xf2, 0xc0, 0xb1, 0xe2, 0x2e, 0x16, 0xf8, 0xb8, 0x86, 0x17, 0x29, 0xc1, 0xa, 0x58, 0x29, 0x58, 0xaf},
+ Base = new byte[]{ 0x9, 0xd, 0x7, 0x1, 0xf8, 0xfd, 0xe2, 0x8f, 0x70, 0x4, 0x3b, 0x83, 0xf2, 0x34, 0x62, 0x25, 0x41, 0x9b, 0x18, 0xa7, 0xf2, 0x7e, 0x9e, 0x3d, 0x2b, 0xfd, 0x4, 0xe1, 0xf, 0x3d, 0x21, 0x3e},
+ Expect = new byte[]{ 0xbf, 0x26, 0xec, 0x7e, 0xc4, 0x13, 0x6, 0x17, 0x33, 0xd4, 0x40, 0x70, 0xea, 0x67, 0xca, 0xb0, 0x2a, 0x85, 0xdc, 0x1b, 0xe8, 0xcf, 0xe1, 0xff, 0x73, 0xd5, 0x41, 0xcc, 0x8, 0x32, 0x55, 0x6},
+ },
+ new TestVector {
+ In = new byte[]{ 0x73, 0x41, 0x81, 0xcd, 0x1a, 0x94, 0x6, 0x52, 0x2a, 0x56, 0xfe, 0x25, 0xe4, 0x3e, 0xcb, 0xf0, 0x29, 0x5d, 0xb5, 0xdd, 0xd0, 0x60, 0x9b, 0x3c, 0x2b, 0x4e, 0x79, 0xc0, 0x6f, 0x8b, 0xd4, 0x6d},
+ Base = new byte[]{ 0xf8, 0xa8, 0x42, 0x1c, 0x7d, 0x21, 0xa9, 0x2d, 0xb3, 0xed, 0xe9, 0x79, 0xe1, 0xfa, 0x6a, 0xcb, 0x6, 0x2b, 0x56, 0xb1, 0x88, 0x5c, 0x71, 0xc5, 0x11, 0x53, 0xcc, 0xb8, 0x80, 0xac, 0x73, 0x15},
+ Expect = new byte[]{ 0x11, 0x76, 0xd0, 0x16, 0x81, 0xf2, 0xcf, 0x92, 0x9d, 0xa2, 0xc7, 0xa3, 0xdf, 0x66, 0xb5, 0xd7, 0x72, 0x9f, 0xd4, 0x22, 0x22, 0x6f, 0xd6, 0x37, 0x42, 0x16, 0xbf, 0x7e, 0x2, 0xfd, 0xf, 0x62},
+ },
+ new TestVector {
+ In = new byte[]{ 0x1f, 0x70, 0x39, 0x1f, 0x6b, 0xa8, 0x58, 0x12, 0x94, 0x13, 0xbd, 0x80, 0x1b, 0x12, 0xac, 0xbf, 0x66, 0x23, 0x62, 0x82, 0x5c, 0xa2, 0x50, 0x9c, 0x81, 0x87, 0x59, 0xa, 0x2b, 0xe, 0x61, 0x72},
+ Base = new byte[]{ 0xd3, 0xea, 0xd0, 0x7a, 0x0, 0x8, 0xf4, 0x45, 0x2, 0xd5, 0x80, 0x8b, 0xff, 0xc8, 0x97, 0x9f, 0x25, 0xa8, 0x59, 0xd5, 0xad, 0xf4, 0x31, 0x2e, 0xa4, 0x87, 0x48, 0x9c, 0x30, 0xe0, 0x1b, 0x3b},
+ Expect = new byte[]{ 0xf8, 0x48, 0x2f, 0x2e, 0x9e, 0x58, 0xbb, 0x6, 0x7e, 0x86, 0xb2, 0x87, 0x24, 0xb3, 0xc0, 0xa3, 0xbb, 0xb5, 0x7, 0x3e, 0x4c, 0x6a, 0xcd, 0x93, 0xdf, 0x54, 0x5e, 0xff, 0xdb, 0xba, 0x50, 0x5f},
+ },
+ new TestVector {
+ In = new byte[]{ 0x3a, 0x7a, 0xe6, 0xcf, 0x8b, 0x88, 0x9d, 0x2b, 0x7a, 0x60, 0xa4, 0x70, 0xad, 0x6a, 0xd9, 0x99, 0x20, 0x6b, 0xf5, 0x7d, 0x90, 0x30, 0xdd, 0xf7, 0xf8, 0x68, 0xc, 0x8b, 0x1a, 0x64, 0x5d, 0xaa},
+ Base = new byte[]{ 0x4d, 0x25, 0x4c, 0x80, 0x83, 0xd8, 0x7f, 0x1a, 0x9b, 0x3e, 0xa7, 0x31, 0xef, 0xcf, 0xf8, 0xa6, 0xf2, 0x31, 0x2d, 0x6f, 0xed, 0x68, 0xe, 0xf8, 0x29, 0x18, 0x51, 0x61, 0xc8, 0xfc, 0x50, 0x60},
+ Expect = new byte[]{ 0x47, 0xb3, 0x56, 0xd5, 0x81, 0x8d, 0xe8, 0xef, 0xac, 0x77, 0x4b, 0x71, 0x4c, 0x42, 0xc4, 0x4b, 0xe6, 0x85, 0x23, 0xdd, 0x57, 0xdb, 0xd7, 0x39, 0x62, 0xd5, 0xa5, 0x26, 0x31, 0x87, 0x62, 0x37},
+ },
+ new TestVector {
+ In = new byte[]{ 0x20, 0x31, 0x61, 0xc3, 0x15, 0x9a, 0x87, 0x6a, 0x2b, 0xea, 0xec, 0x29, 0xd2, 0x42, 0x7f, 0xb0, 0xc7, 0xc3, 0xd, 0x38, 0x2c, 0xd0, 0x13, 0xd2, 0x7c, 0xc3, 0xd3, 0x93, 0xdb, 0xd, 0xaf, 0x6f},
+ Base = new byte[]{ 0x6a, 0xb9, 0x5d, 0x1a, 0xbe, 0x68, 0xc0, 0x9b, 0x0, 0x5c, 0x3d, 0xb9, 0x4, 0x2c, 0xc9, 0x1a, 0xc8, 0x49, 0xf7, 0xe9, 0x4a, 0x2a, 0x4a, 0x9b, 0x89, 0x36, 0x78, 0x97, 0xb, 0x7b, 0x95, 0xbf},
+ Expect = new byte[]{ 0x11, 0xed, 0xae, 0xdc, 0x95, 0xff, 0x78, 0xf5, 0x63, 0xa1, 0xc8, 0xf1, 0x55, 0x91, 0xc0, 0x71, 0xde, 0xa0, 0x92, 0xb4, 0xd7, 0xec, 0xaa, 0xc8, 0xe0, 0x38, 0x7b, 0x5a, 0x16, 0xc, 0x4e, 0x5d},
+ },
+ new TestVector {
+ In = new byte[]{ 0x13, 0xd6, 0x54, 0x91, 0xfe, 0x75, 0xf2, 0x3, 0xa0, 0x8, 0xb4, 0x41, 0x5a, 0xbc, 0x60, 0xd5, 0x32, 0xe6, 0x95, 0xdb, 0xd2, 0xf1, 0xe8, 0x3, 0xac, 0xcb, 0x34, 0xb2, 0xb7, 0x2c, 0x3d, 0x70},
+ Base = new byte[]{ 0x2e, 0x78, 0x4e, 0x4, 0xca, 0x0, 0x73, 0x33, 0x62, 0x56, 0xa8, 0x39, 0x25, 0x5e, 0xd2, 0xf7, 0xd4, 0x79, 0x6a, 0x64, 0xcd, 0xc3, 0x7f, 0x1e, 0xb0, 0xe5, 0xc4, 0xc8, 0xd1, 0xd1, 0xe0, 0xf5},
+ Expect = new byte[]{ 0x56, 0x3e, 0x8c, 0x9a, 0xda, 0xa7, 0xd7, 0x31, 0x1, 0xb0, 0xf2, 0xea, 0xd3, 0xca, 0xe1, 0xea, 0x5d, 0x8f, 0xcd, 0x5c, 0xd3, 0x60, 0x80, 0xbb, 0x8e, 0x6e, 0xc0, 0x3d, 0x61, 0x45, 0x9, 0x17},
+ },
+ new TestVector {
+ In = new byte[]{ 0x68, 0x6f, 0x7d, 0xa9, 0x3b, 0xf2, 0x68, 0xe5, 0x88, 0x6, 0x98, 0x31, 0xf0, 0x47, 0x16, 0x3f, 0x33, 0x58, 0x99, 0x89, 0xd0, 0x82, 0x6e, 0x98, 0x8, 0xfb, 0x67, 0x8e, 0xd5, 0x7e, 0x67, 0x49},
+ Base = new byte[]{ 0x8b, 0x54, 0x9b, 0x2d, 0xf6, 0x42, 0xd3, 0xb2, 0x5f, 0xe8, 0x38, 0xf, 0x8c, 0xc4, 0x37, 0x5f, 0x99, 0xb7, 0xbb, 0x4d, 0x27, 0x5f, 0x77, 0x9f, 0x3b, 0x7c, 0x81, 0xb8, 0xa2, 0xbb, 0xc1, 0x29},
+ Expect = new byte[]{ 0x1, 0x47, 0x69, 0x65, 0x42, 0x6b, 0x61, 0x71, 0x74, 0x9a, 0x8a, 0xdd, 0x92, 0x35, 0x2, 0x5c, 0xe5, 0xf5, 0x57, 0xfe, 0x40, 0x9, 0xf7, 0x39, 0x30, 0x44, 0xeb, 0xbb, 0x8a, 0xe9, 0x52, 0x79},
+ },
+ new TestVector {
+ In = new byte[]{ 0x82, 0xd6, 0x1c, 0xce, 0xdc, 0x80, 0x6a, 0x60, 0x60, 0xa3, 0x34, 0x9a, 0x5e, 0x87, 0xcb, 0xc7, 0xac, 0x11, 0x5e, 0x4f, 0x87, 0x77, 0x62, 0x50, 0xae, 0x25, 0x60, 0x98, 0xa7, 0xc4, 0x49, 0x59},
+ Base = new byte[]{ 0x8b, 0x6b, 0x9d, 0x8, 0xf6, 0x1f, 0xc9, 0x1f, 0xe8, 0xb3, 0x29, 0x53, 0xc4, 0x23, 0x40, 0xf0, 0x7, 0xb5, 0x71, 0xdc, 0xb0, 0xa5, 0x6d, 0x10, 0x72, 0x4e, 0xce, 0xf9, 0x95, 0xc, 0xfb, 0x25},
+ Expect = new byte[]{ 0x9c, 0x49, 0x94, 0x1f, 0x9c, 0x4f, 0x18, 0x71, 0xfa, 0x40, 0x91, 0xfe, 0xd7, 0x16, 0xd3, 0x49, 0x99, 0xc9, 0x52, 0x34, 0xed, 0xf2, 0xfd, 0xfb, 0xa6, 0xd1, 0x4a, 0x5a, 0xfe, 0x9e, 0x5, 0x58},
+ },
+ new TestVector {
+ In = new byte[]{ 0x7d, 0xc7, 0x64, 0x4, 0x83, 0x13, 0x97, 0xd5, 0x88, 0x4f, 0xdf, 0x6f, 0x97, 0xe1, 0x74, 0x4c, 0x9e, 0xb1, 0x18, 0xa3, 0x1a, 0x7b, 0x23, 0xf8, 0xd7, 0x9f, 0x48, 0xce, 0x9c, 0xad, 0x15, 0x4b},
+ Base = new byte[]{ 0x1a, 0xcd, 0x29, 0x27, 0x84, 0xf4, 0x79, 0x19, 0xd4, 0x55, 0xf8, 0x87, 0x44, 0x83, 0x58, 0x61, 0xb, 0xb9, 0x45, 0x96, 0x70, 0xeb, 0x99, 0xde, 0xe4, 0x60, 0x5, 0xf6, 0x89, 0xca, 0x5f, 0xb6},
+ Expect = new byte[]{ 0x0, 0xf4, 0x3c, 0x2, 0x2e, 0x94, 0xea, 0x38, 0x19, 0xb0, 0x36, 0xae, 0x2b, 0x36, 0xb2, 0xa7, 0x61, 0x36, 0xaf, 0x62, 0x8a, 0x75, 0x1f, 0xe5, 0xd0, 0x1e, 0x3, 0xd, 0x44, 0x25, 0x88, 0x59},
+ },
+ new TestVector {
+ In = new byte[]{ 0xfb, 0xc4, 0x51, 0x1d, 0x23, 0xa6, 0x82, 0xae, 0x4e, 0xfd, 0x8, 0xc8, 0x17, 0x9c, 0x1c, 0x6, 0x7f, 0x9c, 0x8b, 0xe7, 0x9b, 0xbc, 0x4e, 0xff, 0x5c, 0xe2, 0x96, 0xc6, 0xbc, 0x1f, 0xf4, 0x45},
+ Base = new byte[]{ 0x55, 0xca, 0xff, 0x21, 0x81, 0xf2, 0x13, 0x6b, 0xe, 0xd0, 0xe1, 0xe2, 0x99, 0x44, 0x48, 0xe1, 0x6c, 0xc9, 0x70, 0x64, 0x6a, 0x98, 0x3d, 0x14, 0xd, 0xc4, 0xea, 0xb3, 0xd9, 0x4c, 0x28, 0x4e},
+ Expect = new byte[]{ 0xae, 0x39, 0xd8, 0x16, 0x53, 0x23, 0x45, 0x79, 0x4d, 0x26, 0x91, 0xe0, 0x80, 0x1c, 0xaa, 0x52, 0x5f, 0xc3, 0x63, 0x4d, 0x40, 0x2c, 0xe9, 0x58, 0xb, 0x33, 0x38, 0xb4, 0x6f, 0x8b, 0xb9, 0x72},
+ },
+ new TestVector {
+ In = new byte[]{ 0x4e, 0x6, 0xc, 0xe1, 0xc, 0xeb, 0xf0, 0x95, 0x9, 0x87, 0x16, 0xc8, 0x66, 0x19, 0xeb, 0x9f, 0x7d, 0xf6, 0x65, 0x24, 0x69, 0x8b, 0xa7, 0x98, 0x8c, 0x3b, 0x90, 0x95, 0xd9, 0xf5, 0x1, 0x34},
+ Base = new byte[]{ 0x57, 0x73, 0x3f, 0x2d, 0x86, 0x96, 0x90, 0xd0, 0xd2, 0xed, 0xae, 0xc9, 0x52, 0x3d, 0xaa, 0x2d, 0xa9, 0x54, 0x45, 0xf4, 0x4f, 0x57, 0x83, 0xc1, 0xfa, 0xec, 0x6c, 0x3a, 0x98, 0x28, 0x18, 0xf3},
+ Expect = new byte[]{ 0xa6, 0x1e, 0x74, 0x55, 0x2c, 0xce, 0x75, 0xf5, 0xe9, 0x72, 0xe4, 0x24, 0xf2, 0xcc, 0xb0, 0x9c, 0x83, 0xbc, 0x1b, 0x67, 0x1, 0x47, 0x48, 0xf0, 0x2c, 0x37, 0x1a, 0x20, 0x9e, 0xf2, 0xfb, 0x2c},
+ },
+ new TestVector {
+ In = new byte[]{ 0x5c, 0x49, 0x2c, 0xba, 0x2c, 0xc8, 0x92, 0x48, 0x8a, 0x9c, 0xeb, 0x91, 0x86, 0xc2, 0xaa, 0xc2, 0x2f, 0x1, 0x5b, 0xf3, 0xef, 0x8d, 0x3e, 0xcc, 0x9c, 0x41, 0x76, 0x97, 0x62, 0x61, 0xaa, 0xb1},
+ Base = new byte[]{ 0x67, 0x97, 0xc2, 0xe7, 0xdc, 0x92, 0xcc, 0xbe, 0x7c, 0x5, 0x6b, 0xec, 0x35, 0xa, 0xb6, 0xd3, 0xbd, 0x2a, 0x2c, 0x6b, 0xc5, 0xa8, 0x7, 0xbb, 0xca, 0xe1, 0xf6, 0xc2, 0xaf, 0x80, 0x36, 0x44},
+ Expect = new byte[]{ 0xfc, 0xf3, 0x7, 0xdf, 0xbc, 0x19, 0x2, 0xb, 0x28, 0xa6, 0x61, 0x8c, 0x6c, 0x62, 0x2f, 0x31, 0x7e, 0x45, 0x96, 0x7d, 0xac, 0xf4, 0xae, 0x4a, 0xa, 0x69, 0x9a, 0x10, 0x76, 0x9f, 0xde, 0x14},
+ },
+ new TestVector {
+ In = new byte[]{ 0xea, 0x33, 0x34, 0x92, 0x96, 0x5, 0x5a, 0x4e, 0x8b, 0x19, 0x2e, 0x3c, 0x23, 0xc5, 0xf4, 0xc8, 0x44, 0x28, 0x2a, 0x3b, 0xfc, 0x19, 0xec, 0xc9, 0xdc, 0x64, 0x6a, 0x42, 0xc3, 0x8d, 0xc2, 0x48},
+ Base = new byte[]{ 0x2c, 0x75, 0xd8, 0x51, 0x42, 0xec, 0xad, 0x3e, 0x69, 0x44, 0x70, 0x4, 0x54, 0xc, 0x1c, 0x23, 0x54, 0x8f, 0xc8, 0xf4, 0x86, 0x25, 0x1b, 0x8a, 0x19, 0x46, 0x3f, 0x3d, 0xf6, 0xf8, 0xac, 0x61},
+ Expect = new byte[]{ 0x5d, 0xca, 0xb6, 0x89, 0x73, 0xf9, 0x5b, 0xd3, 0xae, 0x4b, 0x34, 0xfa, 0xb9, 0x49, 0xfb, 0x7f, 0xb1, 0x5a, 0xf1, 0xd8, 0xca, 0xe2, 0x8c, 0xd6, 0x99, 0xf9, 0xc1, 0xaa, 0x33, 0x37, 0x34, 0x2f},
+ },
+ new TestVector {
+ In = new byte[]{ 0x4f, 0x29, 0x79, 0xb1, 0xec, 0x86, 0x19, 0xe4, 0x5c, 0xa, 0xb, 0x2b, 0x52, 0x9, 0x34, 0x54, 0x1a, 0xb9, 0x44, 0x7, 0xb6, 0x4d, 0x19, 0xa, 0x76, 0xf3, 0x23, 0x14, 0xef, 0xe1, 0x84, 0xe7},
+ Base = new byte[]{ 0xf7, 0xca, 0xe1, 0x8d, 0x8d, 0x36, 0xa7, 0xf5, 0x61, 0x17, 0xb8, 0xb7, 0xe, 0x25, 0x52, 0x27, 0x7f, 0xfc, 0x99, 0xdf, 0x87, 0x56, 0xb5, 0xe1, 0x38, 0xbf, 0x63, 0x68, 0xbc, 0x87, 0xf7, 0x4c},
+ Expect = new byte[]{ 0xe4, 0xe6, 0x34, 0xeb, 0xb4, 0xfb, 0x66, 0x4f, 0xe8, 0xb2, 0xcf, 0xa1, 0x61, 0x5f, 0x0, 0xe6, 0x46, 0x6f, 0xff, 0x73, 0x2c, 0xe1, 0xf8, 0xa0, 0xc8, 0xd2, 0x72, 0x74, 0x31, 0xd1, 0x6f, 0x14},
+ },
+ new TestVector {
+ In = new byte[]{ 0xf5, 0xd8, 0xa9, 0x27, 0x90, 0x1d, 0x4f, 0xa4, 0x24, 0x90, 0x86, 0xb7, 0xff, 0xec, 0x24, 0xf5, 0x29, 0x7d, 0x80, 0x11, 0x8e, 0x4a, 0xc9, 0xd3, 0xfc, 0x9a, 0x82, 0x37, 0x95, 0x1e, 0x3b, 0x7f},
+ Base = new byte[]{ 0x3c, 0x23, 0x5e, 0xdc, 0x2, 0xf9, 0x11, 0x56, 0x41, 0xdb, 0xf5, 0x16, 0xd5, 0xde, 0x8a, 0x73, 0x5d, 0x6e, 0x53, 0xe2, 0x2a, 0xa2, 0xac, 0x14, 0x36, 0x56, 0x4, 0x5f, 0xf2, 0xe9, 0x52, 0x49},
+ Expect = new byte[]{ 0xab, 0x95, 0x15, 0xab, 0x14, 0xaf, 0x9d, 0x27, 0xe, 0x1d, 0xae, 0xc, 0x56, 0x80, 0xcb, 0xc8, 0x88, 0xb, 0xd8, 0xa8, 0xe7, 0xeb, 0x67, 0xb4, 0xda, 0x42, 0xa6, 0x61, 0x96, 0x1e, 0xfc, 0xb},
+ },
+ };
+ }
+
+ [TestClass]
+ public class X25519FieldTests
+ {
+ private readonly byte[] A = {0x21, 0xDD, 0xB0, 0x43, 0xCF, 0xB2, 0xB3, 0xFE, 0xC4, 0xCC, 0xA3, 0x8B, 0xBA, 0x3D, 0xE1, 0x92, 0xDF, 0xEA, 0x85, 0xCE, 0x2B, 0x4A, 0xD8, 0x44, 0x95, 0xAA, 0xB1, 0x3A, 0x5B, 0x62, 0x87, 0x8E};
+ private readonly byte[] B = {0x22, 0x8B, 0x2F, 0x48, 0x4B, 0x86, 0xAC, 0x9C, 0xA4, 0x7B, 0x64, 0xC4, 0x62, 0x76, 0x34, 0x7C, 0x67, 0xBD, 0x59, 0x6F, 0x8D, 0x18, 0x41, 0x4D, 0x96, 0x31, 0xA5, 0x5B, 0x3B, 0xA5, 0x7E, 0xC7};
+
+ [TestMethod]
+ public void Zero()
+ {
+ byte[] actual = new byte[X25519.KeySize];
+ X25519.FieldElement fe = X25519.FieldElement.Zero();
+ fe.CopyTo(actual);
+
+ byte[] expected = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00};
+ CollectionAssert.AreEqual(expected, actual);
+ }
+
+ [TestMethod]
+ public void One()
+ {
+ byte[] actual = new byte[X25519.KeySize];
+ X25519.FieldElement fe = X25519.FieldElement.One();
+ fe.CopyTo(actual);
+
+ byte[] expected = {0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00};
+ CollectionAssert.AreEqual(expected, actual);
+ }
+
+ [TestMethod]
+ public void Add()
+ {
+ X25519.FieldElement a = X25519.FieldElement.FromBytes(A);
+ X25519.FieldElement b = X25519.FieldElement.FromBytes(B);
+
+ byte[] actual = new byte[X25519.KeySize];
+ X25519.FieldElement c = new X25519.FieldElement();
+ X25519.FieldElement.Add(ref c, ref a, ref b);
+ c.CopyTo(actual);
+
+ byte[] expected = {0x43, 0x68, 0xE0, 0x8B, 0x1A, 0x39, 0x60, 0x9B, 0x69, 0x48, 0x08, 0x50, 0x1D, 0xB4, 0x15, 0x0F, 0x47, 0xA8, 0xDF, 0x3D, 0xB9, 0x62, 0x19, 0x92, 0x2B, 0xDC, 0x56, 0x96, 0x96, 0x07, 0x06, 0x56};
+ CollectionAssert.AreEqual(expected, actual);
+ }
+
+ [TestMethod]
+ public void Sub()
+ {
+ X25519.FieldElement a = X25519.FieldElement.FromBytes(A);
+ X25519.FieldElement b = X25519.FieldElement.FromBytes(B);
+
+ byte[] actual = new byte[X25519.KeySize];
+ X25519.FieldElement c = new X25519.FieldElement();
+ X25519.FieldElement.Sub(ref c, ref a, ref b);
+ c.CopyTo(actual);
+
+ byte[] expected = {0xEC, 0x51, 0x81, 0xFB, 0x83, 0x2C, 0x07, 0x62, 0x20, 0x51, 0x3F, 0xC7, 0x57, 0xC7, 0xAC, 0x16, 0x78, 0x2D, 0x2C, 0x5F, 0x9E, 0x31, 0x97, 0xF7, 0xFE, 0x78, 0x0C, 0xDF, 0x1F, 0xBD, 0x08, 0x47};
+ CollectionAssert.AreEqual(expected, actual);
+ }
+
+ [TestMethod]
+ public void Multiply()
+ {
+ X25519.FieldElement a = X25519.FieldElement.FromBytes(A);
+ X25519.FieldElement b = X25519.FieldElement.FromBytes(B);
+
+ byte[] actual = new byte[X25519.KeySize];
+ X25519.FieldElement c = new X25519.FieldElement();
+ X25519.FieldElement.Multiply(ref c, ref a, ref b);
+ c.CopyTo(actual);
+
+ byte[] expected = {0x1E, 0xBE, 0xBD, 0xE0, 0xEC, 0xB1, 0x3C, 0xDB, 0x50, 0x6E, 0xD6, 0x50, 0x02, 0x1A, 0x59, 0x99, 0xC1, 0xC0, 0xFC, 0xE0, 0xBF, 0xDB, 0x64, 0xB0, 0x3E, 0xB3, 0x2D, 0x43, 0x8B, 0x66, 0x43, 0x3C};
+ CollectionAssert.AreEqual(expected, actual);
+ }
+
+ [TestMethod]
+ public void Square()
+ {
+ X25519.FieldElement a = X25519.FieldElement.FromBytes(A);
+
+ byte[] actual = new byte[X25519.KeySize];
+ X25519.FieldElement c = new X25519.FieldElement();
+ X25519.FieldElement.Square(ref c, ref a);
+ c.CopyTo(actual);
+
+ byte[] expected = {0xAE, 0xB2, 0x22, 0xD4, 0x72, 0xF7, 0xF4, 0x09, 0xBB, 0x9A, 0xA9, 0x99, 0xEB, 0x7F, 0xC4, 0xE1, 0x4C, 0x0A, 0x53, 0xEB, 0x3C, 0xFF, 0x5C, 0xE2, 0xF6, 0x92, 0x46, 0x53, 0x29, 0xE1, 0x5D, 0x7A};
+ CollectionAssert.AreEqual(expected, actual);
+
+
+ a = X25519.FieldElement.FromBytes(A);
+ X25519.FieldElement.Square(ref a, ref a);
+ a.CopyTo(actual);
+
+ CollectionAssert.AreEqual(expected, actual);
+ }
+
+ [TestMethod]
+ public void Multiply121666()
+ {
+ X25519.FieldElement a = X25519.FieldElement.FromBytes(A);
+
+ byte[] actual = new byte[X25519.KeySize];
+ X25519.FieldElement c = new X25519.FieldElement();
+ X25519.FieldElement.Multiply121666(ref c, ref a);
+ c.CopyTo(actual);
+
+ byte[] expected = {0x65, 0x3E, 0xE9, 0x9D, 0x08, 0xAC, 0x1A, 0x17, 0x61, 0x4F, 0x2C, 0xED, 0x30, 0x0B, 0x9B, 0xCB, 0x2B, 0x63, 0x53, 0xB9, 0x7D, 0x67, 0x62, 0x11, 0x39, 0xF1, 0x50, 0xC9, 0x6C, 0xA1, 0x66, 0x72};
+ CollectionAssert.AreEqual(expected, actual);
+ }
+
+ [TestMethod]
+ public void Invert()
+ {
+ X25519.FieldElement a = X25519.FieldElement.FromBytes(A);
+
+ byte[] actual = new byte[X25519.KeySize];
+ X25519.FieldElement c = new X25519.FieldElement();
+ X25519.FieldElement.Invert(ref c, ref a);
+ c.CopyTo(actual);
+
+ byte[] expected = {0x8E, 0x66, 0x2F, 0x60, 0xFC, 0xCD, 0x3A, 0x11, 0x36, 0xF5, 0xD9, 0xE6, 0x94, 0x28, 0x04, 0x2A, 0x6B, 0x5D, 0xC4, 0x72, 0x82, 0x30, 0xF3, 0x09, 0xC0, 0x24, 0xDE, 0xCD, 0x60, 0x3F, 0x5D, 0x17};
+ CollectionAssert.AreEqual(expected, actual);
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel.UnitTests/Dtls/AesGcmRecordProtectedTests.cs b/Tools/Hazel-Networking/Hazel.UnitTests/Dtls/AesGcmRecordProtectedTests.cs
new file mode 100644
index 0000000..8e71f2d
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel.UnitTests/Dtls/AesGcmRecordProtectedTests.cs
@@ -0,0 +1,205 @@
+using Hazel.Dtls;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using System.Security.Cryptography;
+using System.Text;
+
+namespace Hazel.UnitTests.Dtls
+{
+ [TestClass]
+ public class AesGcmRecordProtectedTests
+ {
+ private readonly ByteSpan masterSecret;
+ private readonly ByteSpan serverRandom;
+ private readonly ByteSpan clientRandom;
+
+ private const string TestMessage = "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.";
+
+ public AesGcmRecordProtectedTests()
+ {
+ this.masterSecret = new byte[48];
+ this.serverRandom = new byte[32];
+ this.clientRandom = new byte[32];
+
+ using (RandomNumberGenerator random = RandomNumberGenerator.Create())
+ {
+ random.GetBytes(this.masterSecret.GetUnderlyingArray());
+ random.GetBytes(this.serverRandom.GetUnderlyingArray());
+ random.GetBytes(this.clientRandom.GetUnderlyingArray());
+ }
+ }
+
+ [TestMethod]
+ public void ServerCanEncryptAndDecryptData()
+ {
+ using (Aes128GcmRecordProtection recordProtection = new Aes128GcmRecordProtection(this.masterSecret, this.serverRandom, this.clientRandom))
+ {
+ byte[] messageAsBytes = Encoding.UTF8.GetBytes(TestMessage);
+
+ Record record = new Record();
+ record.ContentType = ContentType.ApplicationData;
+ record.ProtocolVersion = ProtocolVersion.DTLS1_2;
+ record.Epoch = 1;
+ record.SequenceNumber = 124;
+ record.Length = (ushort)recordProtection.GetEncryptedSize(messageAsBytes.Length);
+
+ ByteSpan encrypted = new byte[record.Length];
+ recordProtection.EncryptServerPlaintext(encrypted, messageAsBytes, ref record);
+
+ ByteSpan plaintext = new byte[recordProtection.GetDecryptedSize(encrypted.Length)];
+ bool couldDecrypt = recordProtection.DecryptCiphertextFromServer(plaintext, encrypted, ref record);
+ Assert.IsTrue(couldDecrypt);
+ Assert.AreEqual(messageAsBytes.Length, plaintext.Length);
+ Assert.AreEqual(TestMessage, Encoding.UTF8.GetString(plaintext.GetUnderlyingArray(), plaintext.Offset, plaintext.Length));
+ }
+ }
+
+ [TestMethod]
+ public void ClientCanEncryptAndDecryptData()
+ {
+ using (Aes128GcmRecordProtection recordProtection = new Aes128GcmRecordProtection(this.masterSecret, this.serverRandom, this.clientRandom))
+ {
+ byte[] messageAsBytes = Encoding.UTF8.GetBytes(TestMessage);
+
+ Record record = new Record();
+ record.ContentType = ContentType.ApplicationData;
+ record.ProtocolVersion = ProtocolVersion.DTLS1_2;
+ record.Epoch = 1;
+ record.SequenceNumber = 124;
+ record.Length = (ushort)recordProtection.GetEncryptedSize(messageAsBytes.Length);
+
+ ByteSpan encrypted = new byte[record.Length];
+ recordProtection.EncryptClientPlaintext(encrypted, messageAsBytes, ref record);
+
+ ByteSpan plaintext = new byte[recordProtection.GetDecryptedSize(encrypted.Length)];
+ bool couldDecrypt = recordProtection.DecryptCiphertextFromClient(plaintext, encrypted, ref record);
+ Assert.IsTrue(couldDecrypt);
+ Assert.AreEqual(messageAsBytes.Length, plaintext.Length);
+ Assert.AreEqual(TestMessage, Encoding.UTF8.GetString(plaintext.GetUnderlyingArray(), plaintext.Offset, plaintext.Length));
+ }
+ }
+
+ [TestMethod]
+ public void ServerDecryptionFailsWhenRecordModified()
+ {
+ using (Aes128GcmRecordProtection recordProtection = new Aes128GcmRecordProtection(this.masterSecret, this.serverRandom, this.clientRandom))
+ {
+ byte[] messageAsBytes = Encoding.UTF8.GetBytes(TestMessage);
+
+ Record originalRecord = new Record();
+ originalRecord.ContentType = ContentType.ApplicationData;
+ originalRecord.ProtocolVersion = ProtocolVersion.DTLS1_2;
+ originalRecord.Epoch = 1;
+ originalRecord.SequenceNumber = 124;
+ originalRecord.Length = (ushort)recordProtection.GetEncryptedSize(messageAsBytes.Length);
+
+ ByteSpan encrypted = new byte[originalRecord.Length];
+ recordProtection.EncryptServerPlaintext(encrypted, messageAsBytes, ref originalRecord);
+
+ ByteSpan plaintext = new byte[recordProtection.GetDecryptedSize(encrypted.Length)];
+
+ Record record = originalRecord;
+ record.ContentType = ContentType.Handshake;
+ bool couldDecrypt = recordProtection.DecryptCiphertextFromServer(plaintext, encrypted, ref record);
+ Assert.IsFalse(couldDecrypt);
+
+ record = originalRecord;
+ record.Epoch++;
+ couldDecrypt = recordProtection.DecryptCiphertextFromServer(plaintext, encrypted, ref record);
+ Assert.IsFalse(couldDecrypt);
+
+ record = originalRecord;
+ record.SequenceNumber++;
+ couldDecrypt = recordProtection.DecryptCiphertextFromServer(plaintext, encrypted, ref record);
+ Assert.IsFalse(couldDecrypt);
+ }
+ }
+
+ [TestMethod]
+ public void ClientDecryptionFailsWhenRecordModified()
+ {
+ using (Aes128GcmRecordProtection recordProtection = new Aes128GcmRecordProtection(this.masterSecret, this.serverRandom, this.clientRandom))
+ {
+ byte[] messageAsBytes = Encoding.UTF8.GetBytes(TestMessage);
+
+ Record originalRecord = new Record();
+ originalRecord.ContentType = ContentType.ApplicationData;
+ originalRecord.ProtocolVersion = ProtocolVersion.DTLS1_2;
+ originalRecord.Epoch = 1;
+ originalRecord.SequenceNumber = 124;
+ originalRecord.Length = (ushort)recordProtection.GetEncryptedSize(messageAsBytes.Length);
+
+ ByteSpan encrypted = new byte[originalRecord.Length];
+ recordProtection.EncryptClientPlaintext(encrypted, messageAsBytes, ref originalRecord);
+
+ ByteSpan plaintext = new byte[recordProtection.GetDecryptedSize(encrypted.Length)];
+
+ Record record = originalRecord;
+ record.ContentType = ContentType.Handshake;
+ bool couldDecrypt = recordProtection.DecryptCiphertextFromClient(plaintext, encrypted, ref record);
+ Assert.IsFalse(couldDecrypt);
+
+ record = originalRecord;
+ record.Epoch++;
+ couldDecrypt = recordProtection.DecryptCiphertextFromClient(plaintext, encrypted, ref record);
+ Assert.IsFalse(couldDecrypt);
+
+ record = originalRecord;
+ record.SequenceNumber++;
+ couldDecrypt = recordProtection.DecryptCiphertextFromClient(plaintext, encrypted, ref record);
+ Assert.IsFalse(couldDecrypt);
+ }
+ }
+
+ [TestMethod]
+ public void ServerEncryptionCanoverlap()
+ {
+ using (Aes128GcmRecordProtection recordProtection = new Aes128GcmRecordProtection(this.masterSecret, this.serverRandom, this.clientRandom))
+ {
+ ByteSpan messageAsBytes = Encoding.UTF8.GetBytes(TestMessage);
+
+ Record record = new Record();
+ record.ContentType = ContentType.ApplicationData;
+ record.ProtocolVersion = ProtocolVersion.DTLS1_2;
+ record.Epoch = 1;
+ record.SequenceNumber = 124;
+ record.Length = (ushort)recordProtection.GetEncryptedSize(messageAsBytes.Length);
+
+ ByteSpan encrypted = new byte[record.Length];
+ messageAsBytes.CopyTo(encrypted);
+ recordProtection.EncryptServerPlaintext(encrypted, encrypted.Slice(0, messageAsBytes.Length), ref record);
+
+ ByteSpan plaintext = encrypted.Slice(0, recordProtection.GetDecryptedSize(record.Length));
+ bool couldDecrypt = recordProtection.DecryptCiphertextFromServer(plaintext, encrypted, ref record);
+ Assert.IsTrue(couldDecrypt);
+ Assert.AreEqual(messageAsBytes.Length, plaintext.Length);
+ Assert.AreEqual(TestMessage, Encoding.UTF8.GetString(plaintext.GetUnderlyingArray(), plaintext.Offset, plaintext.Length));
+ }
+ }
+
+ [TestMethod]
+ public void ClientEncryptionCanoverlap()
+ {
+ using (Aes128GcmRecordProtection recordProtection = new Aes128GcmRecordProtection(this.masterSecret, this.serverRandom, this.clientRandom))
+ {
+ ByteSpan messageAsBytes = Encoding.UTF8.GetBytes(TestMessage);
+
+ Record record = new Record();
+ record.ContentType = ContentType.ApplicationData;
+ record.ProtocolVersion = ProtocolVersion.DTLS1_2;
+ record.Epoch = 1;
+ record.SequenceNumber = 124;
+ record.Length = (ushort)recordProtection.GetEncryptedSize(messageAsBytes.Length);
+
+ ByteSpan encrypted = new byte[record.Length];
+ messageAsBytes.CopyTo(encrypted);
+ recordProtection.EncryptClientPlaintext(encrypted, encrypted.Slice(0, messageAsBytes.Length), ref record);
+
+ ByteSpan plaintext = encrypted.Slice(0, recordProtection.GetDecryptedSize(record.Length));
+ bool couldDecrypt = recordProtection.DecryptCiphertextFromClient(plaintext, encrypted, ref record);
+ Assert.IsTrue(couldDecrypt);
+ Assert.AreEqual(messageAsBytes.Length, plaintext.Length);
+ Assert.AreEqual(TestMessage, Encoding.UTF8.GetString(plaintext.GetUnderlyingArray(), plaintext.Offset, plaintext.Length));
+ }
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel.UnitTests/Dtls/DtlsConnectionTests.cs b/Tools/Hazel-Networking/Hazel.UnitTests/Dtls/DtlsConnectionTests.cs
new file mode 100644
index 0000000..55e0c91
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel.UnitTests/Dtls/DtlsConnectionTests.cs
@@ -0,0 +1,1192 @@
+using Hazel.Dtls;
+using Hazel.Udp;
+using Hazel.Udp.FewerThreads;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using System;
+using System.Net;
+using System.Net.Sockets;
+using System.Security.Cryptography;
+using System.Security.Cryptography.X509Certificates;
+using System.Threading;
+
+namespace Hazel.UnitTests.Dtls
+{
+ [TestClass]
+ public class DtlsConnectionTests
+ {
+ // Created with command line
+ // openssl req -newkey rsa:2048 -nodes -keyout key.pem -x509 -days 100000 -out certificate.pem
+ const string TestCertificate =
+@"-----BEGIN CERTIFICATE-----
+MIIDbTCCAlWgAwIBAgIUREHeZ36f23eBFQ1T3sJsBwHlSBEwDQYJKoZIhvcNAQEL
+BQAwRTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM
+GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDAgFw0yMTAyMDIxNDE4MTBaGA8yMjk0
+MTExODE0MTgxMFowRTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUx
+ITAfBgNVBAoMGEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDCCASIwDQYJKoZIhvcN
+AQEBBQADggEPADCCAQoCggEBAMeHCR6Y6GFwH7ZnouxPLmqyCIJSCcfaGIuBU3k+
+MG2ZyXKhhhwclL8arx5x1cGmQFvPm5wXGKSiLFChj+bW5XN7xBAc5e9KVBCEabrr
+BY+X9r0a421Yjqn4F47IA2sQ6OygnttYIt0pgeEoQZhGvmc2ZfkELkptIHMavIsx
+B/R0tYgtquruWveIWMtr4O/AuPxkH750SO1OxwU8gj6QXSqskrxvhl9GBzAwBKaF
+W6t7yjR7eFqaGh7B55p4t5zrfYKCVgeyj5Yzr/xdvv3Q3H+0pex+JTMWrpsTaavq
+F2RZYbpTOofuiTwdWbAHnXW1aFSCCIrEdEs9X2FxB73V0fcCAwEAAaNTMFEwHQYD
+VR0OBBYEFETIkxnzoLXO2GcEgxTZgN8ypKowMB8GA1UdIwQYMBaAFETIkxnzoLXO
+2GcEgxTZgN8ypKowMA8GA1UdEwEB/wQFMAMBAf8wDQYJKoZIhvcNAQELBQADggEB
+ACZl7WQec9xLTK0paBIkVUqZKucDCXQH0JC7z4ENbiRtQvWQm6xhAlDo8Tr8oUzj
+0/lft/g6wIo8dJ4jZ/iCSHnKz8qO80Gs/x5NISe9A/8Us1kq8y4nO40QW6xtQMH7
+j74pcfsGKDCaMFSQZnSc93a3ZMEuVPxdI5+qsvFIeC9xxRHUNo245eLqsJAe8s1c
+22Uoeu3gepozrPcIPAHADGr/CFp1HLkg9nFrTcatlNAF/N0PmLjmk/NIx/8h7n7Q
+5vapNkhcyCHsW8XB5ulKmF88QZ5BdvPmtSey0t/n8ru98615G5Wb4TS2MaprzYL3
+5ACeQOohFzevcQrEjjzkZAI=
+-----END CERTIFICATE-----
+";
+
+ const string TestPrivateKey =
+@"-----BEGIN PRIVATE KEY-----
+MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDHhwkemOhhcB+2
+Z6LsTy5qsgiCUgnH2hiLgVN5PjBtmclyoYYcHJS/Gq8ecdXBpkBbz5ucFxikoixQ
+oY/m1uVze8QQHOXvSlQQhGm66wWPl/a9GuNtWI6p+BeOyANrEOjsoJ7bWCLdKYHh
+KEGYRr5nNmX5BC5KbSBzGryLMQf0dLWILarq7lr3iFjLa+DvwLj8ZB++dEjtTscF
+PII+kF0qrJK8b4ZfRgcwMASmhVure8o0e3hamhoeweeaeLec632CglYHso+WM6/8
+Xb790Nx/tKXsfiUzFq6bE2mr6hdkWWG6UzqH7ok8HVmwB511tWhUggiKxHRLPV9h
+cQe91dH3AgMBAAECggEAVq+jVajHJTYqgPwLu7E3EGHi8oOj/jESAuIgGwfa0HNF
+I0lr06DTOyfjt01ruiN5yKmtCKa8LSLMMAfRVlA9BexapUl42HqphTeSHARpuRYj
+u8sHzgTwjoXb7kuVuJlzKQMroU5sbzvOUr1Dql3p8TugGA0p82nv9DJEghC+TQT2
+GnCKhsnmwE7lb8h0z3G3yxdv3yZE0X6oFzllBGCb1O6cwsDeYsv+SyjnyUwJROGz
+/VkC1+B48ALm4DhA5QIUoaRhO8vaCa7dacQTkXw1hVdLcaS9slIdXxbb9GbJvI0c
+baqimIkE02VUUpmlOIKUpf1sRXy1aJFpDSvWsTNLaQKBgQD8TrcVUF7oOhX5hQer
+qfNDFPvCBiWlT+8tnJraauaD1sJLD5jpRWPDu5dZ96CSZVpD1E3GQm+x58CSUknG
+AUHyHfEHTpzx7elVeUj7gianmVu8r9mVKtErwPLDJ4AUhMJjEPX2Plmh9GgFck78
+s2gfIvxdI+znvkH9JwGBznTIRQKBgQDKcpO7wiu025ZyPWR2p+qUh2ZlvBBr/6rg
+GxbFE5VraIS6zSDVOcxjPLc1pVZ/vM2WGbda0eziLpvsyauTXMCueBiNBRWZb5E4
+NK81IgbgZc4VWN9xA00cfOzO4Bjt6990BdOxiQQ1XOz1cN8DFTfsA81qR+nIne58
+LhL0DmFLCwKBgCwsU92FbrhVwxcmdUtWu+JYwCMeFGU283cW3f2zjZwzc1zU5D6j
+CW5xX3Q+6Hv5Bq6tcthtNUT+gDad9ZCXE8ah+1r+Jngs4Rc33tE53i6lqOwGFaAK
+GQkCBP6p4cC15ZqWk5mDHQo/0h5x/uY7OtWIuIpOCeIg60i5FYh2bvfJAoGAPQ7t
+i7V2ZSfNaksl37upPn7P3WMpOMl1if3hkjLj3+84CPcRLf4urMeFIkLpocEZ6Gl9
+KYEjBtyz3mi8vMc+veAu12lvKEXD8MXDCi1nEYri6wFQ8s7iFPOAoKxqGGgJjv6q
+6GLAyC9ssGIIgO+HXEGRVLq3wfAQG5fx03X61h0CgYEAiz3f8xiIR6PC4Gn5iMWu
+wmIDk3EnxdaA+7AwN/M037jbmKfzLxA1n8jXYM+to4SJx8Fxo7MD5EDhq6UoYmPU
+tGe4Ite2N9jzxG7xQrVuIx6Cg4t+E7uZ1eZuhbQ1WpqCXPIFOtXuc4szXfwD4Z+p
+IsdbLCwHYD3GVgk/D7NVxyU=
+-----END PRIVATE KEY-----
+";
+ private static X509Certificate2 GetCertificateForServer()
+ {
+ RSA privateKey = Utils.DecodeRSAKeyFromPEM(TestPrivateKey);
+ return new X509Certificate2(Utils.DecodePEM(TestCertificate)).CopyWithPrivateKey(privateKey);
+ }
+
+ private static X509Certificate2Collection GetCertificateForClient()
+ {
+ X509Certificate2 publicCertificate = new X509Certificate2(Utils.DecodePEM(TestCertificate));
+
+ X509Certificate2Collection clientCertificates = new X509Certificate2Collection();
+ clientCertificates.Add(publicCertificate);
+ return clientCertificates;
+ }
+
+ protected DtlsConnectionListener CreateListener(int numWorkers, IPEndPoint endPoint, ILogger logger, IPMode ipMode = IPMode.IPv4)
+ {
+ DtlsConnectionListener listener = new DtlsConnectionListener(2, endPoint, logger, ipMode);
+ listener.SetCertificate(GetCertificateForServer());
+ return listener;
+
+ }
+
+ protected DtlsUnityConnection CreateConnection(IPEndPoint endPoint, ILogger logger, IPMode ipMode = IPMode.IPv4)
+ {
+ DtlsUnityConnection connection = new DtlsUnityConnection(logger, endPoint, ipMode);
+ connection.SetValidServerCertificates(GetCertificateForClient());
+ return connection;
+ }
+
+ [TestMethod]
+ public void DtlsServerDisposeDisconnectsTest()
+ {
+ IPEndPoint ep = new IPEndPoint(IPAddress.Loopback, 27510);
+
+ bool serverConnected = false;
+ bool serverDisconnected = false;
+ bool clientDisconnected = false;
+
+ Semaphore signal = new Semaphore(0, int.MaxValue);
+
+ using (var listener = (DtlsConnectionListener)CreateListener(2, new IPEndPoint(IPAddress.Any, ep.Port), new TestLogger()))
+ using (var connection = CreateConnection(ep, new TestLogger()))
+ {
+ listener.NewConnection += (evt) =>
+ {
+ serverConnected = true;
+ signal.Release();
+ evt.Connection.Disconnected += (o, et) => {
+ serverDisconnected = true;
+ };
+ };
+ connection.Disconnected += (o, evt) => {
+ clientDisconnected = true;
+ signal.Release();
+ };
+
+ listener.Start();
+ connection.Connect();
+
+ // wait for the client to connect
+ signal.WaitOne(10);
+
+ listener.Dispose();
+
+ // wait for the client to disconnect
+ signal.WaitOne(100);
+
+ Assert.IsTrue(serverConnected);
+ Assert.IsTrue(clientDisconnected);
+ Assert.IsFalse(serverDisconnected);
+ Assert.AreEqual(0, listener.PeerCount);
+ }
+ }
+
+ class MalformedDTLSListener : DtlsConnectionListener
+ {
+ public MalformedDTLSListener(int numWorkers, IPEndPoint endPoint, ILogger logger, IPMode ipMode = IPMode.IPv4)
+ : base(numWorkers, endPoint, logger, ipMode)
+ {
+ }
+
+ public void InjectPacket(ByteSpan packet, IPEndPoint peerAddress, ConnectionId connectionId)
+ {
+ MessageReader reader = MessageReader.GetSized(packet.Length);
+ reader.Length = packet.Length;
+ Array.Copy(packet.GetUnderlyingArray(), packet.Offset, reader.Buffer, reader.Offset, packet.Length);
+
+ base.ProcessIncomingMessageFromOtherThread(reader, peerAddress, connectionId);
+ }
+ }
+
+ class MalformedDTLSClient : DtlsUnityConnection
+ {
+ public Func<ClientHello, ByteSpan, ByteSpan> EncodeClientHelloCallback = Test_CompressionLengthOverrunClientHello;
+
+ public MalformedDTLSClient(ILogger logger, IPEndPoint remoteEndPoint, IPMode ipMode = IPMode.IPv4) : base(logger, remoteEndPoint, ipMode)
+ {
+
+ }
+
+ protected override void SendClientHello(bool isResend)
+ {
+ Test_SendClientHello(EncodeClientHelloCallback);
+ }
+
+ public static ByteSpan Test_CompressionLengthOverrunClientHello(ClientHello clientHello, ByteSpan writer)
+ {
+ ByteSpanBigEndianExtensions.WriteBigEndian16(writer, (ushort)ProtocolVersion.DTLS1_2);
+ writer = writer.Slice(2);
+
+ clientHello.Random.CopyTo(writer);
+ writer = writer.Slice(Hazel.Dtls.Random.Size);
+
+ // Do not encode session ids
+ writer[0] = (byte)0;
+ writer = writer.Slice(1);
+
+ writer[0] = (byte)clientHello.Cookie.Length;
+ clientHello.Cookie.CopyTo(writer.Slice(1));
+ writer = writer.Slice(1 + clientHello.Cookie.Length);
+
+ ByteSpanBigEndianExtensions.WriteBigEndian16(writer, (ushort)clientHello.CipherSuites.Length);
+ clientHello.CipherSuites.CopyTo(writer.Slice(2));
+ writer = writer.Slice(2 + clientHello.CipherSuites.Length);
+
+ // ============ Here is the corruption. writer[0] should be 1. ============
+ writer[0] = 255;
+ writer[1] = (byte)CompressionMethod.Null;
+ writer = writer.Slice(2);
+
+ // Extensions size
+ ByteSpanBigEndianExtensions.WriteBigEndian16(writer, (ushort)(6 + clientHello.SupportedCurves.Length));
+ writer = writer.Slice(2);
+
+ // Supported curves extension
+ ByteSpanBigEndianExtensions.WriteBigEndian16(writer, (ushort)ExtensionType.EllipticCurves);
+ ByteSpanBigEndianExtensions.WriteBigEndian16(writer, (ushort)(2 + clientHello.SupportedCurves.Length), 2);
+ ByteSpanBigEndianExtensions.WriteBigEndian16(writer, (ushort)clientHello.SupportedCurves.Length, 4);
+ clientHello.SupportedCurves.CopyTo(writer.Slice(6));
+
+ return writer;
+ }
+ }
+
+ [TestMethod]
+ public void TestMalformedApplicationData()
+ {
+ IPEndPoint ep = new IPEndPoint(IPAddress.Loopback, 27510);
+
+ IPEndPoint connectionEndPoint = ep;
+ DtlsConnectionListener.ConnectionId connectionId = new ThreadLimitedUdpConnectionListener.ConnectionId();
+
+ Semaphore signal = new Semaphore(0, int.MaxValue);
+
+ using (MalformedDTLSListener listener = new MalformedDTLSListener(2, new IPEndPoint(IPAddress.Any, ep.Port), new TestLogger()))
+ using (DtlsUnityConnection connection = new DtlsUnityConnection(new TestLogger(), ep))
+ {
+ listener.SetCertificate(GetCertificateForServer());
+ connection.SetValidServerCertificates(GetCertificateForClient());
+
+ listener.NewConnection += (evt) =>
+ {
+ connectionEndPoint = evt.Connection.EndPoint;
+ connectionId = ((ThreadLimitedUdpServerConnection)evt.Connection).ConnectionId;
+
+ signal.Release();
+ evt.Connection.Disconnected += (o, et) => {
+ };
+ };
+ connection.Disconnected += (o, evt) => {
+ signal.Release();
+ };
+
+ listener.Start();
+ connection.Connect();
+
+ // wait for the client to connect
+ signal.WaitOne(10);
+
+ ByteSpan data = new byte[5] { 0x01, 0x02, 0x03, 0x04, 0x05 };
+
+ Record record = new Record();
+ record.ContentType = ContentType.ApplicationData;
+ record.ProtocolVersion = ProtocolVersion.DTLS1_2;
+ record.Epoch = 1;
+ record.SequenceNumber = 10;
+ record.Length = (ushort)data.Length;
+
+ ByteSpan encoded = new byte[Record.Size + data.Length];
+ record.Encode(encoded);
+ data.CopyTo(encoded.Slice(Record.Size));
+
+ listener.InjectPacket(encoded, connectionEndPoint, connectionId);
+
+ // wait for the client to disconnect
+ listener.Dispose();
+ signal.WaitOne(100);
+ }
+ }
+
+ [TestMethod]
+ public void TestMalformedConnectionData()
+ {
+ IPEndPoint ep = new IPEndPoint(IPAddress.Loopback, 27510);
+
+ IPEndPoint connectionEndPoint = ep;
+ DtlsConnectionListener.ConnectionId connectionId = new ThreadLimitedUdpConnectionListener.ConnectionId();
+
+ Semaphore signal = new Semaphore(0, int.MaxValue);
+
+ using (DtlsConnectionListener listener = new DtlsConnectionListener(2, new IPEndPoint(IPAddress.Any, ep.Port), new TestLogger()))
+ using (MalformedDTLSClient connection = new MalformedDTLSClient(new TestLogger(), ep))
+ {
+ listener.SetCertificate(GetCertificateForServer());
+ connection.SetValidServerCertificates(GetCertificateForClient());
+
+ listener.NewConnection += (evt) =>
+ {
+ connectionEndPoint = evt.Connection.EndPoint;
+ connectionId = ((ThreadLimitedUdpServerConnection)evt.Connection).ConnectionId;
+
+ signal.Release();
+ evt.Connection.Disconnected += (o, et) => {
+ };
+ };
+ connection.Disconnected += (o, evt) => {
+ signal.Release();
+ };
+
+ listener.Start();
+ connection.Connect();
+
+ Assert.IsTrue(listener.ReceiveThreadRunning, "Listener should be able to handle a malformed hello packet");
+ Assert.AreEqual(ConnectionState.NotConnected, connection.State);
+
+ Assert.AreEqual(0, listener.PeerCount);
+
+ // wait for the client to disconnect
+ listener.Dispose();
+ signal.WaitOne(100);
+ }
+ }
+
+
+ [TestMethod]
+ public void TestReorderedHandshakePacketsConnects()
+ {
+ IPEndPoint captureEndPoint = new IPEndPoint(IPAddress.Loopback, 27511);
+ IPEndPoint listenerEndPoint = new IPEndPoint(IPAddress.Loopback, 27510);
+
+ bool serverConnected = false;
+ bool serverDisconnected = false;
+ bool clientDisconnected = false;
+
+ Semaphore signal = new Semaphore(0, int.MaxValue);
+
+ var logger = new TestLogger("Throttle");
+
+ using (SocketCapture capture = new SocketCapture(captureEndPoint, listenerEndPoint, logger))
+ using (DtlsConnectionListener listener = new DtlsConnectionListener(2, new IPEndPoint(IPAddress.Any, listenerEndPoint.Port), new TestLogger("Server")))
+ using (DtlsUnityConnection connection = new DtlsUnityConnection(new TestLogger("Client "), captureEndPoint))
+ {
+ Semaphore listenerToConnectionThrottle = new Semaphore(0, int.MaxValue);
+ capture.SendToLocalSemaphore = listenerToConnectionThrottle;
+ Thread throttleThread = new Thread(() => {
+ // HelloVerifyRequest
+ capture.AssertPacketsToLocalCountEquals(1);
+ listenerToConnectionThrottle.Release(1);
+
+ // ServerHello, Server Certificate (Fragment)
+ // Server Cert
+ // ServerKeyExchange, ServerHelloDone
+ capture.AssertPacketsToLocalCountEquals(3);
+ capture.ReorderPacketsForLocal(list => list.Swap(0, 1));
+ listenerToConnectionThrottle.Release(3);
+ capture.AssertPacketsToLocalCountEquals(0);
+
+ // Same flight, let's swap the ServerKeyExchange to the front
+ capture.AssertPacketsToLocalCountEquals(3);
+ capture.ReorderPacketsForLocal(list => list.Swap(0, 2));
+ listenerToConnectionThrottle.Release(3);
+ capture.AssertPacketsToLocalCountEquals(0);
+
+ // Same flight, no swap we do matters as long as the ServerKeyExchange gets through.
+ capture.AssertPacketsToLocalCountEquals(3);
+ capture.ReorderPacketsForLocal(list => list.Reverse());
+
+ capture.SendToLocalSemaphore = null;
+ listenerToConnectionThrottle.Release(1);
+ });
+ throttleThread.Start();
+
+ listener.SetCertificate(GetCertificateForServer());
+ connection.SetValidServerCertificates(GetCertificateForClient());
+
+ listener.NewConnection += (evt) =>
+ {
+ serverConnected = true;
+ signal.Release();
+ evt.Connection.Disconnected += (o, et) => {
+ serverDisconnected = true;
+ };
+ };
+ connection.Disconnected += (o, evt) => {
+ clientDisconnected = true;
+ signal.Release();
+ };
+
+ listener.Start();
+ connection.Connect();
+
+ // wait for the client to connect
+ signal.WaitOne(10);
+
+ listener.Dispose();
+
+ // wait for the client to disconnect
+ signal.WaitOne(100);
+
+ Assert.IsTrue(serverConnected);
+ Assert.IsTrue(clientDisconnected);
+ Assert.IsFalse(serverDisconnected);
+ }
+ }
+
+
+ [TestMethod]
+ public void TestResentClientHelloConnects()
+ {
+ IPEndPoint captureEndPoint = new IPEndPoint(IPAddress.Loopback, 27511);
+ IPEndPoint listenerEndPoint = new IPEndPoint(IPAddress.Loopback, 27510);
+
+ bool serverConnected = false;
+ bool serverDisconnected = false;
+ bool clientDisconnected = false;
+
+ Semaphore signal = new Semaphore(0, int.MaxValue);
+
+ var logger = new TestLogger("Throttle");
+
+ using (SocketCapture capture = new SocketCapture(captureEndPoint, listenerEndPoint, logger))
+ using (DtlsConnectionListener listener = new DtlsConnectionListener(2, new IPEndPoint(IPAddress.Any, listenerEndPoint.Port), new TestLogger("Server")))
+ using (DtlsUnityConnection connection = new DtlsUnityConnection(new TestLogger("Client "), captureEndPoint))
+ {
+ Semaphore listenerToConnectionThrottle = new Semaphore(0, int.MaxValue);
+ capture.SendToLocalSemaphore = listenerToConnectionThrottle;
+ Thread throttleThread = new Thread(() => {
+ // Trigger resend of HelloVerifyRequest
+ capture.DiscardPacketForLocal();
+
+ capture.AssertPacketsToLocalCountEquals(1);
+ listenerToConnectionThrottle.Release(1);
+
+ // ServerHello, ServerCertificate
+ // ServerCertificate
+ // ServerKeyExchange, ServerHelloDone
+ capture.AssertPacketsToLocalCountEquals(3);
+ listenerToConnectionThrottle.Release(3);
+
+ // Trigger a resend of ServerKeyExchange, ServerHelloDone
+ capture.DiscardPacketForLocal();
+
+ // From here, flush everything. We recover or not.
+ capture.SendToLocalSemaphore = null;
+ listenerToConnectionThrottle.Release(1);
+ });
+ throttleThread.Start();
+
+ listener.SetCertificate(GetCertificateForServer());
+ connection.SetValidServerCertificates(GetCertificateForClient());
+
+ listener.NewConnection += (evt) =>
+ {
+ serverConnected = true;
+ signal.Release();
+ evt.Connection.Disconnected += (o, et) => {
+ serverDisconnected = true;
+ };
+ };
+ connection.Disconnected += (o, evt) => {
+ clientDisconnected = true;
+ signal.Release();
+ };
+
+ listener.Start();
+ connection.Connect();
+
+ // wait for the client to connect
+ signal.WaitOne(10);
+
+ listener.Dispose();
+
+ // wait for the client to disconnect
+ signal.WaitOne(100);
+
+ Assert.IsTrue(serverConnected);
+ Assert.IsTrue(clientDisconnected);
+ Assert.IsFalse(serverDisconnected);
+ }
+ }
+
+ [TestMethod]
+ public void TestResentServerHelloConnects()
+ {
+ IPEndPoint captureEndPoint = new IPEndPoint(IPAddress.Loopback, 27511);
+ IPEndPoint listenerEndPoint = new IPEndPoint(IPAddress.Loopback, 27510);
+
+ bool serverConnected = false;
+ bool serverDisconnected = false;
+ bool clientDisconnected = false;
+
+ Semaphore signal = new Semaphore(0, int.MaxValue);
+
+ using (SocketCapture capture = new SocketCapture(captureEndPoint, listenerEndPoint))
+ using (DtlsConnectionListener listener = new DtlsConnectionListener(2, new IPEndPoint(IPAddress.Any, listenerEndPoint.Port), new TestLogger("Server")))
+ using (DtlsUnityConnection connection = new DtlsUnityConnection(new TestLogger("Client "), captureEndPoint))
+ {
+ Semaphore listenerToConnectionThrottle = new Semaphore(0, int.MaxValue);
+ capture.SendToLocalSemaphore = listenerToConnectionThrottle;
+ Thread throttleThread = new Thread(() => {
+ // HelloVerifyRequest
+ capture.AssertPacketsToLocalCountEquals(1);
+ listenerToConnectionThrottle.Release(1);
+
+ // ServerHello, Server Certificate
+ // Server Certificate
+ // ServerKeyExchange, ServerHelloDone
+ capture.AssertPacketsToLocalCountEquals(3);
+ capture.DiscardPacketForLocal();
+ listenerToConnectionThrottle.Release(2);
+
+ // Wait for the resends and recover
+ capture.AssertPacketsToLocalCountEquals(3);
+
+ capture.SendToLocalSemaphore = null;
+ listenerToConnectionThrottle.Release(3);
+ });
+ throttleThread.Start();
+
+ listener.SetCertificate(GetCertificateForServer());
+ connection.SetValidServerCertificates(GetCertificateForClient());
+
+ listener.NewConnection += (evt) =>
+ {
+ serverConnected = true;
+ signal.Release();
+ evt.Connection.Disconnected += (o, et) => {
+ serverDisconnected = true;
+ };
+ };
+ connection.Disconnected += (o, evt) => {
+ clientDisconnected = true;
+ signal.Release();
+ };
+
+ listener.Start();
+ connection.Connect();
+
+ // wait for the client to connect
+ signal.WaitOne(10);
+
+ listener.Dispose();
+
+ // wait for the client to disconnect
+ signal.WaitOne(100);
+
+ Assert.IsTrue(serverConnected);
+ Assert.IsTrue(clientDisconnected);
+ Assert.IsFalse(serverDisconnected);
+ }
+ }
+
+ [TestMethod]
+ public void TestConnectionSuccessAfterClientKeyExchangeFlightDropped()
+ {
+ IPEndPoint captureEndPoint = new IPEndPoint(IPAddress.Loopback, 27511);
+ IPEndPoint listenerEndPoint = new IPEndPoint(IPAddress.Loopback, 27510);
+
+ bool serverConnected = false;
+ bool serverDisconnected = false;
+ bool clientDisconnected = false;
+
+ Semaphore signal = new Semaphore(0, int.MaxValue);
+
+ using (SocketCapture capture = new SocketCapture(captureEndPoint, listenerEndPoint))
+ using (DtlsConnectionListener listener = new DtlsConnectionListener(2, new IPEndPoint(IPAddress.Any, listenerEndPoint.Port), new TestLogger()))
+ using (TestDtlsHandshakeDropUnityConnection connection = new TestDtlsHandshakeDropUnityConnection(new TestLogger(), captureEndPoint))
+ {
+ connection.DropSendClientKeyExchangeFlightCount = 1;
+
+ listener.SetCertificate(GetCertificateForServer());
+ connection.SetValidServerCertificates(GetCertificateForClient());
+
+ listener.NewConnection += (evt) =>
+ {
+ serverConnected = true;
+ signal.Release();
+ evt.Connection.Disconnected += (o, et) => {
+ serverDisconnected = true;
+ };
+ };
+ connection.Disconnected += (o, evt) => {
+ clientDisconnected = true;
+ signal.Release();
+ };
+
+ listener.Start();
+ connection.Connect();
+
+ // wait for the client to connect
+ signal.WaitOne(10);
+
+ listener.Dispose();
+
+ // wait for the client to disconnect
+ signal.WaitOne(100);
+
+ Assert.IsTrue(serverConnected);
+ Assert.IsTrue(clientDisconnected);
+ Assert.IsFalse(serverDisconnected);
+ }
+ }
+
+ /// <summary>
+ /// Tests the keepalive functionality from the client,
+ /// </summary>
+ [TestMethod]
+ public void PingDisconnectClientTest()
+ {
+#if DEBUG
+ IPEndPoint ep = new IPEndPoint(IPAddress.Loopback, 27510);
+ using (DtlsConnectionListener listener = (DtlsConnectionListener)CreateListener(2, new IPEndPoint(IPAddress.Any, ep.Port), new TestLogger()))
+ {
+ // Adjust the ping rate to end the test faster
+ listener.NewConnection += (evt) =>
+ {
+ var conn = (ThreadLimitedUdpServerConnection)evt.Connection;
+ conn.KeepAliveInterval = 100;
+ conn.MissingPingsUntilDisconnect = 3;
+ };
+
+ listener.Start();
+
+ for (int i = 0; i < 5; ++i)
+ {
+ using (DtlsUnityConnection connection = (DtlsUnityConnection)CreateConnection(ep, new TestLogger()))
+ {
+ connection.KeepAliveInterval = 100;
+ connection.MissingPingsUntilDisconnect = 3;
+ connection.Connect();
+
+ Thread.Sleep(10);
+
+ // After connecting, quietly stop responding to all messages to fake connection loss.
+ connection.TestDropRate = 1;
+
+ Thread.Sleep(500); //Enough time for ~3 keep alive packets
+
+ Assert.AreEqual(ConnectionState.NotConnected, connection.State);
+ }
+ }
+
+ listener.DisconnectOldConnections(TimeSpan.FromMilliseconds(500), null);
+
+ Assert.AreEqual(0, listener.PeerCount, "All clients disconnected, peer count should be zero.");
+ }
+#else
+ Assert.Inconclusive("Only works in DEBUG");
+#endif
+ }
+
+ [TestMethod]
+ public void ServerDisposeDisconnectsTest()
+ {
+ IPEndPoint ep = new IPEndPoint(IPAddress.Loopback, 4296);
+
+ bool serverConnected = false;
+ bool serverDisconnected = false;
+ bool clientDisconnected = false;
+
+ using (ThreadLimitedUdpConnectionListener listener = this.CreateListener(2, new IPEndPoint(IPAddress.Any, 4296), new TestLogger("SERVER")))
+ using (UdpConnection connection = this.CreateConnection(ep, new TestLogger("CLIENT")))
+ {
+ listener.NewConnection += (evt) =>
+ {
+ serverConnected = true;
+ evt.Connection.Disconnected += (o, et) => serverDisconnected = true;
+ };
+ connection.Disconnected += (o, evt) => clientDisconnected = true;
+
+ listener.Start();
+ connection.Connect();
+
+ Thread.Sleep(100); // Gotta wait for the server to set up the events.
+ listener.Dispose();
+ Thread.Sleep(100);
+
+ Assert.IsTrue(serverConnected);
+ Assert.IsTrue(clientDisconnected);
+ Assert.IsFalse(serverDisconnected);
+ }
+ }
+
+ [TestMethod]
+ public void ClientDisposeDisconnectTest()
+ {
+ IPEndPoint ep = new IPEndPoint(IPAddress.Loopback, 4296);
+
+ bool serverConnected = false;
+ bool serverDisconnected = false;
+ bool clientDisconnected = false;
+
+ using (ThreadLimitedUdpConnectionListener listener = this.CreateListener(2, new IPEndPoint(IPAddress.Any, 4296), new TestLogger()))
+ using (UdpConnection connection = this.CreateConnection(ep, new TestLogger()))
+ {
+ listener.NewConnection += (evt) =>
+ {
+ serverConnected = true;
+ evt.Connection.Disconnected += (o, et) => serverDisconnected = true;
+ };
+
+ connection.Disconnected += (o, et) => clientDisconnected = true;
+
+ listener.Start();
+ connection.Connect();
+
+ Thread.Sleep(100); // Gotta wait for the server to set up the events.
+ connection.Dispose();
+
+ Thread.Sleep(100);
+
+ Assert.IsTrue(serverConnected);
+ Assert.IsTrue(serverDisconnected);
+ Assert.IsFalse(clientDisconnected);
+ }
+ }
+
+ /// <summary>
+ /// Tests the fields on UdpConnection.
+ /// </summary>
+ [TestMethod]
+ public void DtlsFieldTest()
+ {
+ IPEndPoint ep = new IPEndPoint(IPAddress.Loopback, 4296);
+
+ using (ThreadLimitedUdpConnectionListener listener = this.CreateListener(2, new IPEndPoint(IPAddress.Any, 4296), new TestLogger()))
+ using (UdpConnection connection = this.CreateConnection(ep, new TestLogger()))
+ {
+ listener.Start();
+
+ connection.Connect();
+
+ //Connection fields
+ Assert.AreEqual(ep, connection.EndPoint);
+
+ //UdpConnection fields
+ Assert.AreEqual(new IPEndPoint(IPAddress.Loopback, 4296), connection.EndPoint);
+ Assert.AreEqual(1, connection.Statistics.DataBytesSent);
+ Assert.AreEqual(0, connection.Statistics.DataBytesReceived);
+ }
+ }
+
+ [TestMethod]
+ public void DtlsHandshakeTest()
+ {
+ byte[] TestData = new byte[] { 1, 2, 3, 4, 5, 6 };
+ using (ThreadLimitedUdpConnectionListener listener = this.CreateListener(2, new IPEndPoint(IPAddress.Any, 4296), new TestLogger()))
+ using (UdpConnection connection = this.CreateConnection(new IPEndPoint(IPAddress.Loopback, 4296), new TestLogger()))
+ {
+ listener.Start();
+
+ MessageReader output = null;
+ listener.NewConnection += delegate (NewConnectionEventArgs e)
+ {
+ output = e.HandshakeData.Duplicate();
+ };
+
+ connection.Connect(TestData);
+
+ Thread.Sleep(10);
+ for (int i = 0; i < TestData.Length; ++i)
+ {
+ Assert.AreEqual(TestData[i], output.ReadByte());
+ }
+ }
+ }
+
+ [TestMethod]
+ public void DtlsUnreliableMessageSendTest()
+ {
+ byte[] TestData = new byte[] { 1, 2, 3, 4, 5, 6 };
+ using (ThreadLimitedUdpConnectionListener listener = this.CreateListener(2, new IPEndPoint(IPAddress.Any, 4296), new TestLogger()))
+ using (UdpConnection connection = this.CreateConnection(new IPEndPoint(IPAddress.Loopback, 4296), new TestLogger()))
+ {
+ MessageReader output = null;
+ listener.NewConnection += delegate (NewConnectionEventArgs e)
+ {
+ e.Connection.DataReceived += delegate (DataReceivedEventArgs evt)
+ {
+ output = evt.Message.Duplicate();
+ };
+ };
+
+ listener.Start();
+ connection.Connect();
+
+ for (int i = 0; i < 4; ++i)
+ {
+ var msg = MessageWriter.Get(SendOption.None);
+ msg.Write(TestData);
+ connection.Send(msg);
+ msg.Recycle();
+ }
+
+ Thread.Sleep(10);
+ for (int i = 0; i < TestData.Length; ++i)
+ {
+ Assert.AreEqual(TestData[i], output.ReadByte());
+ }
+ }
+ }
+
+ /// <summary>
+ /// Tests IPv4 connectivity.
+ /// </summary>
+ [TestMethod]
+ public void DtlsIPv4ConnectionTest()
+ {
+ using (ThreadLimitedUdpConnectionListener listener = this.CreateListener(2, new IPEndPoint(IPAddress.Any, 4296), new TestLogger()))
+ using (UdpConnection connection = this.CreateConnection(new IPEndPoint(IPAddress.Loopback, 4296), new TestLogger()))
+ {
+ listener.Start();
+
+ connection.Connect();
+
+ Assert.AreEqual(ConnectionState.Connected, connection.State);
+ }
+ }
+
+ [TestMethod]
+ public void DtlsSessionV0ConnectionTest()
+ {
+ using (ThreadLimitedUdpConnectionListener listener = this.CreateListener(2, new IPEndPoint(IPAddress.Any, 4296), new TestLogger()))
+ using (DtlsUnityConnection connection = this.CreateConnection(new IPEndPoint(IPAddress.Loopback, 4296), new TestLogger()))
+ {
+ connection.HazelSessionVersion = 0;
+ listener.Start();
+
+ connection.Connect();
+
+ Assert.AreEqual(ConnectionState.Connected, connection.State);
+ }
+ }
+
+ private class MultipleClientHelloDtlsConnection : DtlsUnityConnection
+ {
+ public MultipleClientHelloDtlsConnection(ILogger logger, IPEndPoint remoteEndPoint, IPMode ipMode = IPMode.IPv4) : base(logger, remoteEndPoint, ipMode)
+ {
+ }
+
+ protected override void SendClientHello(bool isRetransmit)
+ {
+ base.SendClientHello(isRetransmit);
+ base.SendClientHello(true);
+ }
+ }
+
+
+ private class MultipleClientKeyExchangeFlightDtlsConnection : DtlsUnityConnection
+ {
+ public MultipleClientKeyExchangeFlightDtlsConnection(ILogger logger, IPEndPoint remoteEndPoint, IPMode ipMode = IPMode.IPv4) : base(logger, remoteEndPoint, ipMode)
+ {
+ }
+
+ protected override void SendClientKeyExchangeFlight(bool isRetransmit)
+ {
+ base.SendClientKeyExchangeFlight(isRetransmit);
+ base.SendClientKeyExchangeFlight(true);
+ base.SendClientKeyExchangeFlight(true);
+ }
+ }
+
+ /// <summary>
+ /// Tests IPv4 resilience to multiple hellos.
+ /// </summary>
+ [TestMethod]
+ public void ConnectLikeAJerkTest()
+ {
+ using (DtlsConnectionListener listener = this.CreateListener(2, new IPEndPoint(IPAddress.Any, 4296), new TestLogger("Server")))
+ using (MultipleClientHelloDtlsConnection client = new MultipleClientHelloDtlsConnection(new TestLogger("Client "), new IPEndPoint(IPAddress.Loopback, 4296), IPMode.IPv4))
+ {
+ client.SetValidServerCertificates(GetCertificateForClient());
+
+ int connects = 0;
+ listener.NewConnection += (obj) =>
+ {
+ Interlocked.Increment(ref connects);
+ };
+
+ listener.Start();
+ client.Connect(null, 1000);
+
+ Thread.Sleep(2000);
+
+ Assert.AreEqual(0, listener.ReceiveQueueLength);
+ Assert.IsTrue(connects <= 1, $"Too many connections: {connects}");
+ Assert.AreEqual(ConnectionState.Connected, client.State);
+ Assert.IsTrue(client.HandshakeComplete);
+ }
+ }
+
+ /// <summary>
+ /// Tests IPv4 resilience to multiple ClientKeyExchange packets.
+ /// </summary>
+ [TestMethod]
+ public void HandshakeLikeAJerkTest()
+ {
+ using (DtlsConnectionListener listener = this.CreateListener(2, new IPEndPoint(IPAddress.Any, 4296), new TestLogger("Server")))
+ using (MultipleClientKeyExchangeFlightDtlsConnection client = new MultipleClientKeyExchangeFlightDtlsConnection(new TestLogger("Client "), new IPEndPoint(IPAddress.Loopback, 4296), IPMode.IPv4))
+ {
+ client.SetValidServerCertificates(GetCertificateForClient());
+
+ int connects = 0;
+ listener.NewConnection += (obj) =>
+ {
+ Interlocked.Increment(ref connects);
+ };
+
+ listener.Start();
+ client.Connect();
+
+ Thread.Sleep(500);
+
+ Assert.AreEqual(0, listener.ReceiveQueueLength);
+ Assert.IsTrue(connects <= 1, $"Too many connections: {connects}");
+ Assert.AreEqual(ConnectionState.Connected, client.State);
+ Assert.IsTrue(client.HandshakeComplete);
+ }
+ }
+
+ /// <summary>
+ /// Tests dual mode connectivity.
+ /// </summary>
+ [TestMethod]
+ public void MixedConnectionTest()
+ {
+ using (ThreadLimitedUdpConnectionListener listener2 = this.CreateListener(4, new IPEndPoint(IPAddress.IPv6Any, 4296), new TestLogger(), IPMode.IPv6))
+ {
+ listener2.Start();
+
+ listener2.NewConnection += (evt) =>
+ {
+ Console.WriteLine($"Connection: {evt.Connection.EndPoint}");
+ };
+
+ using (UdpConnection connection = this.CreateConnection(new IPEndPoint(IPAddress.Parse("127.0.0.1"), 4296), new TestLogger()))
+ {
+ connection.Connect();
+ Assert.AreEqual(ConnectionState.Connected, connection.State);
+ }
+
+ using (UdpConnection connection2 = this.CreateConnection(new IPEndPoint(IPAddress.IPv6Loopback, 4296), new TestLogger(), IPMode.IPv6))
+ {
+ connection2.Connect();
+ Assert.AreEqual(ConnectionState.Connected, connection2.State);
+ }
+ }
+ }
+
+ /// <summary>
+ /// Tests dual mode connectivity.
+ /// </summary>
+ [TestMethod]
+ public void DtlsIPv6ConnectionTest()
+ {
+ using (ThreadLimitedUdpConnectionListener listener = this.CreateListener(2, new IPEndPoint(IPAddress.IPv6Any, 4296), new TestLogger(), IPMode.IPv6))
+ {
+ listener.Start();
+
+ using (UdpConnection connection = this.CreateConnection(new IPEndPoint(IPAddress.Parse("127.0.0.1"), 4296), new TestLogger(), IPMode.IPv6))
+ {
+ connection.Connect();
+ }
+ }
+ }
+
+ /// <summary>
+ /// Tests server to client unreliable communication on the UdpConnection.
+ /// </summary>
+ [TestMethod]
+ public void DtlsUnreliableServerToClientTest()
+ {
+ using (ThreadLimitedUdpConnectionListener listener = this.CreateListener(2, new IPEndPoint(IPAddress.Any, 4296), new TestLogger()))
+ using (UdpConnection connection = this.CreateConnection(new IPEndPoint(IPAddress.Loopback, 4296), new TestLogger()))
+ {
+ TestHelper.RunServerToClientTest(listener, connection, 10, SendOption.None);
+ }
+ }
+
+ /// <summary>
+ /// Tests server to client reliable communication on the UdpConnection.
+ /// </summary>
+ [TestMethod]
+ public void DtlsReliableServerToClientTest()
+ {
+ using (ThreadLimitedUdpConnectionListener listener = this.CreateListener(2, new IPEndPoint(IPAddress.Any, 4296), new TestLogger()))
+ using (UdpConnection connection = this.CreateConnection(new IPEndPoint(IPAddress.Loopback, 4296), new TestLogger()))
+ {
+ TestHelper.RunServerToClientTest(listener, connection, 10, SendOption.Reliable);
+ }
+ }
+
+ /// <summary>
+ /// Tests server to client unreliable communication on the UdpConnection.
+ /// </summary>
+ [TestMethod]
+ public void DtlsUnreliableClientToServerTest()
+ {
+ using (ThreadLimitedUdpConnectionListener listener = this.CreateListener(2, new IPEndPoint(IPAddress.Any, 4296), new TestLogger()))
+ using (UdpConnection connection = this.CreateConnection(new IPEndPoint(IPAddress.Loopback, 4296), new TestLogger()))
+ {
+ TestHelper.RunClientToServerTest(listener, connection, 10, SendOption.None);
+ }
+ }
+
+ /// <summary>
+ /// Tests server to client reliable communication on the UdpConnection.
+ /// </summary>
+ [TestMethod]
+ public void DtlsReliableClientToServerTest()
+ {
+ using (ThreadLimitedUdpConnectionListener listener = this.CreateListener(2, new IPEndPoint(IPAddress.Any, 4296), new TestLogger()))
+ using (UdpConnection connection = this.CreateConnection(new IPEndPoint(IPAddress.Loopback, 4296), new TestLogger()))
+ {
+ TestHelper.RunClientToServerTest(listener, connection, 10, SendOption.Reliable);
+ }
+ }
+
+ [TestMethod]
+ public void KeepAliveClientTest()
+ {
+ using (ThreadLimitedUdpConnectionListener listener = this.CreateListener(2, new IPEndPoint(IPAddress.Any, 4296), new TestLogger()))
+ using (UdpConnection connection = this.CreateConnection(new IPEndPoint(IPAddress.Loopback, 4296), new TestLogger()))
+ {
+ listener.Start();
+
+ connection.Connect();
+ connection.KeepAliveInterval = 100;
+
+ Thread.Sleep(1050); //Enough time for ~10 keep alive packets
+
+ Assert.AreEqual(ConnectionState.Connected, connection.State);
+ Assert.IsTrue(
+ connection.Statistics.TotalBytesSent >= 500 &&
+ connection.Statistics.TotalBytesSent <= 675,
+ "Sent: " + connection.Statistics.TotalBytesSent
+ );
+ }
+ }
+
+ /// <summary>
+ /// Tests the keepalive functionality from the client,
+ /// </summary>
+ [TestMethod]
+ public void KeepAliveServerTest()
+ {
+ ManualResetEvent mutex = new ManualResetEvent(false);
+
+ using (ThreadLimitedUdpConnectionListener listener = this.CreateListener(2, new IPEndPoint(IPAddress.Any, 4296), new TestLogger()))
+ using (UdpConnection connection = this.CreateConnection(new IPEndPoint(IPAddress.Loopback, 4296), new TestLogger()))
+ {
+ UdpConnection client = null;
+ listener.NewConnection += delegate (NewConnectionEventArgs args)
+ {
+ client = (UdpConnection)args.Connection;
+ client.KeepAliveInterval = 100;
+
+ Thread timeoutThread = new Thread(() =>
+ {
+ Thread.Sleep(1050); //Enough time for ~10 keep alive packets
+ mutex.Set();
+ });
+ timeoutThread.Start();
+ };
+
+ listener.Start();
+
+ connection.Connect();
+
+ mutex.WaitOne();
+
+ Assert.AreEqual(ConnectionState.Connected, client.State);
+
+ Assert.IsTrue(
+ client.Statistics.TotalBytesSent >= 27 &&
+ client.Statistics.TotalBytesSent <= 50,
+ "Sent: " + client.Statistics.TotalBytesSent
+ );
+ }
+ }
+
+ /// <summary>
+ /// Tests disconnection from the client.
+ /// </summary>
+ [TestMethod]
+ public void ClientDisconnectTest()
+ {
+ using (var listener = this.CreateListener(2, new IPEndPoint(IPAddress.Any, 4296), new TestLogger("Server")))
+ using (var connection = this.CreateConnection(new IPEndPoint(IPAddress.Loopback, 4296), new TestLogger("Client")))
+ {
+ ManualResetEvent mutex = new ManualResetEvent(false);
+ ManualResetEvent mutex2 = new ManualResetEvent(false);
+
+ listener.NewConnection += delegate (NewConnectionEventArgs args)
+ {
+ args.Connection.Disconnected += delegate (object sender2, DisconnectedEventArgs args2)
+ {
+ mutex2.Set();
+ };
+
+ mutex.Set();
+ };
+
+ listener.Start();
+
+ connection.Connect();
+
+ Assert.AreEqual(ConnectionState.Connected, connection.State);
+ mutex.WaitOne(1000);
+ Assert.AreEqual(ConnectionState.Connected, connection.State);
+
+ connection.Disconnect("Testing");
+
+ mutex2.WaitOne(1000);
+ Assert.AreEqual(ConnectionState.NotConnected, connection.State);
+ }
+ }
+
+ /// <summary>
+ /// Tests disconnection from the server.
+ /// </summary>
+ [TestMethod]
+ public void ServerDisconnectTest()
+ {
+ using (ThreadLimitedUdpConnectionListener listener = this.CreateListener(2, new IPEndPoint(IPAddress.Any, 4296), new TestLogger("Server")))
+ using (UdpConnection connection = this.CreateConnection(new IPEndPoint(IPAddress.Loopback, 4296), new TestLogger("Client")))
+ {
+ SemaphoreSlim mutex = new SemaphoreSlim(0, 100);
+ ManualResetEventSlim serverMutex = new ManualResetEventSlim(false);
+
+ connection.Disconnected += delegate (object sender, DisconnectedEventArgs args)
+ {
+ mutex.Release();
+ };
+
+ listener.NewConnection += delegate (NewConnectionEventArgs args)
+ {
+ mutex.Release();
+
+ // This has to be on a new thread because the client will go straight from Connecting to NotConnected
+ ThreadPool.QueueUserWorkItem(_ =>
+ {
+ serverMutex.Wait(500);
+ args.Connection.Disconnect("Testing");
+ });
+ };
+
+ listener.Start();
+
+ connection.Connect();
+
+ mutex.Wait(500);
+ Assert.AreEqual(ConnectionState.Connected, connection.State);
+
+ serverMutex.Set();
+
+ mutex.Wait(500);
+ Assert.AreEqual(ConnectionState.NotConnected, connection.State);
+ }
+ }
+
+ /// <summary>
+ /// Tests disconnection from the server.
+ /// </summary>
+ [TestMethod]
+ public void ServerExtraDataDisconnectTest()
+ {
+ using (ThreadLimitedUdpConnectionListener listener = this.CreateListener(2, new IPEndPoint(IPAddress.Any, 4296), new TestLogger()))
+ using (UdpConnection connection = this.CreateConnection(new IPEndPoint(IPAddress.Loopback, 4296), new TestLogger()))
+ {
+ string received = null;
+ ManualResetEvent mutex = new ManualResetEvent(false);
+
+ connection.Disconnected += delegate (object sender, DisconnectedEventArgs args)
+ {
+ // We don't own the message, we have to read the string now
+ received = args.Message.ReadString();
+ mutex.Set();
+ };
+
+ listener.NewConnection += delegate (NewConnectionEventArgs args)
+ {
+ MessageWriter writer = MessageWriter.Get(SendOption.None);
+ writer.Write("Goodbye");
+ args.Connection.Disconnect("Testing", writer);
+ };
+
+ listener.Start();
+
+ connection.Connect();
+
+ mutex.WaitOne(5000);
+
+ Assert.IsNotNull(received);
+ Assert.AreEqual("Goodbye", received);
+ }
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel.UnitTests/Dtls/TestDtlsHandshakeDropUnityConnection.cs b/Tools/Hazel-Networking/Hazel.UnitTests/Dtls/TestDtlsHandshakeDropUnityConnection.cs
new file mode 100644
index 0000000..f33cd14
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel.UnitTests/Dtls/TestDtlsHandshakeDropUnityConnection.cs
@@ -0,0 +1,27 @@
+using Hazel.Dtls;
+using System.Net;
+
+namespace Hazel.UnitTests.Dtls
+{
+ internal class TestDtlsHandshakeDropUnityConnection : DtlsUnityConnection
+ {
+ public int DropSendClientKeyExchangeFlightCount = 0;
+
+ public TestDtlsHandshakeDropUnityConnection(ILogger logger, IPEndPoint remoteEndPoint, IPMode ipMode = IPMode.IPv4) : base(logger, remoteEndPoint, ipMode)
+ {
+
+ }
+
+ protected override bool DropClientKeyExchangeFlight()
+ {
+ if (DropSendClientKeyExchangeFlightCount > 0)
+ {
+ this.logger.WriteInfo($"Dropping SendClientKeyExchangeFlight");
+ --DropSendClientKeyExchangeFlightCount;
+ return true;
+ }
+
+ return false;
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel.UnitTests/Dtls/X25519EcdheRsaSha256Tests.cs b/Tools/Hazel-Networking/Hazel.UnitTests/Dtls/X25519EcdheRsaSha256Tests.cs
new file mode 100644
index 0000000..33cd6dc
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel.UnitTests/Dtls/X25519EcdheRsaSha256Tests.cs
@@ -0,0 +1,226 @@
+using Hazel.Dtls;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using System.Security.Cryptography;
+
+namespace Hazel.UnitTests.Dtls
+{
+ [TestClass]
+ public class X25519EcdheRsaSha256Tests
+ {
+ private readonly RandomNumberGenerator random = RandomNumberGenerator.Create();
+ private readonly RSA privateKey = RSA.Create();
+ private readonly RSA publicKey;
+
+ public X25519EcdheRsaSha256Tests()
+ {
+ RSAParameters keyParameters = this.privateKey.ExportParameters(false);
+ this.publicKey = RSA.Create();
+ this.publicKey.ImportParameters(keyParameters);
+ }
+
+ [TestMethod]
+ public void SmallServerDataFails()
+ {
+ byte[] data;
+
+ using (X25519EcdheRsaSha256 cipherSuite = new X25519EcdheRsaSha256(this.random))
+ {
+ int expectedSize = cipherSuite.CalculateServerMessageSize(this.privateKey);
+ Assert.IsTrue(expectedSize/2 > 1);
+
+ data = new byte[expectedSize/2];
+ random.GetBytes(data);
+ }
+
+ using (X25519EcdheRsaSha256 cipherSuite = new X25519EcdheRsaSha256(this.random))
+ {
+ byte[] sharedKey = new byte[cipherSuite.SharedKeySize()];
+ Assert.IsFalse(cipherSuite.VerifyServerMessageAndGenerateSharedKey(sharedKey, data, this.publicKey));
+ }
+ }
+
+ [TestMethod]
+ public void LargeServerDataFails()
+ {
+ byte[] data;
+
+ using (X25519EcdheRsaSha256 cipherSuite = new X25519EcdheRsaSha256(this.random))
+ {
+ int expectedSize = cipherSuite.CalculateServerMessageSize(this.privateKey);
+ Assert.IsTrue(expectedSize > 0);
+
+ data = new byte[expectedSize * 2];
+ random.GetBytes(data);
+ }
+
+ using (X25519EcdheRsaSha256 cipherSuite = new X25519EcdheRsaSha256(this.random))
+ {
+ byte[] sharedKey = new byte[cipherSuite.SharedKeySize()];
+ Assert.IsFalse(cipherSuite.VerifyServerMessageAndGenerateSharedKey(sharedKey, data, this.publicKey));
+ }
+ }
+
+ [TestMethod]
+ public void RandomServerDataFails()
+ {
+ byte[] data;
+
+ using (X25519EcdheRsaSha256 cipherSuite = new X25519EcdheRsaSha256(this.random))
+ {
+ int expectedSize = cipherSuite.CalculateServerMessageSize(this.privateKey);
+ Assert.IsTrue(expectedSize > 0);
+
+ data = new byte[expectedSize];
+ random.GetBytes(data);
+ }
+
+ using (X25519EcdheRsaSha256 cipherSuite = new X25519EcdheRsaSha256(this.random))
+ {
+ byte[] sharedKey = new byte[cipherSuite.SharedKeySize()];
+ Assert.IsFalse(cipherSuite.VerifyServerMessageAndGenerateSharedKey(sharedKey, data, this.publicKey));
+ }
+ }
+
+ [TestMethod]
+ public void SmallClientDataFails()
+ {
+ byte[] data;
+
+ using (X25519EcdheRsaSha256 cipherSuite = new X25519EcdheRsaSha256(this.random))
+ {
+ int expectedSize = cipherSuite.CalculateClientMessageSize();
+ Assert.IsTrue(expectedSize / 2 > 1);
+
+ data = new byte[expectedSize / 2];
+ random.GetBytes(data);
+ }
+
+ using (X25519EcdheRsaSha256 cipherSuite = new X25519EcdheRsaSha256(this.random))
+ {
+ byte[] sharedKey = new byte[cipherSuite.SharedKeySize()];
+ Assert.IsFalse(cipherSuite.VerifyClientMessageAndGenerateSharedKey(sharedKey, data));
+ }
+ }
+
+ [TestMethod]
+ public void LargeClientDataFails()
+ {
+ byte[] data;
+
+ using (X25519EcdheRsaSha256 cipherSuite = new X25519EcdheRsaSha256(this.random))
+ {
+ int expectedSize = cipherSuite.CalculateClientMessageSize();
+ Assert.IsTrue(expectedSize > 0);
+
+ data = new byte[expectedSize * 2];
+ random.GetBytes(data);
+ }
+
+ using (X25519EcdheRsaSha256 cipherSuite = new X25519EcdheRsaSha256(this.random))
+ {
+ byte[] sharedKey = new byte[cipherSuite.SharedKeySize()];
+ Assert.IsFalse(cipherSuite.VerifyClientMessageAndGenerateSharedKey(sharedKey, data));
+ }
+ }
+
+ [TestMethod]
+ public void RandomClientDataFails()
+ {
+ byte[] data;
+
+ using (X25519EcdheRsaSha256 cipherSuite = new X25519EcdheRsaSha256(this.random))
+ {
+ int expectedSize = cipherSuite.CalculateClientMessageSize();
+ Assert.IsTrue(expectedSize > 0);
+
+ data = new byte[expectedSize];
+ random.GetBytes(data);
+ }
+
+ using (X25519EcdheRsaSha256 cipherSuite = new X25519EcdheRsaSha256(this.random))
+ {
+ byte[] sharedKey = new byte[cipherSuite.SharedKeySize()];
+ Assert.IsFalse(cipherSuite.VerifyClientMessageAndGenerateSharedKey(sharedKey, data));
+ }
+ }
+
+ [TestMethod]
+ public void RandomSignatureFails()
+ {
+ byte[] data;
+
+ using (X25519EcdheRsaSha256 cipherSuite = new X25519EcdheRsaSha256(this.random))
+ {
+ int expectedSize = cipherSuite.CalculateServerMessageSize(this.privateKey);
+ Assert.IsTrue(expectedSize > 0);
+
+ data = new byte[expectedSize];
+ cipherSuite.EncodeServerKeyExchangeMessage(data, this.privateKey);
+ }
+
+ // overwrite signature with random data
+ byte[] randomSignature = new byte[this.privateKey.KeySize/8];
+ random.GetBytes(randomSignature);
+ new ByteSpan(randomSignature).CopyTo(new ByteSpan(data, data.Length - randomSignature.Length, randomSignature.Length));
+
+ using (X25519EcdheRsaSha256 cipherSuite = new X25519EcdheRsaSha256(this.random))
+ {
+ byte[] sharedKey = new byte[cipherSuite.SharedKeySize()];
+ Assert.IsFalse(cipherSuite.VerifyServerMessageAndGenerateSharedKey(sharedKey, data, this.publicKey));
+ }
+ }
+
+ [TestMethod]
+ public void VerifySignature()
+ {
+ byte[] data;
+
+ using (X25519EcdheRsaSha256 cipherSuite = new X25519EcdheRsaSha256(this.random))
+ {
+ int expectedSize = cipherSuite.CalculateServerMessageSize(this.privateKey);
+ Assert.IsTrue(expectedSize > 0);
+
+ data = new byte[expectedSize];
+ cipherSuite.EncodeServerKeyExchangeMessage(data, this.privateKey);
+ }
+
+ using (X25519EcdheRsaSha256 cipherSuite = new X25519EcdheRsaSha256(this.random))
+ {
+ byte[] sharedKey = new byte[cipherSuite.SharedKeySize()];
+ Assert.IsTrue(cipherSuite.VerifyServerMessageAndGenerateSharedKey(sharedKey, data, this.publicKey));
+ }
+ }
+
+ [TestMethod]
+ public void GeneratesSameSharedKey()
+ {
+ byte[] serverSharedSecret;
+ byte[] clientSharedSecret;
+
+ using (X25519EcdheRsaSha256 serverCipherSuite = new X25519EcdheRsaSha256(this.random))
+ {
+ int expectedSize = serverCipherSuite.CalculateServerMessageSize(this.privateKey);
+ Assert.IsTrue(expectedSize > 0);
+
+ byte[] serverKeyExchangeMessage = new byte[expectedSize];
+ serverCipherSuite.EncodeServerKeyExchangeMessage(serverKeyExchangeMessage, this.privateKey);
+
+ byte[] clientKeyExchange;
+
+ using (X25519EcdheRsaSha256 clientCipherSuite = new X25519EcdheRsaSha256(this.random))
+ {
+ clientSharedSecret = new byte[clientCipherSuite.SharedKeySize()];
+ Assert.IsTrue(clientCipherSuite.VerifyServerMessageAndGenerateSharedKey(clientSharedSecret, serverKeyExchangeMessage, this.publicKey));
+
+ clientKeyExchange = new byte[clientCipherSuite.CalculateClientMessageSize()];
+ clientCipherSuite.EncodeClientKeyExchangeMessage(clientKeyExchange);
+ }
+
+ serverSharedSecret = new byte[serverCipherSuite.SharedKeySize()];
+ Assert.IsTrue(serverCipherSuite.VerifyClientMessageAndGenerateSharedKey(serverSharedSecret, clientKeyExchange));
+ }
+
+ CollectionAssert.AreEqual(serverSharedSecret, clientSharedSecret);
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel.UnitTests/Hazel.UnitTests.csproj b/Tools/Hazel-Networking/Hazel.UnitTests/Hazel.UnitTests.csproj
new file mode 100644
index 0000000..f39bede
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel.UnitTests/Hazel.UnitTests.csproj
@@ -0,0 +1,16 @@
+<Project Sdk="Microsoft.NET.Sdk">
+
+ <PropertyGroup>
+ <TargetFramework>net472</TargetFramework>
+ </PropertyGroup>
+
+ <ItemGroup>
+ <PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.0.0" />
+ <PackageReference Include="MSTest.TestAdapter" Version="2.2.8" />
+ <PackageReference Include="MSTest.TestFramework" Version="2.2.8" />
+
+ <ProjectReference Include="..\Hazel\Hazel.csproj" />
+ <PackageReference Include="Portable.BouncyCastle" Version="1.9.0" />
+ </ItemGroup>
+
+</Project>
diff --git a/Tools/Hazel-Networking/Hazel.UnitTests/MessageReaderTests.cs b/Tools/Hazel-Networking/Hazel.UnitTests/MessageReaderTests.cs
new file mode 100644
index 0000000..6cf4ba8
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel.UnitTests/MessageReaderTests.cs
@@ -0,0 +1,689 @@
+using System;
+using System.IO;
+using System.Linq;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+
+namespace Hazel.UnitTests
+{
+ [TestClass]
+ public class MessageReaderTests
+ {
+ [TestMethod]
+ public void ReadProperInt()
+ {
+ const int Test1 = int.MaxValue;
+ const int Test2 = int.MinValue;
+
+ var msg = new MessageWriter(128);
+ msg.StartMessage(1);
+ msg.Write(Test1);
+ msg.Write(Test2);
+ msg.EndMessage();
+
+ Assert.AreEqual(11, msg.Length);
+ Assert.AreEqual(msg.Length, msg.Position);
+
+ MessageReader reader = MessageReader.Get(msg.Buffer, 0);
+ Assert.AreEqual(Test1, reader.ReadInt32());
+ Assert.AreEqual(Test2, reader.ReadInt32());
+ }
+
+ [TestMethod]
+ public void ReadProperBool()
+ {
+ const bool Test1 = true;
+ const bool Test2 = false;
+
+ var msg = new MessageWriter(128);
+ msg.StartMessage(1);
+ msg.Write(Test1);
+ msg.Write(Test2);
+ msg.EndMessage();
+
+ Assert.AreEqual(5, msg.Length);
+ Assert.AreEqual(msg.Length, msg.Position);
+
+ MessageReader reader = MessageReader.Get(msg.Buffer, 0);
+
+ Assert.AreEqual(Test1, reader.ReadBoolean());
+ Assert.AreEqual(Test2, reader.ReadBoolean());
+
+ }
+
+ [TestMethod]
+ public void ReadProperString()
+ {
+ const string Test1 = "Hello";
+ string Test2 = new string(' ', 1024);
+ var msg = new MessageWriter(2048);
+ msg.StartMessage(1);
+ msg.Write(Test1);
+ msg.Write(Test2);
+ msg.Write(string.Empty);
+ msg.EndMessage();
+
+ Assert.AreEqual(msg.Length, msg.Position);
+
+ MessageReader reader = MessageReader.Get(msg.Buffer, 0);
+
+ Assert.AreEqual(Test1, reader.ReadString());
+ Assert.AreEqual(Test2, reader.ReadString());
+ Assert.AreEqual(string.Empty, reader.ReadString());
+
+ }
+
+ [TestMethod]
+ public void ReadProperFloat()
+ {
+ const float Test1 = 12.34f;
+
+ var msg = new MessageWriter(2048);
+ msg.StartMessage(1);
+ msg.Write(Test1);
+ msg.EndMessage();
+
+ Assert.AreEqual(7, msg.Length);
+ Assert.AreEqual(msg.Length, msg.Position);
+
+ MessageReader reader = MessageReader.Get(msg.Buffer, 0);
+
+ Assert.AreEqual(Test1, reader.ReadSingle());
+ }
+
+ [TestMethod]
+ public void RemoveMessageWorks()
+ {
+ const byte Test0 = 11;
+ const byte Test3 = 33;
+ const byte Test4 = 44;
+ const byte Test5 = 55;
+
+ var msg = new MessageWriter(2048);
+ msg.StartMessage(0);
+ msg.Write(Test0);
+ msg.EndMessage();
+
+ msg.StartMessage(12);
+ msg.StartMessage(23);
+
+ msg.StartMessage(34);
+ msg.Write(Test3);
+ msg.EndMessage();
+
+ msg.StartMessage(45);
+ msg.Write(Test4);
+ msg.EndMessage();
+
+ msg.EndMessage();
+ msg.EndMessage();
+
+ msg.StartMessage(56);
+ msg.Write(Test5);
+ msg.EndMessage();
+
+ MessageReader reader = MessageReader.Get(msg.Buffer);
+ reader.Length = msg.Length;
+
+ var zero = reader.ReadMessage();
+
+ var one = reader.ReadMessage();
+ var two = one.ReadMessage();
+ var three = two.ReadMessage();
+ two.RemoveMessage(three);
+
+ // Reader becomes invalid
+ Assert.AreNotEqual(Test3, three.ReadByte());
+
+ // Unrealistic, but nice. Earlier data is not affected
+ Assert.AreEqual(Test0, zero.ReadByte());
+
+ // Continuing to read depth-first works
+ var four = two.ReadMessage();
+ Assert.AreEqual(Test4, four.ReadByte());
+
+ var five = reader.ReadMessage();
+ Assert.AreEqual(Test5, five.ReadByte());
+ }
+
+ [TestMethod]
+ public void InsertMessageWorks()
+ {
+ const byte Test0 = 11;
+ const byte Test3 = 33;
+ const byte Test4 = 44;
+ const byte Test5 = 55;
+ const byte TestInsert = 66;
+
+ var msg = new MessageWriter(2048);
+ msg.StartMessage(0);
+ msg.Write(Test0);
+ msg.EndMessage();
+
+ msg.StartMessage(12);
+ msg.StartMessage(23);
+
+ msg.StartMessage(34);
+ msg.Write(Test3);
+ msg.EndMessage();
+
+ msg.StartMessage(45);
+ msg.Write(Test4);
+ msg.EndMessage();
+
+ msg.EndMessage();
+ msg.EndMessage();
+
+ msg.StartMessage(56);
+ msg.Write(Test5);
+ msg.EndMessage();
+
+ MessageReader reader = MessageReader.Get(msg.Buffer);
+
+ MessageWriter writer = MessageWriter.Get(SendOption.Reliable);
+ writer.StartMessage(5);
+ writer.Write(TestInsert);
+ writer.EndMessage();
+
+ reader.ReadMessage();
+ var one = reader.ReadMessage();
+ var two = one.ReadMessage();
+ var three = two.ReadMessage();
+
+ two.InsertMessage(three, writer);
+
+ //set the position back to zero to read back the updated message
+ reader.Position = 0;
+
+ var zero = reader.ReadMessage();
+ Assert.AreEqual(Test0, zero.ReadByte());
+ one = reader.ReadMessage();
+ two = one.ReadMessage();
+ var insert = two.ReadMessage();
+ Assert.AreEqual(TestInsert, insert.ReadByte());
+ three = two.ReadMessage();
+ Assert.AreEqual(Test3, three.ReadByte());
+ var four = two.ReadMessage();
+ Assert.AreEqual(Test4, four.ReadByte());
+
+ var five = reader.ReadMessage();
+ Assert.AreEqual(Test5, five.ReadByte());
+ }
+
+ [TestMethod]
+ public void InsertMessageWorksWithSendOptionNone()
+ {
+ const byte Test0 = 11;
+ const byte Test3 = 33;
+ const byte Test4 = 44;
+ const byte Test5 = 55;
+ const byte TestInsert = 66;
+
+ var msg = new MessageWriter(2048);
+ msg.StartMessage(0);
+ msg.Write(Test0);
+ msg.EndMessage();
+
+ msg.StartMessage(12);
+ msg.StartMessage(23);
+
+ msg.StartMessage(34);
+ msg.Write(Test3);
+ msg.EndMessage();
+
+ msg.StartMessage(45);
+ msg.Write(Test4);
+ msg.EndMessage();
+
+ msg.EndMessage();
+ msg.EndMessage();
+
+ msg.StartMessage(56);
+ msg.Write(Test5);
+ msg.EndMessage();
+
+ MessageReader reader = MessageReader.Get(msg.Buffer);
+
+ MessageWriter writer = MessageWriter.Get(SendOption.None);
+ writer.StartMessage(5);
+ writer.Write(TestInsert);
+ writer.EndMessage();
+
+ reader.ReadMessage();
+ var one = reader.ReadMessage();
+ var two = one.ReadMessage();
+ var three = two.ReadMessage();
+
+ two.InsertMessage(three, writer);
+
+ //set the position back to zero to read back the updated message
+ reader.Position = 0;
+
+ var zero = reader.ReadMessage();
+ Assert.AreEqual(Test0, zero.ReadByte());
+ one = reader.ReadMessage();
+ two = one.ReadMessage();
+ var insert = two.ReadMessage();
+ Assert.AreEqual(TestInsert, insert.ReadByte());
+ three = two.ReadMessage();
+ Assert.AreEqual(Test3, three.ReadByte());
+ var four = two.ReadMessage();
+ Assert.AreEqual(Test4, four.ReadByte());
+
+ var five = reader.ReadMessage();
+ Assert.AreEqual(Test5, five.ReadByte());
+
+ }
+
+ [TestMethod]
+ public void InsertMessageWithoutStartMessageInWriter()
+ {
+ const byte Test0 = 11;
+ const byte Test3 = 33;
+ const byte Test4 = 44;
+ const byte Test5 = 55;
+ const byte TestInsert = 66;
+
+ var msg = new MessageWriter(2048);
+ msg.StartMessage(0);
+ msg.Write(Test0);
+ msg.EndMessage();
+
+ msg.StartMessage(12);
+ msg.StartMessage(23);
+
+ msg.StartMessage(34);
+ msg.Write(Test3);
+ msg.EndMessage();
+
+ msg.StartMessage(45);
+ msg.Write(Test4);
+ msg.EndMessage();
+
+ msg.EndMessage();
+ msg.EndMessage();
+
+ msg.StartMessage(56);
+ msg.Write(Test5);
+ msg.EndMessage();
+
+ MessageReader reader = MessageReader.Get(msg.Buffer);
+
+ MessageWriter writer = MessageWriter.Get(SendOption.Reliable);
+ writer.Write(TestInsert);
+
+ reader.ReadMessage();
+ var one = reader.ReadMessage();
+ var two = one.ReadMessage();
+ var three = two.ReadMessage();
+
+ two.InsertMessage(three, writer);
+
+ //set the position back to zero to read back the updated message
+ reader.Position = 0;
+
+ var zero = reader.ReadMessage();
+ Assert.AreEqual(Test0, zero.ReadByte());
+ one = reader.ReadMessage();
+ two = one.ReadMessage();
+ Assert.AreEqual(TestInsert, two.ReadByte());
+ three = two.ReadMessage();
+ Assert.AreEqual(Test3, three.ReadByte());
+ var four = two.ReadMessage();
+ Assert.AreEqual(Test4, four.ReadByte());
+
+ var five = reader.ReadMessage();
+ Assert.AreEqual(Test5, five.ReadByte());
+ }
+
+ [TestMethod]
+ public void InsertMessageWithMultipleMessagesInWriter()
+ {
+ const byte Test0 = 11;
+ const byte Test3 = 33;
+ const byte Test4 = 44;
+ const byte Test5 = 55;
+ const byte TestInsert = 66;
+ const byte TestInsert2 = 77;
+
+ var msg = new MessageWriter(2048);
+ msg.StartMessage(0);
+ msg.Write(Test0);
+ msg.EndMessage();
+
+ msg.StartMessage(12);
+ msg.StartMessage(23);
+
+ msg.StartMessage(34);
+ msg.Write(Test3);
+ msg.EndMessage();
+
+ msg.StartMessage(45);
+ msg.Write(Test4);
+ msg.EndMessage();
+
+ msg.EndMessage();
+ msg.EndMessage();
+
+ msg.StartMessage(56);
+ msg.Write(Test5);
+ msg.EndMessage();
+
+ MessageReader reader = MessageReader.Get(msg.Buffer);
+
+ MessageWriter writer = MessageWriter.Get(SendOption.Reliable);
+ writer.StartMessage(5);
+ writer.Write(TestInsert);
+ writer.EndMessage();
+
+ writer.StartMessage(6);
+ writer.Write(TestInsert2);
+ writer.EndMessage();
+
+ reader.ReadMessage();
+ var one = reader.ReadMessage();
+ var two = one.ReadMessage();
+ var three = two.ReadMessage();
+
+ two.InsertMessage(three, writer);
+
+ //set the position back to zero to read back the updated message
+ reader.Position = 0;
+
+ var zero = reader.ReadMessage();
+ Assert.AreEqual(Test0, zero.ReadByte());
+ one = reader.ReadMessage();
+ two = one.ReadMessage();
+ var insert = two.ReadMessage();
+ Assert.AreEqual(TestInsert, insert.ReadByte());
+ var insert2 = two.ReadMessage();
+ Assert.AreEqual(TestInsert2, insert2.ReadByte());
+ three = two.ReadMessage();
+ Assert.AreEqual(Test3, three.ReadByte());
+ var four = two.ReadMessage();
+ Assert.AreEqual(Test4, four.ReadByte());
+
+ var five = reader.ReadMessage();
+ Assert.AreEqual(Test5, five.ReadByte());
+ }
+
+ [TestMethod]
+ public void InsertMessageMultipleInsertsWithoutReset()
+ {
+ const byte Test0 = 11;
+ const byte Test3 = 33;
+ const byte Test4 = 44;
+ const byte Test5 = 55;
+ const byte Test6 = 66;
+ const byte TestInsert = 77;
+ const byte TestInsert2 = 88;
+
+ var msg = new MessageWriter(2048);
+ msg.StartMessage(0);
+ msg.Write(Test0);
+ msg.EndMessage();
+
+ msg.StartMessage(12);
+ msg.StartMessage(23);
+
+ msg.StartMessage(34);
+ msg.Write(Test3);
+ msg.EndMessage();
+
+ msg.StartMessage(45);
+ msg.Write(Test4);
+ msg.EndMessage();
+
+ msg.EndMessage();
+
+ msg.StartMessage(56);
+ msg.Write(Test5);
+ msg.EndMessage();
+
+ msg.EndMessage();
+
+ msg.StartMessage(67);
+ msg.Write(Test6);
+ msg.EndMessage();
+
+ MessageReader reader = MessageReader.Get(msg.Buffer);
+
+ MessageWriter writer = MessageWriter.Get(SendOption.Reliable);
+ writer.StartMessage(5);
+ writer.Write(TestInsert);
+ writer.EndMessage();
+
+ MessageWriter writer2 = MessageWriter.Get(SendOption.Reliable);
+ writer2.StartMessage(6);
+ writer2.Write(TestInsert2);
+ writer2.EndMessage();
+
+ reader.ReadMessage();
+ var one = reader.ReadMessage();
+ var two = one.ReadMessage();
+ var three = two.ReadMessage();
+
+ two.InsertMessage(three, writer);
+
+ // three becomes invalid
+ Assert.AreNotEqual(Test3, three.ReadByte());
+
+ // Continuing to read works
+ var four = two.ReadMessage();
+ Assert.AreEqual(Test4, four.ReadByte());
+
+ var five = one.ReadMessage();
+ Assert.AreEqual(Test5, five.ReadByte());
+
+ reader.InsertMessage(one, writer2);
+
+ var six = reader.ReadMessage();
+ Assert.AreEqual(Test6, six.ReadByte());
+ }
+
+ [TestMethod]
+ public void CopySubMessage()
+ {
+ const byte Test1 = 12;
+ const byte Test2 = 146;
+
+ var msg = new MessageWriter(2048);
+ msg.StartMessage(1);
+
+ msg.StartMessage(2);
+ msg.Write(Test1);
+ msg.Write(Test2);
+ msg.EndMessage();
+
+ msg.EndMessage();
+
+ MessageReader handleMessage = MessageReader.Get(msg.Buffer, 0);
+ Assert.AreEqual(1, handleMessage.Tag);
+
+ var parentReader = MessageReader.Get(handleMessage);
+
+ handleMessage.Recycle();
+ SetZero(handleMessage);
+
+ Assert.AreEqual(1, parentReader.Tag);
+
+ for (int i = 0; i < 5; ++i)
+ {
+
+ var reader = parentReader.ReadMessage();
+ Assert.AreEqual(2, reader.Tag);
+ Assert.AreEqual(Test1, reader.ReadByte());
+ Assert.AreEqual(Test2, reader.ReadByte());
+
+ var temp = parentReader;
+ parentReader = MessageReader.CopyMessageIntoParent(reader);
+
+ temp.Recycle();
+ SetZero(temp);
+ SetZero(reader);
+ }
+ }
+
+ [TestMethod]
+ public void ReadMessageLength()
+ {
+ var msg = new MessageWriter(2048);
+ msg.StartMessage(1);
+ msg.Write(65534);
+ msg.StartMessage(2);
+ msg.Write("HO");
+ msg.EndMessage();
+ msg.StartMessage(2);
+ msg.EndMessage();
+ msg.EndMessage();
+
+ Assert.AreEqual(msg.Length, msg.Position);
+
+ MessageReader reader = MessageReader.Get(msg.Buffer, 0);
+ Assert.AreEqual(1, reader.Tag);
+ Assert.AreEqual(65534, reader.ReadInt32()); // Content
+
+ var sub = reader.ReadMessage();
+ Assert.AreEqual(3, sub.Length);
+ Assert.AreEqual(2, sub.Tag);
+ Assert.AreEqual("HO", sub.ReadString());
+
+ sub = reader.ReadMessage();
+ Assert.AreEqual(0, sub.Length);
+ Assert.AreEqual(2, sub.Tag);
+ }
+
+ [TestMethod]
+ public void ReadMessageAsNewBufferLength()
+ {
+ var msg = new MessageWriter(2048);
+ msg.StartMessage(1);
+ msg.Write(65534);
+ msg.StartMessage(2);
+ msg.Write("HO");
+ msg.EndMessage();
+ msg.StartMessage(232);
+ msg.EndMessage();
+ msg.EndMessage();
+
+ Assert.AreEqual(msg.Length, msg.Position);
+
+ MessageReader reader = MessageReader.Get(msg.Buffer, 0);
+ Assert.AreEqual(1, reader.Tag);
+ Assert.AreEqual(65534, reader.ReadInt32()); // Content
+
+ var sub = reader.ReadMessageAsNewBuffer();
+ Assert.AreEqual(0, sub.Position);
+ Assert.AreEqual(0, sub.Offset);
+
+ Assert.AreEqual(3, sub.Length);
+ Assert.AreEqual(2, sub.Tag);
+ Assert.AreEqual("HO", sub.ReadString());
+
+ sub.Recycle();
+
+ sub = reader.ReadMessageAsNewBuffer();
+ Assert.AreEqual(0, sub.Position);
+ Assert.AreEqual(0, sub.Offset);
+
+ Assert.AreEqual(0, sub.Length);
+ Assert.AreEqual(232, sub.Tag);
+ sub.Recycle();
+ }
+
+ [TestMethod]
+ public void ReadStringProtectsAgainstOverrun()
+ {
+ const string TestDataFromAPreviousPacket = "You shouldn't be able to see this data";
+
+ // An extra byte from the length of TestData when written via MessageWriter
+ int DataLength = TestDataFromAPreviousPacket.Length + 1;
+
+ // THE BUG
+ //
+ // No bound checks. When the server wants to read a string from
+ // an offset, it reads the packed int at that offset, treats it
+ // as a length and then proceeds to read the data that comes after
+ // it without any bound checks. This can be chained with something
+ // else to create an infoleak.
+
+ MessageWriter writer = MessageWriter.Get(SendOption.None);
+
+ // This will be our malicious "string length"
+ writer.WritePacked(DataLength);
+
+ // This is data from a "previous packet"
+ writer.Write(TestDataFromAPreviousPacket);
+
+ byte[] testData = writer.ToByteArray(includeHeader: false);
+
+ // One extra byte for the MessageWriter header, one more for the malicious data
+ Assert.AreEqual(DataLength + 1, testData.Length);
+
+ var dut = MessageReader.Get(testData);
+
+ // If Length is short by even a byte, ReadString should obey that.
+ dut.Length--;
+
+ try
+ {
+ dut.ReadString();
+ Assert.Fail("ReadString is expected to throw");
+ }
+ catch (InvalidDataException) { }
+ }
+
+ [TestMethod]
+ public void ReadMessageProtectsAgainstOverrun()
+ {
+ const string TestDataFromAPreviousPacket = "You shouldn't be able to see this data";
+
+ // An extra byte from the length of TestData when written via MessageWriter
+ // Extra 3 bytes for the length + tag header for ReadMessage.
+ int DataLength = TestDataFromAPreviousPacket.Length + 1 + 3;
+
+ // THE BUG
+ //
+ // No bound checks. When the server wants to read a message, it
+ // reads the uint16 at that offset, treats it as a length without any bound checks.
+ // This can be allow a later ReadString or ReadBytes to create an infoleak.
+
+ MessageWriter writer = MessageWriter.Get(SendOption.None);
+
+ // This is the malicious length. No data in this message, so it should be zero.
+ writer.Write((ushort)1);
+ writer.Write((byte)0); // Tag
+
+ // This is data from a "previous packet"
+ writer.Write(TestDataFromAPreviousPacket);
+
+ byte[] testData = writer.ToByteArray(includeHeader: false);
+
+ Assert.AreEqual(DataLength, testData.Length);
+
+ var outer = MessageReader.Get(testData);
+
+ // Length is just the malicious message header.
+ outer.Length = 3;
+
+ try
+ {
+ outer.ReadMessage();
+ Assert.Fail("ReadMessage is expected to throw");
+ }
+ catch (InvalidDataException) { }
+ }
+
+ [TestMethod]
+ public void GetLittleEndian()
+ {
+ Assert.IsTrue(MessageWriter.IsLittleEndian());
+ }
+
+ private void SetZero(MessageReader reader)
+ {
+ for (int i = 0; i < reader.Buffer.Length; ++i)
+ reader.Buffer[i] = 0;
+ }
+ }
+
+} \ No newline at end of file
diff --git a/Tools/Hazel-Networking/Hazel.UnitTests/MessageWriterTests.cs b/Tools/Hazel-Networking/Hazel.UnitTests/MessageWriterTests.cs
new file mode 100644
index 0000000..b292a5d
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel.UnitTests/MessageWriterTests.cs
@@ -0,0 +1,213 @@
+using System;
+using System.IO;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+
+namespace Hazel.UnitTests
+{
+ [TestClass]
+ public class MessageWriterTests
+ {
+
+ [TestMethod]
+ public void CancelMessages()
+ {
+ var msg = new MessageWriter(128);
+
+ msg.StartMessage(1);
+ msg.Write(32);
+
+ msg.StartMessage(2);
+ msg.Write(2);
+ msg.CancelMessage();
+
+ Assert.AreEqual(7, msg.Length);
+ Assert.IsFalse(msg.HasBytes(7));
+
+ msg.CancelMessage();
+
+ Assert.AreEqual(0, msg.Length);
+ Assert.IsFalse(msg.HasBytes(1));
+ }
+
+ [TestMethod]
+ public void HasBytes()
+ {
+ var msg = new MessageWriter(128);
+
+ msg.StartMessage(1);
+ msg.Write(32);
+
+ msg.StartMessage(2);
+ msg.Write(2);
+ msg.EndMessage();
+
+ // Assert.AreEqual(7, msg.Length);
+ Assert.IsTrue(msg.HasBytes(7));
+ }
+
+ [TestMethod]
+ public void WriteProperInt()
+ {
+ const int Test1 = int.MaxValue;
+ const int Test2 = int.MinValue;
+
+ var msg = new MessageWriter(128);
+ msg.Write(Test1);
+ msg.Write(Test2);
+
+ Assert.AreEqual(8, msg.Length);
+ Assert.AreEqual(msg.Length, msg.Position);
+
+ using (MemoryStream m = new MemoryStream(msg.Buffer, 0, msg.Length))
+ using (BinaryReader reader = new BinaryReader(m))
+ {
+ Assert.AreEqual(Test1, reader.ReadInt32());
+ Assert.AreEqual(Test2, reader.ReadInt32());
+ }
+ }
+
+ [TestMethod]
+ public void WriteProperBool()
+ {
+ const bool Test1 = true;
+ const bool Test2 = false;
+
+ var msg = new MessageWriter(128);
+ msg.Write(Test1);
+ msg.Write(Test2);
+
+ Assert.AreEqual(2, msg.Length);
+ Assert.AreEqual(msg.Length, msg.Position);
+
+ using (MemoryStream m = new MemoryStream(msg.Buffer, 0, msg.Length))
+ using (BinaryReader reader = new BinaryReader(m))
+ {
+ Assert.AreEqual(Test1, reader.ReadBoolean());
+ Assert.AreEqual(Test2, reader.ReadBoolean());
+ }
+ }
+
+ [TestMethod]
+ public void WriteProperString()
+ {
+ const string Test1 = "Hello";
+ string Test2 = new string(' ', 1024);
+ var msg = new MessageWriter(2048);
+ msg.Write(Test1);
+ msg.Write(Test2);
+ msg.Write(string.Empty);
+
+ Assert.AreEqual(msg.Length, msg.Position);
+
+ using (MemoryStream m = new MemoryStream(msg.Buffer, 0, msg.Length))
+ using (BinaryReader reader = new BinaryReader(m))
+ {
+ Assert.AreEqual(Test1, reader.ReadString());
+ Assert.AreEqual(Test2, reader.ReadString());
+ Assert.AreEqual(string.Empty, reader.ReadString());
+ }
+ }
+
+ [TestMethod]
+ public void WriteProperFloat()
+ {
+ const float Test1 = 12.34f;
+
+ var msg = new MessageWriter(2048);
+ msg.Write(Test1);
+
+ Assert.AreEqual(msg.Length, msg.Position);
+
+ using (MemoryStream m = new MemoryStream(msg.Buffer, 0, msg.Length))
+ using (BinaryReader reader = new BinaryReader(m))
+ {
+ Assert.AreEqual(Test1, reader.ReadSingle());
+ }
+ }
+
+ [TestMethod]
+ public void WritePackedUint()
+ {
+ var msg = new MessageWriter(2048);
+ msg.StartMessage(0);
+ msg.WritePacked(8u);
+ msg.WritePacked(250u);
+ msg.WritePacked(68000u);
+ msg.EndMessage();
+
+ Assert.AreEqual(3 + 1 + 2 + 3, msg.Position);
+ Assert.AreEqual(msg.Length, msg.Position);
+
+ MessageReader reader = MessageReader.Get(msg.Buffer, 0);
+
+ Assert.AreEqual(8u, reader.ReadPackedUInt32());
+ Assert.AreEqual(250u, reader.ReadPackedUInt32());
+ Assert.AreEqual(68000u, reader.ReadPackedUInt32());
+ }
+
+ [TestMethod]
+ public void WritePackedInt()
+ {
+ var msg = new MessageWriter(2048);
+ msg.StartMessage(0);
+ msg.WritePacked(8);
+ msg.WritePacked(250);
+ msg.WritePacked(68000);
+ msg.WritePacked(60168000);
+ msg.WritePacked(-68000);
+ msg.WritePacked(-250);
+ msg.WritePacked(-8);
+
+ msg.WritePacked(0);
+ msg.WritePacked(-1);
+ msg.WritePacked(int.MinValue);
+ msg.WritePacked(int.MaxValue);
+ msg.EndMessage();
+
+ Assert.AreEqual(3 + 1 + 2 + 3 + 4 + 5 + 5 + 5 + 1 + 5 + 5 + 5, msg.Position);
+ Assert.AreEqual(msg.Length, msg.Position);
+
+ MessageReader reader = MessageReader.Get(msg.Buffer, 0);
+
+ Assert.AreEqual(8, reader.ReadPackedInt32());
+ Assert.AreEqual(250, reader.ReadPackedInt32());
+ Assert.AreEqual(68000, reader.ReadPackedInt32());
+ Assert.AreEqual(60168000, reader.ReadPackedInt32());
+
+ Assert.AreEqual(-68000, reader.ReadPackedInt32());
+ Assert.AreEqual(-250, reader.ReadPackedInt32());
+ Assert.AreEqual(-8, reader.ReadPackedInt32());
+
+ Assert.AreEqual(0, reader.ReadPackedInt32());
+ Assert.AreEqual(-1, reader.ReadPackedInt32());
+ Assert.AreEqual(int.MinValue, reader.ReadPackedInt32());
+ Assert.AreEqual(int.MaxValue, reader.ReadPackedInt32());
+ }
+
+ [TestMethod]
+ public void WritesMessageLength()
+ {
+ var msg = new MessageWriter(2048);
+ msg.StartMessage(1);
+ msg.Write(65534);
+ msg.EndMessage();
+
+ Assert.AreEqual(2 + 1 + 4, msg.Position);
+ Assert.AreEqual(msg.Length, msg.Position);
+
+ using (MemoryStream m = new MemoryStream(msg.Buffer, 0, msg.Length))
+ using (BinaryReader reader = new BinaryReader(m))
+ {
+ Assert.AreEqual(4, reader.ReadUInt16()); // Length After Type and Target
+ Assert.AreEqual(1, reader.ReadByte()); // Type
+ Assert.AreEqual(65534, reader.ReadInt32()); // Content
+ }
+ }
+
+ [TestMethod]
+ public void GetLittleEndian()
+ {
+ Assert.IsTrue(MessageWriter.IsLittleEndian());
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel.UnitTests/SocketCapture.cs b/Tools/Hazel-Networking/Hazel.UnitTests/SocketCapture.cs
new file mode 100644
index 0000000..584d08c
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel.UnitTests/SocketCapture.cs
@@ -0,0 +1,292 @@
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using System;
+using System.Collections;
+using System.Collections.Concurrent;
+using System.Collections.Generic;
+using System.Linq;
+using System.Net;
+using System.Net.Sockets;
+using System.Text;
+using System.Threading;
+
+namespace Hazel.UnitTests
+{
+ /// <summary>
+ /// Acts as an intermediate between to sockets.
+ ///
+ /// Use SendToLocalSemaphore and SendToRemoteSemaphore for
+ /// explicit control of packet flow.
+ /// </summary>
+ public class SocketCapture : IDisposable
+ {
+ private IPEndPoint localEndPoint;
+ private readonly IPEndPoint remoteEndPoint;
+
+ private Socket captureSocket;
+
+ private Thread receiveThread;
+ private Thread forLocalThread;
+ private Thread forRemoteThread;
+
+ private ILogger logger;
+
+ private readonly BlockingCollection<ByteSpan> forLocal = new BlockingCollection<ByteSpan>();
+ private readonly BlockingCollection<ByteSpan> forRemote = new BlockingCollection<ByteSpan>();
+
+ /// <summary>
+ /// Useful for debug logging, prefer <see cref="AssertPacketsToLocalCountEquals(int)"/> for assertions
+ /// </summary>
+ public int PacketsForLocalCount => this.forLocal.Count;
+
+ /// <summary>
+ /// Useful for debug logging, prefer <see cref="AssertPacketsToRemoteCountEquals(int)"/> for assertions
+ /// </summary>
+ public int PacketsForRemoteCount => this.forRemote.Count;
+
+ public Semaphore SendToLocalSemaphore = null;
+ public Semaphore SendToRemoteSemaphore = null;
+
+ private CancellationTokenSource cancellationSource = new CancellationTokenSource();
+ private readonly CancellationToken cancellationToken;
+
+ public SocketCapture(IPEndPoint captureEndpoint, IPEndPoint remoteEndPoint, ILogger logger = null)
+ {
+ this.logger = logger ?? new NullLogger();
+ this.cancellationToken = this.cancellationSource.Token;
+
+ this.remoteEndPoint = remoteEndPoint;
+
+ this.captureSocket = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp);
+ this.captureSocket.Bind(captureEndpoint);
+
+ this.receiveThread = new Thread(this.ReceiveLoop);
+ this.receiveThread.Start();
+
+ this.forLocalThread = new Thread(this.SendToLocalLoop);
+ this.forLocalThread.Start();
+
+ this.forRemoteThread = new Thread(this.SendToRemoteLoop);
+ this.forRemoteThread.Start();
+ }
+
+ public void Dispose()
+ {
+ if (this.cancellationSource != null)
+ {
+ this.cancellationSource.Cancel();
+ this.cancellationSource.Dispose();
+ this.cancellationSource = null;
+ }
+
+ if (this.captureSocket != null)
+ {
+ this.captureSocket.Close();
+ this.captureSocket.Dispose();
+ this.captureSocket = null;
+ }
+
+ if (this.receiveThread != null)
+ {
+ this.receiveThread.Join();
+ this.receiveThread = null;
+ }
+
+ if (this.forLocalThread != null)
+ {
+ this.forLocalThread.Join();
+ this.forLocalThread = null;
+ }
+
+ if (this.forRemoteThread != null)
+ {
+ this.forRemoteThread.Join();
+ this.forRemoteThread = null;
+ }
+
+ GC.SuppressFinalize(this);
+ }
+
+ private void ReceiveLoop()
+ {
+ try
+ {
+ IPEndPoint fromEndPoint = new IPEndPoint(IPAddress.Any, 0);
+
+ for (; ; )
+ {
+ byte[] buffer = new byte[2000];
+ EndPoint endPoint = fromEndPoint;
+ int read = this.captureSocket.ReceiveFrom(buffer, ref endPoint);
+ if (read > 0)
+ {
+ // from the remote endpoint?
+ if (IPEndPoint.Equals(endPoint, remoteEndPoint))
+ {
+ this.forLocal.Add(new ByteSpan(buffer, 0, read));
+ }
+ else
+ {
+ this.localEndPoint = (IPEndPoint)endPoint;
+ this.forRemote.Add(new ByteSpan(buffer, 0, read));
+ }
+ }
+ }
+ }
+ catch (SocketException)
+ {
+ }
+ finally
+ {
+ this.forLocal.CompleteAdding();
+ this.forRemote.CompleteAdding();
+ }
+ }
+
+ private void SendToRemoteLoop()
+ {
+ while (!this.cancellationToken.IsCancellationRequested)
+ {
+ if (this.SendToRemoteSemaphore != null)
+ {
+ if (!this.SendToRemoteSemaphore.WaitOne(100))
+ {
+ continue;
+ }
+ }
+
+ if (this.forRemote.TryTake(out var packet))
+ {
+ this.logger.WriteInfo($"Passed 1 packet of {packet.Length} bytes to remote");
+ this.captureSocket.SendTo(packet.GetUnderlyingArray(), packet.Offset, packet.Length, SocketFlags.None, this.remoteEndPoint);
+ }
+ }
+ }
+
+ private void SendToLocalLoop()
+ {
+ while (!this.cancellationToken.IsCancellationRequested)
+ {
+ if (this.SendToLocalSemaphore != null)
+ {
+ if (!this.SendToLocalSemaphore.WaitOne(100))
+ {
+ continue;
+ }
+ }
+
+ if (this.forLocal.TryTake(out var packet))
+ {
+ this.logger.WriteInfo($"Passed 1 packet of {packet.Length} bytes to local");
+ this.captureSocket.SendTo(packet.GetUnderlyingArray(), packet.Offset, packet.Length, SocketFlags.None, this.localEndPoint);
+ }
+ }
+ }
+
+ public void AssertPacketsToLocalCountEquals(int pktCnt)
+ {
+ DateTime start = DateTime.UtcNow;
+ while (this.forLocal.Count != pktCnt)
+ {
+ if ((DateTime.UtcNow - start).TotalSeconds >= 5)
+ {
+ Assert.AreEqual(pktCnt, this.forLocal.Count);
+ }
+
+ Thread.Yield();
+ }
+ }
+
+ public void AssertPacketsToRemoteCountEquals(int pktCnt)
+ {
+ DateTime start = DateTime.UtcNow;
+ while (this.forRemote.Count != pktCnt)
+ {
+ if ((DateTime.UtcNow - start).TotalSeconds >= 5)
+ {
+ Assert.AreEqual(pktCnt, this.forRemote.Count);
+ }
+
+ Thread.Yield();
+ }
+ }
+
+ public void DiscardPacketForLocal(int numToDiscard = 1)
+ {
+ for (int i = 0; i < numToDiscard; ++i)
+ {
+ this.forLocal.Take();
+ }
+ }
+
+ public void DiscardPacketForRemote(int numToDiscard = 1)
+ {
+ for (int i = 0; i < numToDiscard; ++i)
+ {
+ this.forRemote.Take();
+ }
+ }
+
+ public ByteSpan PeekPacketForLocal()
+ {
+ return this.forLocal.First();
+ }
+
+ public void ReorderPacketsForRemote(Action<List<ByteSpan>> reorderCallback)
+ {
+ List<ByteSpan> buffer = new List<ByteSpan>();
+ while (this.forRemote.TryTake(out var pkt)) buffer.Add(pkt);
+ reorderCallback(buffer);
+ foreach (var item in buffer)
+ {
+ this.forRemote.Add(item);
+ }
+ }
+
+ public void ReorderPacketsForLocal(Action<List<ByteSpan>> reorderCallback)
+ {
+ List<ByteSpan> buffer = new List<ByteSpan>();
+ while (this.forLocal.TryTake(out var pkt)) buffer.Add(pkt);
+ reorderCallback(buffer);
+ foreach (var item in buffer)
+ {
+ this.forLocal.Add(item);
+ }
+ }
+
+ internal void ReleasePacketsForRemote(int numToSend)
+ {
+ var newExpected = this.forRemote.Count - numToSend;
+ this.SendToRemoteSemaphore.Release(numToSend);
+ this.AssertPacketsToRemoteCountEquals(newExpected);
+ }
+
+ internal void ReleasePacketsToLocal(int numToSend)
+ {
+ var newExpected = this.forLocal.Count - numToSend;
+ this.SendToLocalSemaphore.Release(numToSend);
+ this.AssertPacketsToLocalCountEquals(newExpected);
+ }
+
+ internal string PacketsForRemoteToString()
+ {
+ StringBuilder sb = new StringBuilder();
+ foreach (var item in this.forRemote)
+ {
+ sb.AppendLine(item.ToString());
+ }
+
+ return sb.ToString();
+ }
+
+ internal string PacketsForLocalToString()
+ {
+ StringBuilder sb = new StringBuilder();
+ foreach(var item in this.forLocal)
+ {
+ sb.AppendLine(item.ToString());
+ }
+
+ return sb.ToString();
+ }
+ }
+} \ No newline at end of file
diff --git a/Tools/Hazel-Networking/Hazel.UnitTests/StatisticsTests.cs b/Tools/Hazel-Networking/Hazel.UnitTests/StatisticsTests.cs
new file mode 100644
index 0000000..8452a82
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel.UnitTests/StatisticsTests.cs
@@ -0,0 +1,159 @@
+using System;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+
+namespace Hazel.UnitTests
+{
+ [TestClass]
+ public class StatisticsTests
+ {
+ [TestMethod]
+ public void SendTests()
+ {
+ ConnectionStatistics statistics = new ConnectionStatistics();
+
+ statistics.LogUnreliableSend(10);
+
+ Assert.AreEqual(1, statistics.MessagesSent);
+ Assert.AreEqual(1, statistics.UnreliableMessagesSent);
+ Assert.AreEqual(0, statistics.ReliableMessagesSent);
+ Assert.AreEqual(0, statistics.FragmentedMessagesSent);
+ Assert.AreEqual(0, statistics.AcknowledgementMessagesSent);
+ Assert.AreEqual(0, statistics.HelloMessagesSent);
+
+ Assert.AreEqual(10, statistics.DataBytesSent);
+
+ statistics.LogReliableSend(5);
+
+ Assert.AreEqual(2, statistics.MessagesSent);
+ Assert.AreEqual(1, statistics.UnreliableMessagesSent);
+ Assert.AreEqual(1, statistics.ReliableMessagesSent);
+ Assert.AreEqual(0, statistics.FragmentedMessagesSent);
+ Assert.AreEqual(0, statistics.AcknowledgementMessagesSent);
+ Assert.AreEqual(0, statistics.HelloMessagesSent);
+
+ Assert.AreEqual(15, statistics.DataBytesSent);
+
+ statistics.LogFragmentedSend(6);
+
+ Assert.AreEqual(3, statistics.MessagesSent);
+ Assert.AreEqual(1, statistics.UnreliableMessagesSent);
+ Assert.AreEqual(1, statistics.ReliableMessagesSent);
+ Assert.AreEqual(1, statistics.FragmentedMessagesSent);
+ Assert.AreEqual(0, statistics.AcknowledgementMessagesSent);
+ Assert.AreEqual(0, statistics.HelloMessagesSent);
+
+ Assert.AreEqual(21, statistics.DataBytesSent);
+
+ statistics.LogAcknowledgementSend();
+
+ Assert.AreEqual(4, statistics.MessagesSent);
+ Assert.AreEqual(1, statistics.UnreliableMessagesSent);
+ Assert.AreEqual(1, statistics.ReliableMessagesSent);
+ Assert.AreEqual(1, statistics.FragmentedMessagesSent);
+ Assert.AreEqual(1, statistics.AcknowledgementMessagesSent);
+ Assert.AreEqual(0, statistics.HelloMessagesSent);
+
+ Assert.AreEqual(21, statistics.DataBytesSent);
+
+ statistics.LogHelloSend();
+
+ Assert.AreEqual(5, statistics.MessagesSent);
+ Assert.AreEqual(1, statistics.UnreliableMessagesSent);
+ Assert.AreEqual(1, statistics.ReliableMessagesSent);
+ Assert.AreEqual(1, statistics.FragmentedMessagesSent);
+ Assert.AreEqual(1, statistics.AcknowledgementMessagesSent);
+ Assert.AreEqual(1, statistics.HelloMessagesSent);
+
+ Assert.AreEqual(21, statistics.DataBytesSent);
+
+ Assert.AreEqual(0, statistics.MessagesReceived);
+ Assert.AreEqual(0, statistics.UnreliableMessagesReceived);
+ Assert.AreEqual(0, statistics.ReliableMessagesReceived);
+ Assert.AreEqual(0, statistics.FragmentedMessagesReceived);
+ Assert.AreEqual(0, statistics.AcknowledgementMessagesReceived);
+ Assert.AreEqual(0, statistics.HelloMessagesReceived);
+
+ Assert.AreEqual(0, statistics.DataBytesReceived);
+ Assert.AreEqual(0, statistics.TotalBytesReceived);
+
+ statistics.LogPacketSend(11);
+ Assert.AreEqual(11, statistics.TotalBytesSent);
+ }
+
+ [TestMethod]
+ public void ReceiveTests()
+ {
+ ConnectionStatistics statistics = new ConnectionStatistics();
+
+ statistics.LogUnreliableReceive(10, 11);
+
+ Assert.AreEqual(1, statistics.MessagesReceived);
+ Assert.AreEqual(1, statistics.UnreliableMessagesReceived);
+ Assert.AreEqual(0, statistics.ReliableMessagesReceived);
+ Assert.AreEqual(0, statistics.FragmentedMessagesReceived);
+ Assert.AreEqual(0, statistics.AcknowledgementMessagesReceived);
+ Assert.AreEqual(0, statistics.HelloMessagesReceived);
+
+ Assert.AreEqual(10, statistics.DataBytesReceived);
+ Assert.AreEqual(11, statistics.TotalBytesReceived);
+
+ statistics.LogReliableReceive(5, 8);
+
+ Assert.AreEqual(2, statistics.MessagesReceived);
+ Assert.AreEqual(1, statistics.UnreliableMessagesReceived);
+ Assert.AreEqual(1, statistics.ReliableMessagesReceived);
+ Assert.AreEqual(0, statistics.FragmentedMessagesReceived);
+ Assert.AreEqual(0, statistics.AcknowledgementMessagesReceived);
+ Assert.AreEqual(0, statistics.HelloMessagesReceived);
+
+ Assert.AreEqual(15, statistics.DataBytesReceived);
+ Assert.AreEqual(19, statistics.TotalBytesReceived);
+
+ statistics.LogFragmentedReceive(6, 10);
+
+ Assert.AreEqual(3, statistics.MessagesReceived);
+ Assert.AreEqual(1, statistics.UnreliableMessagesReceived);
+ Assert.AreEqual(1, statistics.ReliableMessagesReceived);
+ Assert.AreEqual(1, statistics.FragmentedMessagesReceived);
+ Assert.AreEqual(0, statistics.AcknowledgementMessagesReceived);
+ Assert.AreEqual(0, statistics.HelloMessagesReceived);
+
+ Assert.AreEqual(21, statistics.DataBytesReceived);
+ Assert.AreEqual(29, statistics.TotalBytesReceived);
+
+ statistics.LogAcknowledgementReceive(4);
+
+ Assert.AreEqual(4, statistics.MessagesReceived);
+ Assert.AreEqual(1, statistics.UnreliableMessagesReceived);
+ Assert.AreEqual(1, statistics.ReliableMessagesReceived);
+ Assert.AreEqual(1, statistics.FragmentedMessagesReceived);
+ Assert.AreEqual(1, statistics.AcknowledgementMessagesReceived);
+ Assert.AreEqual(0, statistics.HelloMessagesReceived);
+
+ Assert.AreEqual(21, statistics.DataBytesReceived);
+ Assert.AreEqual(33, statistics.TotalBytesReceived);
+
+ statistics.LogHelloReceive(7);
+
+ Assert.AreEqual(5, statistics.MessagesReceived);
+ Assert.AreEqual(1, statistics.UnreliableMessagesReceived);
+ Assert.AreEqual(1, statistics.ReliableMessagesReceived);
+ Assert.AreEqual(1, statistics.FragmentedMessagesReceived);
+ Assert.AreEqual(1, statistics.AcknowledgementMessagesReceived);
+ Assert.AreEqual(1, statistics.HelloMessagesReceived);
+
+ Assert.AreEqual(21, statistics.DataBytesReceived);
+ Assert.AreEqual(40, statistics.TotalBytesReceived);
+
+ Assert.AreEqual(0, statistics.MessagesSent);
+ Assert.AreEqual(0, statistics.UnreliableMessagesSent);
+ Assert.AreEqual(0, statistics.ReliableMessagesSent);
+ Assert.AreEqual(0, statistics.FragmentedMessagesSent);
+ Assert.AreEqual(0, statistics.AcknowledgementMessagesSent);
+ Assert.AreEqual(0, statistics.HelloMessagesSent);
+
+ Assert.AreEqual(0, statistics.DataBytesSent);
+ Assert.AreEqual(0, statistics.TotalBytesSent);
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel.UnitTests/StressTests.cs b/Tools/Hazel-Networking/Hazel.UnitTests/StressTests.cs
new file mode 100644
index 0000000..c92b0b7
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel.UnitTests/StressTests.cs
@@ -0,0 +1,74 @@
+using System;
+using System.Linq;
+using System.Net;
+using System.Net.Sockets;
+using System.Threading;
+using System.Threading.Tasks;
+using Hazel.Dtls;
+using Hazel.Udp;
+using Hazel.Udp.FewerThreads;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+
+namespace Hazel.UnitTests
+{
+ [TestClass]
+ public class StressTests
+ {
+ // [TestMethod]
+ public void StressTestOpeningConnections()
+ {
+ // Start a listener in another process, or even better,
+ // adjust the target IP and start listening on another computer.
+ var ep = new IPEndPoint(IPAddress.Loopback, 22023);
+ Parallel.For(0, 10000,
+ new ParallelOptions { MaxDegreeOfParallelism = 64 },
+ (i) => {
+
+ var connection = new UdpClientConnection(new TestLogger(), ep);
+ connection.KeepAliveInterval = 50;
+
+ connection.Connect(new byte[5]);
+ });
+ }
+
+ // This was a thing that happened to us a DDoS. Mildly instructional that we straight up ignore it.
+ public void SourceAmpAttack()
+ {
+ var localEp = new IPEndPoint(IPAddress.Any, 11710);
+ var serverEp = new IPEndPoint(IPAddress.Loopback, 11710);
+ using (ThreadLimitedUdpConnectionListener listener = new ThreadLimitedUdpConnectionListener(4, localEp, new ConsoleLogger(true)))
+ {
+ listener.Start();
+
+ Socket socket = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp);
+ socket.DontFragment = false;
+
+ try
+ {
+ const int SIO_UDP_CONNRESET = -1744830452;
+ socket.IOControl(SIO_UDP_CONNRESET, new byte[1], null);
+ }
+ catch { } // Only necessary on Windows
+
+ string byteAsHex = "f23c 92d1 c277 001b 54c2 50c1 0800 4500 0035 7488 0000 3b11 2637 062f ac75 2d4f 0506 a7ea 5607 0021 5e07 ffff ffff 5453 6f75 7263 6520 456e 6769 6e65 2051 7565 7279 00";
+ byte[] bytes = StringToByteArray(byteAsHex.Replace(" ", ""));
+ socket.SendTo(bytes, serverEp);
+
+ while (socket.Poll(50000, SelectMode.SelectRead))
+ {
+ byte[] buffer = new byte[1024];
+ int len = socket.Receive(buffer);
+ Console.WriteLine($"got {len} bytes: " + string.Join(" ", buffer.Select(b => b.ToString("X"))));
+ Console.WriteLine($"got {len} bytes: " + string.Join(" ", buffer.Select(b => (char)b)));
+ }
+ }
+ }
+
+ public static byte[] StringToByteArray(string hex)
+ {
+ return Enumerable.Range(0, hex.Length / 2)
+ .Select(x => Convert.ToByte(hex.Substring(x * 2, 2), 16))
+ .ToArray();
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel.UnitTests/TestHelper.cs b/Tools/Hazel-Networking/Hazel.UnitTests/TestHelper.cs
new file mode 100644
index 0000000..f3e9dfb
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel.UnitTests/TestHelper.cs
@@ -0,0 +1,337 @@
+using System;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+
+using Hazel;
+using System.Net;
+using System.Threading;
+using System.Diagnostics;
+using Hazel.Udp.FewerThreads;
+
+namespace Hazel.UnitTests
+{
+ [TestClass]
+ public static class TestHelper
+ {
+ /// <summary>
+ /// Runs a general test on the given listener and connection.
+ /// </summary>
+ /// <param name="listener">The listener to test.</param>
+ /// <param name="connection">The connection to test.</param>
+ internal static void RunServerToClientTest(ThreadLimitedUdpConnectionListener listener, Connection connection, int dataSize, SendOption sendOption)
+ {
+ //Setup meta stuff
+ MessageWriter data = BuildData(sendOption, dataSize);
+ ManualResetEvent mutex = new ManualResetEvent(false);
+
+ //Setup listener
+ listener.NewConnection += delegate (NewConnectionEventArgs ncArgs)
+ {
+ ncArgs.Connection.Send(data);
+ };
+
+ listener.Start();
+
+ DataReceivedEventArgs? result = null;
+ //Setup conneciton
+ connection.DataReceived += delegate (DataReceivedEventArgs a)
+ {
+ Trace.WriteLine("Data was received correctly.");
+
+ try
+ {
+ result = a;
+ }
+ finally
+ {
+ mutex.Set();
+ }
+ };
+
+ connection.Connect();
+
+ //Wait until data is received
+ mutex.WaitOne();
+
+ var dataReader = ConvertToMessageReader(data);
+ Assert.AreEqual(dataReader.Length, result.Value.Message.Length);
+ for (int i = 0; i < dataReader.Length; i++)
+ {
+ Assert.AreEqual(dataReader.ReadByte(), result.Value.Message.ReadByte());
+ }
+
+ Assert.AreEqual(sendOption, result.Value.SendOption);
+ }
+
+ /// <summary>
+ /// Runs a general test on the given listener and connection.
+ /// </summary>
+ /// <param name="listener">The listener to test.</param>
+ /// <param name="connection">The connection to test.</param>
+ internal static void RunServerToClientTest(NetworkConnectionListener listener, Connection connection, int dataSize, SendOption sendOption)
+ {
+ //Setup meta stuff
+ MessageWriter data = BuildData(sendOption, dataSize);
+ ManualResetEvent mutex = new ManualResetEvent(false);
+
+ //Setup listener
+ listener.NewConnection += delegate (NewConnectionEventArgs ncArgs)
+ {
+ ncArgs.Connection.Send(data);
+ };
+
+ listener.Start();
+
+ DataReceivedEventArgs? result = null;
+ //Setup conneciton
+ connection.DataReceived += delegate (DataReceivedEventArgs a)
+ {
+ Trace.WriteLine("Data was received correctly.");
+
+ try
+ {
+ result = a;
+ }
+ finally
+ {
+ mutex.Set();
+ }
+ };
+
+ connection.Connect();
+
+ //Wait until data is received
+ mutex.WaitOne();
+
+ var dataReader = ConvertToMessageReader(data);
+ Assert.AreEqual(dataReader.Length, result.Value.Message.Length);
+ for (int i = 0; i < dataReader.Length; i++)
+ {
+ Assert.AreEqual(dataReader.ReadByte(), result.Value.Message.ReadByte());
+ }
+
+ Assert.AreEqual(sendOption, result.Value.SendOption);
+ }
+
+ /// <summary>
+ /// Runs a general test on the given listener and connection.
+ /// </summary>
+ /// <param name="listener">The listener to test.</param>
+ /// <param name="connection">The connection to test.</param>
+ internal static void RunClientToServerTest(NetworkConnectionListener listener, Connection connection, int dataSize, SendOption sendOption)
+ {
+ //Setup meta stuff
+ MessageWriter data = BuildData(sendOption, dataSize);
+ ManualResetEvent mutex = new ManualResetEvent(false);
+ ManualResetEvent mutex2 = new ManualResetEvent(false);
+
+ //Setup listener
+ DataReceivedEventArgs? result = null;
+ listener.NewConnection += delegate (NewConnectionEventArgs args)
+ {
+ args.Connection.DataReceived += delegate (DataReceivedEventArgs innerArgs)
+ {
+ Trace.WriteLine("Data was received correctly.");
+
+ result = innerArgs;
+
+ mutex2.Set();
+ };
+
+ mutex.Set();
+ };
+
+ listener.Start();
+
+ //Connect
+ connection.Connect();
+
+ mutex.WaitOne();
+
+ connection.Send(data);
+
+ //Wait until data is received
+ mutex2.WaitOne();
+
+ var dataReader = ConvertToMessageReader(data);
+ Assert.AreEqual(dataReader.Length, result.Value.Message.Length);
+ for (int i = 0; i < data.Length; i++)
+ {
+ Assert.AreEqual(dataReader.ReadByte(), result.Value.Message.ReadByte());
+ }
+
+ Assert.AreEqual(sendOption, result.Value.SendOption);
+ }
+
+
+ /// <summary>
+ /// Runs a general test on the given listener and connection.
+ /// </summary>
+ /// <param name="listener">The listener to test.</param>
+ /// <param name="connection">The connection to test.</param>
+ internal static void RunClientToServerTest(ThreadLimitedUdpConnectionListener listener, Connection connection, int dataSize, SendOption sendOption)
+ {
+ //Setup meta stuff
+ MessageWriter data = BuildData(sendOption, dataSize);
+ ManualResetEvent mutex = new ManualResetEvent(false);
+ ManualResetEvent mutex2 = new ManualResetEvent(false);
+
+ //Setup listener
+ DataReceivedEventArgs? result = null;
+ listener.NewConnection += delegate (NewConnectionEventArgs args)
+ {
+ args.Connection.DataReceived += delegate (DataReceivedEventArgs innerArgs)
+ {
+ Trace.WriteLine("Data was received correctly.");
+
+ result = innerArgs;
+
+ mutex2.Set();
+ };
+
+ mutex.Set();
+ };
+
+ listener.Start();
+
+ //Connect
+ connection.Connect();
+
+ Assert.IsTrue(mutex.WaitOne(100), "Timeout while connecting");
+
+ connection.Send(data);
+
+ //Wait until data is received
+ Assert.IsTrue(mutex2.WaitOne(100), "Timeout while sending data");
+
+ var dataReader = ConvertToMessageReader(data);
+ Assert.AreEqual(dataReader.Length, result.Value.Message.Length);
+ for (int i = 0; i < dataReader.Length; i++)
+ {
+ Assert.AreEqual(dataReader.ReadByte(), result.Value.Message.ReadByte());
+ }
+
+ Assert.AreEqual(sendOption, result.Value.SendOption);
+ }
+
+ /// <summary>
+ /// Runs a server disconnect test on the given listener and connection.
+ /// </summary>
+ /// <param name="listener">The listener to test.</param>
+ /// <param name="connection">The connection to test.</param>
+ internal static void RunServerDisconnectTest(NetworkConnectionListener listener, Connection connection)
+ {
+ ManualResetEvent mutex = new ManualResetEvent(false);
+
+ connection.Disconnected += delegate (object sender, DisconnectedEventArgs args)
+ {
+ mutex.Set();
+ };
+
+ listener.NewConnection += delegate (NewConnectionEventArgs args)
+ {
+ args.Connection.Disconnect("Testing");
+ };
+
+ listener.Start();
+
+ connection.Connect();
+
+ mutex.WaitOne();
+ }
+
+ /// <summary>
+ /// Runs a client disconnect test on the given listener and connection.
+ /// </summary>
+ /// <param name="listener">The listener to test.</param>
+ /// <param name="connection">The connection to test.</param>
+ internal static void RunClientDisconnectTest(NetworkConnectionListener listener, Connection connection)
+ {
+ ManualResetEvent mutex = new ManualResetEvent(false);
+ ManualResetEvent mutex2 = new ManualResetEvent(false);
+
+ listener.NewConnection += delegate (NewConnectionEventArgs args)
+ {
+ args.Connection.Disconnected += delegate (object sender2, DisconnectedEventArgs args2)
+ {
+ mutex2.Set();
+ };
+
+ mutex.Set();
+ };
+
+ listener.Start();
+
+ connection.Connect();
+
+ mutex.WaitOne();
+
+ connection.Disconnect("Testing");
+
+ mutex2.WaitOne();
+ }
+
+ /// <summary>
+ /// Ensures a client sends a disconnect packet to the server on Dispose.
+ /// </summary>
+ /// <param name="listener">The listener to test.</param>
+ /// <param name="connection">The connection to test.</param>
+ internal static void RunClientDisconnectOnDisposeTest(NetworkConnectionListener listener, Connection connection)
+ {
+ ManualResetEvent mutex = new ManualResetEvent(false);
+ ManualResetEvent mutex2 = new ManualResetEvent(false);
+
+ listener.NewConnection += delegate (NewConnectionEventArgs args)
+ {
+ args.Connection.Disconnected += delegate (object sender2, DisconnectedEventArgs args2)
+ {
+ mutex2.Set();
+ };
+
+ mutex.Set();
+ };
+
+ listener.Start();
+
+ connection.Connect();
+
+ if (!mutex.WaitOne(TimeSpan.FromSeconds(1)))
+ {
+ Assert.Fail("Timeout waiting for client connection");
+ }
+
+ connection.Dispose();
+
+ if (!mutex2.WaitOne(TimeSpan.FromSeconds(1)))
+ {
+ Assert.Fail("Timeout waiting for client disconnect packet");
+ }
+ }
+
+ private static MessageReader ConvertToMessageReader(MessageWriter writer)
+ {
+ var output = new MessageReader();
+ output.Buffer = writer.Buffer;
+ output.Offset = writer.SendOption == SendOption.None ? 1 : 3;
+ output.Length = writer.Length - output.Offset;
+ output.Position = 0;
+
+ return output;
+ }
+
+ /// <summary>
+ /// Builds new data of increaseing value bytes.
+ /// </summary>
+ /// <param name="dataSize">The number of bytes to generate.</param>
+ /// <returns>The data.</returns>
+ static MessageWriter BuildData(SendOption sendOption, int dataSize)
+ {
+ var output = MessageWriter.Get(sendOption);
+ for (int i = 0; i < dataSize; i++)
+ {
+ output.Write((byte)i);
+ }
+
+ return output;
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel.UnitTests/TestLogger.cs b/Tools/Hazel-Networking/Hazel.UnitTests/TestLogger.cs
new file mode 100644
index 0000000..01ca893
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel.UnitTests/TestLogger.cs
@@ -0,0 +1,66 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+
+namespace Hazel.UnitTests
+{
+ public class TestLogger : ILogger
+ {
+ private readonly string prefix;
+
+ public TestLogger(string prefix = "")
+ {
+ this.prefix = prefix;
+ }
+
+ public void WriteVerbose(string msg)
+ {
+ if (string.IsNullOrEmpty(this.prefix))
+ {
+ Console.WriteLine($"[VERBOSE] {msg}");
+ }
+ else
+ {
+ Console.WriteLine($"[{this.prefix}][VERBOSE] {msg}");
+ }
+ }
+
+ public void WriteWarning(string msg)
+ {
+ if (string.IsNullOrEmpty(this.prefix))
+ {
+ Console.WriteLine($"[WARN] {msg}");
+ }
+ else
+ {
+ Console.WriteLine($"[{this.prefix}][WARN] {msg}");
+ }
+ }
+
+ public void WriteError(string msg)
+ {
+ if (string.IsNullOrEmpty(this.prefix))
+ {
+ Console.WriteLine($"[ERROR] {msg}");
+ }
+ else
+ {
+ Console.WriteLine($"[{this.prefix}][ERROR] {msg}");
+ }
+ }
+
+ public void WriteInfo(string msg)
+ {
+ if (string.IsNullOrEmpty(this.prefix))
+ {
+ Console.WriteLine($"[INFO] {msg}");
+ }
+ else
+ {
+ Console.WriteLine($"[{this.prefix}][INFO] {msg}");
+ }
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel.UnitTests/ThreadLimitedUdpConnectionTests.cs b/Tools/Hazel-Networking/Hazel.UnitTests/ThreadLimitedUdpConnectionTests.cs
new file mode 100644
index 0000000..8b9998f
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel.UnitTests/ThreadLimitedUdpConnectionTests.cs
@@ -0,0 +1,786 @@
+using System;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using System.Net;
+using System.Threading;
+using Hazel.Udp;
+using Hazel.Udp.FewerThreads;
+using System.Net.Sockets;
+using System.Linq;
+using System.Collections;
+using System.Collections.Generic;
+
+namespace Hazel.UnitTests
+{
+ [TestClass]
+ public class ThreadLimitedUdpConnectionTests
+ {
+ protected ThreadLimitedUdpConnectionListener CreateListener(int numWorkers, IPEndPoint endPoint, ILogger logger, IPMode ipMode = IPMode.IPv4)
+ {
+ return new ThreadLimitedUdpConnectionListener(numWorkers, endPoint, logger, ipMode);
+ }
+
+ protected UdpConnection CreateConnection(IPEndPoint endPoint, ILogger logger, IPMode ipMode = IPMode.IPv4)
+ {
+ return new UdpClientConnection(logger, endPoint, ipMode);
+ }
+
+ [TestMethod]
+ public void ServerDisposeDisconnectsTest()
+ {
+ IPEndPoint ep = new IPEndPoint(IPAddress.Loopback, 4296);
+
+ bool serverConnected = false;
+ bool serverDisconnected = false;
+ bool clientDisconnected = false;
+
+ using (ThreadLimitedUdpConnectionListener listener = this.CreateListener(2, new IPEndPoint(IPAddress.Any, 4296), new TestLogger("SERVER")))
+ using (UdpConnection connection = this.CreateConnection(ep, new TestLogger("CLIENT")))
+ {
+ listener.NewConnection += (evt) =>
+ {
+ serverConnected = true;
+ evt.Connection.Disconnected += (o, et) => serverDisconnected = true;
+ };
+ connection.Disconnected += (o, evt) => clientDisconnected = true;
+
+ listener.Start();
+ connection.Connect();
+
+ Thread.Sleep(100); // Gotta wait for the server to set up the events.
+ listener.Dispose();
+ Thread.Sleep(100);
+
+ Assert.IsTrue(serverConnected);
+ Assert.IsTrue(clientDisconnected);
+ Assert.IsFalse(serverDisconnected);
+ }
+ }
+
+ [TestMethod]
+ public void ClientDisposeDisconnectTest()
+ {
+ IPEndPoint ep = new IPEndPoint(IPAddress.Loopback, 4296);
+
+ bool serverConnected = false;
+ bool serverDisconnected = false;
+ bool clientDisconnected = false;
+
+ using (ThreadLimitedUdpConnectionListener listener = this.CreateListener(2, new IPEndPoint(IPAddress.Any, 4296), new TestLogger()))
+ using (UdpConnection connection = this.CreateConnection(ep, new TestLogger()))
+ {
+ listener.NewConnection += (evt) =>
+ {
+ serverConnected = true;
+ evt.Connection.Disconnected += (o, et) => serverDisconnected = true;
+ };
+
+ connection.Disconnected += (o, et) => clientDisconnected = true;
+
+ listener.Start();
+ connection.Connect();
+
+ Thread.Sleep(100); // Gotta wait for the server to set up the events.
+ connection.Dispose();
+
+ Thread.Sleep(100);
+
+ Assert.IsTrue(serverConnected);
+ Assert.IsTrue(serverDisconnected);
+ Assert.IsFalse(clientDisconnected);
+ }
+ }
+
+ /// <summary>
+ /// Tests the fields on UdpConnection.
+ /// </summary>
+ [TestMethod]
+ public void UdpFieldTest()
+ {
+ IPEndPoint ep = new IPEndPoint(IPAddress.Loopback, 4296);
+
+ using (ThreadLimitedUdpConnectionListener listener = this.CreateListener(2, new IPEndPoint(IPAddress.Any, 4296), new TestLogger()))
+ using (UdpConnection connection = this.CreateConnection(ep, new TestLogger()))
+ {
+ listener.Start();
+
+ connection.Connect();
+
+ //Connection fields
+ Assert.AreEqual(ep, connection.EndPoint);
+
+ //UdpConnection fields
+ Assert.AreEqual(new IPEndPoint(IPAddress.Loopback, 4296), connection.EndPoint);
+ Assert.AreEqual(1, connection.Statistics.DataBytesSent);
+ Assert.AreEqual(0, connection.Statistics.DataBytesReceived);
+ }
+ }
+
+ [TestMethod]
+ public void UdpHandshakeTest()
+ {
+ byte[] TestData = new byte[] { 1, 2, 3, 4, 5, 6 };
+ using (ThreadLimitedUdpConnectionListener listener = this.CreateListener(2, new IPEndPoint(IPAddress.Any, 4296), new TestLogger()))
+ using (UdpConnection connection = this.CreateConnection(new IPEndPoint(IPAddress.Loopback, 4296), new TestLogger()))
+ {
+ listener.Start();
+
+ MessageReader output = null;
+ listener.NewConnection += delegate (NewConnectionEventArgs e)
+ {
+ output = e.HandshakeData.Duplicate();
+ };
+
+ connection.Connect(TestData);
+
+ Thread.Sleep(10);
+ for (int i = 0; i < TestData.Length; ++i)
+ {
+ Assert.AreEqual(TestData[i], output.ReadByte());
+ }
+ }
+ }
+
+ [TestMethod]
+ public void UdpUnreliableMessageSendTest()
+ {
+ byte[] TestData = new byte[] { 1, 2, 3, 4, 5, 6 };
+ using (ThreadLimitedUdpConnectionListener listener = this.CreateListener(2, new IPEndPoint(IPAddress.Any, 4296), new TestLogger()))
+ using (UdpConnection connection = this.CreateConnection(new IPEndPoint(IPAddress.Loopback, 4296), new TestLogger()))
+ {
+ MessageReader output = null;
+ listener.NewConnection += delegate (NewConnectionEventArgs e)
+ {
+ e.Connection.DataReceived += delegate (DataReceivedEventArgs evt)
+ {
+ output = evt.Message.Duplicate();
+ };
+ };
+
+ listener.Start();
+ connection.Connect();
+
+ for (int i = 0; i < 4; ++i)
+ {
+ var msg = MessageWriter.Get(SendOption.None);
+ msg.Write(TestData);
+ connection.Send(msg);
+ msg.Recycle();
+ }
+
+ Thread.Sleep(10);
+ for (int i = 0; i < TestData.Length; ++i)
+ {
+ Assert.AreEqual(TestData[i], output.ReadByte());
+ }
+ }
+ }
+
+ [TestMethod]
+ public void UdpReliableMessageResendTest()
+ {
+ byte[] TestData = new byte[] { 1, 2, 3, 4, 5, 6 };
+
+ var listenerEp = new IPEndPoint(IPAddress.Loopback, 4296);
+ var captureEp = new IPEndPoint(IPAddress.Loopback, 4297);
+
+ using (SocketCapture capture = new SocketCapture(captureEp, listenerEp, new TestLogger()))
+ using (ThreadLimitedUdpConnectionListener listener = this.CreateListener(2, new IPEndPoint(IPAddress.Any, listenerEp.Port), new TestLogger()))
+ using (UdpConnection connection = this.CreateConnection(captureEp, new TestLogger()))
+ using (SemaphoreSlim readLock = new SemaphoreSlim(0, 1))
+ {
+ connection.ResendTimeoutMs = 100;
+ connection.KeepAliveInterval = Timeout.Infinite; // Don't let pings interfere.
+
+ MessageReader output = null;
+ listener.NewConnection += delegate (NewConnectionEventArgs e)
+ {
+ var udpConn = (UdpConnection)e.Connection;
+ udpConn.KeepAliveInterval = Timeout.Infinite; // Don't let pings interfere.
+
+ e.Connection.DataReceived += delegate (DataReceivedEventArgs evt)
+ {
+ output = evt.Message.Duplicate();
+ readLock.Release();
+ };
+ };
+
+ listener.Start();
+ connection.Connect();
+
+ capture.AssertPacketsToRemoteCountEquals(0);
+
+ const int NumberOfPacketsToResend = 4;
+ const int NumberOfTimesToResend = 3;
+ using (capture.SendToRemoteSemaphore = new Semaphore(0, int.MaxValue))
+ {
+ for (int pktCnt = 0; pktCnt < NumberOfPacketsToResend; ++pktCnt)
+ {
+ Console.WriteLine("Send blocked pkt");
+ var msg = MessageWriter.Get(SendOption.Reliable);
+ msg.Write(TestData);
+ connection.Send(msg);
+ msg.Recycle();
+
+ for (int drops = 0; drops < NumberOfTimesToResend; ++drops)
+ {
+ capture.AssertPacketsToRemoteCountEquals(1);
+ capture.DiscardPacketForRemote();
+ }
+
+ capture.AssertPacketsToRemoteCountEquals(1);
+ capture.SendToRemoteSemaphore.Release(); // Actually let it send.
+
+ Assert.IsTrue(readLock.Wait(1000));
+ for (int i = 0; i < TestData.Length; ++i)
+ {
+ Assert.AreEqual(TestData[i], output.ReadByte());
+ }
+
+ output = null;
+ }
+ }
+
+ Assert.AreEqual(NumberOfPacketsToResend * NumberOfTimesToResend, connection.Statistics.MessagesResent);
+ }
+ }
+
+ [TestMethod]
+ public void UdpReliableMessageAckTest()
+ {
+ byte[] TestData = new byte[] { 1, 2, 3, 4, 5, 6 };
+
+ var listenerEp = new IPEndPoint(IPAddress.Loopback, 4296);
+ var captureEp = new IPEndPoint(IPAddress.Loopback, 4297);
+
+ using (SocketCapture capture = new SocketCapture(captureEp, listenerEp, new TestLogger()))
+ using (ThreadLimitedUdpConnectionListener listener = this.CreateListener(2, new IPEndPoint(IPAddress.Any, listenerEp.Port), new TestLogger()))
+ using (UdpConnection connection = this.CreateConnection(captureEp, new TestLogger()))
+ {
+ connection.ResendTimeoutMs = 100;
+ connection.KeepAliveInterval = Timeout.Infinite; // Don't let pings interfere.
+
+ listener.NewConnection += delegate (NewConnectionEventArgs e)
+ {
+ var udpConn = (UdpConnection)e.Connection;
+ udpConn.KeepAliveInterval = Timeout.Infinite; // Don't let pings interfere.
+ };
+
+ listener.Start();
+ connection.Connect();
+
+ capture.AssertPacketsToLocalCountEquals(0);
+
+ const int NumberOfPacketsToSend = 4;
+ using (capture.SendToLocalSemaphore = new Semaphore(0, int.MaxValue))
+ {
+ for (int pktCnt = 0; pktCnt < NumberOfPacketsToSend; ++pktCnt)
+ {
+ Console.WriteLine("Send blocked pkt");
+ var msg = MessageWriter.Get(SendOption.Reliable);
+ msg.Write(TestData);
+ connection.Send(msg);
+
+ msg.Recycle();
+
+ capture.AssertPacketsToLocalCountEquals(1);
+
+ var ack = capture.PeekPacketForLocal();
+ Assert.AreEqual(10, ack[0]); // enum SendOptionInternal.Acknowledgement
+ Assert.AreEqual(0, ack[1]);
+ Assert.AreEqual(pktCnt + 1, ack[2]);
+ Assert.AreEqual(255, ack[3]);
+
+ capture.SendToLocalSemaphore.Release(); // Actually let it send.
+ capture.AssertPacketsToLocalCountEquals(0);
+ }
+ }
+
+ // +1 for Hello packet
+ Thread.Sleep(100); // The final ack has to actually be processed.
+ Assert.AreEqual(1 + NumberOfPacketsToSend, connection.Statistics.ReliablePacketsAcknowledged);
+ }
+ }
+
+ [TestMethod]
+ public void UdpReliableMessageAckWithDropTest()
+ {
+ byte[] TestData = new byte[] { 1, 2, 3, 4, 5, 6 };
+
+ var listenerEp = new IPEndPoint(IPAddress.Loopback, 4296);
+ var captureEp = new IPEndPoint(IPAddress.Loopback, 4297);
+
+ using (SocketCapture capture = new SocketCapture(captureEp, listenerEp, new TestLogger()))
+ using (ThreadLimitedUdpConnectionListener listener = this.CreateListener(2, new IPEndPoint(IPAddress.Any, listenerEp.Port), new TestLogger()))
+ using (UdpConnection connection = this.CreateConnection(captureEp, new TestLogger()))
+ {
+ connection.ResendTimeoutMs = 10000; // No resends please
+ connection.KeepAliveInterval = Timeout.Infinite; // Don't let pings interfere.
+
+ listener.NewConnection += delegate (NewConnectionEventArgs e)
+ {
+ var udpConn = (UdpConnection)e.Connection;
+ udpConn.KeepAliveInterval = Timeout.Infinite; // Don't let pings interfere.
+ };
+
+ listener.Start();
+ connection.Connect();
+
+ capture.AssertPacketsToLocalCountEquals(0);
+
+ using (capture.SendToRemoteSemaphore = new Semaphore(0, int.MaxValue))
+ using (capture.SendToLocalSemaphore = new Semaphore(0, int.MaxValue))
+ {
+ // Send 3 packets to remote
+ for (int pktCnt = 0; pktCnt < 3; ++pktCnt)
+ {
+ var msg = MessageWriter.Get(SendOption.Reliable);
+ msg.Write(TestData);
+ connection.Send(msg);
+ msg.Recycle();
+ }
+
+ // Drop the middle packet
+ capture.AssertPacketsToRemoteCountEquals(3);
+ capture.ReorderPacketsForRemote(list => list.Sort(SortByPacketId.Instance));
+ Console.WriteLine(capture.PacketsForRemoteToString());
+
+ capture.ReleasePacketsForRemote(1);
+ capture.DiscardPacketForRemote();
+ capture.ReleasePacketsForRemote(1);
+
+ // Receive 2 acks
+ capture.AssertPacketsToLocalCountEquals(2);
+ capture.ReorderPacketsForLocal(list => list.Sort(SortByPacketId.Instance));
+ Console.WriteLine(capture.PacketsForLocalToString());
+
+ var ack1 = capture.PeekPacketForLocal();
+ Assert.AreEqual(10, ack1[0]); // enum SendOptionInternal.Acknowledgement
+ Assert.AreEqual(0, ack1[1]);
+ Assert.AreEqual(1, ack1[2]);
+ Assert.AreEqual(255, ack1[3]);
+ capture.ReleasePacketsToLocal(1);
+
+ var ack2 = capture.PeekPacketForLocal();
+ Assert.AreEqual(10, ack2[0]); // enum SendOptionInternal.Acknowledgement
+ Assert.AreEqual(0, ack2[1]);
+ Assert.AreEqual(3, ack2[2]);
+ Assert.AreEqual(254, ack2[3]); // The server is expecting packet 2
+ capture.ReleasePacketsToLocal(1);
+ }
+
+ // +1 for Hello packet, +2 for reliable
+ Thread.Sleep(100); // The final ack has to actually be processed.
+ Assert.AreEqual(3, connection.Statistics.ReliablePacketsAcknowledged);
+ }
+ }
+
+ [TestMethod]
+ public void UdpReliableMessageAckFillsDroppedAcksTest()
+ {
+ byte[] TestData = new byte[] { 1, 2, 3, 4, 5, 6 };
+
+ var listenerEp = new IPEndPoint(IPAddress.Loopback, 4296);
+ var captureEp = new IPEndPoint(IPAddress.Loopback, 4297);
+
+ using (SocketCapture capture = new SocketCapture(captureEp, listenerEp, new TestLogger()))
+ using (ThreadLimitedUdpConnectionListener listener = this.CreateListener(2, new IPEndPoint(IPAddress.Any, listenerEp.Port), new TestLogger()))
+ using (UdpConnection connection = this.CreateConnection(captureEp, new TestLogger("Client")))
+ {
+ connection.ResendTimeoutMs = 10000; // No resends please
+ connection.KeepAliveInterval = Timeout.Infinite; // Don't let pings interfere.
+
+ listener.NewConnection += delegate (NewConnectionEventArgs e)
+ {
+ var udpConn = (UdpConnection)e.Connection;
+ udpConn.KeepAliveInterval = Timeout.Infinite; // Don't let pings interfere.
+ };
+
+ listener.Start();
+ connection.Connect();
+
+ capture.AssertPacketsToLocalCountEquals(0);
+
+ using (capture.SendToLocalSemaphore = new Semaphore(0, int.MaxValue))
+ {
+ // Send 4 packets to remote
+ for (int pktCnt = 0; pktCnt < 4; ++pktCnt)
+ {
+ var msg = MessageWriter.Get(SendOption.Reliable);
+ msg.Write(TestData);
+ connection.Send(msg);
+ msg.Recycle();
+ }
+
+ // Receive 4 acks, Drop the middle two
+ capture.AssertPacketsToLocalCountEquals(4);
+ capture.ReorderPacketsForLocal(list => list.Sort(SortByPacketId.Instance));
+
+ var ack1 = capture.PeekPacketForLocal();
+ Assert.AreEqual(10, ack1[0]); // enum SendOptionInternal.Acknowledgement
+ Assert.AreEqual(0, ack1[1]);
+ Assert.AreEqual(1, ack1[2]);
+ Assert.AreEqual(255, ack1[3]);
+ capture.ReleasePacketsToLocal(1);
+
+ capture.DiscardPacketForLocal(2);
+
+ var ack4 = capture.PeekPacketForLocal();
+ Assert.AreEqual(10, ack4[0]); // enum SendOptionInternal.Acknowledgement
+ Assert.AreEqual(0, ack4[1]);
+ Assert.AreEqual(4, ack4[2]);
+ Assert.AreEqual(255, ack4[3]);
+ capture.ReleasePacketsToLocal(1);
+ }
+
+ // +1 for Hello packet, +4 for reliable despite the dropped ack
+ Thread.Sleep(100); // The final ack has to actually be processed.
+ Assert.AreEqual(3, connection.Statistics.AcknowledgementMessagesReceived);
+ Assert.AreEqual(5, connection.Statistics.ReliablePacketsAcknowledged);
+ }
+ }
+
+ private class SortByPacketId : IComparer<ByteSpan>
+ {
+ public static SortByPacketId Instance = new SortByPacketId();
+
+ public int Compare(ByteSpan x, ByteSpan y)
+ {
+ ushort xId = BitConverter.ToUInt16(x.GetUnderlyingArray(), 1);
+ ushort yId = BitConverter.ToUInt16(y.GetUnderlyingArray(), 1);
+ return xId.CompareTo(yId);
+ }
+ }
+
+ /// <summary>
+ /// Tests IPv4 connectivity.
+ /// </summary>
+ [TestMethod]
+ public void UdpIPv4ConnectionTest()
+ {
+ using (ThreadLimitedUdpConnectionListener listener = this.CreateListener(2, new IPEndPoint(IPAddress.Any, 4296), new TestLogger()))
+ using (UdpConnection connection = this.CreateConnection(new IPEndPoint(IPAddress.Loopback, 4296), new TestLogger()))
+ {
+ listener.Start();
+
+ connection.Connect();
+
+ Assert.AreEqual(ConnectionState.Connected, connection.State);
+
+ Console.Write($"Client sent {connection.Statistics.TotalBytesSent} bytes ");
+ }
+ }
+
+ /// <summary>
+ /// Tests IPv4 resilience to multiple hellos.
+ /// </summary>
+ [TestMethod]
+ public void ConnectLikeAJerkTest()
+ {
+ using (ThreadLimitedUdpConnectionListener listener = this.CreateListener(2, new IPEndPoint(IPAddress.Any, 4296), new TestLogger()))
+ using (Socket socket = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp))
+ {
+ int connects = 0;
+ listener.NewConnection += (obj) =>
+ {
+ Interlocked.Increment(ref connects);
+ };
+
+ listener.Start();
+
+ socket.Bind(new IPEndPoint(IPAddress.Any, 0));
+ var bytes = new byte[2];
+ bytes[0] = (byte)UdpSendOption.Hello;
+ for (int i = 0; i < 10; ++i)
+ {
+ socket.SendTo(bytes, new IPEndPoint(IPAddress.Loopback, 4296));
+ }
+
+ Thread.Sleep(500);
+
+ Assert.AreEqual(0, listener.ReceiveQueueLength);
+ Assert.IsTrue(connects <= 1, $"Too many connections: {connects}");
+ }
+ }
+
+ /// <summary>
+ /// Tests dual mode connectivity.
+ /// </summary>
+ [TestMethod]
+ public void MixedConnectionTest()
+ {
+
+ using (ThreadLimitedUdpConnectionListener listener2 = this.CreateListener(4, new IPEndPoint(IPAddress.IPv6Any, 4296), new ConsoleLogger(true), IPMode.IPv6))
+ {
+ listener2.Start();
+
+ listener2.NewConnection += (evt) =>
+ {
+ Console.WriteLine($"Connection: {evt.Connection.EndPoint}");
+ };
+
+ using (UdpConnection connection = this.CreateConnection(new IPEndPoint(IPAddress.Parse("127.0.0.1"), 4296), new TestLogger()))
+ {
+ connection.Connect();
+ Assert.AreEqual(ConnectionState.Connected, connection.State);
+ }
+
+ using (UdpConnection connection2 = this.CreateConnection(new IPEndPoint(IPAddress.IPv6Loopback, 4296), new TestLogger(), IPMode.IPv6))
+ {
+ connection2.Connect();
+ Assert.AreEqual(ConnectionState.Connected, connection2.State);
+ }
+ }
+ }
+
+ /// <summary>
+ /// Tests dual mode connectivity.
+ /// </summary>
+ [TestMethod]
+ public void UdpIPv6ConnectionTest()
+ {
+ using (ThreadLimitedUdpConnectionListener listener = this.CreateListener(2, new IPEndPoint(IPAddress.IPv6Any, 4296), new TestLogger(), IPMode.IPv6))
+ {
+ listener.Start();
+
+ using (UdpConnection connection = this.CreateConnection(new IPEndPoint(IPAddress.Parse("127.0.0.1"), 4296), new TestLogger(), IPMode.IPv6))
+ {
+ connection.Connect();
+ }
+ }
+ }
+
+ /// <summary>
+ /// Tests server to client unreliable communication on the UdpConnection.
+ /// </summary>
+ [TestMethod]
+ public void UdpUnreliableServerToClientTest()
+ {
+ using (ThreadLimitedUdpConnectionListener listener = this.CreateListener(2, new IPEndPoint(IPAddress.Any, 4296), new TestLogger()))
+ using (UdpConnection connection = this.CreateConnection(new IPEndPoint(IPAddress.Loopback, 4296), new TestLogger()))
+ {
+ TestHelper.RunServerToClientTest(listener, connection, 10, SendOption.None);
+ }
+ }
+
+ /// <summary>
+ /// Tests server to client reliable communication on the UdpConnection.
+ /// </summary>
+ [TestMethod]
+ public void UdpReliableServerToClientTest()
+ {
+ using (ThreadLimitedUdpConnectionListener listener = this.CreateListener(2, new IPEndPoint(IPAddress.Any, 4296), new TestLogger()))
+ using (UdpConnection connection = this.CreateConnection(new IPEndPoint(IPAddress.Loopback, 4296), new TestLogger()))
+ {
+ TestHelper.RunServerToClientTest(listener, connection, 10, SendOption.Reliable);
+ }
+ }
+
+ /// <summary>
+ /// Tests server to client unreliable communication on the UdpConnection.
+ /// </summary>
+ [TestMethod]
+ public void UdpUnreliableClientToServerTest()
+ {
+ using (ThreadLimitedUdpConnectionListener listener = this.CreateListener(2, new IPEndPoint(IPAddress.Any, 4296), new TestLogger()))
+ using (UdpConnection connection = this.CreateConnection(new IPEndPoint(IPAddress.Loopback, 4296), new TestLogger()))
+ {
+ TestHelper.RunClientToServerTest(listener, connection, 10, SendOption.None);
+ }
+ }
+
+ /// <summary>
+ /// Tests server to client reliable communication on the UdpConnection.
+ /// </summary>
+ [TestMethod]
+ public void UdpReliableClientToServerTest()
+ {
+ using (ThreadLimitedUdpConnectionListener listener = this.CreateListener(2, new IPEndPoint(IPAddress.Any, 4296), new TestLogger()))
+ using (UdpConnection connection = this.CreateConnection(new IPEndPoint(IPAddress.Loopback, 4296), new TestLogger()))
+ {
+ TestHelper.RunClientToServerTest(listener, connection, 10, SendOption.Reliable);
+ }
+ }
+
+ /// <summary>
+ /// Tests the keepalive functionality from the client,
+ /// </summary>
+ [TestMethod]
+ public virtual void KeepAliveClientTest()
+ {
+ using (ThreadLimitedUdpConnectionListener listener = this.CreateListener(2, new IPEndPoint(IPAddress.Any, 4296), new TestLogger()))
+ using (UdpConnection connection = this.CreateConnection(new IPEndPoint(IPAddress.Loopback, 4296), new TestLogger()))
+ {
+ listener.Start();
+
+ connection.Connect();
+ connection.KeepAliveInterval = 100;
+
+ Thread.Sleep(1050); //Enough time for ~10 keep alive packets
+
+ Assert.AreEqual(ConnectionState.Connected, connection.State);
+ Assert.IsTrue(
+ connection.Statistics.TotalBytesSent >= 30 &&
+ connection.Statistics.TotalBytesSent <= 50,
+ "Sent: " + connection.Statistics.TotalBytesSent
+ );
+ }
+ }
+
+ /// <summary>
+ /// Tests the keepalive functionality from the client,
+ /// </summary>
+ [TestMethod]
+ public void KeepAliveServerTest()
+ {
+ ManualResetEvent mutex = new ManualResetEvent(false);
+
+ using (ThreadLimitedUdpConnectionListener listener = this.CreateListener(2, new IPEndPoint(IPAddress.Any, 4296), new TestLogger()))
+ using (UdpConnection connection = this.CreateConnection(new IPEndPoint(IPAddress.Loopback, 4296), new TestLogger()))
+ {
+ UdpConnection client = null;
+ listener.NewConnection += delegate (NewConnectionEventArgs args)
+ {
+ client = (UdpConnection)args.Connection;
+ client.KeepAliveInterval = 100;
+
+ Thread timeoutThread = new Thread(() =>
+ {
+ Thread.Sleep(1050); //Enough time for ~10 keep alive packets
+ mutex.Set();
+ });
+ timeoutThread.Start();
+ };
+
+ listener.Start();
+
+ connection.Connect();
+
+ mutex.WaitOne();
+
+ Assert.AreEqual(ConnectionState.Connected, client.State);
+
+ Assert.IsTrue(
+ client.Statistics.TotalBytesSent >= 27 &&
+ client.Statistics.TotalBytesSent <= 50,
+ "Sent: " + client.Statistics.TotalBytesSent
+ );
+ }
+ }
+
+ /// <summary>
+ /// Tests disconnection from the client.
+ /// </summary>
+ [TestMethod]
+ public void ClientDisconnectTest()
+ {
+ using (ThreadLimitedUdpConnectionListener listener = this.CreateListener(2, new IPEndPoint(IPAddress.Any, 4296), new TestLogger()))
+ using (UdpConnection connection = this.CreateConnection(new IPEndPoint(IPAddress.Loopback, 4296), new TestLogger()))
+ {
+ ManualResetEvent mutex = new ManualResetEvent(false);
+ ManualResetEvent mutex2 = new ManualResetEvent(false);
+
+ listener.NewConnection += delegate (NewConnectionEventArgs args)
+ {
+ args.Connection.Disconnected += delegate (object sender2, DisconnectedEventArgs args2)
+ {
+ mutex2.Set();
+ };
+
+ mutex.Set();
+ };
+
+ listener.Start();
+
+ connection.Connect();
+
+ mutex.WaitOne(1000);
+ Assert.AreEqual(ConnectionState.Connected, connection.State);
+
+ connection.Disconnect("Testing");
+
+ mutex2.WaitOne(1000);
+ Assert.AreEqual(ConnectionState.NotConnected, connection.State);
+ }
+ }
+
+ /// <summary>
+ /// Tests disconnection from the server.
+ /// </summary>
+ [TestMethod]
+ public void ServerDisconnectTest()
+ {
+ using (ThreadLimitedUdpConnectionListener listener = this.CreateListener(2, new IPEndPoint(IPAddress.Any, 4296), new TestLogger()))
+ using (UdpConnection connection = this.CreateConnection(new IPEndPoint(IPAddress.Loopback, 4296), new TestLogger()))
+ {
+ SemaphoreSlim mutex = new SemaphoreSlim(0, 100);
+ ManualResetEventSlim serverMutex = new ManualResetEventSlim(false);
+
+ connection.Disconnected += delegate (object sender, DisconnectedEventArgs args)
+ {
+ mutex.Release();
+ };
+
+ listener.NewConnection += delegate (NewConnectionEventArgs args)
+ {
+ mutex.Release();
+
+ // This has to be on a new thread because the client will go straight from Connecting to NotConnected
+ ThreadPool.QueueUserWorkItem(_ =>
+ {
+ serverMutex.Wait(500);
+ args.Connection.Disconnect("Testing");
+ });
+ };
+
+ listener.Start();
+
+ connection.Connect();
+
+ mutex.Wait(500);
+ Assert.AreEqual(ConnectionState.Connected, connection.State);
+
+ serverMutex.Set();
+
+ mutex.Wait(500);
+ Assert.AreEqual(ConnectionState.NotConnected, connection.State);
+ }
+ }
+
+ /// <summary>
+ /// Tests disconnection from the server.
+ /// </summary>
+ [TestMethod]
+ public void ServerExtraDataDisconnectTest()
+ {
+ using (ThreadLimitedUdpConnectionListener listener = this.CreateListener(2, new IPEndPoint(IPAddress.Any, 4296), new TestLogger()))
+ using (UdpConnection connection = this.CreateConnection(new IPEndPoint(IPAddress.Loopback, 4296), new TestLogger()))
+ {
+ string received = null;
+ ManualResetEvent mutex = new ManualResetEvent(false);
+
+ connection.Disconnected += delegate (object sender, DisconnectedEventArgs args)
+ {
+ // We don't own the message, we have to read the string now
+ received = args.Message.ReadString();
+ mutex.Set();
+ };
+
+ listener.NewConnection += delegate (NewConnectionEventArgs args)
+ {
+ MessageWriter writer = MessageWriter.Get(SendOption.None);
+ writer.Write("Goodbye");
+ args.Connection.Disconnect("Testing", writer);
+ };
+
+ listener.Start();
+
+ connection.Connect();
+
+ mutex.WaitOne(5000);
+
+ Assert.IsNotNull(received);
+ Assert.AreEqual("Goodbye", received);
+ }
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel.UnitTests/UPnPTests.cs b/Tools/Hazel-Networking/Hazel.UnitTests/UPnPTests.cs
new file mode 100644
index 0000000..3bee911
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel.UnitTests/UPnPTests.cs
@@ -0,0 +1,54 @@
+using System;
+using Hazel.UPnP;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+
+namespace Hazel.UnitTests
+{
+ // [TestClass]
+ // TODO: These tests are super flaky because of hardware differences. Not sure what can be done.
+ public class UPnPTests
+ {
+ [TestMethod]
+ public void CanForwardPort()
+ {
+ using (UPnPHelper dut = new UPnPHelper(Logger.Instance))
+ {
+ Assert.IsTrue(dut.ForwardPort(22023, "Hazel Test"));
+ }
+ }
+
+ [TestMethod]
+ public void CanDeletePort()
+ {
+ using (UPnPHelper dut = new UPnPHelper(Logger.Instance))
+ {
+ Assert.IsTrue(dut.DeleteForwardingRule(22023));
+ }
+ }
+ }
+
+ public class Logger : ILogger
+ {
+ public static readonly ILogger Instance = new Logger();
+
+ public void WriteVerbose(string msg)
+ {
+ Console.WriteLine(msg);
+ }
+
+ public void WriteWarning(string msg)
+ {
+ Console.WriteLine(msg);
+ }
+
+ public void WriteError(string msg)
+ {
+ Console.WriteLine(msg);
+ }
+
+ public void WriteInfo(string msg)
+ {
+ Console.WriteLine(msg);
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel.UnitTests/UdpConnectionTestHarness.cs b/Tools/Hazel-Networking/Hazel.UnitTests/UdpConnectionTestHarness.cs
new file mode 100644
index 0000000..1e865a8
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel.UnitTests/UdpConnectionTestHarness.cs
@@ -0,0 +1,60 @@
+using Hazel.Udp;
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+
+namespace Hazel.UnitTests
+{
+ internal class UdpConnectionTestHarness : UdpConnection
+ {
+ public List<MessageReader> BytesSent = new List<MessageReader>();
+
+ public UdpConnectionTestHarness() : base(new TestLogger())
+ {
+ }
+
+ public ushort ReliableReceiveLast => this.reliableReceiveLast;
+
+
+ public override void Connect(byte[] bytes = null, int timeout = 5000)
+ {
+ this.State = ConnectionState.Connected;
+ }
+
+ public override void ConnectAsync(byte[] bytes = null)
+ {
+ this.State = ConnectionState.Connected;
+ }
+
+ protected override bool SendDisconnect(MessageWriter writer)
+ {
+ lock (this)
+ {
+ if (this.State != ConnectionState.Connected)
+ {
+ return false;
+ }
+
+ this.State = ConnectionState.NotConnected;
+ }
+
+ return true;
+ }
+
+ protected override void WriteBytesToConnection(byte[] bytes, int length)
+ {
+ this.BytesSent.Add(MessageReader.Get(bytes));
+ }
+
+ public void Test_Receive(MessageWriter msg)
+ {
+ byte[] buffer = new byte[msg.Length];
+ Buffer.BlockCopy(msg.Buffer, 0, buffer, 0, msg.Length);
+
+ var data = MessageReader.Get(buffer);
+ this.HandleReceive(data, data.Length);
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel.UnitTests/UdpConnectionTests.cs b/Tools/Hazel-Networking/Hazel.UnitTests/UdpConnectionTests.cs
new file mode 100644
index 0000000..1ccd561
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel.UnitTests/UdpConnectionTests.cs
@@ -0,0 +1,514 @@
+using System;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using System.Net;
+using System.Threading;
+using Hazel.Udp;
+using System.Net.Sockets;
+using System.Threading.Tasks;
+
+namespace Hazel.UnitTests
+{
+ [TestClass]
+ public class UdpConnectionTests
+ {
+ [TestMethod]
+ public void ServerDisposeDisconnectsTest()
+ {
+ IPEndPoint ep = new IPEndPoint(IPAddress.Loopback, 4296);
+
+ bool serverConnected = false;
+ bool serverDisconnected = false;
+ bool clientDisconnected = false;
+
+ using (UdpConnectionListener listener = new UdpConnectionListener(new IPEndPoint(IPAddress.Any, 4296)))
+ using (UdpConnection connection = new UdpClientConnection(new TestLogger("Client"), ep))
+ {
+ listener.NewConnection += (evt) =>
+ {
+ serverConnected = true;
+ evt.Connection.Disconnected += (o, et) => serverDisconnected = true;
+ };
+ connection.Disconnected += (o, evt) => clientDisconnected = true;
+
+ listener.Start();
+ connection.Connect();
+
+ Thread.Sleep(100); // Gotta wait for the server to set up the events.
+ listener.Dispose();
+ Thread.Sleep(100);
+
+ Assert.IsTrue(serverConnected);
+ Assert.IsTrue(clientDisconnected);
+ Assert.IsFalse(serverDisconnected);
+ }
+ }
+
+ [TestMethod]
+ public void ClientServerDisposeDisconnectsTest()
+ {
+ IPEndPoint ep = new IPEndPoint(IPAddress.Loopback, 4296);
+
+ bool serverConnected = false;
+ bool serverDisconnected = false;
+ bool clientDisconnected = false;
+
+ using (UdpConnectionListener listener = new UdpConnectionListener(new IPEndPoint(IPAddress.Any, 4296)))
+ using (UdpConnection connection = new UdpClientConnection(new TestLogger("Client"), ep))
+ {
+ listener.NewConnection += (evt) =>
+ {
+ serverConnected = true;
+ evt.Connection.Disconnected += (o, et) => serverDisconnected = true;
+ };
+
+ connection.Disconnected += (o, et) => clientDisconnected = true;
+
+ listener.Start();
+ connection.Connect();
+
+ Thread.Sleep(100); // Gotta wait for the server to set up the events.
+ connection.Dispose();
+
+ Thread.Sleep(100);
+
+ Assert.IsTrue(serverConnected);
+ Assert.IsTrue(serverDisconnected);
+ Assert.IsFalse(clientDisconnected);
+ }
+ }
+
+ /// <summary>
+ /// Tests the fields on UdpConnection.
+ /// </summary>
+ [TestMethod]
+ public void UdpFieldTest()
+ {
+ IPEndPoint ep = new IPEndPoint(IPAddress.Loopback, 4296);
+
+ using (UdpConnectionListener listener = new UdpConnectionListener(new IPEndPoint(IPAddress.Any, 4296)))
+ using (UdpConnection connection = new UdpClientConnection(new TestLogger("Client"), ep))
+ {
+ listener.Start();
+
+ connection.Connect();
+
+ //Connection fields
+ Assert.AreEqual(ep, connection.EndPoint);
+
+ //UdpConnection fields
+ Assert.AreEqual(new IPEndPoint(IPAddress.Loopback, 4296), connection.EndPoint);
+ Assert.AreEqual(1, connection.Statistics.DataBytesSent);
+ Assert.AreEqual(0, connection.Statistics.DataBytesReceived);
+ }
+ }
+
+ [TestMethod]
+ public void UdpHandshakeTest()
+ {
+ byte[] TestData = new byte[] { 1, 2, 3, 4, 5, 6 };
+ using (UdpConnectionListener listener = new UdpConnectionListener(new IPEndPoint(IPAddress.Any, 4296)))
+ using (UdpConnection connection = new UdpClientConnection(new TestLogger("Client"), new IPEndPoint(IPAddress.Loopback, 4296)))
+ {
+ listener.Start();
+
+ MessageReader output = null;
+ listener.NewConnection += delegate (NewConnectionEventArgs e)
+ {
+ output = e.HandshakeData.Duplicate();
+ };
+
+ connection.Connect(TestData);
+
+ Thread.Sleep(10);
+ for (int i = 0; i < TestData.Length; ++i)
+ {
+ Assert.AreEqual(TestData[i], output.ReadByte());
+ }
+ }
+ }
+
+ [TestMethod]
+ public void UdpUnreliableMessageSendTest()
+ {
+ byte[] TestData = new byte[] { 1, 2, 3, 4, 5, 6 };
+ using (UdpConnectionListener listener = new UdpConnectionListener(new IPEndPoint(IPAddress.Any, 4296)))
+ using (UdpConnection connection = new UdpClientConnection(new TestLogger("Client"), new IPEndPoint(IPAddress.Loopback, 4296)))
+ {
+ MessageReader output = null;
+ listener.NewConnection += delegate (NewConnectionEventArgs e)
+ {
+ e.Connection.DataReceived += delegate (DataReceivedEventArgs evt)
+ {
+ output = evt.Message;
+ };
+ };
+
+ listener.Start();
+ connection.Connect();
+
+ for (int i = 0; i < 4; ++i)
+ {
+ var msg = MessageWriter.Get(SendOption.None);
+ msg.Write(TestData);
+ connection.Send(msg);
+ msg.Recycle();
+ }
+
+ Thread.Sleep(10);
+ for (int i = 0; i < TestData.Length; ++i)
+ {
+ Assert.AreEqual(TestData[i], output.ReadByte());
+ }
+ }
+ }
+
+ /// <summary>
+ /// Tests IPv4 connectivity.
+ /// </summary>
+ [TestMethod]
+ public void UdpIPv4ConnectionTest()
+ {
+ using (UdpConnectionListener listener = new UdpConnectionListener(new IPEndPoint(IPAddress.Any, 4296)))
+ using (UdpConnection connection = new UdpClientConnection(new TestLogger("Client"), new IPEndPoint(IPAddress.Loopback, 4296)))
+ {
+ listener.Start();
+
+ connection.Connect();
+ }
+ }
+
+ /// <summary>
+ /// Tests dual mode connectivity.
+ /// </summary>
+ [TestMethod]
+ public void MixedConnectionTest()
+ {
+ using (UdpConnectionListener listener2 = new UdpConnectionListener(new IPEndPoint(IPAddress.IPv6Any, 4296), IPMode.IPv6))
+ {
+ listener2.Start();
+
+ listener2.NewConnection += (evt) =>
+ {
+ Console.WriteLine("v6 connection: " + ((NetworkConnection)evt.Connection).GetIP4Address());
+ };
+
+ using (UdpConnection connection = new UdpClientConnection(new TestLogger("Client"), new IPEndPoint(IPAddress.Parse("127.0.0.1"), 4296)))
+ {
+ connection.Connect();
+ Assert.AreEqual(ConnectionState.Connected, connection.State);
+ }
+
+ using (UdpConnection connection = new UdpClientConnection(new TestLogger("Client"), new IPEndPoint(IPAddress.IPv6Loopback, 4296), IPMode.IPv6))
+ {
+ connection.Connect();
+ Assert.AreEqual(ConnectionState.Connected, connection.State);
+ }
+ }
+ }
+
+ /// <summary>
+ /// Tests IPv4 resilience to non-hello connections.
+ /// </summary>
+ [TestMethod]
+ public void FalseConnectionTest()
+ {
+ using (UdpConnectionListener listener = new UdpConnectionListener(new IPEndPoint(IPAddress.Any, 4296)))
+ using (Socket socket = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp))
+ {
+ int connects = 0;
+ listener.NewConnection += (obj) =>
+ {
+ Interlocked.Increment(ref connects);
+ };
+
+ listener.Start();
+
+ socket.Bind(new IPEndPoint(IPAddress.Any, 0));
+ var bytes = new byte[2];
+ bytes[0] = (byte)32;
+ for (int i = 0; i < 10; ++i)
+ {
+ socket.SendTo(bytes, new IPEndPoint(IPAddress.Loopback, 4296));
+ }
+
+ Thread.Sleep(500);
+
+ Assert.AreEqual(0, connects);
+ }
+ }
+
+ /// <summary>
+ /// Tests IPv4 resilience to multiple hellos.
+ /// </summary>
+ [TestMethod]
+ public void ConnectLikeAJerkTest()
+ {
+ using (UdpConnectionListener listener = new UdpConnectionListener(new IPEndPoint(IPAddress.Any, 4296)))
+ using (Socket socket = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp))
+ {
+ int connects = 0;
+ listener.NewConnection += (obj) =>
+ {
+ Interlocked.Increment(ref connects);
+ };
+
+ listener.Start();
+
+ socket.Bind(new IPEndPoint(IPAddress.Any, 0));
+ var bytes = new byte[2];
+ bytes[0] = (byte)UdpSendOption.Hello;
+ for (int i = 0; i < 10; ++i)
+ {
+ socket.SendTo(bytes, new IPEndPoint(IPAddress.Loopback, 4296));
+ }
+
+ Thread.Sleep(500);
+
+ Assert.AreEqual(1, connects);
+ }
+ }
+
+ /// <summary>
+ /// Tests dual mode connectivity.
+ /// </summary>
+ [TestMethod]
+ public void UdpIPv6ConnectionTest()
+ {
+ using (UdpConnectionListener listener = new UdpConnectionListener(new IPEndPoint(IPAddress.IPv6Any, 4296), IPMode.IPv6))
+ {
+ listener.Start();
+
+ using (UdpConnection connection = new UdpClientConnection(new TestLogger("Client"), new IPEndPoint(IPAddress.Parse("127.0.0.1"), 4296), IPMode.IPv6))
+ {
+ connection.Connect();
+ }
+ }
+ }
+
+ /// <summary>
+ /// Tests server to client unreliable communication on the UdpConnection.
+ /// </summary>
+ [TestMethod]
+ public void UdpUnreliableServerToClientTest()
+ {
+ using (UdpConnectionListener listener = new UdpConnectionListener(new IPEndPoint(IPAddress.Any, 4296)))
+ using (UdpConnection connection = new UdpClientConnection(new TestLogger("Client"), new IPEndPoint(IPAddress.Loopback, 4296)))
+ {
+ TestHelper.RunServerToClientTest(listener, connection, 10, SendOption.None);
+ }
+ }
+
+ /// <summary>
+ /// Tests server to client reliable communication on the UdpConnection.
+ /// </summary>
+ [TestMethod]
+ public void UdpReliableServerToClientTest()
+ {
+ using (UdpConnectionListener listener = new UdpConnectionListener(new IPEndPoint(IPAddress.Any, 4296)))
+ using (UdpConnection connection = new UdpClientConnection(new TestLogger("Client"), new IPEndPoint(IPAddress.Loopback, 4296)))
+ {
+ TestHelper.RunServerToClientTest(listener, connection, 10, SendOption.Reliable);
+ }
+ }
+
+ /// <summary>
+ /// Tests server to client unreliable communication on the UdpConnection.
+ /// </summary>
+ [TestMethod]
+ public void UdpUnreliableClientToServerTest()
+ {
+ using (UdpConnectionListener listener = new UdpConnectionListener(new IPEndPoint(IPAddress.Any, 4296)))
+ using (UdpConnection connection = new UdpClientConnection(new TestLogger("Client"), new IPEndPoint(IPAddress.Loopback, 4296)))
+ {
+ TestHelper.RunClientToServerTest(listener, connection, 10, SendOption.None);
+ }
+ }
+
+ /// <summary>
+ /// Tests server to client reliable communication on the UdpConnection.
+ /// </summary>
+ [TestMethod]
+ public void UdpReliableClientToServerTest()
+ {
+ using (UdpConnectionListener listener = new UdpConnectionListener(new IPEndPoint(IPAddress.Any, 4296)))
+ using (UdpConnection connection = new UdpClientConnection(new TestLogger("Client"), new IPEndPoint(IPAddress.Loopback, 4296)))
+ {
+ TestHelper.RunClientToServerTest(listener, connection, 10, SendOption.Reliable);
+ }
+ }
+
+ /// <summary>
+ /// Tests the keepalive functionality from the client,
+ /// </summary>
+ [TestMethod]
+ public void PingDisconnectClientTest()
+ {
+#if DEBUG
+ using (UdpConnectionListener listener = new UdpConnectionListener(new IPEndPoint(IPAddress.Any, 4296)))
+ using (UdpConnection connection = new UdpClientConnection(new TestLogger("Client"), new IPEndPoint(IPAddress.Loopback, 4296)))
+ {
+ listener.Start();
+
+ connection.Connect();
+
+ // After connecting, quietly stop responding to all messages to fake connection loss.
+ Thread.Sleep(10);
+ listener.TestDropRate = 1;
+
+ connection.KeepAliveInterval = 100;
+
+ Thread.Sleep(1050); //Enough time for ~10 keep alive packets
+
+ Assert.AreEqual(ConnectionState.NotConnected, connection.State);
+ Assert.AreEqual(3 * connection.MissingPingsUntilDisconnect + 4, connection.Statistics.TotalBytesSent); // + 4 for connecting overhead
+ }
+#else
+ Assert.Inconclusive("Only works in DEBUG");
+#endif
+ }
+
+ /// <summary>
+ /// Tests the keepalive functionality from the client,
+ /// </summary>
+ [TestMethod]
+ public void KeepAliveClientTest()
+ {
+ using (UdpConnectionListener listener = new UdpConnectionListener(new IPEndPoint(IPAddress.Any, 4296)))
+ using (UdpConnection connection = new UdpClientConnection(new TestLogger("Client"), new IPEndPoint(IPAddress.Loopback, 4296)))
+ {
+ listener.Start();
+
+ connection.Connect();
+ connection.KeepAliveInterval = 100;
+
+ Thread.Sleep(1050); //Enough time for ~10 keep alive packets
+
+ Assert.AreEqual(ConnectionState.Connected, connection.State);
+ Assert.IsTrue(
+ connection.Statistics.TotalBytesSent >= 30 &&
+ connection.Statistics.TotalBytesSent <= 50,
+ "Sent: " + connection.Statistics.TotalBytesSent
+ );
+ }
+ }
+
+ /// <summary>
+ /// Tests the keepalive functionality from the client,
+ /// </summary>
+ [TestMethod]
+ public void KeepAliveServerTest()
+ {
+ ManualResetEvent mutex = new ManualResetEvent(false);
+
+ using (UdpConnectionListener listener = new UdpConnectionListener(new IPEndPoint(IPAddress.Any, 4296)))
+ using (UdpConnection connection = new UdpClientConnection(new TestLogger("Client"), new IPEndPoint(IPAddress.Loopback, 4296)))
+ {
+ UdpConnection client = null;
+ listener.NewConnection += delegate (NewConnectionEventArgs args)
+ {
+ client = (UdpConnection)args.Connection;
+ client.KeepAliveInterval = 100;
+
+ Thread.Sleep(1050); //Enough time for ~10 keep alive packets
+
+ mutex.Set();
+ };
+
+ listener.Start();
+
+ connection.Connect();
+
+ mutex.WaitOne();
+
+ Assert.AreEqual(ConnectionState.Connected, client.State);
+
+ Assert.IsTrue(
+ client.Statistics.TotalBytesSent >= 27 &&
+ client.Statistics.TotalBytesSent <= 50,
+ "Sent: " + client.Statistics.TotalBytesSent
+ );
+ }
+ }
+
+ /// <summary>
+ /// Tests disconnection from the client.
+ /// </summary>
+ [TestMethod]
+ public void ClientDisconnectTest()
+ {
+ using (UdpConnectionListener listener = new UdpConnectionListener(new IPEndPoint(IPAddress.Any, 4296)))
+ using (UdpConnection connection = new UdpClientConnection(new TestLogger("Client"), new IPEndPoint(IPAddress.Loopback, 4296)))
+ {
+ TestHelper.RunClientDisconnectTest(listener, connection);
+ }
+ }
+
+ /// <summary>
+ /// Test that a disconnect is sent when the client is disposed.
+ /// </summary>
+ public void ClientDisconnectOnDisposeTest()
+ {
+ using (UdpConnectionListener listener = new UdpConnectionListener(new IPEndPoint(IPAddress.Any, 4296)))
+ using (UdpConnection connection = new UdpClientConnection(new TestLogger("Client"), new IPEndPoint(IPAddress.Loopback, 4296)))
+ {
+ TestHelper.RunClientDisconnectOnDisposeTest(listener, connection);
+ }
+ }
+
+ /// <summary>
+ /// Tests disconnection from the server.
+ /// </summary>
+ [TestMethod]
+ public void ServerDisconnectTest()
+ {
+ using (UdpConnectionListener listener = new UdpConnectionListener(new IPEndPoint(IPAddress.Any, 4296)))
+ using (UdpConnection connection = new UdpClientConnection(new TestLogger("Client"), new IPEndPoint(IPAddress.Loopback, 4296)))
+ {
+ TestHelper.RunServerDisconnectTest(listener, connection);
+ }
+ }
+
+ /// <summary>
+ /// Tests disconnection from the server.
+ /// </summary>
+ [TestMethod]
+ public void ServerExtraDataDisconnectTest()
+ {
+ using (UdpConnectionListener listener = new UdpConnectionListener(new IPEndPoint(IPAddress.Any, 4296)))
+ using (UdpConnection connection = new UdpClientConnection(new TestLogger("Client"), new IPEndPoint(IPAddress.Loopback, 4296)))
+ {
+ string received = null;
+ ManualResetEvent mutex = new ManualResetEvent(false);
+
+ connection.Disconnected += delegate (object sender, DisconnectedEventArgs args)
+ {
+ // We don't own the message, we have to read the string now
+ received = args.Message.ReadString();
+ mutex.Set();
+ };
+
+ listener.NewConnection += delegate (NewConnectionEventArgs args)
+ {
+ // As it turns out, the UdpConnectionListener can have an issue on loopback where the disconnect can happen before the hello confirm
+ // Tossing it on a different thread makes this test more reliable. Perhaps something to think about elsewhere though.
+ Task.Run(async () =>
+ {
+ await Task.Delay(1);
+ MessageWriter writer = MessageWriter.Get(SendOption.None);
+ writer.Write("Goodbye");
+ args.Connection.Disconnect("Testing", writer);
+ });
+ };
+
+ listener.Start();
+
+ connection.Connect();
+
+ mutex.WaitOne();
+
+ Assert.IsNotNull(received);
+ Assert.AreEqual("Goodbye", received);
+ }
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel.UnitTests/UdpReliabilityTests.cs b/Tools/Hazel-Networking/Hazel.UnitTests/UdpReliabilityTests.cs
new file mode 100644
index 0000000..ede7698
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel.UnitTests/UdpReliabilityTests.cs
@@ -0,0 +1,116 @@
+using System;
+using System.Collections.Generic;
+using Hazel.Udp;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+
+namespace Hazel.UnitTests
+{
+ [TestClass]
+ public class UdpReliabilityTests
+ {
+ [TestMethod]
+ public void TestReliableWrapOffByOne()
+ {
+ List<MessageReader> messagesReceived = new List<MessageReader>();
+
+ UdpConnectionTestHarness dut = new UdpConnectionTestHarness();
+ dut.DataReceived += evt =>
+ {
+ messagesReceived.Add(evt.Message);
+ };
+
+ MessageWriter data = MessageWriter.Get(SendOption.Reliable);
+
+ Assert.AreEqual(ushort.MaxValue, dut.ReliableReceiveLast);
+
+ SetReliableId(data, 10);
+ dut.Test_Receive(data);
+
+ // This message may not be received if there is an off-by-one error when marking missed pkts up to 10.
+ SetReliableId(data, 9);
+ dut.Test_Receive(data);
+
+ // Both messages should be received.
+ Assert.AreEqual(2, messagesReceived.Count);
+ messagesReceived.Clear();
+
+ Assert.AreEqual(2, dut.BytesSent.Count);
+ dut.BytesSent.Clear();
+ }
+
+ [TestMethod]
+ public void TestThatAllMessagesAreReceived()
+ {
+ List<MessageReader> messagesReceived = new List<MessageReader>();
+
+ UdpConnectionTestHarness dut = new UdpConnectionTestHarness();
+ dut.DataReceived += evt =>
+ {
+ messagesReceived.Add(evt.Message);
+ };
+
+ MessageWriter data = MessageWriter.Get(SendOption.Reliable);
+
+ for (int i = 0; i < ushort.MaxValue * 2; ++i)
+ {
+ // Send a new message, it should be received and ack'd
+ SetReliableId(data, i);
+ dut.Test_Receive(data);
+
+ // Resend an old message, it should be ignored
+ if (i > 2)
+ {
+ SetReliableId(data, i - 1);
+ dut.Test_Receive(data);
+
+ // It should still be ack'd
+ Assert.AreEqual(2, dut.BytesSent.Count);
+ dut.BytesSent.RemoveAt(1);
+ }
+
+ Assert.AreEqual(1, messagesReceived.Count);
+ messagesReceived.Clear();
+
+ Assert.AreEqual(1, dut.BytesSent.Count);
+ dut.BytesSent.Clear();
+ }
+ }
+
+ [TestMethod]
+ public void TestAcksForNotReceivedMessages()
+ {
+ List<MessageReader> messagesReceived = new List<MessageReader>();
+
+ UdpConnectionTestHarness dut = new UdpConnectionTestHarness();
+ dut.DataReceived += evt =>
+ {
+ messagesReceived.Add(evt.Message);
+ };
+
+ MessageWriter data = MessageWriter.Get(SendOption.Reliable);
+
+ SetReliableId(data, 1);
+ dut.Test_Receive(data);
+
+ SetReliableId(data, 3);
+ dut.Test_Receive(data);
+
+ MessageReader ackPacket = dut.BytesSent[1];
+ // Must be ack
+ Assert.AreEqual(4, ackPacket.Length);
+
+ byte recentPackets = ackPacket.Buffer[3];
+ // Last packet was not received
+ Assert.AreEqual(0, recentPackets & 1);
+ // The packet before that was.
+ Assert.AreEqual(1, (recentPackets >> 1) & 1);
+ }
+
+ private static void SetReliableId(MessageWriter data, int i)
+ {
+ ushort id = (ushort)i;
+ data.Buffer[1] = (byte)(id >> 8);
+ data.Buffer[2] = (byte)id;
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel.UnitTests/UnityUdpConnectionTests.cs b/Tools/Hazel-Networking/Hazel.UnitTests/UnityUdpConnectionTests.cs
new file mode 100644
index 0000000..0745578
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel.UnitTests/UnityUdpConnectionTests.cs
@@ -0,0 +1,489 @@
+using System;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using System.Net;
+using System.Threading;
+using Hazel.Udp;
+using System.Net.Sockets;
+using System.Threading.Tasks;
+
+namespace Hazel.UnitTests
+{
+ [TestClass]
+ public class UnityUdpConnectionTests
+ {
+ private ILogger logger = new ConsoleLogger(true);
+
+ [TestMethod]
+ public void ServerDisposeDisconnectsTest()
+ {
+ IPEndPoint ep = new IPEndPoint(IPAddress.Loopback, 4296);
+
+ bool serverConnected = false;
+ bool serverDisconnected = false;
+ bool clientDisconnected = false;
+
+ using (UdpConnectionListener listener = new UdpConnectionListener(new IPEndPoint(IPAddress.Any, 4296)))
+ using (UdpConnection connection = new UnityUdpClientConnection(logger, ep))
+ {
+ listener.NewConnection += (evt) =>
+ {
+ serverConnected = true;
+ evt.Connection.Disconnected += (o, et) => serverDisconnected = true;
+ };
+ connection.Disconnected += (o, evt) => clientDisconnected = true;
+
+ listener.Start();
+ connection.Connect();
+
+ Thread.Sleep(100); // Gotta wait for the server to set up the events.
+ listener.Dispose();
+ Thread.Sleep(100);
+
+ Assert.IsTrue(serverConnected);
+ Assert.IsTrue(clientDisconnected);
+ Assert.IsFalse(serverDisconnected);
+ }
+ }
+
+ [TestMethod]
+ public void ClientServerDisposeDisconnectsTest()
+ {
+ IPEndPoint ep = new IPEndPoint(IPAddress.Loopback, 4296);
+
+ bool serverConnected = false;
+ bool serverDisconnected = false;
+ bool clientDisconnected = false;
+
+ using (UdpConnectionListener listener = new UdpConnectionListener(new IPEndPoint(IPAddress.Any, 4296)))
+ using (UdpConnection connection = new UnityUdpClientConnection(logger, ep))
+ {
+ listener.NewConnection += (evt) =>
+ {
+ serverConnected = true;
+ evt.Connection.Disconnected += (o, et) => serverDisconnected = true;
+ };
+
+ connection.Disconnected += (o, et) => clientDisconnected = true;
+
+ listener.Start();
+ connection.Connect();
+
+ Thread.Sleep(100); // Gotta wait for the server to set up the events.
+ connection.Dispose();
+
+ Thread.Sleep(100);
+
+ Assert.IsTrue(serverConnected);
+ Assert.IsTrue(serverDisconnected);
+ Assert.IsFalse(clientDisconnected);
+ }
+ }
+
+ /// <summary>
+ /// Tests the fields on UdpConnection.
+ /// </summary>
+ [TestMethod]
+ public void UdpFieldTest()
+ {
+ IPEndPoint ep = new IPEndPoint(IPAddress.Loopback, 4296);
+
+ using (UdpConnectionListener listener = new UdpConnectionListener(new IPEndPoint(IPAddress.Any, 4296)))
+ using (UdpConnection connection = new UnityUdpClientConnection(logger, ep))
+ {
+ listener.Start();
+
+ connection.Connect();
+
+ //Connection fields
+ Assert.AreEqual(ep, connection.EndPoint);
+
+ //UdpConnection fields
+ Assert.AreEqual(new IPEndPoint(IPAddress.Loopback, 4296), connection.EndPoint);
+ Assert.AreEqual(1, connection.Statistics.DataBytesSent);
+ Assert.AreEqual(0, connection.Statistics.DataBytesReceived);
+ }
+ }
+
+ [TestMethod]
+ public void UdpHandshakeTest()
+ {
+ byte[] TestData = new byte[] { 1, 2, 3, 4, 5, 6 };
+
+ using (ManualResetEventSlim mutex = new ManualResetEventSlim(false))
+ using (UdpConnectionListener listener = new UdpConnectionListener(new IPEndPoint(IPAddress.Any, 4296)))
+ using (UdpConnection connection = new UnityUdpClientConnection(logger, new IPEndPoint(IPAddress.Loopback, 4296)))
+ {
+ listener.Start();
+
+ MessageReader output = null;
+ listener.NewConnection += delegate (NewConnectionEventArgs e)
+ {
+ output = e.HandshakeData.Duplicate();
+ mutex.Set();
+ };
+
+ connection.Connect(TestData);
+ mutex.Wait(5000);
+
+ for (int i = 0; i < TestData.Length; ++i)
+ {
+ Assert.AreEqual(TestData[i], output.ReadByte());
+ }
+ }
+ }
+
+ [TestMethod]
+ public void UdpUnreliableMessageSendTest()
+ {
+ byte[] TestData = new byte[] { 1, 2, 3, 4, 5, 6 };
+ using (UdpConnectionListener listener = new UdpConnectionListener(new IPEndPoint(IPAddress.Any, 4296)))
+ using (UdpConnection connection = new UnityUdpClientConnection(logger, new IPEndPoint(IPAddress.Loopback, 4296)))
+ {
+ MessageReader output = null;
+ listener.NewConnection += delegate (NewConnectionEventArgs e)
+ {
+ e.Connection.DataReceived += delegate (DataReceivedEventArgs evt)
+ {
+ output = evt.Message;
+ };
+ };
+
+ listener.Start();
+ connection.Connect();
+
+ for (int i = 0; i < 4; ++i)
+ {
+ var msg = MessageWriter.Get(SendOption.None);
+ msg.Write(TestData);
+ connection.Send(msg);
+ msg.Recycle();
+ }
+
+ Thread.Sleep(10);
+ for (int i = 0; i < TestData.Length; ++i)
+ {
+ Assert.AreEqual(TestData[i], output.ReadByte());
+ }
+ }
+ }
+
+ /// <summary>
+ /// Tests IPv4 connectivity.
+ /// </summary>
+ [TestMethod]
+ public void UdpIPv4ConnectionTest()
+ {
+ using (UdpConnectionListener listener = new UdpConnectionListener(new IPEndPoint(IPAddress.Any, 4296)))
+ using (UdpConnection connection = new UnityUdpClientConnection(logger, new IPEndPoint(IPAddress.Loopback, 4296)))
+ {
+ listener.Start();
+
+ connection.Connect();
+ }
+ }
+
+ /// <summary>
+ /// Tests dual mode connectivity.
+ /// </summary>
+ [TestMethod]
+ public void MixedConnectionTest()
+ {
+ using (UdpConnectionListener listener2 = new UdpConnectionListener(new IPEndPoint(IPAddress.IPv6Any, 4296), IPMode.IPv6))
+ {
+ listener2.Start();
+
+ listener2.NewConnection += (evt) =>
+ {
+ Console.WriteLine("v6 connection: " + ((NetworkConnection)evt.Connection).GetIP4Address());
+ };
+
+ using (UdpConnection connection = new UnityUdpClientConnection(logger, new IPEndPoint(IPAddress.Parse("127.0.0.1"), 4296)))
+ {
+ connection.Connect();
+ Assert.AreEqual(ConnectionState.Connected, connection.State);
+ }
+
+ using (UdpConnection connection = new UnityUdpClientConnection(logger, new IPEndPoint(IPAddress.IPv6Loopback, 4296), IPMode.IPv6))
+ {
+ connection.Connect();
+ Assert.AreEqual(ConnectionState.Connected, connection.State);
+ }
+ }
+ }
+
+ /// <summary>
+ /// Tests IPv4 resilience to non-hello connections.
+ /// </summary>
+ [TestMethod]
+ public void FalseConnectionTest()
+ {
+ using (UdpConnectionListener listener = new UdpConnectionListener(new IPEndPoint(IPAddress.Any, 4296)))
+ using (Socket socket = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp))
+ {
+ int connects = 0;
+ listener.NewConnection += (obj) =>
+ {
+ Interlocked.Increment(ref connects);
+ };
+
+ listener.Start();
+
+ socket.Bind(new IPEndPoint(IPAddress.Any, 0));
+ var bytes = new byte[2];
+ bytes[0] = (byte)32;
+ for (int i = 0; i < 10; ++i)
+ {
+ socket.SendTo(bytes, new IPEndPoint(IPAddress.Loopback, 4296));
+ }
+
+ Thread.Sleep(500);
+
+ Assert.AreEqual(0, connects);
+ }
+ }
+
+ /// <summary>
+ /// Tests IPv4 resilience to multiple hellos.
+ /// </summary>
+ [TestMethod]
+ public void ConnectLikeAJerkTest()
+ {
+ using (UdpConnectionListener listener = new UdpConnectionListener(new IPEndPoint(IPAddress.Any, 4296)))
+ using (Socket socket = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp))
+ {
+ int connects = 0;
+ listener.NewConnection += (obj) =>
+ {
+ Interlocked.Increment(ref connects);
+ };
+
+ listener.Start();
+
+ socket.Bind(new IPEndPoint(IPAddress.Any, 0));
+ var bytes = new byte[2];
+ bytes[0] = (byte)UdpSendOption.Hello;
+ for (int i = 0; i < 10; ++i)
+ {
+ socket.SendTo(bytes, new IPEndPoint(IPAddress.Loopback, 4296));
+ }
+
+ Thread.Sleep(500);
+
+ Assert.AreEqual(1, connects);
+ }
+ }
+
+ /// <summary>
+ /// Tests dual mode connectivity.
+ /// </summary>
+ [TestMethod]
+ public void UdpIPv6ConnectionTest()
+ {
+ using (UdpConnectionListener listener = new UdpConnectionListener(new IPEndPoint(IPAddress.IPv6Any, 4296), IPMode.IPv6))
+ {
+ listener.Start();
+
+ using (UdpConnection connection = new UnityUdpClientConnection(logger, new IPEndPoint(IPAddress.Parse("127.0.0.1"), 4296), IPMode.IPv6))
+ {
+ connection.Connect();
+ }
+ }
+ }
+
+ /// <summary>
+ /// Tests server to client unreliable communication on the UdpConnection.
+ /// </summary>
+ [TestMethod]
+ public void UdpUnreliableServerToClientTest()
+ {
+ using (UdpConnectionListener listener = new UdpConnectionListener(new IPEndPoint(IPAddress.Any, 4296)))
+ using (UdpConnection connection = new UnityUdpClientConnection(logger, new IPEndPoint(IPAddress.Loopback, 4296)))
+ {
+ TestHelper.RunServerToClientTest(listener, connection, 10, SendOption.None);
+ }
+ }
+
+ /// <summary>
+ /// Tests server to client reliable communication on the UdpConnection.
+ /// </summary>
+ [TestMethod]
+ public void UdpReliableServerToClientTest()
+ {
+ using (UdpConnectionListener listener = new UdpConnectionListener(new IPEndPoint(IPAddress.Any, 4296)))
+ using (UdpConnection connection = new UnityUdpClientConnection(logger, new IPEndPoint(IPAddress.Loopback, 4296)))
+ {
+ TestHelper.RunServerToClientTest(listener, connection, 10, SendOption.Reliable);
+ }
+ }
+
+ /// <summary>
+ /// Tests server to client unreliable communication on the UdpConnection.
+ /// </summary>
+ [TestMethod]
+ public void UdpUnreliableClientToServerTest()
+ {
+ using (UdpConnectionListener listener = new UdpConnectionListener(new IPEndPoint(IPAddress.Any, 4296)))
+ using (UdpConnection connection = new UnityUdpClientConnection(logger, new IPEndPoint(IPAddress.Loopback, 4296)))
+ {
+ TestHelper.RunClientToServerTest(listener, connection, 10, SendOption.None);
+ }
+ }
+
+ /// <summary>
+ /// Tests server to client reliable communication on the UdpConnection.
+ /// </summary>
+ [TestMethod]
+ public void UdpReliableClientToServerTest()
+ {
+ using (UdpConnectionListener listener = new UdpConnectionListener(new IPEndPoint(IPAddress.Any, 4296)))
+ using (UdpConnection connection = new UnityUdpClientConnection(logger, new IPEndPoint(IPAddress.Loopback, 4296)))
+ {
+ TestHelper.RunClientToServerTest(listener, connection, 10, SendOption.Reliable);
+ }
+ }
+
+ /// <summary>
+ /// Tests the keepalive functionality from the client,
+ /// </summary>
+ [TestMethod]
+ public void KeepAliveClientTest()
+ {
+ using (UdpConnectionListener listener = new UdpConnectionListener(new IPEndPoint(IPAddress.Any, 4296)))
+ using (UdpConnection connection = new UnityUdpClientConnection(logger, new IPEndPoint(IPAddress.Loopback, 4296)))
+ {
+ listener.Start();
+
+ connection.Connect();
+ connection.KeepAliveInterval = 100;
+
+ Thread.Sleep(1050); //Enough time for ~10 keep alive packets
+
+ Assert.AreEqual(ConnectionState.Connected, connection.State);
+ Assert.IsTrue(
+ connection.Statistics.TotalBytesSent >= 30 &&
+ connection.Statistics.TotalBytesSent <= 50,
+ "Sent: " + connection.Statistics.TotalBytesSent
+ );
+ }
+ }
+
+ /// <summary>
+ /// Tests the keepalive functionality from the client,
+ /// </summary>
+ [TestMethod]
+ public void KeepAliveServerTest()
+ {
+ ManualResetEvent mutex = new ManualResetEvent(false);
+
+ using (UdpConnectionListener listener = new UdpConnectionListener(new IPEndPoint(IPAddress.Any, 4296)))
+ using (UdpConnection connection = new UnityUdpClientConnection(logger, new IPEndPoint(IPAddress.Loopback, 4296)))
+ {
+ UdpConnection client = null;
+ listener.NewConnection += delegate (NewConnectionEventArgs args)
+ {
+ client = (UdpConnection)args.Connection;
+ client.KeepAliveInterval = 100;
+
+ Thread.Sleep(1050); //Enough time for ~10 keep alive packets
+
+ mutex.Set();
+ };
+
+ listener.Start();
+
+ connection.Connect();
+
+ mutex.WaitOne();
+
+ Assert.AreEqual(ConnectionState.Connected, client.State);
+
+ Assert.IsTrue(
+ client.Statistics.TotalBytesSent >= 27 &&
+ client.Statistics.TotalBytesSent <= 50,
+ "Sent: " + client.Statistics.TotalBytesSent
+ );
+ }
+ }
+
+ /// <summary>
+ /// Tests disconnection from the client.
+ /// </summary>
+ [TestMethod]
+ public void ClientDisconnectTest()
+ {
+ using (UdpConnectionListener listener = new UdpConnectionListener(new IPEndPoint(IPAddress.Any, 4296)))
+ using (UdpConnection connection = new UnityUdpClientConnection(logger, new IPEndPoint(IPAddress.Loopback, 4296)))
+ {
+ TestHelper.RunClientDisconnectTest(listener, connection);
+ }
+ }
+
+ /// <summary>
+ /// Test that a disconnect is sent when the client is disposed.
+ /// </summary>
+ public void ClientDisconnectOnDisposeTest()
+ {
+ using (UdpConnectionListener listener = new UdpConnectionListener(new IPEndPoint(IPAddress.Any, 4296)))
+ using (UdpConnection connection = new UnityUdpClientConnection(logger, new IPEndPoint(IPAddress.Loopback, 4296)))
+ {
+ TestHelper.RunClientDisconnectOnDisposeTest(listener, connection);
+ }
+ }
+
+ /// <summary>
+ /// Tests disconnection from the server.
+ /// </summary>
+ [TestMethod]
+ public void ServerDisconnectTest()
+ {
+ using (UdpConnectionListener listener = new UdpConnectionListener(new IPEndPoint(IPAddress.Any, 4296)))
+ using (UdpConnection connection = new UnityUdpClientConnection(logger, new IPEndPoint(IPAddress.Loopback, 4296)))
+ {
+ TestHelper.RunServerDisconnectTest(listener, connection);
+ }
+ }
+
+ /// <summary>
+ /// Tests disconnection from the server.
+ /// </summary>
+ [TestMethod]
+ public void ServerExtraDataDisconnectTest()
+ {
+ using (UdpConnectionListener listener = new UdpConnectionListener(new IPEndPoint(IPAddress.Any, 4296)))
+ using (UdpConnection connection = new UnityUdpClientConnection(logger, new IPEndPoint(IPAddress.Loopback, 4296)))
+ {
+ string received = null;
+ ManualResetEvent mutex = new ManualResetEvent(false);
+
+ connection.Disconnected += delegate (object sender, DisconnectedEventArgs args)
+ {
+ // We don't own the message, we have to read the string now
+ received = args.Message.ReadString();
+ mutex.Set();
+ };
+
+ listener.NewConnection += delegate (NewConnectionEventArgs args)
+ {
+ // As it turns out, the UdpConnectionListener can have an issue on loopback where the disconnect can happen before the hello confirm
+ // Tossing it on a different thread makes this test more reliable. Perhaps something to think about elsewhere though.
+ Task.Run(async () =>
+ {
+ await Task.Delay(1);
+ MessageWriter writer = MessageWriter.Get(SendOption.None);
+ writer.Write("Goodbye");
+ args.Connection.Disconnect("Testing", writer);
+ });
+ };
+
+ listener.Start();
+
+ connection.Connect();
+
+ mutex.WaitOne();
+
+ Assert.IsNotNull(received);
+ Assert.AreEqual("Goodbye", received);
+ }
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel.UnitTests/Utils.cs b/Tools/Hazel-Networking/Hazel.UnitTests/Utils.cs
new file mode 100644
index 0000000..df62b80
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel.UnitTests/Utils.cs
@@ -0,0 +1,99 @@
+using Org.BouncyCastle.Crypto;
+using Org.BouncyCastle.Crypto.Parameters;
+using Org.BouncyCastle.OpenSsl;
+using Org.BouncyCastle.Security;
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.IO;
+using System.Security.Cryptography;
+
+namespace Hazel.UnitTests
+{
+ static class Utils
+ {
+ /// <summary>
+ /// Hex encode a byte array (lower case)
+ /// </summary>
+ public static string BytesToHex(byte[] data)
+ {
+ string chars = "0123456789abcdef";
+
+ StringBuilder sb = new StringBuilder(data.Length * 2);
+ for (int ii = 0, nn = data.Length; ii != nn; ++ii)
+ {
+ sb.Append(chars[data[ii] >> 4]);
+ sb.Append(chars[data[ii] & 0xF]);
+ }
+
+ return sb.ToString().ToLower();
+ }
+
+ /// <summary>
+ /// Decode a hex string to a byte array (lowercase)
+ /// </summary>
+ public static byte[] HexToBytes(string hex)
+ {
+ hex = hex.ToLower();
+ hex = hex = string.Concat(hex.Where(c => !char.IsWhiteSpace(c)));
+
+ byte[] output = new byte[hex.Length / 2];
+
+ for (int ii = 0; ii != hex.Length; ++ii)
+ {
+ byte nibble;
+
+ char c = hex[ii];
+ if (c >= 'a')
+ {
+ nibble = (byte)(0x0A + c - 'a');
+ }
+ else
+ {
+ nibble = (byte)(c - '0');
+ }
+
+ if ((ii & 1) == 0)
+ {
+ output[ii / 2] = (byte)(nibble << 4);
+ }
+ else
+ {
+ output[ii / 2] |= nibble;
+ }
+ }
+
+ return output;
+ }
+
+ public static byte[] DecodePEM(string pemData)
+ {
+ List<byte> result = new List<byte>();
+
+ pemData = pemData.Replace("\r", "");
+ string[] lines = pemData.Split('\n');
+ foreach (string line in lines)
+ {
+ if (line.StartsWith("-----"))
+ {
+ continue;
+ }
+
+ byte[] lineData = Convert.FromBase64String(line);
+ result.AddRange(lineData);
+ }
+
+ return result.ToArray();
+ }
+
+ public static RSA DecodeRSAKeyFromPEM(string pemData)
+ {
+ var pemReader = new PemReader(new StringReader(pemData));
+ var parameters = DotNetUtilities.ToRSAParameters((RsaPrivateCrtKeyParameters)pemReader.ReadObject());
+ var rsa = RSA.Create();
+ rsa.ImportParameters(parameters);
+ return rsa;
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel.sln b/Tools/Hazel-Networking/Hazel.sln
new file mode 100644
index 0000000..bc27b0a
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel.sln
@@ -0,0 +1,31 @@
+
+Microsoft Visual Studio Solution File, Format Version 12.00
+# Visual Studio Version 16
+VisualStudioVersion = 16.0.30523.141
+MinimumVisualStudioVersion = 10.0.40219.1
+Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Hazel", "Hazel\Hazel.csproj", "{02CFBD30-D77D-400F-94B2-700F60EFDD7F}"
+EndProject
+Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Hazel.UnitTests", "Hazel.UnitTests\Hazel.UnitTests.csproj", "{1394E4CA-E17A-42F5-9216-8046ACA8D16B}"
+EndProject
+Global
+ GlobalSection(SolutionConfigurationPlatforms) = preSolution
+ Debug|Any CPU = Debug|Any CPU
+ Release|Any CPU = Release|Any CPU
+ EndGlobalSection
+ GlobalSection(ProjectConfigurationPlatforms) = postSolution
+ {02CFBD30-D77D-400F-94B2-700F60EFDD7F}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {02CFBD30-D77D-400F-94B2-700F60EFDD7F}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {02CFBD30-D77D-400F-94B2-700F60EFDD7F}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {02CFBD30-D77D-400F-94B2-700F60EFDD7F}.Release|Any CPU.Build.0 = Release|Any CPU
+ {1394E4CA-E17A-42F5-9216-8046ACA8D16B}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {1394E4CA-E17A-42F5-9216-8046ACA8D16B}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {1394E4CA-E17A-42F5-9216-8046ACA8D16B}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {1394E4CA-E17A-42F5-9216-8046ACA8D16B}.Release|Any CPU.Build.0 = Release|Any CPU
+ EndGlobalSection
+ GlobalSection(SolutionProperties) = preSolution
+ HideSolutionNode = FALSE
+ EndGlobalSection
+ GlobalSection(ExtensibilityGlobals) = postSolution
+ SolutionGuid = {8AC8C0A6-FB6B-4E63-9042-EDF10C6A51B8}
+ EndGlobalSection
+EndGlobal
diff --git a/Tools/Hazel-Networking/Hazel/ByteSpan.cs b/Tools/Hazel-Networking/Hazel/ByteSpan.cs
new file mode 100644
index 0000000..7cfa3b5
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/ByteSpan.cs
@@ -0,0 +1,191 @@
+using System;
+
+namespace Hazel
+{
+ /// <summary>
+ /// This is a minimal implementation of `System.Span` in .NET 5.0
+ /// </summary>
+ public struct ByteSpan
+ {
+ private readonly byte[] array_;
+
+ /// <summary>
+ /// Createa a new span object containing an entire array
+ /// </summary>
+ public ByteSpan(byte[] array)
+ {
+ if (array == null)
+ {
+ this.array_ = null;
+ this.Offset = 0;
+ this.Length = 0;
+ }
+ else
+ {
+ this.array_ = array;
+ this.Offset = 0;
+ this.Length = array.Length;
+ }
+ }
+
+ /// <summary>
+ /// Creates a new span object containing a subset of an array
+ /// </summary>
+ public ByteSpan(byte[] array, int offset, int length)
+ {
+ if (array == null)
+ {
+ if (offset != 0)
+ {
+ throw new ArgumentException("Invalid offset", nameof(offset));
+ }
+ if (length != 0)
+ {
+ throw new ArgumentException("Invalid length", nameof(offset));
+ }
+
+ this.array_ = null;
+ this.Offset = 0;
+ this.Length = 0;
+ }
+ else
+ {
+ if (offset < 0 || offset > array.Length)
+ {
+ throw new ArgumentException("Invalid offset", nameof(offset));
+ }
+ if (length < 0)
+ {
+ throw new ArgumentException($"Invalid length: {length}", nameof(length));
+ }
+ if ((offset + length) > array.Length)
+ {
+ throw new ArgumentException($"Invalid length. Length: {length} Offset: {offset} Array size: {array.Length}", nameof(length));
+ }
+
+ this.array_ = array;
+ this.Offset = offset;
+ this.Length = length;
+ }
+ }
+
+ /// <summary>
+ /// Returns the underlying array.
+ ///
+ /// WARNING: This does not return the span, but the entire underlying storage block
+ /// </summary>
+ public byte[] GetUnderlyingArray()
+ {
+ return this.array_;
+ }
+
+ /// <summary>
+ /// Returns the offset into the underlying array
+ /// </summary>
+ public int Offset { get; }
+
+ /// <summary>
+ /// Returns the length of the current span
+ /// </summary>
+ public int Length { get; }
+
+ /// <summary>
+ /// Gets the span element at the specified index
+ /// </summary>
+ public byte this[int index]
+ {
+ get
+ {
+ if (index < 0 || index >= this.Length)
+ {
+ throw new IndexOutOfRangeException();
+ }
+
+ return this.array_[this.Offset + index];
+ }
+ set
+ {
+ if (index < 0 || index >= this.Length)
+ {
+ throw new IndexOutOfRangeException();
+ }
+
+ this.array_[this.Offset + index] = value;
+ }
+ }
+
+ /// <summary>
+ /// Create a new span that is a subset of this span [offset, this.Length-offset)
+ /// </summary>
+ public ByteSpan Slice(int offset)
+ {
+ return Slice(offset, this.Length - offset);
+ }
+
+ /// <summary>
+ /// Create a new span that is a subset of this span [offset, length)
+ /// </summary>
+ public ByteSpan Slice(int offset, int length)
+ {
+ return new ByteSpan(this.array_, this.Offset + offset, length);
+ }
+
+ /// <summary>
+ /// Copies the contents of the span to an array
+ /// </summary>
+ public void CopyTo(byte[] array, int offset)
+ {
+ CopyTo(new ByteSpan(array, offset, array.Length - offset));
+ }
+
+ /// <summary>
+ /// Copies the contents of the span to another span
+ /// </summary>
+ public void CopyTo(ByteSpan destination)
+ {
+ if (destination.Length < this.Length)
+ {
+ throw new ArgumentException("Destination span is shorter than source", nameof(destination));
+ }
+
+ if (Length > 0)
+ {
+ Buffer.BlockCopy(this.array_, this.Offset, destination.array_, destination.Offset, this.Length);
+ }
+ }
+
+ /// <summary>
+ /// Create a new array with the contents of this span
+ /// </summary>
+ public byte[] ToArray()
+ {
+ byte[] result = new byte[Length];
+ CopyTo(result);
+ return result;
+ }
+
+ public override string ToString()
+ {
+ return string.Join(" ", this.ToArray());
+ }
+
+ /// <summary>
+ /// Implicit conversion from byte[] -> ByteSpan
+ /// </summary>
+ public static implicit operator ByteSpan(byte[] array)
+ {
+ return new ByteSpan(array);
+ }
+
+ /// <summary>
+ /// Retuns an empty span object
+ /// </summary>
+ public static ByteSpan Empty
+ {
+ get
+ {
+ return new ByteSpan(null);
+ }
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/ByteSpanExtensions.cs b/Tools/Hazel-Networking/Hazel/ByteSpanExtensions.cs
new file mode 100644
index 0000000..3a9d1ac
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/ByteSpanExtensions.cs
@@ -0,0 +1,131 @@
+namespace Hazel
+{
+ /// <summary>
+ /// Extension functions for (en/de)coding integer values
+ /// </summary>
+ public static class ByteSpanBigEndianExtensions
+ {
+ // Write a 16-bit integer in big-endian format to output[0..2)
+ public static void WriteBigEndian16(this ByteSpan output, ushort value, int offset = 0)
+ {
+ output[offset + 0] = (byte)(value >> 8);
+ output[offset + 1] = (byte)(value >> 0);
+ }
+
+ // Write a 24-bit integer in big-endian format to output[0..3)
+ public static void WriteBigEndian24(this ByteSpan output, uint value, int offset = 0)
+ {
+ output[offset + 0] = (byte)(value >> 16);
+ output[offset + 1] = (byte)(value >> 8);
+ output[offset + 2] = (byte)(value >> 0);
+ }
+
+ // Write a 32-bit integer in big-endian format to output[0..4)
+ public static void WriteBigEndian32(this ByteSpan output, uint value, int offset)
+ {
+ output[offset + 0] = (byte)(value >> 24);
+ output[offset + 1] = (byte)(value >> 16);
+ output[offset + 2] = (byte)(value >> 8);
+ output[offset + 3] = (byte)(value >> 0);
+ }
+
+ // Write a 48-bit integer in big-endian format to output[0..6)
+ public static void WriteBigEndian48(this ByteSpan output, ulong value, int offset = 0)
+ {
+ output[offset + 0] = (byte)(value >> 40);
+ output[offset + 1] = (byte)(value >> 32);
+ output[offset + 2] = (byte)(value >> 24);
+ output[offset + 3] = (byte)(value >> 16);
+ output[offset + 4] = (byte)(value >> 8);
+ output[offset + 5] = (byte)(value >> 0);
+ }
+
+ // Write a 64-bit integer in big-endian format to output[0..8)
+ public static void WriteBigEndian64(this ByteSpan output, ulong value, int offset = 0)
+ {
+ output[offset + 0] = (byte)(value >> 56);
+ output[offset + 1] = (byte)(value >> 48);
+ output[offset + 2] = (byte)(value >> 40);
+ output[offset + 3] = (byte)(value >> 32);
+ output[offset + 4] = (byte)(value >> 24);
+ output[offset + 5] = (byte)(value >> 16);
+ output[offset + 6] = (byte)(value >> 8);
+ output[offset + 7] = (byte)(value >> 0);
+ }
+
+ // Read a 16-bit integer in big-endian format from input[0..2)
+ public static ushort ReadBigEndian16(this ByteSpan input, int offset = 0)
+ {
+ ushort value = 0;
+ value |= (ushort)(input[offset + 0] << 8);
+ value |= (ushort)(input[offset + 1] << 0);
+ return value;
+ }
+
+ // Read a 24-bit integer in big-endian format from input[0..3)
+ public static uint ReadBigEndian24(this ByteSpan input, int offset = 0)
+ {
+ uint value = 0;
+ value |= (uint)input[offset + 0] << 16;
+ value |= (uint)input[offset + 1] << 8;
+ value |= (uint)input[offset + 2] << 0;
+ return value;
+ }
+
+ // Read a 48-bit integer in big-endian format from input[0..3)
+ public static ulong ReadBigEndian48(this ByteSpan input, int offset = 0)
+ {
+ ulong value = 0;
+ value |= (ulong)input[offset + 0] << 40;
+ value |= (ulong)input[offset + 1] << 32;
+ value |= (ulong)input[offset + 2] << 24;
+ value |= (ulong)input[offset + 3] << 16;
+ value |= (ulong)input[offset + 4] << 8;
+ value |= (ulong)input[offset + 5] << 0;
+ return value;
+ }
+ }
+
+ public static class ByteSpanLittleEndianExtensions
+ {
+ // Read a 24-bit integer in little-endian format from input[0..3)
+ public static uint ReadLittleEndian24(this ByteSpan input, int offset = 0)
+ {
+ uint value = 0;
+ value |= (uint)input[offset + 0];
+ value |= (uint)input[offset + 1] << 8;
+ value |= (uint)input[offset + 2] << 16;
+ return value;
+ }
+
+ // Read a 24-bit integer in little-endian format from input[0..4)
+ public static uint ReadLittleEndian32(this ByteSpan input, int offset = 0)
+ {
+ uint value = 0;
+ value |= (uint)input[offset + 0];
+ value |= (uint)input[offset + 1] << 8;
+ value |= (uint)input[offset + 2] << 16;
+ value |= (uint)input[offset + 3] << 24;
+ return value;
+ }
+
+ /// <summary>
+ /// Reuse an existing span if there is enough space,
+ /// otherwise allocate new storage
+ /// </summary>
+ /// <param name="source">
+ /// Source span we should attempt to reuse
+ /// </param>
+ /// <param name="requiredSize">Required size (bytes)</param>
+ public static ByteSpan ReuseSpanIfPossible(this ByteSpan source, int requiredSize)
+ {
+ if (source.Length >= requiredSize)
+ {
+ return source.Slice(0, requiredSize);
+ }
+
+ return new byte[requiredSize];
+ }
+
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/Connection.cs b/Tools/Hazel-Networking/Hazel/Connection.cs
new file mode 100644
index 0000000..da2f59a
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/Connection.cs
@@ -0,0 +1,234 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Net.Sockets;
+using System.Net;
+using System.Threading;
+
+namespace Hazel
+{
+ /// <summary>
+ /// Base class for all connections.
+ /// </summary>
+ /// <remarks>
+ /// <para>
+ /// Connection is the base class for all connections that Hazel can make. It provides common functionality and a
+ /// standard interface to allow connections to be swapped easily.
+ /// </para>
+ /// <para>
+ /// Any class inheriting from Connection should provide the 3 standard guarantees that Hazel provides:
+ /// <list type="bullet">
+ /// <item>
+ /// <description>Thread Safe</description>
+ /// </item>
+ /// <item>
+ /// <description>Connection Orientated</description>
+ /// </item>
+ /// <item>
+ /// <description>Packet/Message Based</description>
+ /// </item>
+ /// </list>
+ /// </para>
+ /// </remarks>
+ /// <threadsafety static="true" instance="true"/>
+ public abstract class Connection : IDisposable
+ {
+ /// <summary>
+ /// Called when a message has been received.
+ /// </summary>
+ /// <remarks>
+ /// <para>
+ /// DataReceived is invoked everytime a message is received from the end point of this connection, the message
+ /// that was received can be found in the <see cref="DataReceivedEventArgs"/> alongside other information from the
+ /// event.
+ /// </para>
+ /// <include file="DocInclude/common.xml" path="docs/item[@name='Event_Thread_Safety_Warning']/*" />
+ /// </remarks>
+ /// <example>
+ /// <code language="C#" source="DocInclude/TcpClientExample.cs"/>
+ /// </example>
+ public event Action<DataReceivedEventArgs> DataReceived;
+
+ public int TestLagMs = -1;
+ public int TestDropRate = 0;
+ protected int testDropCount = 0;
+
+ /// <summary>
+ /// Called when the end point disconnects or an error occurs.
+ /// </summary>
+ /// <remarks>
+ /// <para>
+ /// Disconnected is invoked when the connection is closed due to an exception occuring or because the remote
+ /// end point disconnected. If it was invoked due to an exception occuring then the exception is available
+ /// in the <see cref="DisconnectedEventArgs"/> passed with the event.
+ /// </para>
+ /// <include file="DocInclude/common.xml" path="docs/item[@name='Event_Thread_Safety_Warning']/*" />
+ /// </remarks>
+ /// <example>
+ /// <code language="C#" source="DocInclude/TcpClientExample.cs"/>
+ /// </example>
+ public event EventHandler<DisconnectedEventArgs> Disconnected;
+
+ /// <summary>
+ /// The remote end point of this Connection.
+ /// </summary>
+ /// <remarks>
+ /// This is the end point that this connection is connected to (i.e. the other device). This returns an abstract
+ /// <see cref="ConnectionEndPoint"/> which can then be cast to an appropriate end point depending on the
+ /// connection type.
+ /// </remarks>
+ public IPEndPoint EndPoint { get; protected set; }
+
+ public IPMode IPMode { get; protected set; }
+
+ /// <summary>
+ /// The traffic statistics about this Connection.
+ /// </summary>
+ /// <remarks>
+ /// Contains statistics about the number of messages and bytes sent and received by this connection.
+ /// </remarks>
+ public ConnectionStatistics Statistics { get; protected set; }
+
+ /// <summary>
+ /// The state of this connection.
+ /// </summary>
+ /// <remarks>
+ /// All implementers should be aware that when this is set to ConnectionState.Connected it will
+ /// release all threads that are blocked on <see cref="WaitOnConnect"/>.
+ /// </remarks>
+ public ConnectionState State
+ {
+ get
+ {
+ return this._state;
+ }
+
+ protected set
+ {
+ this._state = value;
+ this.SetState(value);
+ }
+ }
+
+ protected ConnectionState _state;
+ protected virtual void SetState(ConnectionState state) { }
+
+ /// <summary>
+ /// Constructor that initializes the ConnecitonStatistics object.
+ /// </summary>
+ /// <remarks>
+ /// This constructor initialises <see cref="Statistics"/> with empty statistics and sets <see cref="State"/> to
+ /// <see cref="ConnectionState.NotConnected"/>.
+ /// </remarks>
+ protected Connection()
+ {
+ this.Statistics = new ConnectionStatistics();
+ this.State = ConnectionState.NotConnected;
+ }
+
+ /// <summary>
+ /// Sends a number of bytes to the end point of the connection using the specified <see cref="SendOption"/>.
+ /// </summary>
+ /// <param name="msg">The message to send.</param>
+ public abstract SendErrors Send(MessageWriter msg);
+
+ /// <summary>
+ /// Connects the connection to a server and begins listening.
+ /// This method blocks and may thrown if there is a problem connecting.
+ /// </summary>
+ /// <param name="bytes">The bytes of data to send in the handshake.</param>
+ /// <param name="timeout">The number of milliseconds to wait before giving up on the connect attempt.</param>
+ public abstract void Connect(byte[] bytes = null, int timeout = 5000);
+
+ /// <summary>
+ /// Connects the connection to a server and begins listening.
+ /// This method does not block.
+ /// </summary>
+ /// <param name="bytes">The bytes of data to send in the handshake.</param>
+ public abstract void ConnectAsync(byte[] bytes = null);
+
+ /// <summary>
+ /// Invokes the DataReceived event.
+ /// </summary>
+ /// <param name="msg">The bytes received.</param>
+ /// <param name="sendOption">The <see cref="SendOption"/> the message was received with.</param>
+ /// <remarks>
+ /// Invokes the <see cref="DataReceived"/> event on this connection to alert subscribers a new message has been
+ /// received. The bytes and the send option that the message was sent with should be passed in to give to the
+ /// subscribers.
+ /// </remarks>
+ protected void InvokeDataReceived(MessageReader msg, SendOption sendOption)
+ {
+ // Make a copy to avoid race condition between null check and invocation
+ Action<DataReceivedEventArgs> handler = DataReceived;
+ if (handler != null)
+ {
+ try
+ {
+ handler(new DataReceivedEventArgs(this, msg, sendOption));
+ }
+ catch { }
+ }
+ else
+ {
+ msg.Recycle();
+ }
+ }
+
+ /// <summary>
+ /// Invokes the Disconnected event.
+ /// </summary>
+ /// <param name="e">The exception, if any, that occurred to cause this.</param>
+ /// <param name="reader">Extra disconnect data</param>
+ /// <remarks>
+ /// Invokes the <see cref="Disconnected"/> event to alert subscribres this connection has been disconnected either
+ /// by the end point or because an error occurred. If an error occurred the error should be passed in in order to
+ /// pass to the subscribers, otherwise null can be passed in.
+ /// </remarks>
+ protected void InvokeDisconnected(string e, MessageReader reader)
+ {
+ // Make a copy to avoid race condition between null check and invocation
+ EventHandler<DisconnectedEventArgs> handler = Disconnected;
+ if (handler != null)
+ {
+ DisconnectedEventArgs args = new DisconnectedEventArgs(e, reader);
+ try
+ {
+ handler(this, args);
+ }
+ catch
+ {
+ }
+ }
+ }
+
+ /// <summary>
+ /// For times when you want to force the disconnect handler to fire as well as close it.
+ /// If you only want to close it, just use Dispose.
+ /// </summary>
+ public abstract void Disconnect(string reason, MessageWriter writer = null);
+
+ /// <summary>
+ /// Disposes of this NetworkConnection.
+ /// </summary>
+ public void Dispose()
+ {
+ Dispose(true);
+ GC.SuppressFinalize(this);
+ }
+
+ /// <summary>
+ /// Disposes of this NetworkConnection.
+ /// </summary>
+ /// <param name="disposing">Are we currently disposing?</param>
+ protected virtual void Dispose(bool disposing)
+ {
+ if (disposing)
+ {
+ this.DataReceived = null;
+ this.Disconnected = null;
+ }
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/ConnectionListener.cs b/Tools/Hazel-Networking/Hazel/ConnectionListener.cs
new file mode 100644
index 0000000..f952847
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/ConnectionListener.cs
@@ -0,0 +1,160 @@
+using System;
+using System.Net;
+
+namespace Hazel
+{
+ /// <summary>
+ /// Base class for all connection listeners.
+ /// </summary>
+ /// <remarks>
+ /// <para>
+ /// ConnectionListeners are server side objects that listen for clients and create matching server side connections
+ /// for each client in a similar way to TCP does. These connections should be ready for communication immediately.
+ /// </para>
+ /// <para>
+ /// Each time a client connects the <see cref="NewConnection"/> event will be invoked to alert all subscribers to
+ /// the new connection. A disconnected event is then present on the <see cref="Connection"/> that is passed to the
+ /// subscribers.
+ /// </para>
+ /// </remarks>
+ /// <threadsafety static="true" instance="true"/>
+ public abstract class ConnectionListener : IDisposable
+ {
+ /// <summary>
+ /// The max size Hazel attempts to read from the network.
+ /// Defaults to 8096.
+ /// </summary>
+ /// <remarks>
+ /// 8096 is 5 times the standard modern MTU of 1500, so it's already too large imo.
+ /// If Hazel ever implements fragmented packets, then we might consider a larger value since combining 5
+ /// packets into 1 reader would be realistic and would cause reallocations. That said, Hazel is not meant
+ /// for transferring large contiguous blocks of data, so... please don't?
+ /// </remarks>
+ public int ReceiveBufferSize = 8096;
+
+ public readonly ListenerStatistics Statistics = new ListenerStatistics();
+
+ public abstract double AveragePing { get; }
+ public abstract int ConnectionCount { get; }
+ public abstract int SendQueueLength { get; }
+ public abstract int ReceiveQueueLength { get; }
+
+ /// <summary>
+ /// A callback for early connection rejection.
+ /// * Return false to reject connection.
+ /// * A null response is ok, we just won't send anything.
+ /// </summary>
+ public AcceptConnectionCheck AcceptConnection;
+ public delegate bool AcceptConnectionCheck(IPEndPoint endPoint, byte[] input, out byte[] response);
+
+ /// <summary>
+ /// Invoked when a new client connects.
+ /// </summary>
+ /// <remarks>
+ /// <para>
+ /// NewConnection is invoked each time a client connects to the listener. The
+ /// <see cref="NewConnectionEventArgs"/> contains the new <see cref="Connection"/> for communication with this
+ /// client.
+ /// </para>
+ /// <para>
+ /// Hazel may or may not store connections so it is your responsibility to keep track and properly Dispose of
+ /// connections to your server.
+ /// </para>
+ /// <include file="DocInclude/common.xml" path="docs/item[@name='Event_Thread_Safety_Warning']/*" />
+ /// </remarks>
+ /// <example>
+ /// <code language="C#" source="DocInclude/TcpListenerExample.cs"/>
+ /// </example>
+ public event Action<NewConnectionEventArgs> NewConnection;
+
+ /// <summary>
+ /// Invoked when an internal error causes the listener to be unable to continue handling messages.
+ /// </summary>
+ /// <remarks>
+ /// Support for this is still pretty limited. At the time of writing, only iOS devices need this in one case:
+ /// When iOS suspends an app, it might also free our socket while not allowing Unity to run in the background.
+ /// When Unity resumes, it can't know that time passed or the socket is freed, so we used to continuously throw internal errors.
+ /// </remarks>
+ public event Action<HazelInternalErrors> OnInternalError;
+
+ /// <summary>
+ /// Makes this connection listener begin listening for connections.
+ /// </summary>
+ /// <remarks>
+ /// <para>
+ /// This instructs the listener to begin listening for new clients connecting to the server. When a new client
+ /// connects the <see cref="NewConnection"/> event will be invoked containing the connection to the new client.
+ /// </para>
+ /// <para>
+ /// To stop listening you should call <see cref="Dispose()"/>.
+ /// </para>
+ /// </remarks>
+ /// <example>
+ /// <code language="C#" source="DocInclude/TcpListenerExample.cs"/>
+ /// </example>
+ public abstract void Start();
+
+ /// <summary>
+ /// Invokes the NewConnection event with the supplied connection.
+ /// </summary>
+ /// <param name="msg">The user sent bytes that were received as part of the handshake.</param>
+ /// <param name="connection">The connection to pass in the arguments.</param>
+ /// <remarks>
+ /// Implementers should call this to invoke the <see cref="NewConnection"/> event before data is received so that
+ /// subscribers do not miss any data that may have been sent immediately after connecting.
+ /// </remarks>
+ protected void InvokeNewConnection(MessageReader msg, Connection connection)
+ {
+ // Make a copy to avoid race condition between null check and invocation
+ Action<NewConnectionEventArgs> handler = NewConnection;
+ if (handler != null)
+ {
+ try
+ {
+ handler(new NewConnectionEventArgs(msg, connection));
+ }
+ catch (Exception e)
+ {
+ }
+ }
+ }
+
+
+ /// <summary>
+ /// Invokes the InternalError event with the supplied reason.
+ /// </summary>
+ protected void InvokeInternalError(HazelInternalErrors reason)
+ {
+ // Make a copy to avoid race condition between null check and invocation
+ Action<HazelInternalErrors> handler = this.OnInternalError;
+ if (handler != null)
+ {
+ try
+ {
+ handler(reason);
+ }
+ catch
+ {
+ }
+ }
+ }
+
+ /// <summary>
+ /// Call to dispose of the connection listener.
+ /// </summary>
+ public void Dispose()
+ {
+ Dispose(true);
+ }
+
+ /// <summary>
+ /// Called when the object is being disposed.
+ /// </summary>
+ /// <param name="disposing">Are we disposing?</param>
+ protected virtual void Dispose(bool disposing)
+ {
+ this.NewConnection = null;
+ this.OnInternalError = null;
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/ConnectionState.cs b/Tools/Hazel-Networking/Hazel/ConnectionState.cs
new file mode 100644
index 0000000..5d3f5c9
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/ConnectionState.cs
@@ -0,0 +1,28 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+
+namespace Hazel
+{
+ /// <summary>
+ /// Represents the state a <see cref="Connection"/> is currently in.
+ /// </summary>
+ public enum ConnectionState
+ {
+ /// <summary>
+ /// The Connection has either not been established yet or has been disconnected.
+ /// </summary>
+ NotConnected,
+
+ /// <summary>
+ /// The Connection is currently connecting to an endpoint.
+ /// </summary>
+ Connecting,
+
+ /// <summary>
+ /// The Connection is connected and data can be transfered.
+ /// </summary>
+ Connected,
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/ConnectionStatistics.cs b/Tools/Hazel-Networking/Hazel/ConnectionStatistics.cs
new file mode 100644
index 0000000..f2c3ed9
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/ConnectionStatistics.cs
@@ -0,0 +1,574 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading;
+
+
+namespace Hazel
+{
+ /// <summary>
+ /// Holds statistics about the traffic through a <see cref="Connection"/>.
+ /// </summary>
+ /// <threadsafety static="true" instance="true"/>
+ public class ConnectionStatistics
+ {
+ private const int ExpectedMTU = 1200;
+
+ /// <summary>
+ /// The total number of messages sent.
+ /// </summary>
+ public int MessagesSent
+ {
+ get
+ {
+ return UnreliableMessagesSent + ReliableMessagesSent + FragmentedMessagesSent + AcknowledgementMessagesSent + HelloMessagesSent;
+ }
+ }
+
+ private int packetsSent;
+ public int PacketsSent => this.packetsSent;
+
+ private int reliablePacketsAcknowledged;
+ public int ReliablePacketsAcknowledged => this.reliablePacketsAcknowledged;
+
+ /// <summary>
+ /// The number of messages sent larger than 576 bytes. This is smaller than most default MTUs.
+ /// </summary>
+ /// <remarks>
+ /// This is the number of unreliable messages that were sent from the <see cref="Connection"/>, incremented
+ /// each time that LogUnreliableSend is called by the Connection. Messages that caused an error are not
+ /// counted and messages are only counted once all other operations in the send are complete.
+ /// </remarks>
+ public int FragmentableMessagesSent
+ {
+ get
+ {
+ return fragmentableMessagesSent;
+ }
+ }
+
+ /// <summary>
+ /// The number of messages sent larger than 576 bytes.
+ /// </summary>
+ int fragmentableMessagesSent;
+
+ /// <summary>
+ /// The number of unreliable messages sent.
+ /// </summary>
+ /// <remarks>
+ /// This is the number of unreliable messages that were sent from the <see cref="Connection"/>, incremented
+ /// each time that LogUnreliableSend is called by the Connection. Messages that caused an error are not
+ /// counted and messages are only counted once all other operations in the send are complete.
+ /// </remarks>
+ public int UnreliableMessagesSent
+ {
+ get
+ {
+ return unreliableMessagesSent;
+ }
+ }
+
+ /// <summary>
+ /// The number of unreliable messages sent.
+ /// </summary>
+ int unreliableMessagesSent;
+
+ /// <summary>
+ /// The number of reliable messages sent.
+ /// </summary>
+ /// <remarks>
+ /// This is the number of reliable messages that were sent from the <see cref="Connection"/>, incremented
+ /// each time that LogReliableSend is called by the Connection. Messages that caused an error are not
+ /// counted and messages are only counted once all other operations in the send are complete.
+ /// </remarks>
+ public int ReliableMessagesSent
+ {
+ get
+ {
+ return reliableMessagesSent;
+ }
+ }
+
+ /// <summary>
+ /// The number of unreliable messages sent.
+ /// </summary>
+ int reliableMessagesSent;
+
+ /// <summary>
+ /// The number of fragmented messages sent.
+ /// </summary>
+ /// <remarks>
+ /// This is the number of fragmented messages that were sent from the <see cref="Connection"/>, incremented
+ /// each time that LogFragmentedSend is called by the Connection. Messages that caused an error are not
+ /// counted and messages are only counted once all other operations in the send are complete.
+ /// </remarks>
+ public int FragmentedMessagesSent
+ {
+ get
+ {
+ return fragmentedMessagesSent;
+ }
+ }
+
+ /// <summary>
+ /// The number of fragmented messages sent.
+ /// </summary>
+ int fragmentedMessagesSent;
+
+ /// <summary>
+ /// The number of acknowledgement messages sent.
+ /// </summary>
+ /// <remarks>
+ /// This is the number of acknowledgements that were sent from the <see cref="Connection"/>, incremented
+ /// each time that LogAcknowledgementSend is called by the Connection. Messages that caused an error are not
+ /// counted and messages are only counted once all other operations in the send are complete.
+ /// </remarks>
+ public int AcknowledgementMessagesSent
+ {
+ get
+ {
+ return acknowledgementMessagesSent;
+ }
+ }
+
+ /// <summary>
+ /// The number of acknowledgement messages sent.
+ /// </summary>
+ int acknowledgementMessagesSent;
+
+ /// <summary>
+ /// The number of hello messages sent.
+ /// </summary>
+ /// <remarks>
+ /// This is the number of hello messages that were sent from the <see cref="Connection"/>, incremented
+ /// each time that LogHelloSend is called by the Connection. Messages that caused an error are not
+ /// counted and messages are only counted once all other operations in the send are complete.
+ /// </remarks>
+ public int HelloMessagesSent
+ {
+ get
+ {
+ return helloMessagesSent;
+ }
+ }
+
+ /// <summary>
+ /// The number of hello messages sent.
+ /// </summary>
+ int helloMessagesSent;
+
+ /// <summary>
+ /// The number of bytes of data sent.
+ /// </summary>
+ /// <remarks>
+ /// <para>
+ /// This is the number of bytes of data (i.e. user bytes) that were sent from the <see cref="Connection"/>,
+ /// accumulated each time that LogSend is called by the Connection. Messages that caused an error are not
+ /// counted and messages are only counted once all other operations in the send are complete.
+ /// </para>
+ /// <para>
+ /// For the number of bytes including protocol bytes see <see cref="TotalBytesSent"/>.
+ /// </para>
+ /// </remarks>
+ public long DataBytesSent
+ {
+ get
+ {
+ return Interlocked.Read(ref dataBytesSent);
+ }
+ }
+
+ /// <summary>
+ /// The number of bytes of data sent.
+ /// </summary>
+ long dataBytesSent;
+
+ /// <summary>
+ /// The number of bytes sent in total.
+ /// </summary>
+ /// <remarks>
+ /// <para>
+ /// This is the total number of bytes (the data bytes plus protocol bytes) that were sent from the
+ /// <see cref="Connection"/>, accumulated each time that LogSend is called by the Connection. Messages that
+ /// caused an error are not counted and messages are only counted once all other operations in the send are
+ /// complete.
+ /// </para>
+ /// <para>
+ /// For the number of data bytes excluding protocol bytes see <see cref="DataBytesSent"/>.
+ /// </para>
+ /// </remarks>
+ public long TotalBytesSent
+ {
+ get
+ {
+ return Interlocked.Read(ref totalBytesSent);
+ }
+ }
+
+ /// <summary>
+ /// The number of bytes sent in total.
+ /// </summary>
+ long totalBytesSent;
+
+ /// <summary>
+ /// The total number of messages received.
+ /// </summary>
+ public int MessagesReceived
+ {
+ get
+ {
+ return UnreliableMessagesReceived + ReliableMessagesReceived + FragmentedMessagesReceived + AcknowledgementMessagesReceived + helloMessagesReceived;
+ }
+ }
+
+ /// <summary>
+ /// The number of unreliable messages received.
+ /// </summary>
+ /// <remarks>
+ /// This is the number of unreliable messages that were received by the <see cref="Connection"/>, incremented
+ /// each time that LogUnreliableReceive is called by the Connection. Messages are counted before the receive event is invoked.
+ /// </remarks>
+ public int UnreliableMessagesReceived
+ {
+ get
+ {
+ return unreliableMessagesReceived;
+ }
+ }
+
+ /// <summary>
+ /// The number of unreliable messages received.
+ /// </summary>
+ int unreliableMessagesReceived;
+
+ /// <summary>
+ /// The number of reliable messages received.
+ /// </summary>
+ /// <remarks>
+ /// This is the number of reliable messages that were received by the <see cref="Connection"/>, incremented
+ /// each time that LogReliableReceive is called by the Connection. Messages are counted before the receive event is invoked.
+ /// </remarks>
+ public int ReliableMessagesReceived
+ {
+ get
+ {
+ return reliableMessagesReceived;
+ }
+ }
+
+ /// <summary>
+ /// The number of reliable messages received.
+ /// </summary>
+ int reliableMessagesReceived;
+
+ /// <summary>
+ /// The number of fragmented messages received.
+ /// </summary>
+ /// <remarks>
+ /// This is the number of fragmented messages that were received by the <see cref="Connection"/>, incremented
+ /// each time that LogFragmentedReceive is called by the Connection. Messages are counted before the receive event is invoked.
+ /// </remarks>
+ public int FragmentedMessagesReceived
+ {
+ get
+ {
+ return fragmentedMessagesReceived;
+ }
+ }
+
+ /// <summary>
+ /// The number of fragmented messages received.
+ /// </summary>
+ int fragmentedMessagesReceived;
+
+ /// <summary>
+ /// The number of acknowledgement messages received.
+ /// </summary>
+ /// <remarks>
+ /// This is the number of acknowledgement messages that were received by the <see cref="Connection"/>, incremented
+ /// each time that LogAcknowledgemntReceive is called by the Connection. Messages are counted before the receive event is invoked.
+ /// </remarks>
+ public int AcknowledgementMessagesReceived
+ {
+ get
+ {
+ return acknowledgementMessagesReceived;
+ }
+ }
+
+ /// <summary>
+ /// The number of acknowledgement messages received.
+ /// </summary>
+ int acknowledgementMessagesReceived;
+
+ /// <summary>
+ /// The number of ping messages received.
+ /// </summary>
+ /// <remarks>
+ /// This is the number of hello messages that were received by the <see cref="Connection"/>, incremented
+ /// each time that LogHelloReceive is called by the Connection. Messages are counted before the receive event is invoked.
+ /// </remarks>
+ public int PingMessagesReceived
+ {
+ get
+ {
+ return pingMessagesReceived;
+ }
+ }
+
+ /// <summary>
+ /// The number of hello messages received.
+ /// </summary>
+ int pingMessagesReceived;
+
+ /// <summary>
+ /// The number of hello messages received.
+ /// </summary>
+ /// <remarks>
+ /// This is the number of hello messages that were received by the <see cref="Connection"/>, incremented
+ /// each time that LogHelloReceive is called by the Connection. Messages are counted before the receive event is invoked.
+ /// </remarks>
+ public int HelloMessagesReceived
+ {
+ get
+ {
+ return helloMessagesReceived;
+ }
+ }
+
+ /// <summary>
+ /// The number of hello messages received.
+ /// </summary>
+ int helloMessagesReceived;
+
+ /// <summary>
+ /// The number of bytes of data received.
+ /// </summary>
+ /// <remarks>
+ /// <para>
+ /// This is the number of bytes of data (i.e. user bytes) that were received by the <see cref="Connection"/>,
+ /// accumulated each time that LogReceive is called by the Connection. Messages are counted before the receive
+ /// event is invoked.
+ /// </para>
+ /// <para>
+ /// For the number of bytes including protocol bytes see <see cref="TotalBytesReceived"/>.
+ /// </para>
+ /// </remarks>
+ public long DataBytesReceived
+ {
+ get
+ {
+ return Interlocked.Read(ref dataBytesReceived);
+ }
+ }
+
+ /// <summary>
+ /// The number of bytes of data received.
+ /// </summary>
+ long dataBytesReceived;
+
+ /// <summary>
+ /// The number of bytes received in total.
+ /// </summary>
+ /// <remarks>
+ /// <para>
+ /// This is the total number of bytes (the data bytes plus protocol bytes) that were received by the
+ /// <see cref="Connection"/>, accumulated each time that LogReceive is called by the Connection. Messages are
+ /// counted before the receive event is invoked.
+ /// </para>
+ /// <para>
+ /// For the number of data bytes excluding protocol bytes see <see cref="DataBytesReceived"/>.
+ /// </para>
+ /// </remarks>
+ public long TotalBytesReceived
+ {
+ get
+ {
+ return Interlocked.Read(ref totalBytesReceived);
+ }
+ }
+
+ /// <summary>
+ /// The number of bytes received in total.
+ /// </summary>
+ long totalBytesReceived;
+
+ public int MessagesResent { get { return messagesResent; } }
+ int messagesResent;
+
+ /// <summary>
+ /// Logs the sending of an unreliable data packet in the statistics.
+ /// </summary>
+ /// <param name="dataLength">The number of bytes of data sent.</param>
+ /// <remarks>
+ /// This should be called after the data has been sent and should only be called for data that is sent sucessfully.
+ /// </remarks>
+ internal void LogUnreliableSend(int dataLength)
+ {
+ Interlocked.Increment(ref unreliableMessagesSent);
+ Interlocked.Add(ref dataBytesSent, dataLength);
+
+ }
+
+ /// <param name="totalLength">The total number of bytes sent.</param>
+ internal void LogPacketSend(int totalLength)
+ {
+ Interlocked.Increment(ref this.packetsSent);
+ Interlocked.Add(ref totalBytesSent, totalLength);
+
+ if (totalLength > ExpectedMTU)
+ {
+ Interlocked.Increment(ref fragmentableMessagesSent);
+ }
+ }
+
+ /// <summary>
+ /// Logs the sending of a reliable data packet in the statistics.
+ /// </summary>
+ /// <param name="dataLength">The number of bytes of data sent.</param>
+ /// <remarks>
+ /// This should be called after the data has been sent and should only be called for data that is sent sucessfully.
+ /// </remarks>
+ internal void LogReliableSend(int dataLength)
+ {
+ Interlocked.Increment(ref reliableMessagesSent);
+ Interlocked.Add(ref dataBytesSent, dataLength);
+ }
+
+ /// <summary>
+ /// Logs the sending of a fragmented data packet in the statistics.
+ /// </summary>
+ /// <param name="dataLength">The number of bytes of data sent.</param>
+ /// <param name="totalLength">The total number of bytes sent.</param>
+ /// <remarks>
+ /// This should be called after the data has been sent and should only be called for data that is sent sucessfully.
+ /// </remarks>
+ internal void LogFragmentedSend(int dataLength)
+ {
+ Interlocked.Increment(ref fragmentedMessagesSent);
+ Interlocked.Add(ref dataBytesSent, dataLength);
+ }
+
+ /// <summary>
+ /// Logs the sending of a acknowledgement data packet in the statistics.
+ /// </summary>
+ /// <param name="totalLength">The total number of bytes sent.</param>
+ /// <remarks>
+ /// This should be called after the data has been sent and should only be called for data that is sent sucessfully.
+ /// </remarks>
+ internal void LogAcknowledgementSend()
+ {
+ Interlocked.Increment(ref acknowledgementMessagesSent);
+ }
+
+ /// <summary>
+ /// Logs the sending of a hellp data packet in the statistics.
+ /// </summary>
+ /// <param name="totalLength">The total number of bytes sent.</param>
+ /// <remarks>
+ /// This should be called after the data has been sent and should only be called for data that is sent sucessfully.
+ /// </remarks>
+ internal void LogHelloSend()
+ {
+ Interlocked.Increment(ref helloMessagesSent);
+ }
+
+ /// <summary>
+ /// Logs the receiving of an unreliable data packet in the statistics.
+ /// </summary>
+ /// <param name="dataLength">The number of bytes of data received.</param>
+ /// <param name="totalLength">The total number of bytes received.</param>
+ /// <remarks>
+ /// This should be called before the received event is invoked so it is up to date for subscribers to that event.
+ /// </remarks>
+ internal void LogUnreliableReceive(int dataLength, int totalLength)
+ {
+ Interlocked.Increment(ref unreliableMessagesReceived);
+ Interlocked.Add(ref dataBytesReceived, dataLength);
+ Interlocked.Add(ref totalBytesReceived, totalLength);
+ }
+
+ /// <summary>
+ /// Logs the receiving of a reliable data packet in the statistics.
+ /// </summary>
+ /// <param name="dataLength">The number of bytes of data received.</param>
+ /// <param name="totalLength">The total number of bytes received.</param>
+ /// <remarks>
+ /// This should be called before the received event is invoked so it is up to date for subscribers to that event.
+ /// </remarks>
+ internal void LogReliableReceive(int dataLength, int totalLength)
+ {
+ Interlocked.Increment(ref reliableMessagesReceived);
+ Interlocked.Add(ref dataBytesReceived, dataLength);
+ Interlocked.Add(ref totalBytesReceived, totalLength);
+ }
+
+ /// <summary>
+ /// Logs the receiving of a fragmented data packet in the statistics.
+ /// </summary>
+ /// <param name="dataLength">The number of bytes of data received.</param>
+ /// <param name="totalLength">The total number of bytes received.</param>
+ /// <remarks>
+ /// This should be called before the received event is invoked so it is up to date for subscribers to that event.
+ /// </remarks>
+ internal void LogFragmentedReceive(int dataLength, int totalLength)
+ {
+ Interlocked.Increment(ref fragmentedMessagesReceived);
+ Interlocked.Add(ref dataBytesReceived, dataLength);
+ Interlocked.Add(ref totalBytesReceived, totalLength);
+ }
+
+ /// <summary>
+ /// Logs the receiving of an acknowledgement data packet in the statistics.
+ /// </summary>
+ /// <param name="totalLength">The total number of bytes received.</param>
+ /// <remarks>
+ /// This should be called before the received event is invoked so it is up to date for subscribers to that event.
+ /// </remarks>
+ internal void LogAcknowledgementReceive(int totalLength)
+ {
+ Interlocked.Increment(ref acknowledgementMessagesReceived);
+ Interlocked.Add(ref totalBytesReceived, totalLength);
+ }
+
+ /// <summary>
+ /// Logs the unique acknowledgement of a ping or reliable data packet.
+ /// </summary>
+ internal void LogReliablePacketAcknowledged()
+ {
+ Interlocked.Increment(ref this.reliablePacketsAcknowledged);
+ }
+
+ /// <summary>
+ /// Logs the receiving of a hello data packet in the statistics.
+ /// </summary>
+ /// <param name="totalLength">The total number of bytes received.</param>
+ /// <remarks>
+ /// This should be called before the received event is invoked so it is up to date for subscribers to that event.
+ /// </remarks>
+ internal void LogPingReceive(int totalLength)
+ {
+ Interlocked.Increment(ref pingMessagesReceived);
+ Interlocked.Add(ref totalBytesReceived, totalLength);
+ }
+
+ /// <summary>
+ /// Logs the receiving of a hello data packet in the statistics.
+ /// </summary>
+ /// <param name="totalLength">The total number of bytes received.</param>
+ /// <remarks>
+ /// This should be called before the received event is invoked so it is up to date for subscribers to that event.
+ /// </remarks>
+ internal void LogHelloReceive(int totalLength)
+ {
+ Interlocked.Increment(ref helloMessagesReceived);
+ Interlocked.Add(ref totalBytesReceived, totalLength);
+ }
+
+ internal void LogMessageResent()
+ {
+ Interlocked.Increment(ref messagesResent);
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/Crypto/AesGcm.cs b/Tools/Hazel-Networking/Hazel/Crypto/AesGcm.cs
new file mode 100644
index 0000000..bfbbc01
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/Crypto/AesGcm.cs
@@ -0,0 +1,369 @@
+using System;
+using System.Diagnostics;
+using System.Security.Cryptography;
+
+namespace Hazel.Crypto
+{
+ /// <summary>
+ /// Implementation of AEAD_AES128_GCM based on:
+ /// * RFC 5116 [1]
+ /// * NIST SP 800-38d [2]
+ ///
+ /// [1] https://tools.ietf.org/html/rfc5116
+ /// [2] https://nvlpubs.nist.gov/nistpubs/Legacy/SP/nistspecialpublication800-38d.pdf
+ ///
+ /// Adapted from: https://gist.github.com/mendsley/777e6bd9ae7eddcb2b0c0fe18247dc60
+ /// </summary>
+ public class Aes128Gcm : IDisposable
+ {
+ public const int KeySize = 16;
+ public const int NonceSize = 12;
+ public const int CiphertextOverhead = TagSize;
+
+ private const int TagSize = 16;
+
+ private readonly IAes encryptor_;
+
+ private readonly ByteSpan hashSubkey_;
+ private readonly ByteSpan blockJ_;
+ private readonly ByteSpan blockS_;
+ private readonly ByteSpan blockZ_;
+ private readonly ByteSpan blockV_;
+ private readonly ByteSpan blockScratch_;
+
+ /// <summary>
+ /// Creates a new instance of an AEAD_AES128_GCM cipher
+ /// </summary>
+ /// <param name="key">Symmetric key</param>
+ public Aes128Gcm(ByteSpan key)
+ {
+ if (key.Length != KeySize)
+ {
+ throw new ArgumentException("Invalid key length", nameof(key));
+ }
+
+ // Create the AES block cipher
+ this.encryptor_ = CryptoProvider.CreateAes(key);
+
+ // Allocate scratch space
+ ByteSpan scratchSpace = new byte[96];
+ this.hashSubkey_ = scratchSpace.Slice(0, 16);
+ this.blockJ_ = scratchSpace.Slice(16, 16);
+ this.blockS_ = scratchSpace.Slice(32, 16);
+ this.blockZ_ = scratchSpace.Slice(48, 16);
+ this.blockV_ = scratchSpace.Slice(64, 16);
+ this.blockScratch_ = scratchSpace.Slice(80, 16);
+
+ // Create the GHASH subkey by encrypting the 0-block
+ this.encryptor_.EncryptBlock(this.hashSubkey_, this.hashSubkey_);
+ }
+
+ /// <summary>
+ /// Encryptes the specified plaintext and generates an authentication
+ /// tag for the provided additional data. Returns the byte array
+ /// containg both the ciphertext and authentication tag.
+ /// </summary>
+ /// <param name="output">
+ /// Array in which to encode the encrypted ciphertext and
+ /// authentication tag. This array must be large enough to hold
+ /// `plaintext.Lengh + CiphertextOverhead` bytes.
+ /// </param>
+ /// <param name="nonce">Unique value for this message</param>
+ /// <param name="plaintext">Plaintext data to encrypt</param>
+ /// <param name="associatedData">
+ /// Additional data used to authenticate the message
+ /// </param>
+ public void Seal(ByteSpan output, ByteSpan nonce, ByteSpan plaintext, ByteSpan associatedData)
+ {
+ if (nonce.Length != NonceSize)
+ {
+ throw new ArgumentException("Invalid nonce size", nameof(nonce));
+ }
+ if (output.Length < plaintext.Length + CiphertextOverhead)
+ {
+ throw new ArgumentException("Invalid output size", nameof(output));
+ }
+
+ // Create the initial counter block
+ nonce.CopyTo(this.blockJ_);
+
+ // Encrypt the plaintext to output
+ GCTR(output, this.blockJ_, 2, plaintext);
+
+ // Generate and append the authentication tag
+ int tagOffset = plaintext.Length;
+ GenerateAuthenticationTag(output.Slice(tagOffset), output.Slice(0, tagOffset), associatedData);
+ }
+
+ /// <summary>
+ /// Validates the authentication tag against the provided additional
+ /// data, then decrypts the cipher text returning the original
+ /// plaintext.
+ /// </summary>
+ /// <param name="nonce">
+ /// The unique value used to seal this message
+ /// </param>
+ /// <param name="ciphertext">
+ /// Combined ciphertext and authentication tag
+ /// </param>
+ /// <param name="associatedData">
+ /// Additional data used to authenticate the message
+ /// </param>
+ /// <param name="output">
+ /// On successful validation and decryprion, Open writes the original
+ /// plaintext to output. Must contain enough space to hold
+ /// `ciphertext.Length - CiphertextOverhead` bytes.
+ /// </param>
+ /// <returns>
+ /// True if the data was validated and successfully decrypted.
+ /// Otherwise, false.
+ /// </returns>
+ public bool Open(ByteSpan output, ByteSpan nonce, ByteSpan ciphertext, ByteSpan associatedData)
+ {
+ if (nonce.Length != NonceSize)
+ {
+ throw new ArgumentException("Invalid nonce size", nameof(nonce));
+ }
+ if (ciphertext.Length < CiphertextOverhead)
+ {
+ throw new ArgumentException("Invalid ciphertext size", nameof(ciphertext));
+ }
+ else if (output.Length < ciphertext.Length - CiphertextOverhead)
+ {
+ throw new ArgumentException("Invalid output size", nameof(output));
+ }
+
+ // Split ciphertext into actual ciphertext and authentication
+ // tag components.
+ ByteSpan authenticationTag = ciphertext.Slice(ciphertext.Length - TagSize);
+ ciphertext = ciphertext.Slice(0, ciphertext.Length - TagSize);
+
+ // Create the initial counter block
+ nonce.CopyTo(this.blockJ_);
+
+ // Verify the tags match
+ GenerateAuthenticationTag(this.blockScratch_, ciphertext, associatedData);
+ if (0 == Const.ConstantCompareSpans(this.blockScratch_, authenticationTag))
+ {
+ return false;
+ }
+
+ // Decrypt the cipher text to output
+ GCTR(output, this.blockJ_, 2, ciphertext);
+ return true;
+ }
+
+ /// <summary>
+ /// Release resources acquired by the cipher
+ /// </summary>
+ public void Dispose()
+ {
+ this.encryptor_.Dispose();
+ }
+
+ // Generate the authentication tag for a ciphertext+associated data
+ void GenerateAuthenticationTag(ByteSpan output, ByteSpan ciphertext, ByteSpan associatedData)
+ {
+ Debug.Assert(output.Length >= 16);
+
+ // Hash `Associated data || Ciphertext || len(AssociatedD data) || len(Ciphertext)`
+ // into `blockS`
+ {
+ // Clear hash output block
+ SetSpanToZeros(this.blockS_);
+
+ // Write associated data blocks to hash
+ int fullBlocks = associatedData.Length / 16;
+ GHASH(this.blockS_, associatedData, fullBlocks);
+ if (fullBlocks * 16 < associatedData.Length)
+ {
+ SetSpanToZeros(this.blockScratch_);
+ associatedData.Slice(fullBlocks * 16).CopyTo(this.blockScratch_);
+ GHASH(this.blockS_, this.blockScratch_, 1);
+ }
+
+ // Write ciphertext blocks to hash
+ fullBlocks = ciphertext.Length / 16;
+ GHASH(this.blockS_, ciphertext, fullBlocks);
+ if (fullBlocks * 16 < ciphertext.Length)
+ {
+ SetSpanToZeros(this.blockScratch_);
+ ciphertext.Slice(fullBlocks * 16).CopyTo(this.blockScratch_);
+ GHASH(this.blockS_, this.blockScratch_, 1);
+ }
+
+ // Write bit sizes to hash
+ ulong associatedDataLengthInBits = (ulong)(8 * associatedData.Length);
+ ulong ciphertextDataLengthInBits = (ulong)(8 * ciphertext.Length);
+ this.blockScratch_.WriteBigEndian64(associatedDataLengthInBits);
+ this.blockScratch_.WriteBigEndian64(ciphertextDataLengthInBits, 8);
+
+ GHASH(this.blockS_, this.blockScratch_, 1);
+ }
+
+ // Encrypt the tag. GCM requires this because `GASH` is not
+ // cryptographically secure. An attacker could derive our hash
+ // subkey `hashSubkey_` from an unencrypted tag.
+ GCTR(output, this.blockJ_, 1, this.blockS_);
+ }
+
+ // Run the GCTR cipher
+ void GCTR(ByteSpan output, ByteSpan counterBlock, uint counter, ByteSpan data)
+ {
+ Debug.Assert(counterBlock.Length == 16);
+ Debug.Assert(output.Length >= data.Length);
+
+ // Loop through plaintext blocks
+ int writeIndex = 0;
+ int numBlocks = (data.Length + 15) / 16;
+ for (int ii = 0; ii != numBlocks; ++ii)
+ {
+ // Encode counter into block
+ // CB[1] = J0
+ // CB[i] = inc[32](CB[i-1])
+ counterBlock.WriteBigEndian32(counter, 12);
+ ++counter;
+
+ // CIPH[k](CB[i])
+ this.encryptor_.EncryptBlock(counterBlock.Slice(0, 16), this.blockScratch_);
+
+ // Y[i] = X[i] xor CIPH[k](CB[i])
+ for (int jj = 0; jj != 16 && writeIndex < data.Length; ++jj, ++writeIndex)
+ {
+ output[writeIndex] = (byte)(data[writeIndex] ^ this.blockScratch_[jj]);
+ }
+ }
+ }
+
+ // Run the GHASH function
+ void GHASH(ByteSpan output, ByteSpan data, int numBlocks)
+ {
+ ///TODO(mendsley): See Ref[6] for opitmizations of GHASH on both hardware and software
+ ///
+ ///[6] D. McGrew, J. Viega, The Galois/Counter Mode of Operation (GCM), Natl. Inst. Stand.
+ ///Technol. [Web page], http://www.csrc.nist.gov/groups/ST/toolkit/BCM/documents/
+ ///proposedmodes / gcm / gcm - revised - spec.pdf, May 31, 2005.
+
+ Debug.Assert(output.Length == 16);
+ Debug.Assert(data.Length >= numBlocks * 16);
+
+ int readIndex = 0;
+ for (int ii = 0; ii != numBlocks; ++ii)
+ {
+ for (int jj = 0; jj != 16; ++jj, ++readIndex)
+ {
+ // Y[ii-1] xor X[ii]
+ output[jj] ^= data[readIndex];
+ }
+
+ // Y[ii] = (Y[ii-1] xor X[ii]) · H
+ MultiplyGF128Elements(output, this.hashSubkey_, this.blockZ_, this.blockV_);
+ }
+ }
+
+ // Multiply two Galois field elements `X` and `Y` together and store
+ // the result in `X` such that at the end of the function:
+ // X = X·Y
+ static void MultiplyGF128Elements(ByteSpan X, ByteSpan Y, ByteSpan scratchZ, ByteSpan scratchV)
+ {
+ Debug.Assert(X.Length == 16);
+ Debug.Assert(Y.Length == 16);
+ Debug.Assert(scratchZ.Length == 16);
+ Debug.Assert(scratchV.Length == 16);
+
+ // Galois (finite) fields represented by GF(p) define a set of
+ // closed algebraic operations. For AES128_GCM we'll be dealing
+ // with the GF(2^128) field.
+ //
+ // We treat each incoming 16 byte block as a polynomial in field
+ // and define multiplication between two polynomials as the
+ // polynomial product reduced by (mod) the field polynomial:
+ // 1 + x + x^2 + x^7 + x^128
+ //
+ // Field polynomials are represented by a 128 bit string. Bit n is
+ // the coefficient of the x^n term. We use little-endian bit
+ // ordering (not to be confused with byte ordering) for these
+ // coefficients. E.g. X[0] & 0x00000001 represents the 7th bit in
+ // the bit string defined by X, _not_ the 0th bit.
+ //
+
+ // What follows is a modified version of the "peasant's algorithm"
+ // to multiply two numbers:
+ //
+ // Z contains the accumulated product
+ // V is a copy of Y (so we can modify it via shifting).
+ //
+ // We calculate Z = X·V as follows
+ // We loop through each of the 128 bits in X maintaining the
+ // following loop invariant: X·V + Z = the final product
+ //
+ // On each iteration `ii`:
+ //
+ // If the `ii`th bit of `X` is set, add the add the polynomial
+ // in `V` to `X`: `X[n] = X[n] ^ V[n]`
+ //
+ // Double V (Shift one bit right since we're storing little
+ // endian bit). This has the effect of multiplying V by the
+ // polynomial `x`. We track the unrepresentable coefficient
+ // of `x^128` by storing the most significant bit before the
+ // shift `V[15] >> 7` as `carry`
+ //
+ // Check if we've overflowed our multiplication. If overflow
+ // occurred, there will be a non-zero coefficient for the
+ // `x^128` term in the step above `carry`
+ //
+ // If we have overflowed, our polynomial is exactly of degree
+ // 129 (since we're only multiplying by `x`). We reduce the
+ // polynomial back into degree 128 by adding our field's
+ // irreducible polynomial: 1 + x + x^2 + x^7 + x^128. This
+ // reduction cancels out the x^128 term (x^128 + x^128 in GF(2)
+ // is zero). Therefore this modulo can be achieved by simply
+ // adding the irreducible polynomial to the new value of `V`. The
+ // irreducible polynomial is represented by the bit string:
+ // `11100001` followed by 120 `0`s. We can add this value to `V`
+ // by: `V[0] = V[0] ^ 0xE1`.
+ SetSpanToZeros(scratchZ);
+ X.CopyTo(scratchV);
+
+ for (int ii = 0; ii != 128; ++ii)
+ {
+ int bitIndex = 7 - (ii % 8);
+ if ((Y[ii / 8] & (1 << bitIndex)) != 0)
+ {
+ for (int jj = 0; jj != 16; ++jj)
+ {
+ scratchZ[jj] ^= scratchV[jj];
+ }
+ }
+
+ bool carry = false;
+ for (int jj = 0; jj != 16; ++jj)
+ {
+ bool newCarry = (scratchV[jj] & 0x01) != 0;
+ scratchV[jj] >>= 1;
+ if (carry)
+ {
+ scratchV[jj] |= 0x80;
+ }
+ carry = newCarry;
+ }
+
+ if (carry)
+ {
+ scratchV[0] ^= 0xE1;
+ }
+ }
+
+ scratchZ.CopyTo(X);
+ }
+
+ // Set the contents of a span to all zero
+ static void SetSpanToZeros(ByteSpan span)
+ {
+ for (int ii = 0, nn = span.Length; ii != nn; ++ii)
+ {
+ span[ii] = 0;
+ }
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/Crypto/Const.cs b/Tools/Hazel-Networking/Hazel/Crypto/Const.cs
new file mode 100644
index 0000000..4dfef47
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/Crypto/Const.cs
@@ -0,0 +1,82 @@
+using System.Diagnostics;
+
+namespace Hazel.Crypto
+{
+ public static class Const
+ {
+
+ /// <summary>
+ /// Compare two bytes for equality.
+ ///
+ /// This takes care to always use a constant amount of time to prevent
+ /// leaking information through side-channel attacks.
+ ///
+ /// This is aceived by collapsing the xor bits down into a single bit.
+ ///
+ /// Ported from:
+ /// https://github.com/mendsley/tiny/blob/master/include/tiny/crypto/constant.h
+ /// </summary>
+ /// <returns>
+ /// Returns `1` is the two bytes or equivalent. Otherwise, returns `0`
+ /// </returns>
+ public static byte ConstantCompareByte(byte a, byte b)
+ {
+ byte result = (byte)(~(a ^ b));
+
+ // collapse bits down to the LSB
+ result &= (byte)(result >> 4);
+ result &= (byte)(result >> 2);
+ result &= (byte)(result >> 1);
+
+ return result;
+ }
+
+ /// <summary>
+ /// Compare two equal length spans for equality.
+ ///
+ /// This takes care to always use a constant amount of time to prevent
+ /// leaking information through side-channel attacks.
+ ///
+ /// Ported from:
+ /// https://github.com/mendsley/tiny/blob/master/include/tiny/crypto/constant.h
+ /// </summary>
+ /// <returns>
+ /// Returns `1` if the spans are equivalent. Others, returns `0`.
+ /// </returns>
+ public static byte ConstantCompareSpans(ByteSpan a, ByteSpan b)
+ {
+ Debug.Assert(a.Length == b.Length);
+
+ byte value = 0;
+ for (int ii = 0, nn = a.Length; ii != nn; ++ii)
+ {
+ value |= (byte)(a[ii] ^ b[ii]);
+ }
+
+ return ConstantCompareByte(value, 0);
+ }
+
+ /// <summary>
+ /// Compare a span against an all zero span
+ ///
+ /// This takes care to always use a constant amount of time to prevent
+ /// leaking information through side-channel attacks.
+ ///
+ /// Ported from:
+ /// https://github.com/mendsley/tiny/blob/master/include/tiny/crypto/constant.h
+ /// </summary>
+ /// <returns>
+ /// Returns `1` if the spans is all zeros. Others, returns `0`.
+ /// </returns>
+ public static byte ConstantCompareZeroSpan(ByteSpan a)
+ {
+ byte value = 0;
+ for (int ii = 0, nn = a.Length; ii != nn; ++ii)
+ {
+ value |= (byte)(a[ii] ^ 0);
+ }
+
+ return ConstantCompareByte(value, 0);
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/Crypto/CryptoProvider.cs b/Tools/Hazel-Networking/Hazel/Crypto/CryptoProvider.cs
new file mode 100644
index 0000000..2c56c70
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/Crypto/CryptoProvider.cs
@@ -0,0 +1,36 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+
+namespace Hazel.Crypto
+{
+ public static class CryptoProvider
+ {
+ public delegate IAes CreateAesOverrideDelegate(ByteSpan key);
+
+ /// <summary>
+ /// Override the default AES creation function
+ /// </summary>
+ public static CreateAesOverrideDelegate OverrideCreateAes = null;
+
+ /// <summary>
+ /// Create a new AES cipher
+ /// </summary>
+ /// <param name="key">Encrtyption key</param>
+ public static IAes CreateAes(ByteSpan key)
+ {
+ if (OverrideCreateAes != null)
+ {
+ IAes result = OverrideCreateAes(key);
+ if (null != result)
+ {
+ return result;
+ }
+ }
+
+ return new DefaultAes(key);
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/Crypto/DefaultAes.cs b/Tools/Hazel-Networking/Hazel/Crypto/DefaultAes.cs
new file mode 100644
index 0000000..da72fb8
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/Crypto/DefaultAes.cs
@@ -0,0 +1,49 @@
+using System;
+using System.Security.Cryptography;
+
+namespace Hazel.Crypto
+{
+ /// <summary>
+ /// AES provider using the default System.Security.Cryptography implementation
+ /// </summary>
+ public class DefaultAes : IAes
+ {
+ private readonly ICryptoTransform encryptor_;
+
+ /// <summary>
+ /// Create a new default instance of the AES block cipher
+ /// </summary>
+ /// <param name="key">Encryption key</param>
+ public DefaultAes(ByteSpan key)
+ {
+ // Create the AES block cipher
+ using (Aes aes = Aes.Create())
+ {
+ aes.KeySize = key.Length * 8;
+ aes.BlockSize = aes.KeySize;
+ aes.Mode = CipherMode.ECB;
+ aes.Padding = PaddingMode.Zeros;
+ aes.Key = key.ToArray();
+
+ this.encryptor_ = aes.CreateEncryptor();
+ }
+ }
+
+ /// <inheritdoc/>
+ public void Dispose()
+ {
+ this.encryptor_.Dispose();
+ }
+
+ /// <inheritdoc/>
+ public int EncryptBlock(ByteSpan inputSpan, ByteSpan outputSpan)
+ {
+ if (inputSpan.Length != outputSpan.Length)
+ {
+ throw new ArgumentException($"ouputSpan length ({outputSpan.Length}) does not match inputSpan length ({inputSpan.Length})", nameof(outputSpan));
+ }
+
+ return this.encryptor_.TransformBlock(inputSpan.GetUnderlyingArray(), inputSpan.Offset, inputSpan.Length, outputSpan.GetUnderlyingArray(), outputSpan.Offset);
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/Crypto/IAes.cs b/Tools/Hazel-Networking/Hazel/Crypto/IAes.cs
new file mode 100644
index 0000000..6c494cd
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/Crypto/IAes.cs
@@ -0,0 +1,27 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+
+namespace Hazel.Crypto
+{
+ /// <summary>
+ /// AES encryption interface
+ /// </summary>
+ public interface IAes : IDisposable
+ {
+ /// <summary>
+ /// Encrypts the specified region of the input byte array and copies
+ /// the resulting transform to the specified region of the output
+ /// array.
+ /// </summary>
+ /// <param name="inputSpan">The input for which to encrypt</param>
+ /// <param name="outputSpan">
+ /// The otput to which to write the encrypted data. This span can
+ /// overlap with `inputSpan`.
+ /// </param>
+ /// <returns>The number of bytes written</returns>
+ int EncryptBlock(ByteSpan inputSpan, ByteSpan outputSpan);
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/Crypto/Sha256Stream.cs b/Tools/Hazel-Networking/Hazel/Crypto/Sha256Stream.cs
new file mode 100644
index 0000000..1903693
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/Crypto/Sha256Stream.cs
@@ -0,0 +1,86 @@
+using System;
+using System.Security.Cryptography;
+
+namespace Hazel.Crypto
+{
+ /// <summary>
+ /// Streams data into a SHA256 digest
+ /// </summary>
+ public class Sha256Stream : IDisposable
+ {
+ /// <summary>
+ /// Size of the SHA256 digest in bytes
+ /// </summary>
+ public const int DigestSize = 32;
+
+ private SHA256 hash = SHA256.Create();
+ private bool isHashFinished = false;
+
+ struct EmptyArray
+ {
+ public static readonly byte[] Value = new byte[0];
+ }
+
+ /// <summary>
+ /// Create a new instance of a SHA256 stream
+ /// </summary>
+ public Sha256Stream()
+ {
+ }
+
+ /// <summary>
+ /// Release resources associated with the stream
+ /// </summary>
+ public void Dispose()
+ {
+ this.hash?.Dispose();
+ this.hash = null;
+
+ GC.SuppressFinalize(this);
+ }
+
+ /// <summary>
+ /// Reset the stream to its initial state
+ /// </summary>
+ public void Reset()
+ {
+ this.hash?.Dispose();
+ this.hash = SHA256.Create();
+ this.isHashFinished = false;
+ }
+
+ /// <summary>
+ /// Add data to the stream
+ /// </summary>
+ public void AddData(ByteSpan data)
+ {
+ while (data.Length > 0)
+ {
+ int offset = this.hash.TransformBlock(data.GetUnderlyingArray(), data.Offset, data.Length, null, 0);
+ data = data.Slice(offset);
+ }
+ }
+
+ /// <summary>
+ /// Calculate the final hash of the stream data
+ /// </summary>
+ /// <param name="output">
+ /// Target span to which the hash will be written
+ /// </param>
+ public void CopyOrCalculateFinalHash(ByteSpan output)
+ {
+ if (output.Length != DigestSize)
+ {
+ throw new ArgumentException($"Expected a span of {DigestSize} bytes. Got a span of {output.Length} bytes", nameof(output));
+ }
+
+ if (this.isHashFinished == false)
+ {
+ this.hash.TransformFinalBlock(EmptyArray.Value, 0, 0);
+ this.isHashFinished = true;
+ }
+
+ new ByteSpan(this.hash.Hash).CopyTo(output);
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/Crypto/SpanCryptoExtensions.cs b/Tools/Hazel-Networking/Hazel/Crypto/SpanCryptoExtensions.cs
new file mode 100644
index 0000000..03164ec
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/Crypto/SpanCryptoExtensions.cs
@@ -0,0 +1,36 @@
+using System;
+using System.Security.Cryptography;
+
+namespace Hazel.Crypto
+{
+ public static class SpanCryptoExtensions
+ {
+ /// <summary>
+ /// Clear a span's contents to zero
+ /// </summary>
+ public static void SecureClear(this ByteSpan span)
+ {
+ if (span.Length > 0)
+ {
+ Array.Clear(span.GetUnderlyingArray(), span.Offset, span.Length);
+ }
+ }
+
+ /// <summary>
+ /// Fill a byte span with random data
+ /// </summary>
+ /// <param name="random">Entropy source</param>
+ public static void FillWithRandom(this ByteSpan span, RandomNumberGenerator random)
+ {
+ if (span.Offset == 0 && span.Length == span.GetUnderlyingArray().Length)
+ {
+ random.GetBytes(span.GetUnderlyingArray());
+ return;
+ }
+
+ byte[] temp = new byte[span.Length];
+ random.GetBytes(temp);
+ new ByteSpan(temp).CopyTo(span);
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/Crypto/X25519.cs b/Tools/Hazel-Networking/Hazel/Crypto/X25519.cs
new file mode 100644
index 0000000..3f4624b
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/Crypto/X25519.cs
@@ -0,0 +1,844 @@
+using System;
+using System.Diagnostics;
+
+namespace Hazel.Crypto
+{
+ /// <summary>
+ /// The x25519 key agreement algorithm
+ /// </summary>
+ public static class X25519
+ {
+ public const int KeySize = 32;
+
+ /// <summary>
+ /// Element in the GF(2^255 - 19) field
+ /// </summary>
+ public partial struct FieldElement
+ {
+ public int x0, x1, x2, x3, x4;
+ public int x5, x6, x7, x8, x9;
+ };
+
+ private static readonly byte[] BasePoint = {9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
+
+ /// <summary>
+ /// Performs the core x25519 function: Multiplying an EC point by a scalar value
+ /// </summary>
+ public static bool Func(ByteSpan output, ByteSpan scalar, ByteSpan point)
+ {
+ InternalFunc(output, scalar, point);
+ if (Const.ConstantCompareZeroSpan(output) == 1)
+ {
+ return false;
+ }
+
+ return true;
+ }
+
+ /// <summary>
+ /// Multiplies the base x25519 point by the provided scalar value
+ /// </summary>
+ public static void Func(ByteSpan output, ByteSpan scalar)
+ {
+ InternalFunc(output, scalar, BasePoint);
+ }
+
+ // The FieldElement code below is ported from the original
+ // public domain reference implemtation of X25519
+ // by D. J. Bernstien
+ //
+ // See: https://cr.yp.to/ecdh.html
+
+ private static void InternalFunc(ByteSpan output, ByteSpan scalar, ByteSpan point)
+ {
+ if (output.Length != KeySize)
+ {
+ throw new ArgumentException("Invalid output size", nameof(output));
+ }
+ else if (scalar.Length != KeySize)
+ {
+ throw new ArgumentException("Invalid scalar size", nameof(scalar));
+ }
+ else if (point.Length != KeySize)
+ {
+ throw new ArgumentException("Invalid point size", nameof(point));
+ }
+
+ // copy the scalar so we can properly mask it
+ ByteSpan maskedScalar = new byte[32];
+ scalar.CopyTo(maskedScalar);
+ maskedScalar[0] &= 248;
+ maskedScalar[31] &= 127;
+ maskedScalar[31] |= 64;
+
+ FieldElement x1 = FieldElement.FromBytes(point);
+ FieldElement x2 = FieldElement.One();
+ FieldElement x3 = x1;
+ FieldElement z2 = FieldElement.Zero();
+ FieldElement z3 = FieldElement.One();
+
+ FieldElement tmp0 = new FieldElement();
+ FieldElement tmp1 = new FieldElement();
+
+ int swap = 0;
+ for (int pos = 254; pos >= 0; --pos)
+ {
+ int b = (int)maskedScalar[pos / 8] >> (int)(pos % 8);
+ b &= 1;
+ swap ^= b;
+
+ FieldElement.ConditionalSwap(ref x2, ref x3, swap);
+ FieldElement.ConditionalSwap(ref z2, ref z3, swap);
+ swap = b;
+
+ FieldElement.Sub(ref tmp0, ref x3, ref z3);
+ FieldElement.Sub(ref tmp1, ref x2, ref z2);
+ FieldElement.Add(ref x2, ref x2, ref z2);
+ FieldElement.Add(ref z2, ref x3, ref z3);
+ FieldElement.Multiply(ref z3, ref tmp0, ref x2);
+ FieldElement.Multiply(ref z2, ref z2, ref tmp1);
+ FieldElement.Square(ref tmp0, ref tmp1);
+ FieldElement.Square(ref tmp1, ref x2);
+ FieldElement.Add(ref x3, ref z3, ref z2);
+ FieldElement.Sub(ref z2, ref z3, ref z2);
+ FieldElement.Multiply(ref x2, ref tmp1, ref tmp0);
+ FieldElement.Sub(ref tmp1, ref tmp1, ref tmp0);
+ FieldElement.Square(ref z2, ref z2);
+ FieldElement.Multiply121666(ref z3, ref tmp1);
+ FieldElement.Square(ref x3, ref x3);
+ FieldElement.Add(ref tmp0, ref tmp0, ref z3);
+ FieldElement.Multiply(ref z3, ref x1, ref z2);
+ FieldElement.Multiply(ref z2, ref tmp1, ref tmp0);
+ }
+
+ FieldElement.ConditionalSwap(ref x2, ref x3, swap);
+ FieldElement.ConditionalSwap(ref z2, ref z3, swap);
+
+ FieldElement.Invert(ref z2, ref z2);
+ FieldElement.Multiply(ref x2, ref x2, ref z2);
+ x2.CopyTo(output);
+ }
+
+
+ /// <summary>
+ /// Mathematical operators over GF(2^255 - 19)
+ /// </summary>
+ partial struct FieldElement
+ {
+ /// <summary>
+ /// Convert a byte array to a field element
+ /// </summary>
+ public static FieldElement FromBytes(ByteSpan bytes)
+ {
+ Debug.Assert(bytes.Length >= KeySize);
+
+ long tmp0 = (long)bytes.ReadLittleEndian32();
+ long tmp1 = (long)bytes.ReadLittleEndian24(4) << 6;
+ long tmp2 = (long)bytes.ReadLittleEndian24(7) << 5;
+ long tmp3 = (long)bytes.ReadLittleEndian24(10) << 3;
+ long tmp4 = (long)bytes.ReadLittleEndian24(13) << 2;
+ long tmp5 = (long)bytes.ReadLittleEndian32(16);
+ long tmp6 = (long)bytes.ReadLittleEndian24(20) << 7;
+ long tmp7 = (long)bytes.ReadLittleEndian24(23) << 5;
+ long tmp8 = (long)bytes.ReadLittleEndian24(26) << 4;
+ long tmp9 = (long)(bytes.ReadLittleEndian24(29) & 0x007FFFFF) << 2;
+
+ long carry9 = (tmp9 + (1L<<24)) >> 25;
+ tmp0 += carry9 * 19;
+ tmp9 -= carry9 << 25;
+ long carry1 = (tmp1 + (1L<<24)) >> 25;
+ tmp2 += carry1;
+ tmp1 -= carry1 << 25;
+ long carry3 = (tmp3 + (1L<<24)) >> 25;
+ tmp4 += carry3;
+ tmp3 -= carry3 << 25;
+ long carry5 = (tmp5 + (1L<<24)) >> 25;
+ tmp6 += carry5;
+ tmp5 -= carry5 << 25;
+ long carry7 = (tmp7 + (1L<<24)) >> 25;
+ tmp8 += carry7;
+ tmp7 -= carry7 << 25;
+
+ long carry0 = (tmp0 + (1L<<25)) >> 26;
+ tmp1 += carry0;
+ tmp0 -= carry0 << 26;
+ long carry2 = (tmp2 + (1L<<25)) >> 26;
+ tmp3 += carry2;
+ tmp2 -= carry2 << 26;
+ long carry4 = (tmp4 + (1L<<25)) >> 26;
+ tmp5 += carry4;
+ tmp4 -= carry4 << 26;
+ long carry6 = (tmp6 + (1L<<25)) >> 26;
+ tmp7 += carry6;
+ tmp6 -= carry6 << 26;
+ long carry8 = (tmp8 + (1L<<25)) >> 26;
+ tmp9 += carry8;
+ tmp8 -= carry8 << 26;
+
+ return new FieldElement
+ {
+ x0 = (int)tmp0,
+ x1 = (int)tmp1,
+ x2 = (int)tmp2,
+ x3 = (int)tmp3,
+ x4 = (int)tmp4,
+ x5 = (int)tmp5,
+ x6 = (int)tmp6,
+ x7 = (int)tmp7,
+ x8 = (int)tmp8,
+ x9 = (int)tmp9,
+ };
+ }
+
+ /// <summary>
+ /// Convert the field element to a byte array
+ /// </summary>
+ public void CopyTo(ByteSpan output)
+ {
+ Debug.Assert(output.Length >= 32);
+
+ long q = (19 * this.x9 + (1L << 24)) >> 25;
+ q = ((long)this.x0 + q) >> 26;
+ q = ((long)this.x1 + q) >> 25;
+ q = ((long)this.x2 + q) >> 26;
+ q = ((long)this.x3 + q) >> 25;
+ q = ((long)this.x4 + q) >> 26;
+ q = ((long)this.x5 + q) >> 25;
+ q = ((long)this.x6 + q) >> 26;
+ q = ((long)this.x7 + q) >> 25;
+ q = ((long)this.x8 + q) >> 26;
+ q = ((long)this.x9 + q) >> 25;
+
+ this.x0 = (int)((long)this.x0 + (19L * q));
+
+ int carry0 = (int)(this.x0 >> 26);
+ this.x1 = (int)((int)this.x1 + carry0);
+ this.x0 = (int)((int)this.x0 - (carry0 << 26));
+ int carry1 = (int)(this.x1 >> 25);
+ this.x2 = (int)((int)this.x2 + carry1);
+ this.x1 = (int)((int)this.x1 - (carry1 << 25));
+ int carry2 = (int)(this.x2 >> 26);
+ this.x3 = (int)((int)this.x3 + carry2);
+ this.x2 = (int)((int)this.x2 - (carry2 << 26));
+ int carry3 = (int)(this.x3 >> 25);
+ this.x4 = (int)((int)this.x4 + carry3);
+ this.x3 = (int)((int)this.x3 - (carry3 << 25));
+ int carry4 = (int)(this.x4 >> 26);
+ this.x5 = (int)((int)this.x5 + carry4);
+ this.x4 = (int)((int)this.x4 - (carry4 << 26));
+ int carry5 = (int)(this.x5 >> 25);
+ this.x6 = (int)((int)this.x6 + carry5);
+ this.x5 = (int)((int)this.x5 - (carry5 << 25));
+ int carry6 = (int)(this.x6 >> 26);
+ this.x7 = (int)((int)this.x7 + carry6);
+ this.x6 = (int)((int)this.x6 - (carry6 << 26));
+ int carry7 = (int)(this.x7 >> 25);
+ this.x8 = (int)((int)this.x8 + carry7);
+ this.x7 = (int)((int)this.x7 - (carry7 << 25));
+ int carry8 = (int)(this.x8 >> 26);
+ this.x9 = (int)((int)this.x9 + carry8);
+ this.x8 = (int)((int)this.x8 - (carry8 << 26));
+ int carry9 = (int)(this.x9 >> 25);
+ this.x9 = (int)((int)this.x9 - (carry9 << 25));
+
+ output[ 0] = (byte)(this.x0 >> 0);
+ output[ 1] = (byte)(this.x0 >> 8);
+ output[ 2] = (byte)(this.x0 >> 16);
+ output[ 3] = (byte)((this.x0 >> 24) | (this.x1 << 2));
+ output[ 4] = (byte)(this.x1 >> 6);
+ output[ 5] = (byte)(this.x1 >> 14);
+ output[ 6] = (byte)((this.x1 >> 22) | (this.x2 << 3));
+ output[ 7] = (byte)(this.x2 >> 5);
+ output[ 8] = (byte)(this.x2 >> 13);
+ output[ 9] = (byte)((this.x2 >> 21) | (this.x3 << 5));
+ output[10] = (byte)(this.x3 >> 3);
+ output[11] = (byte)(this.x3 >> 11);
+ output[12] = (byte)((this.x3 >> 19) | (this.x4 << 6));
+ output[13] = (byte)(this.x4 >> 2);
+ output[14] = (byte)(this.x4 >> 10);
+ output[15] = (byte)(this.x4 >> 18);
+ output[16] = (byte)(this.x5 >> 0);
+ output[17] = (byte)(this.x5 >> 8);
+ output[18] = (byte)(this.x5 >> 16);
+ output[19] = (byte)((this.x5 >> 24) | (this.x6 << 1));
+ output[20] = (byte)(this.x6 >> 7);
+ output[21] = (byte)(this.x6 >> 15);
+ output[22] = (byte)((this.x6 >> 23) | (this.x7 << 3));
+ output[23] = (byte)(this.x7 >> 5);
+ output[24] = (byte)(this.x7 >> 13);
+ output[25] = (byte)((this.x7 >> 21) | (this.x8 << 4));
+ output[26] = (byte)(this.x8 >> 4);
+ output[27] = (byte)(this.x8 >> 12);
+ output[28] = (byte)((this.x8 >> 20) | (this.x9 << 6));
+ output[29] = (byte)(this.x9 >> 2);
+ output[30] = (byte)(this.x9 >> 10);
+ output[31] = (byte)(this.x9 >> 18);
+ }
+
+ /// <summary>
+ /// Set the field element to `0`
+ /// </summary>
+ public static FieldElement Zero()
+ {
+ return new FieldElement();
+ }
+
+ /// <summary>
+ /// Set the field element to `1`
+ /// </summary>
+ public static FieldElement One()
+ {
+ FieldElement result = Zero();
+ result.x0 = 1;
+ return result;
+ }
+
+ /// <summary>
+ /// Add two field elements
+ /// </summary>
+ public static void Add(ref FieldElement output, ref FieldElement a, ref FieldElement b)
+ {
+ output.x0 = a.x0 + b.x0;
+ output.x1 = a.x1 + b.x1;
+ output.x2 = a.x2 + b.x2;
+ output.x3 = a.x3 + b.x3;
+ output.x4 = a.x4 + b.x4;
+ output.x5 = a.x5 + b.x5;
+ output.x6 = a.x6 + b.x6;
+ output.x7 = a.x7 + b.x7;
+ output.x8 = a.x8 + b.x8;
+ output.x9 = a.x9 + b.x9;
+ }
+
+ /// <summary>
+ /// Subtract two field elements
+ /// </summary>
+ public static void Sub(ref FieldElement output, ref FieldElement a, ref FieldElement b)
+ {
+ output.x0 = a.x0 - b.x0;
+ output.x1 = a.x1 - b.x1;
+ output.x2 = a.x2 - b.x2;
+ output.x3 = a.x3 - b.x3;
+ output.x4 = a.x4 - b.x4;
+ output.x5 = a.x5 - b.x5;
+ output.x6 = a.x6 - b.x6;
+ output.x7 = a.x7 - b.x7;
+ output.x8 = a.x8 - b.x8;
+ output.x9 = a.x9 - b.x9;
+ }
+
+ /// <summary>
+ /// Multiply two field elements
+ /// </summary>
+ public static void Multiply(ref FieldElement output, ref FieldElement a, ref FieldElement b)
+ {
+ int b1_19 = 19 * b.x1;
+ int b2_19 = 19 * b.x2;
+ int b3_19 = 19 * b.x3;
+ int b4_19 = 19 * b.x4;
+ int b5_19 = 19 * b.x5;
+ int b6_19 = 19 * b.x6;
+ int b7_19 = 19 * b.x7;
+ int b8_19 = 19 * b.x8;
+ int b9_19 = 19 * b.x9;
+
+ int a1_2 = 2 * a.x1;
+ int a3_2 = 2 * a.x3;
+ int a5_2 = 2 * a.x5;
+ int a7_2 = 2 * a.x7;
+ int a9_2 = 2 * a.x9;
+
+ long a0b0 = (long)a.x0 * (long)b.x0;
+ long a0b1 = (long)a.x0 * (long)b.x1;
+ long a0b2 = (long)a.x0 * (long)b.x2;
+ long a0b3 = (long)a.x0 * (long)b.x3;
+ long a0b4 = (long)a.x0 * (long)b.x4;
+ long a0b5 = (long)a.x0 * (long)b.x5;
+ long a0b6 = (long)a.x0 * (long)b.x6;
+ long a0b7 = (long)a.x0 * (long)b.x7;
+ long a0b8 = (long)a.x0 * (long)b.x8;
+ long a0b9 = (long)a.x0 * (long)b.x9;
+ long a1b0 = (long)a.x1 * (long)b.x0;
+ long a1b1_2 = (long)a1_2 * (long)b.x1;
+ long a1b2 = (long)a.x1 * (long)b.x2;
+ long a1b3_2 = (long)a1_2 * (long)b.x3;
+ long a1b4 = (long)a.x1 * (long)b.x4;
+ long a1b5_2 = (long)a1_2 * (long)b.x5;
+ long a1b6 = (long)a.x1 * (long)b.x6;
+ long a1b7_2 = (long)a1_2 * (long)b.x7;
+ long a1b8 = (long)a.x1 * (long)b.x8;
+ long a1b9_38 = (long)a1_2 * (long)b9_19;
+ long a2b0 = (long)a.x2 * (long)b.x0;
+ long a2b1 = (long)a.x2 * (long)b.x1;
+ long a2b2 = (long)a.x2 * (long)b.x2;
+ long a2b3 = (long)a.x2 * (long)b.x3;
+ long a2b4 = (long)a.x2 * (long)b.x4;
+ long a2b5 = (long)a.x2 * (long)b.x5;
+ long a2b6 = (long)a.x2 * (long)b.x6;
+ long a2b7 = (long)a.x2 * (long)b.x7;
+ long a2b8_19 = (long)a.x2 * (long)b8_19;
+ long a2b9_19 = (long)a.x2 * (long)b9_19;
+ long a3b0 = (long)a.x3 * (long)b.x0;
+ long a3b1_2 = (long)a3_2 * (long)b.x1;
+ long a3b2 = (long)a.x3 * (long)b.x2;
+ long a3b3_2 = (long)a3_2 * (long)b.x3;
+ long a3b4 = (long)a.x3 * (long)b.x4;
+ long a3b5_2 = (long)a3_2 * (long)b.x5;
+ long a3b6 = (long)a.x3 * (long)b.x6;
+ long a3b7_38 = (long)a3_2 * (long)b7_19;
+ long a3b8_19 = (long)a.x3 * (long)b8_19;
+ long a3b9_38 = (long)a3_2 * (long)b9_19;
+ long a4b0 = (long)a.x4 * (long)b.x0;
+ long a4b1 = (long)a.x4 * (long)b.x1;
+ long a4b2 = (long)a.x4 * (long)b.x2;
+ long a4b3 = (long)a.x4 * (long)b.x3;
+ long a4b4 = (long)a.x4 * (long)b.x4;
+ long a4b5 = (long)a.x4 * (long)b.x5;
+ long a4b6_19 = (long)a.x4 * (long)b6_19;
+ long a4b7_19 = (long)a.x4 * (long)b7_19;
+ long a4b8_19 = (long)a.x4 * (long)b8_19;
+ long a4b9_19 = (long)a.x4 * (long)b9_19;
+ long a5b0 = (long)a.x5 * (long)b.x0;
+ long a5b1_2 = (long)a5_2 * (long)b.x1;
+ long a5b2 = (long)a.x5 * (long)b.x2;
+ long a5b3_2 = (long)a5_2 * (long)b.x3;
+ long a5b4 = (long)a.x5 * (long)b.x4;
+ long a5b5_38 = (long)a5_2 * (long)b5_19;
+ long a5b6_19 = (long)a.x5 * (long)b6_19;
+ long a5b7_38 = (long)a5_2 * (long)b7_19;
+ long a5b8_19 = (long)a.x5 * (long)b8_19;
+ long a5b9_38 = (long)a5_2 * (long)b9_19;
+ long a6b0 = (long)a.x6 * (long)b.x0;
+ long a6b1 = (long)a.x6 * (long)b.x1;
+ long a6b2 = (long)a.x6 * (long)b.x2;
+ long a6b3 = (long)a.x6 * (long)b.x3;
+ long a6b4_19 = (long)a.x6 * (long)b4_19;
+ long a6b5_19 = (long)a.x6 * (long)b5_19;
+ long a6b6_19 = (long)a.x6 * (long)b6_19;
+ long a6b7_19 = (long)a.x6 * (long)b7_19;
+ long a6b8_19 = (long)a.x6 * (long)b8_19;
+ long a6b9_19 = (long)a.x6 * (long)b9_19;
+ long a7b0 = (long)a.x7 * (long)b.x0;
+ long a7b1_2 = (long)a7_2 * (long)b.x1;
+ long a7b2 = (long)a.x7 * (long)b.x2;
+ long a7b3_38 = (long)a7_2 * (long)b3_19;
+ long a7b4_19 = (long)a.x7 * (long)b4_19;
+ long a7b5_38 = (long)a7_2 * (long)b5_19;
+ long a7b6_19 = (long)a.x7 * (long)b6_19;
+ long a7b7_38 = (long)a7_2 * (long)b7_19;
+ long a7b8_19 = (long)a.x7 * (long)b8_19;
+ long a7b9_38 = (long)a7_2 * (long)b9_19;
+ long a8b0 = (long)a.x8 * (long)b.x0;
+ long a8b1 = (long)a.x8 * (long)b.x1;
+ long a8b2_19 = (long)a.x8 * (long)b2_19;
+ long a8b3_19 = (long)a.x8 * (long)b3_19;
+ long a8b4_19 = (long)a.x8 * (long)b4_19;
+ long a8b5_19 = (long)a.x8 * (long)b5_19;
+ long a8b6_19 = (long)a.x8 * (long)b6_19;
+ long a8b7_19 = (long)a.x8 * (long)b7_19;
+ long a8b8_19 = (long)a.x8 * (long)b8_19;
+ long a8b9_19 = (long)a.x8 * (long)b9_19;
+ long a9b0 = (long)a.x9 * (long)b.x0;
+ long a9b1_38 = (long)a9_2 * (long)b1_19;
+ long a9b2_19 = (long)a.x9 * (long)b2_19;
+ long a9b3_38 = (long)a9_2 * (long)b3_19;
+ long a9b4_19 = (long)a.x9 * (long)b4_19;
+ long a9b5_38 = (long)a9_2 * (long)b5_19;
+ long a9b6_19 = (long)a.x9 * (long)b6_19;
+ long a9b7_38 = (long)a9_2 * (long)b7_19;
+ long a9b8_19 = (long)a.x9 * (long)b8_19;
+ long a9b9_38 = (long)a9_2 * (long)b9_19;
+
+ long h0 = a0b0 + a1b9_38 + a2b8_19 + a3b7_38 + a4b6_19 + a5b5_38 + a6b4_19 + a7b3_38 + a8b2_19 + a9b1_38;
+ long h1 = a0b1 + a1b0 + a2b9_19 + a3b8_19 + a4b7_19 + a5b6_19 + a6b5_19 + a7b4_19 + a8b3_19 + a9b2_19;
+ long h2 = a0b2 + a1b1_2 + a2b0 + a3b9_38 + a4b8_19 + a5b7_38 + a6b6_19 + a7b5_38 + a8b4_19 + a9b3_38;
+ long h3 = a0b3 + a1b2 + a2b1 + a3b0 + a4b9_19 + a5b8_19 + a6b7_19 + a7b6_19 + a8b5_19 + a9b4_19;
+ long h4 = a0b4 + a1b3_2 + a2b2 + a3b1_2 + a4b0 + a5b9_38 + a6b8_19 + a7b7_38 + a8b6_19 + a9b5_38;
+ long h5 = a0b5 + a1b4 + a2b3 + a3b2 + a4b1 + a5b0 + a6b9_19 + a7b8_19 + a8b7_19 + a9b6_19;
+ long h6 = a0b6 + a1b5_2 + a2b4 + a3b3_2 + a4b2 + a5b1_2 + a6b0 + a7b9_38 + a8b8_19 + a9b7_38;
+ long h7 = a0b7 + a1b6 + a2b5 + a3b4 + a4b3 + a5b2 + a6b1 + a7b0 + a8b9_19 + a9b8_19;
+ long h8 = a0b8 + a1b7_2 + a2b6 + a3b5_2 + a4b4 + a5b3_2 + a6b2 + a7b1_2 + a8b0 + a9b9_38;
+ long h9 = a0b9 + a1b8 + a2b7 + a3b6 + a4b5 + a5b4 + a6b3 + a7b2 + a8b1 + a9b0;
+
+ long carry0 = (h0 + (1L << 25)) >> 26;
+ h1 += carry0;
+ h0 -= carry0 << 26;
+ long carry4 = (h4 + (1L << 25)) >> 26;
+ h5 += carry4;
+ h4 -= carry4 << 26;
+
+ long carry1 = (h1 + (1L << 24)) >> 25;
+ h2 += carry1;
+ h1 -= carry1 << 25;
+ long carry5 = (h5 + (1L << 24)) >> 25;
+ h6 += carry5;
+ h5 -= carry5 << 25;
+
+ long carry2 = (h2 + (1L << 25)) >> 26;
+ h3 += carry2;
+ h2 -= carry2 << 26;
+ long carry6 = (h6 + (1L << 25)) >> 26;
+ h7 += carry6;
+ h6 -= carry6 << 26;
+
+ long carry3 = (h3 + (1L << 24)) >> 25;
+ h4 += carry3;
+ h3 -= carry3 << 25;
+ long carry7 = (h7 + (1L << 24)) >> 25;
+ h8 += carry7;
+ h7 -= carry7 << 25;
+
+ carry4 = (h4 + (1L << 25)) >> 26;
+ h5 += carry4;
+ h4 -= carry4 << 26;
+ long carry8 = (h8 + (1L << 25)) >> 26;
+ h9 += carry8;
+ h8 -= carry8 << 26;
+
+ long carry9 = (h9 + (1L << 24)) >> 25;
+ h0 += carry9 * 19;
+ h9 -= carry9 << 25;
+
+ carry0 = (h0 + (1L << 25)) >> 26;
+ h1 += carry0;
+ h0 -= carry0 << 26;
+
+ output.x0 = (int)h0;
+ output.x1 = (int)h1;
+ output.x2 = (int)h2;
+ output.x3 = (int)h3;
+ output.x4 = (int)h4;
+ output.x5 = (int)h5;
+ output.x6 = (int)h6;
+ output.x7 = (int)h7;
+ output.x8 = (int)h8;
+ output.x9 = (int)h9;
+ }
+
+ /// <summary>
+ /// Square a field element
+ /// </summary>
+ public static void Square(ref FieldElement output, ref FieldElement a)
+ {
+ int a0_2 = a.x0 * 2;
+ int a1_2 = a.x1 * 2;
+ int a2_2 = a.x2 * 2;
+ int a3_2 = a.x3 * 2;
+ int a4_2 = a.x4 * 2;
+ int a5_2 = a.x5 * 2;
+ int a6_2 = a.x6 * 2;
+ int a7_2 = a.x7 * 2;
+
+ int a5_38 = a.x5 * 38;
+ int a6_19 = a.x6 * 19;
+ int a7_38 = a.x7 * 38;
+ int a8_19 = a.x8 * 19;
+ int a9_38 = a.x9 * 38;
+
+ long a0a0 = (long)a.x0 * (long)a.x0;
+ long a0a1_2 = (long)a0_2 * (long)a.x1;
+ long a0a2_2 = (long)a0_2 * (long)a.x2;
+ long a0a3_2 = (long)a0_2 * (long)a.x3;
+ long a0a4_2 = (long)a0_2 * (long)a.x4;
+ long a0a5_2 = (long)a0_2 * (long)a.x5;
+ long a0a6_2 = (long)a0_2 * (long)a.x6;
+ long a0a7_2 = (long)a0_2 * (long)a.x7;
+ long a0a8_2 = (long)a0_2 * (long)a.x8;
+ long a0a9_2 = (long)a0_2 * (long)a.x9;
+ long a1a1_2 = (long)a1_2 * (long)a.x1;
+ long a1a2_2 = (long)a1_2 * (long)a.x2;
+ long a1a3_4 = (long)a1_2 * (long)a3_2;
+ long a1a4_2 = (long)a1_2 * (long)a.x4;
+ long a1a5_4 = (long)a1_2 * (long)a5_2;
+ long a1a6_2 = (long)a1_2 * (long)a.x6;
+ long a1a7_4 = (long)a1_2 * (long)a7_2;
+ long a1a8_2 = (long)a1_2 * (long)a.x8;
+ long a1a9_76 = (long)a1_2 * (long)a9_38;
+ long a2a2 = (long)a.x2 * (long)a.x2;
+ long a2a3_2 = (long)a2_2 * (long)a.x3;
+ long a2a4_2 = (long)a2_2 * (long)a.x4;
+ long a2a5_2 = (long)a2_2 * (long)a.x5;
+ long a2a6_2 = (long)a2_2 * (long)a.x6;
+ long a2a7_2 = (long)a2_2 * (long)a.x7;
+ long a2a8_38 = (long)a2_2 * (long)a8_19;
+ long a2a9_38 = (long)a.x2 * (long)a9_38;
+ long a3a3_2 = (long)a3_2 * (long)a.x3;
+ long a3a4_2 = (long)a3_2 * (long)a.x4;
+ long a3a5_4 = (long)a3_2 * (long)a5_2;
+ long a3a6_2 = (long)a3_2 * (long)a.x6;
+ long a3a7_76 = (long)a3_2 * (long)a7_38;
+ long a3a8_38 = (long)a3_2 * (long)a8_19;
+ long a3a9_76 = (long)a3_2 * (long)a9_38;
+ long a4a4 = (long)a.x4 * (long)a.x4;
+ long a4a5_2 = (long)a4_2 * (long)a.x5;
+ long a4a6_38 = (long)a4_2 * (long)a6_19;
+ long a4a7_38 = (long)a.x4 * (long)a7_38;
+ long a4a8_38 = (long)a4_2 * (long)a8_19;
+ long a4a9_38 = (long)a.x4 * (long)a9_38;
+ long a5a5_38 = (long)a.x5 * (long)a5_38;
+ long a5a6_38 = (long)a5_2 * (long)a6_19;
+ long a5a7_76 = (long)a5_2 * (long)a7_38;
+ long a5a8_38 = (long)a5_2 * (long)a8_19;
+ long a5a9_76 = (long)a5_2 * (long)a9_38;
+ long a6a6_19 = (long)a.x6 * (long)a6_19;
+ long a6a7_38 = (long)a.x6 * (long)a7_38;
+ long a6a8_38 = (long)a6_2 * (long)a8_19;
+ long a6a9_38 = (long)a.x6 * (long)a9_38;
+ long a7a7_38 = (long)a.x7 * (long)a7_38;
+ long a7a8_38 = (long)a7_2 * (long)a8_19;
+ long a7a9_76 = (long)a7_2 * (long)a9_38;
+ long a8a8_19 = (long)a.x8 * (long)a8_19;
+ long a8a9_38 = (long)a.x8 * (long)a9_38;
+ long a9a9_38 = (long)a.x9 * (long)a9_38;
+
+ long h0 = a0a0 + a1a9_76 + a2a8_38 + a3a7_76 + a4a6_38 + a5a5_38;
+ long h1 = a0a1_2 + a2a9_38 + a3a8_38 + a4a7_38 + a5a6_38;
+ long h2 = a0a2_2 + a1a1_2 + a3a9_76 + a4a8_38 + a5a7_76 + a6a6_19;
+ long h3 = a0a3_2 + a1a2_2 + a4a9_38 + a5a8_38 + a6a7_38;
+ long h4 = a0a4_2 + a1a3_4 + a2a2 + a5a9_76 + a6a8_38 + a7a7_38;
+ long h5 = a0a5_2 + a1a4_2 + a2a3_2 + a6a9_38 + a7a8_38;
+ long h6 = a0a6_2 + a1a5_4 + a2a4_2 + a3a3_2 + a7a9_76 + a8a8_19;
+ long h7 = a0a7_2 + a1a6_2 + a2a5_2 + a3a4_2 + a8a9_38;
+ long h8 = a0a8_2 + a1a7_4 + a2a6_2 + a3a5_4 + a4a4 + a9a9_38;
+ long h9 = a0a9_2 + a1a8_2 + a2a7_2 + a3a6_2 + a4a5_2;
+
+ long carry0 = (h0 + (1L << 25)) >> 26;
+ h1 += carry0;
+ h0 -= carry0 << 26;
+ long carry4 = (h4 + (1L << 25)) >> 26;
+ h5 += carry4;
+ h4 -= carry4 << 26;
+
+ long carry1 = (h1 + (1L << 24)) >> 25;
+ h2 += carry1;
+ h1 -= carry1 << 25;
+ long carry5 = (h5 + (1L << 24)) >> 25;
+ h6 += carry5;
+ h5 -= carry5 << 25;
+
+ long carry2 = (h2 + (1L << 25)) >> 26;
+ h3 += carry2;
+ h2 -= carry2 << 26;
+ long carry6 = (h6 + (1L << 25)) >> 26;
+ h7 += carry6;
+ h6 -= carry6 << 26;
+
+ long carry3 = (h3 + (1L << 24)) >> 25;
+ h4 += carry3;
+ h3 -= carry3 << 25;
+ long carry7 = (h7 + (1L << 24)) >> 25;
+ h8 += carry7;
+ h7 -= carry7 << 25;
+
+ carry4 = (h4 + (1L << 25)) >> 26;
+ h5 += carry4;
+ h4 -= carry4 << 26;
+ long carry8 = (h8 + (1L << 25)) >> 26;
+ h9 += carry8;
+ h8 -= carry8 << 26;
+
+ long carry9 = (h9 + (1L << 24)) >> 25;
+ h0 += carry9 * 19;
+ h9 -= carry9 << 25;
+
+ carry0 = (h0 + (1L << 25)) >> 26;
+ h1 += carry0;
+ h0 -= carry0 << 26;
+
+ output.x0 = (int)h0;
+ output.x1 = (int)h1;
+ output.x2 = (int)h2;
+ output.x3 = (int)h3;
+ output.x4 = (int)h4;
+ output.x5 = (int)h5;
+ output.x6 = (int)h6;
+ output.x7 = (int)h7;
+ output.x8 = (int)h8;
+ output.x9 = (int)h9;
+ }
+
+ /// <summary>
+ /// Multiplay a field element by 121666
+ /// </summary>
+ public static void Multiply121666(ref FieldElement output, ref FieldElement a)
+ {
+ long h0 = (long)a.x0 * 121666L;
+ long h1 = (long)a.x1 * 121666L;
+ long h2 = (long)a.x2 * 121666L;
+ long h3 = (long)a.x3 * 121666L;
+ long h4 = (long)a.x4 * 121666L;
+ long h5 = (long)a.x5 * 121666L;
+ long h6 = (long)a.x6 * 121666L;
+ long h7 = (long)a.x7 * 121666L;
+ long h8 = (long)a.x8 * 121666L;
+ long h9 = (long)a.x9 * 121666L;
+
+ long carry9 = (h9 + (1L<<24)) >> 25;
+ h0 += carry9 * 19;
+ h9 -= carry9 << 25;
+ long carry1 = (h1 + (1L<<24)) >> 25;
+ h2 += carry1;
+ h1 -= carry1 << 25;
+ long carry3 = (h3 + (1L<<24)) >> 25;
+ h4 += carry3;
+ h3 -= carry3 << 25;
+ long carry5 = (h5 + (1L<<24)) >> 25;
+ h6 += carry5;
+ h5 -= carry5 << 25;
+ long carry7 = (h7 + (1L<<24)) >> 25;
+ h8 += carry7;
+ h7 -= carry7 << 25;
+
+ long carry0 = (h0 + (1L << 25)) >> 26;
+ h1 += carry0;
+ h0 -= carry0 << 26;
+ long carry2 = (h2 + (1L << 25)) >> 26;
+ h3 += carry2;
+ h2 -= carry2 << 26;
+ long carry4 = (h4 + (1L << 25)) >> 26;
+ h5 += carry4;
+ h4 -= carry4 << 26;
+ long carry6 = (h6 + (1L << 25)) >> 26;
+ h7 += carry6;
+ h6 -= carry6 << 26;
+ long carry8 = (h8 + (1L << 25)) >> 26;
+ h9 += carry8;
+ h8 -= carry8 << 26;
+
+ output.x0 = (int)h0;
+ output.x1 = (int)h1;
+ output.x2 = (int)h2;
+ output.x3 = (int)h3;
+ output.x4 = (int)h4;
+ output.x5 = (int)h5;
+ output.x6 = (int)h6;
+ output.x7 = (int)h7;
+ output.x8 = (int)h8;
+ output.x9 = (int)h9;
+ }
+
+ /// <summary>
+ /// Invert a field element
+ /// </summary>
+ public static void Invert(ref FieldElement output, ref FieldElement a)
+ {
+ FieldElement t0 = new FieldElement();
+ Square(ref t0, ref a);
+
+ FieldElement t1 = new FieldElement();
+ Square(ref t1, ref t0);
+ Square(ref t1, ref t1);
+
+ FieldElement t2= new FieldElement();
+ Multiply(ref t1, ref a, ref t1);
+ Multiply(ref t0, ref t0, ref t1);
+ Square(ref t2, ref t0);
+ //Square(ref t2, ref t2);
+
+ Multiply(ref t1, ref t1, ref t2);
+ Square(ref t2, ref t1);
+ for (int ii = 1; ii < 5; ++ii)
+ {
+ Square(ref t2, ref t2);
+ }
+
+ Multiply(ref t1, ref t2, ref t1);
+ Square(ref t2, ref t1);
+ for (int ii = 1; ii < 10; ++ii)
+ {
+ Square(ref t2, ref t2);
+ }
+
+ FieldElement t3 = new FieldElement();
+ Multiply(ref t2, ref t2, ref t1);
+ Square(ref t3, ref t2);
+ for (int ii = 1; ii < 20; ++ii)
+ {
+ Square(ref t3, ref t3);
+ }
+
+ Multiply(ref t2, ref t3, ref t2);
+ Square(ref t2, ref t2);
+ for (int ii = 1; ii < 10; ++ii)
+ {
+ Square(ref t2, ref t2);
+ }
+
+ Multiply(ref t1, ref t2, ref t1);
+ Square(ref t2, ref t1);
+ for (int ii = 1; ii < 50; ++ii)
+ {
+ Square(ref t2, ref t2);
+ }
+
+ Multiply(ref t2, ref t2, ref t1);
+ Square(ref t3, ref t2);
+ for (int ii = 1; ii < 100; ++ii)
+ {
+ Square(ref t3, ref t3);
+ }
+
+ Multiply(ref t2, ref t3, ref t2);
+ Square(ref t2, ref t2);
+ for (int ii = 1; ii < 50; ++ii)
+ {
+ Square(ref t2, ref t2);
+ }
+
+ Multiply(ref t1, ref t2, ref t1);
+ Square(ref t1, ref t1);
+ for (int ii = 1; ii < 5; ++ii)
+ {
+ Square(ref t1, ref t1);
+ }
+
+ Multiply(ref output, ref t1, ref t0);
+ }
+
+ /// <summary>
+ /// Swaps `a` and `b` if `swap` is 1
+ /// </summary>
+ public static void ConditionalSwap(ref FieldElement a, ref FieldElement b, int swap)
+ {
+ Debug.Assert(swap == 0 || swap == 1);
+ swap = -swap;
+
+ FieldElement temp = new FieldElement
+ {
+ x0 = swap & (a.x0 ^ b.x0),
+ x1 = swap & (a.x1 ^ b.x1),
+ x2 = swap & (a.x2 ^ b.x2),
+ x3 = swap & (a.x3 ^ b.x3),
+ x4 = swap & (a.x4 ^ b.x4),
+ x5 = swap & (a.x5 ^ b.x5),
+ x6 = swap & (a.x6 ^ b.x6),
+ x7 = swap & (a.x7 ^ b.x7),
+ x8 = swap & (a.x8 ^ b.x8),
+ x9 = swap & (a.x9 ^ b.x9),
+ };
+
+ a.x0 ^= temp.x0;
+ a.x1 ^= temp.x1;
+ a.x2 ^= temp.x2;
+ a.x3 ^= temp.x3;
+ a.x4 ^= temp.x4;
+ a.x5 ^= temp.x5;
+ a.x6 ^= temp.x6;
+ a.x7 ^= temp.x7;
+ a.x8 ^= temp.x8;
+ a.x9 ^= temp.x9;
+
+ b.x0 ^= temp.x0;
+ b.x1 ^= temp.x1;
+ b.x2 ^= temp.x2;
+ b.x3 ^= temp.x3;
+ b.x4 ^= temp.x4;
+ b.x5 ^= temp.x5;
+ b.x6 ^= temp.x6;
+ b.x7 ^= temp.x7;
+ b.x8 ^= temp.x8;
+ b.x9 ^= temp.x9;
+ }
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/DataReceivedEventArgs.cs b/Tools/Hazel-Networking/Hazel/DataReceivedEventArgs.cs
new file mode 100644
index 0000000..35609fc
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/DataReceivedEventArgs.cs
@@ -0,0 +1,29 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+
+namespace Hazel
+{
+ public struct DataReceivedEventArgs
+ {
+ public readonly Connection Sender;
+
+ /// <summary>
+ /// The bytes received from the client.
+ /// </summary>
+ public readonly MessageReader Message;
+
+ /// <summary>
+ /// The <see cref="SendOption"/> the data was sent with.
+ /// </summary>
+ public readonly SendOption SendOption;
+
+ public DataReceivedEventArgs(Connection sender, MessageReader msg, SendOption sendOption)
+ {
+ this.Sender = sender;
+ this.Message = msg;
+ this.SendOption = sendOption;
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/DisconnectedEventArgs.cs b/Tools/Hazel-Networking/Hazel/DisconnectedEventArgs.cs
new file mode 100644
index 0000000..a7fb05c
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/DisconnectedEventArgs.cs
@@ -0,0 +1,24 @@
+using System;
+
+namespace Hazel
+{
+ public class DisconnectedEventArgs : EventArgs
+ {
+ /// <summary>
+ /// Optional disconnect reason. May be null.
+ /// </summary>
+ public readonly string Reason;
+
+ /// <summary>
+ /// Optional data sent with a disconnect message. May be null.
+ /// You must not recycle this. If you need the message outside of a callback, you should copy it.
+ /// </summary>
+ public readonly MessageReader Message;
+
+ public DisconnectedEventArgs(string reason, MessageReader message)
+ {
+ this.Reason = reason;
+ this.Message = message;
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/Dtls/AesGcmRecordProtection.cs b/Tools/Hazel-Networking/Hazel/Dtls/AesGcmRecordProtection.cs
new file mode 100644
index 0000000..65df39e
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/Dtls/AesGcmRecordProtection.cs
@@ -0,0 +1,147 @@
+using Hazel.Crypto;
+using System;
+using System.Diagnostics;
+
+namespace Hazel.Dtls
+{
+ /// <summary>
+ /// *_AES_128_GCM_* cipher suite
+ /// </summary>
+ public class Aes128GcmRecordProtection: IRecordProtection
+ {
+ private const int ImplicitNonceSize = 4;
+ private const int ExplicitNonceSize = 8;
+
+ private readonly Aes128Gcm serverWriteCipher;
+ private readonly Aes128Gcm clientWriteCipher;
+
+ private readonly ByteSpan serverWriteIV;
+ private readonly ByteSpan clientWriteIV;
+
+ /// <summary>
+ /// Create a new instance of the AES128_GCM record protection
+ /// </summary>
+ /// <param name="masterSecret">Shared secret</param>
+ /// <param name="serverRandom">Server random data</param>
+ /// <param name="clientRandom">Client random data</param>
+ public Aes128GcmRecordProtection(ByteSpan masterSecret, ByteSpan serverRandom, ByteSpan clientRandom)
+ {
+ ByteSpan combinedRandom = new byte[serverRandom.Length + clientRandom.Length];
+ serverRandom.CopyTo(combinedRandom);
+ clientRandom.CopyTo(combinedRandom.Slice(serverRandom.Length));
+
+ // Expand master_secret to encryption keys
+ const int ExpandedSize = 0
+ + 0 // mac_key_length
+ + 0 // mac_key_length
+ + Aes128Gcm.KeySize // enc_key_length
+ + Aes128Gcm.KeySize // enc_key_length
+ + ImplicitNonceSize // fixed_iv_length
+ + ImplicitNonceSize // fixed_iv_length
+ ;
+
+ ByteSpan expandedKey = new byte[ExpandedSize];
+ PrfSha256.ExpandSecret(expandedKey, masterSecret, PrfLabel.KEY_EXPANSION, combinedRandom);
+
+ ByteSpan clientWriteKey = expandedKey.Slice(0, Aes128Gcm.KeySize);
+ ByteSpan serverWriteKey = expandedKey.Slice(Aes128Gcm.KeySize, Aes128Gcm.KeySize);
+ this.clientWriteIV = expandedKey.Slice(2 * Aes128Gcm.KeySize, ImplicitNonceSize);
+ this.serverWriteIV = expandedKey.Slice(2 * Aes128Gcm.KeySize + ImplicitNonceSize, ImplicitNonceSize);
+
+ this.serverWriteCipher = new Aes128Gcm(serverWriteKey);
+ this.clientWriteCipher = new Aes128Gcm(clientWriteKey);
+ }
+
+ /// <inheritdoc />
+ public void Dispose()
+ {
+ this.serverWriteCipher.Dispose();
+ this.clientWriteCipher.Dispose();
+ }
+
+ /// <inheritdoc />
+ private static int GetEncryptedSizeImpl(int dataSize)
+ {
+ return dataSize + Aes128Gcm.CiphertextOverhead;
+ }
+
+ /// <inheritdoc />
+ public int GetEncryptedSize(int dataSize)
+ {
+ return GetEncryptedSizeImpl(dataSize);
+ }
+
+ private static int GetDecryptedSizeImpl(int dataSize)
+ {
+ return dataSize - Aes128Gcm.CiphertextOverhead;
+ }
+
+ /// <inheritdoc />
+ public int GetDecryptedSize(int dataSize)
+ {
+ return GetDecryptedSizeImpl(dataSize);
+ }
+
+ /// <inheritdoc />
+ public void EncryptServerPlaintext(ByteSpan output, ByteSpan input, ref Record record)
+ {
+ EncryptPlaintext(output, input, ref record, this.serverWriteCipher, this.serverWriteIV);
+ }
+
+ /// <inheritdoc />
+ public void EncryptClientPlaintext(ByteSpan output, ByteSpan input, ref Record record)
+ {
+ EncryptPlaintext(output, input, ref record, this.clientWriteCipher, this.clientWriteIV);
+ }
+
+ private static void EncryptPlaintext(ByteSpan output, ByteSpan input, ref Record record, Aes128Gcm cipher, ByteSpan writeIV)
+ {
+ Debug.Assert(output.Length >= GetEncryptedSizeImpl(input.Length));
+
+ // Build GCM nonce (authenticated data)
+ ByteSpan nonce = new byte[ImplicitNonceSize + ExplicitNonceSize];
+ writeIV.CopyTo(nonce);
+ nonce.WriteBigEndian16(record.Epoch, ImplicitNonceSize);
+ nonce.WriteBigEndian48(record.SequenceNumber, ImplicitNonceSize + 2);
+
+ // Serialize record as additional data
+ Record plaintextRecord = record;
+ plaintextRecord.Length = (ushort)input.Length;
+ ByteSpan associatedData = new byte[Record.Size];
+ plaintextRecord.Encode(associatedData);
+
+ cipher.Seal(output, nonce, input, associatedData);
+ }
+
+ /// <inheritdoc />
+ public bool DecryptCiphertextFromServer(ByteSpan output, ByteSpan input, ref Record record)
+ {
+ return DecryptCiphertext(output, input, ref record, this.serverWriteCipher, this.serverWriteIV);
+ }
+
+ /// <inheritdoc />
+ public bool DecryptCiphertextFromClient(ByteSpan output, ByteSpan input, ref Record record)
+ {
+ return DecryptCiphertext(output, input, ref record, this.clientWriteCipher, this.clientWriteIV);
+ }
+
+ private static bool DecryptCiphertext(ByteSpan output, ByteSpan input, ref Record record, Aes128Gcm cipher, ByteSpan writeIV)
+ {
+ Debug.Assert(output.Length >= GetDecryptedSizeImpl(input.Length));
+
+ // Build GCM nonce (authenticated data)
+ ByteSpan nonce = new byte[ImplicitNonceSize + ExplicitNonceSize];
+ writeIV.CopyTo(nonce);
+ nonce.WriteBigEndian16(record.Epoch, ImplicitNonceSize);
+ nonce.WriteBigEndian48(record.SequenceNumber, ImplicitNonceSize + 2);
+
+ // Serialize record as additional data
+ Record plaintextRecord = record;
+ plaintextRecord.Length = (ushort)GetDecryptedSizeImpl(input.Length);
+ ByteSpan associatedData = new byte[Record.Size];
+ plaintextRecord.Encode(associatedData);
+
+ return cipher.Open(output, nonce, input, associatedData);
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/Dtls/DtlsConnectionListener.cs b/Tools/Hazel-Networking/Hazel/Dtls/DtlsConnectionListener.cs
new file mode 100644
index 0000000..61f41d3
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/Dtls/DtlsConnectionListener.cs
@@ -0,0 +1,1424 @@
+using System;
+using System.Collections.Concurrent;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Net;
+using System.Security.Cryptography;
+using System.Security.Cryptography.X509Certificates;
+using System.Threading;
+using Hazel.Udp.FewerThreads;
+using Hazel.Crypto;
+
+namespace Hazel.Dtls
+{
+ /// <summary>
+ /// Listens for new UDP-DTLS connections and creates UdpConnections for them.
+ /// </summary>
+ /// <inheritdoc />
+ public class DtlsConnectionListener : ThreadLimitedUdpConnectionListener
+ {
+ private const int MaxCertFragmentSizeV0 = 1200;
+
+ // Min MTU - UDP+IP header - 1 (for good measure. :))
+ private const int MaxCertFragmentSizeV1 = 576 - 32 - 1;
+
+ /// <summary>
+ /// Current state of handshake sequence
+ /// </summary>
+ enum HandshakeState
+ {
+ ExpectingHello,
+ ExpectingClientKeyExchange,
+ ExpectingChangeCipherSpec,
+ ExpectingFinish
+ }
+
+ /// <summary>
+ /// State to manage the current epoch `N`
+ /// </summary>
+ struct CurrentEpoch
+ {
+ public ulong NextOutgoingSequence;
+
+ public ulong NextExpectedSequence;
+ public ulong PreviousSequenceWindowBitmask;
+
+ public IRecordProtection RecordProtection;
+ public IRecordProtection PreviousRecordProtection;
+
+ // Need to keep these around so we can re-transmit our
+ // last handshake record flight
+ public ByteSpan ExpectedClientFinishedVerification;
+ public ByteSpan ServerFinishedVerification;
+ public ulong NextOutgoingSequenceForPreviousEpoch;
+ }
+
+ /// <summary>
+ /// State to manage the transition from the current
+ /// epoch `N` to epoch `N+1`
+ /// </summary>
+ struct NextEpoch
+ {
+ public ushort Epoch;
+
+ public HandshakeState State;
+ public CipherSuite SelectedCipherSuite;
+
+ public ulong NextOutgoingSequence;
+
+ public IHandshakeCipherSuite Handshake;
+ public IRecordProtection RecordProtection;
+
+ public ByteSpan ClientRandom;
+ public ByteSpan ServerRandom;
+
+ public Sha256Stream VerificationStream;
+
+ public ByteSpan ClientVerification;
+ public ByteSpan ServerVerification;
+
+ }
+
+ /// <summary>
+ /// Per-peer state
+ /// </summary>
+ sealed class PeerData : IDisposable
+ {
+ public ushort Epoch;
+ public bool CanHandleApplicationData;
+
+ public HazelDtlsSessionInfo Session;
+
+ public CurrentEpoch CurrentEpoch;
+ public NextEpoch NextEpoch;
+
+ public ConnectionId ConnectionId;
+
+ public readonly List<ByteSpan> QueuedApplicationDataMessage = new List<ByteSpan>();
+ public readonly ConcurrentBag<MessageReader> ApplicationData = new ConcurrentBag<MessageReader>();
+ public readonly ProtocolVersion ProtocolVersion;
+
+ public DateTime StartOfNegotiation;
+
+ public PeerData(ConnectionId connectionId, ulong nextExpectedSequenceNumber, ProtocolVersion protocolVersion)
+ {
+ ByteSpan block = new byte[2 * Finished.Size];
+ this.CurrentEpoch.ServerFinishedVerification = block.Slice(0, Finished.Size);
+ this.CurrentEpoch.ExpectedClientFinishedVerification = block.Slice(Finished.Size, Finished.Size);
+ this.ProtocolVersion = protocolVersion;
+
+ ResetPeer(connectionId, nextExpectedSequenceNumber);
+ }
+
+ public void ResetPeer(ConnectionId connectionId, ulong nextExpectedSequenceNumber)
+ {
+ Dispose();
+
+ this.Epoch = 0;
+ this.CanHandleApplicationData = false;
+ this.QueuedApplicationDataMessage.Clear();
+
+ this.CurrentEpoch.NextOutgoingSequence = 2; // Account for our ClientHelloVerify
+ this.CurrentEpoch.NextExpectedSequence = nextExpectedSequenceNumber;
+ this.CurrentEpoch.PreviousSequenceWindowBitmask = 0;
+ this.CurrentEpoch.RecordProtection = NullRecordProtection.Instance;
+ this.CurrentEpoch.PreviousRecordProtection = null;
+ this.CurrentEpoch.ServerFinishedVerification.SecureClear();
+ this.CurrentEpoch.ExpectedClientFinishedVerification.SecureClear();
+
+ this.NextEpoch.State = HandshakeState.ExpectingHello;
+ this.NextEpoch.RecordProtection = null;
+ this.NextEpoch.Handshake = null;
+ this.NextEpoch.ClientRandom = new byte[Random.Size];
+ this.NextEpoch.ServerRandom = new byte[Random.Size];
+ this.NextEpoch.VerificationStream = new Sha256Stream();
+ this.NextEpoch.ClientVerification = new byte[Finished.Size];
+ this.NextEpoch.ServerVerification = new byte[Finished.Size];
+
+ this.ConnectionId = connectionId;
+
+ this.StartOfNegotiation = DateTime.UtcNow;
+ }
+
+ public void Dispose()
+ {
+ this.CurrentEpoch.RecordProtection?.Dispose();
+ this.CurrentEpoch.PreviousRecordProtection?.Dispose();
+ this.NextEpoch.RecordProtection?.Dispose();
+ this.NextEpoch.Handshake?.Dispose();
+ this.NextEpoch.VerificationStream?.Dispose();
+
+ while (this.ApplicationData.TryTake(out var msg))
+ {
+ try
+ {
+ msg.Recycle();
+ }
+ catch { }
+ }
+ }
+ }
+
+ private RandomNumberGenerator random;
+
+ // Private key component of certificate's public key
+ private ByteSpan encodedCertificate;
+ private RSA certificatePrivateKey;
+
+ // HMAC key to validate ClientHello cookie
+ private ThreadedHmacHelper hmacHelper;
+ private HMAC CurrentCookieHmac {
+ get
+ {
+ return hmacHelper.GetCurrentCookieHmacsForThread();
+ }
+ }
+ private HMAC PreviousCookieHmac
+ {
+ get
+ {
+ return hmacHelper.GetPreviousCookieHmacsForThread();
+ }
+ }
+
+ private ConcurrentStack<ConnectionId> staleConnections = new ConcurrentStack<ConnectionId>();
+ private readonly ConcurrentDictionary<IPEndPoint, PeerData> existingPeers = new ConcurrentDictionary<IPEndPoint, PeerData>();
+ public int PeerCount => this.existingPeers.Count;
+
+ // TODO: Move these into an DtlsErrorStatistics class
+ public int NonPeerNonHelloPacketsDropped;
+ public int NonVerifiedFinishedHandshake;
+ public int NonPeerVerifyHelloRequests;
+ public int PeerVerifyHelloRequests;
+
+ private int connectionSerial_unsafe = 0;
+
+ private Timer staleConnectionUpkeep;
+
+ /// <summary>
+ /// Create a new instance of the DTLS listener
+ /// </summary>
+ /// <param name="numWorkers"></param>
+ /// <param name="endPoint"></param>
+ /// <param name="logger"></param>
+ /// <param name="ipMode"></param>
+ public DtlsConnectionListener(int numWorkers, IPEndPoint endPoint, ILogger logger, IPMode ipMode = IPMode.IPv4)
+ : base(numWorkers, endPoint, logger, ipMode)
+ {
+ this.random = RandomNumberGenerator.Create();
+
+ this.staleConnectionUpkeep = new Timer(this.HandleStaleConnections, null, 2500, 1000);
+ this.hmacHelper = new ThreadedHmacHelper(logger);
+ }
+
+ /// <inheritdoc />
+ protected override void Dispose(bool disposing)
+ {
+ base.Dispose(disposing);
+
+ this.staleConnectionUpkeep.Dispose();
+
+ this.random?.Dispose();
+ this.random = null;
+
+ this.hmacHelper?.Dispose();
+ this.hmacHelper = null;
+
+ foreach (var pair in this.existingPeers)
+ {
+ pair.Value.Dispose();
+ }
+ this.existingPeers.Clear();
+ }
+
+ /// <summary>
+ /// Set the certificate key pair for the listener
+ /// </summary>
+ /// <param name="certificate">Certificate for the server</param>
+ public void SetCertificate(X509Certificate2 certificate)
+ {
+ if (!certificate.HasPrivateKey)
+ {
+ throw new ArgumentException("Certificate must have a private key attached", nameof(certificate));
+ }
+
+ RSA privateKey = certificate.GetRSAPrivateKey();
+ if (privateKey == null)
+ {
+ throw new ArgumentException("Certificate must be signed by an RSA key", nameof(certificate));
+ }
+
+ this.certificatePrivateKey?.Dispose();
+ this.certificatePrivateKey = privateKey;
+
+ this.encodedCertificate = Certificate.Encode(certificate);
+ }
+
+ /// <summary>
+ /// Handle an incoming datagram from the network.
+ ///
+ /// This is primarily a wrapper around ProcessIncomingMessage
+ /// to ensure `reader.Recycle()` is always called
+ /// </summary>
+ protected override void ReadCallback(MessageReader reader, IPEndPoint peerAddress, ConnectionId connectionId)
+ {
+ try
+ {
+ ByteSpan message = new ByteSpan(reader.Buffer, reader.Offset + reader.Position, reader.BytesRemaining);
+ this.ProcessIncomingMessage(message, peerAddress);
+ }
+ finally
+ {
+ reader.Recycle();
+ }
+ }
+
+ /// <summary>
+ /// Handle an incoming datagram from the network
+ /// </summary>
+ private void ProcessIncomingMessage(ByteSpan message, IPEndPoint peerAddress)
+ {
+ PeerData peer = null;
+ if (!this.existingPeers.TryGetValue(peerAddress, out peer))
+ {
+ lock (this.existingPeers)
+ {
+ if (!this.existingPeers.TryGetValue(peerAddress, out peer))
+ {
+ HandleNonPeerRecord(message, peerAddress);
+ return;
+ }
+ }
+ }
+
+ ConnectionId peerConnectionId;
+
+ lock (peer)
+ {
+ peerConnectionId = peer.ConnectionId;
+
+ // Each incoming packet may contain multiple DTLS
+ // records
+ while (message.Length > 0)
+ {
+ Record record;
+ if (!Record.Parse(out record, peer.ProtocolVersion, message))
+ {
+ this.Logger.WriteError($"Dropping malformed record from `{peerAddress}`");
+ return;
+ }
+ message = message.Slice(Record.Size);
+
+ if (message.Length < record.Length)
+ {
+ this.Logger.WriteError($"Dropping malformed record from `{peerAddress}` Length({record.Length}) AvailableBytes({message.Length})");
+ return;
+ }
+
+ ByteSpan recordPayload = message.Slice(0, record.Length);
+ message = message.Slice(record.Length);
+
+ // Early-out and drop ApplicationData records
+ if (record.ContentType == ContentType.ApplicationData && !peer.CanHandleApplicationData)
+ {
+ this.Logger.WriteInfo($"Dropping ApplicationData record from `{peerAddress}` Cannot process yet");
+ continue;
+ }
+
+ // Drop records from a different epoch
+ if (record.Epoch != peer.Epoch)
+ {
+ // Handle existing client negotiating a new connection
+ if (record.Epoch == 0 && record.ContentType == ContentType.Handshake)
+ {
+ ByteSpan handshakePayload = recordPayload;
+
+ Handshake handshake;
+ if (!Handshake.Parse(out handshake, recordPayload))
+ {
+ this.Logger.WriteError($"Dropping malformed re-negotiation Handshake from `{peerAddress}`");
+ continue;
+ }
+ handshakePayload = handshakePayload.Slice(Handshake.Size);
+
+ if (handshake.FragmentOffset != 0 || handshake.Length != handshake.FragmentLength)
+ {
+ this.Logger.WriteError($"Dropping fragmented re-negotiation Handshake from `{peerAddress}`");
+ continue;
+ }
+ else if (handshake.MessageType != HandshakeType.ClientHello)
+ {
+ this.Logger.WriteVerbose($"Dropping non-ClientHello re-negotiation Handshake from `{peerAddress}`");
+ continue;
+ }
+ else if (handshakePayload.Length < handshake.Length)
+ {
+ this.Logger.WriteError($"Dropping malformed re-negotiation Handshake from `{peerAddress}`: Length({handshake.Length}) AvailableBytes({handshakePayload.Length})");
+ }
+
+ if (!this.HandleClientHello(peer, peerAddress, ref record, ref handshake, recordPayload, handshakePayload))
+ {
+ return;
+ }
+ continue;
+ }
+
+ this.Logger.WriteVerbose($"Dropping bad-epoch record from `{peerAddress}` RecordEpoch({record.Epoch}) CurrentEpoch({peer.Epoch})");
+ continue;
+ }
+
+ // Prevent replay attacks by dropping records
+ // we've already processed
+ int windowIndex = (int)(peer.CurrentEpoch.NextExpectedSequence - record.SequenceNumber - 1);
+ ulong windowMask = 1ul << windowIndex;
+ if (record.SequenceNumber < peer.CurrentEpoch.NextExpectedSequence)
+ {
+ if (windowIndex >= 64)
+ {
+ this.Logger.WriteInfo($"Dropping too-old record from `{peerAddress}` Sequence({record.SequenceNumber}) Expected({peer.CurrentEpoch.NextExpectedSequence})");
+ continue;
+ }
+
+ if ((peer.CurrentEpoch.PreviousSequenceWindowBitmask & windowMask) != 0)
+ {
+ this.Logger.WriteInfo($"Dropping duplicate record from `{peerAddress}`");
+ continue;
+ }
+ }
+
+ // Validate record authenticity
+ int decryptedSize = peer.CurrentEpoch.RecordProtection.GetDecryptedSize(recordPayload.Length);
+ if (decryptedSize < 0)
+ {
+ this.Logger.WriteInfo($"Dropping malformed record: Length {recordPayload.Length} Decrypted length: {decryptedSize}");
+ continue;
+ }
+
+ ByteSpan decryptedPayload = recordPayload.ReuseSpanIfPossible(decryptedSize);
+ ProtocolVersion protocolVersion = peer.ProtocolVersion;
+
+ if (!peer.CurrentEpoch.RecordProtection.DecryptCiphertextFromClient(decryptedPayload, recordPayload, ref record))
+ {
+ this.Logger.WriteVerbose($"Dropping non-authentic {record.ContentType} record from `{peerAddress}`");
+ return;
+ }
+
+ recordPayload = decryptedPayload;
+
+ // Update our squence number bookeeping
+ if (record.SequenceNumber >= peer.CurrentEpoch.NextExpectedSequence)
+ {
+ int windowShift = (int)(record.SequenceNumber + 1 - peer.CurrentEpoch.NextExpectedSequence);
+ peer.CurrentEpoch.PreviousSequenceWindowBitmask <<= windowShift;
+ peer.CurrentEpoch.NextExpectedSequence = record.SequenceNumber + 1;
+ }
+ else
+ {
+ peer.CurrentEpoch.PreviousSequenceWindowBitmask |= windowMask;
+ }
+
+ // This is handy for debugging, but too verbose even for verbose.
+ // this.Logger.WriteVerbose($"Record type {record.ContentType} ({peer.NextEpoch.State})");
+ switch (record.ContentType)
+ {
+ case ContentType.ChangeCipherSpec:
+ if (peer.NextEpoch.State != HandshakeState.ExpectingChangeCipherSpec)
+ {
+ this.Logger.WriteError($"Dropping unexpected ChangeChiperSpec record from `{peerAddress}` State({peer.NextEpoch.State})");
+ break;
+ }
+ else if (peer.NextEpoch.RecordProtection == null)
+ {
+ ///NOTE(mendsley): This _should_ not
+ /// happen on a well-formed server.
+ Debug.Assert(false, "How did we receive a ChangeCipherSpec message without a pending record protection instance?");
+
+ this.Logger.WriteError($"Dropping ChangeCipherSpec message from `{peerAddress}`: No pending record protection");
+ break;
+ }
+
+ if (!ChangeCipherSpec.Parse(recordPayload))
+ {
+ this.Logger.WriteError($"Dropping malformed ChangeCipherSpec message from `{peerAddress}`");
+ break;
+ }
+
+ // Migrate to the next epoch
+ peer.Epoch = peer.NextEpoch.Epoch;
+ peer.CanHandleApplicationData = false; // Need a Finished message
+ peer.CurrentEpoch.NextOutgoingSequenceForPreviousEpoch = peer.CurrentEpoch.NextOutgoingSequence;
+ peer.CurrentEpoch.PreviousRecordProtection?.Dispose();
+ peer.CurrentEpoch.PreviousRecordProtection = peer.CurrentEpoch.RecordProtection;
+ peer.CurrentEpoch.RecordProtection = peer.NextEpoch.RecordProtection;
+ peer.CurrentEpoch.NextOutgoingSequence = 1;
+ peer.CurrentEpoch.NextExpectedSequence = 1;
+ peer.CurrentEpoch.PreviousSequenceWindowBitmask = 0;
+ peer.NextEpoch.ClientVerification.CopyTo(peer.CurrentEpoch.ExpectedClientFinishedVerification);
+ peer.NextEpoch.ServerVerification.CopyTo(peer.CurrentEpoch.ServerFinishedVerification);
+
+ peer.NextEpoch.State = HandshakeState.ExpectingHello;
+ peer.NextEpoch.Handshake?.Dispose();
+ peer.NextEpoch.Handshake = null;
+ peer.NextEpoch.NextOutgoingSequence = 1;
+ peer.NextEpoch.RecordProtection = null;
+ peer.NextEpoch.VerificationStream.Reset();
+ peer.NextEpoch.ClientVerification.SecureClear();
+ peer.NextEpoch.ServerVerification.SecureClear();
+ break;
+
+ case ContentType.Alert:
+ this.Logger.WriteError($"Dropping unsupported Alert record from `{peerAddress}`");
+ break;
+
+ case ContentType.Handshake:
+ if (!ProcessHandshake(peer, peerAddress, ref record, recordPayload))
+ {
+ return;
+ }
+ break;
+
+ case ContentType.ApplicationData:
+ // Forward data to the application
+ MessageReader reader = MessageReader.GetSized(recordPayload.Length);
+ reader.Length = recordPayload.Length;
+ recordPayload.CopyTo(reader.Buffer);
+
+ peer.ApplicationData.Add(reader);
+ break;
+ }
+ }
+ }
+
+ // The peer lock must be exited before leaving the DtlsConnectionListener context to prevent deadlocks
+ // because ApplicationData processing may reenter this context
+ while (peer.ApplicationData.TryTake(out var appMsg))
+ {
+ base.ReadCallback(appMsg, peerAddress, peerConnectionId);
+ }
+ }
+
+ /// <summary>
+ /// Process an incoming Handshake protocol message
+ /// </summary>
+ /// <param name="peer">Originating peer</param>
+ /// <param name="peerAddress">Peer's network address</param>
+ /// <param name="record">Parent record</param>
+ /// <param name="message">Record payload</param>
+ /// <returns>
+ /// True if further processing of the underlying datagram
+ /// should be continues. Otherwise, false.
+ /// </returns>
+ private bool ProcessHandshake(PeerData peer, IPEndPoint peerAddress, ref Record record, ByteSpan message)
+ {
+ // Each record may have multiple handshake payloads
+ while (message.Length > 0)
+ {
+ ByteSpan originalMessage = message;
+
+ Handshake handshake;
+ if (!Handshake.Parse(out handshake, message))
+ {
+ this.Logger.WriteError($"Dropping malformed Handshake message from `{peerAddress}`");
+ return false;
+ }
+ message = message.Slice(Handshake.Size);
+
+ if (message.Length < handshake.Length)
+ {
+ this.Logger.WriteError($"Dropping malformed Handshake message from `{peerAddress}`");
+ return false;
+ }
+
+ ByteSpan payload = message.Slice(0, (int)message.Length);
+ message = message.Slice((int)handshake.Length);
+ originalMessage = originalMessage.Slice(0, Handshake.Size + (int)handshake.Length);
+
+ // We do not support fragmented handshake messages
+ // from the client
+ if (handshake.FragmentOffset != 0 || handshake.FragmentLength != handshake.Length)
+ {
+ this.Logger.WriteError($"Dropping fragmented Handshake message from `{peerAddress}` Offset({handshake.FragmentOffset}) FragmentLength({handshake.FragmentLength}) Length({handshake.Length})");
+ continue;
+ }
+
+ ByteSpan packet;
+ ByteSpan writer;
+
+#if DEBUG
+ this.Logger.WriteVerbose($"Received handshake {handshake.MessageType} ({peer.NextEpoch.State})");
+#endif
+ switch (handshake.MessageType)
+ {
+ case HandshakeType.ClientHello:
+ if (!this.HandleClientHello(peer, peerAddress, ref record, ref handshake, originalMessage, payload))
+ {
+ return false;
+ }
+ break;
+
+ case HandshakeType.ClientKeyExchange:
+ if (peer.NextEpoch.State != HandshakeState.ExpectingClientKeyExchange)
+ {
+ this.Logger.WriteError($"Dropping unexpected ClientKeyExchange message form `{peerAddress}` State({peer.NextEpoch.State})");
+ continue;
+ }
+ else if (handshake.MessageSequence != 5)
+ {
+ this.Logger.WriteError($"Dropping bad-sequence ClientKeyExchange message from `{peerAddress}` MessageSequence({handshake.MessageSequence})");
+ continue;
+ }
+
+ ByteSpan sharedSecret = new byte[peer.NextEpoch.Handshake.SharedKeySize()];
+ if (!peer.NextEpoch.Handshake.VerifyClientMessageAndGenerateSharedKey(sharedSecret, payload))
+ {
+ this.Logger.WriteError($"Dropping malformed ClientKeyExchange message from `{peerAddress}`");
+ return false;
+ }
+
+ // Record incoming ClientKeyExchange message
+ // to verification stream
+ peer.NextEpoch.VerificationStream.AddData(originalMessage);
+
+ ByteSpan randomSeed = new byte[2 * Random.Size];
+ peer.NextEpoch.ClientRandom.CopyTo(randomSeed);
+ peer.NextEpoch.ServerRandom.CopyTo(randomSeed.Slice(Random.Size));
+
+ const int MasterSecretSize = 48;
+ ByteSpan masterSecret = new byte[MasterSecretSize];
+ PrfSha256.ExpandSecret(
+ masterSecret
+ , sharedSecret
+ , PrfLabel.MASTER_SECRET
+ , randomSeed
+ );
+
+ // Create the record protection for the upcoming epoch
+ switch (peer.NextEpoch.SelectedCipherSuite)
+ {
+ case CipherSuite.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256:
+ peer.NextEpoch.RecordProtection = new Aes128GcmRecordProtection(
+ masterSecret
+ , peer.NextEpoch.ServerRandom
+ , peer.NextEpoch.ClientRandom);
+ break;
+
+ default:
+ Debug.Assert(false, $"How did we agree to a cipher suite {peer.NextEpoch.SelectedCipherSuite} we can't create?");
+ this.Logger.WriteError($"Dropping ClientKeyExchange message from `{peerAddress}` Unsuppored cipher suite");
+ return false;
+ }
+
+ // Generate verification signatures
+ ByteSpan handshakeStreamHash = new byte[Sha256Stream.DigestSize];
+ peer.NextEpoch.VerificationStream.CopyOrCalculateFinalHash(handshakeStreamHash);
+
+ PrfSha256.ExpandSecret(
+ peer.NextEpoch.ClientVerification
+ , masterSecret
+ , PrfLabel.CLIENT_FINISHED
+ , handshakeStreamHash
+ );
+ PrfSha256.ExpandSecret(
+ peer.NextEpoch.ServerVerification
+ , masterSecret
+ , PrfLabel.SERVER_FINISHED
+ , handshakeStreamHash
+ );
+
+
+ // Update handshake state
+ masterSecret.SecureClear();
+ peer.NextEpoch.State = HandshakeState.ExpectingChangeCipherSpec;
+ break;
+
+ case HandshakeType.Finished:
+ // Unlike other handshake messages, this is
+ // for the current epoch - not the next epoch
+
+ // Cannot process a Finished message for
+ // epoch 0
+ if (peer.Epoch == 0)
+ {
+ this.Logger.WriteError($"Dropping Finished message for 0-epoch from `{peerAddress}`");
+ continue;
+ }
+ // Cannot process a Finished message when we
+ // are negotiating the next epoch
+ else if (peer.NextEpoch.State != HandshakeState.ExpectingHello)
+ {
+ this.Logger.WriteError($"Dropping Finished message while negotiating new epoch from `{peerAddress}`");
+ continue;
+ }
+ // Cannot process a Finished message without
+ // verify data
+ else if (peer.CurrentEpoch.ExpectedClientFinishedVerification.Length != Finished.Size || peer.CurrentEpoch.ServerFinishedVerification.Length != Finished.Size)
+ {
+ ///NOTE(mendsley): This _should_ not
+ /// happen on a well-formed server.
+ Debug.Assert(false, "How do we have an established non-zero epoch without verify data?");
+
+ this.Logger.WriteError($"Dropping Finished message (no verify data) from `{peerAddress}`");
+ return false;
+ }
+ // Cannot process a Finished message without
+ // record protection for the previous epoch
+ else if (peer.CurrentEpoch.PreviousRecordProtection == null)
+ {
+ ///NOTE(mendsley): This _should_ not
+ /// happen on a well-formed server.
+ Debug.Assert(false, "How do we have an established non-zero epoch with record protection for the previous epoch?");
+
+ this.Logger.WriteError($"Dropping Finished message from `{peerAddress}`: No previous epoch record protection");
+ return false;
+ }
+
+ // Verify message sequence
+ if (handshake.MessageSequence != 6)
+ {
+ this.Logger.WriteError($"Dropping bad-sequence Finished message from `{peerAddress}` MessageSequence({handshake.MessageSequence})");
+ continue;
+ }
+
+ // Verify the client has the correct
+ // handshake sequence
+ if (payload.Length != Finished.Size)
+ {
+ this.Logger.WriteError($"Dropping malformed Finished message from `{peerAddress}`");
+ return false;
+ }
+ else if (1 != Crypto.Const.ConstantCompareSpans(payload, peer.CurrentEpoch.ExpectedClientFinishedVerification))
+ {
+
+#if DEBUG
+ this.Logger.WriteError($"Dropping non-verified Finished Handshake from `{peerAddress}`");
+#else
+ Interlocked.Increment(ref this.NonVerifiedFinishedHandshake);
+#endif
+
+ // Abort the connection here
+ //
+ // The client is either broken, or
+ // doen not agree on our epoch settings.
+ //
+ // Either way, there is not a feasible
+ // way to progress the connection.
+ MarkConnectionAsStale(peer.ConnectionId);
+ this.existingPeers.TryRemove(peerAddress, out _);
+
+ return false;
+ }
+
+ ProtocolVersion protocolVersion = peer.ProtocolVersion;
+
+ // Describe our ChangeCipherSpec+Finished
+ Handshake outgoingHandshake = new Handshake();
+ outgoingHandshake.MessageType = HandshakeType.Finished;
+ outgoingHandshake.Length = Finished.Size;
+ outgoingHandshake.MessageSequence = 7;
+ outgoingHandshake.FragmentOffset = 0;
+ outgoingHandshake.FragmentLength = outgoingHandshake.Length;
+
+ Record changeCipherSpecRecord = new Record();
+ changeCipherSpecRecord.ContentType = ContentType.ChangeCipherSpec;
+ changeCipherSpecRecord.ProtocolVersion = protocolVersion;
+ changeCipherSpecRecord.Epoch = (ushort)(peer.Epoch - 1);
+ changeCipherSpecRecord.SequenceNumber = peer.CurrentEpoch.NextOutgoingSequenceForPreviousEpoch;
+ changeCipherSpecRecord.Length = (ushort)peer.CurrentEpoch.PreviousRecordProtection.GetEncryptedSize(ChangeCipherSpec.Size);
+ ++peer.CurrentEpoch.NextOutgoingSequenceForPreviousEpoch;
+
+ int plaintextFinishedPayloadSize = Handshake.Size + (int)outgoingHandshake.Length;
+ Record finishedRecord = new Record();
+ finishedRecord.ContentType = ContentType.Handshake;
+ finishedRecord.ProtocolVersion = protocolVersion;
+ finishedRecord.Epoch = peer.Epoch;
+ finishedRecord.SequenceNumber = peer.CurrentEpoch.NextOutgoingSequence;
+ finishedRecord.Length = (ushort)peer.CurrentEpoch.RecordProtection.GetEncryptedSize(plaintextFinishedPayloadSize);
+ ++peer.CurrentEpoch.NextOutgoingSequence;
+
+ // Encode the flight into wire format
+ packet = new byte[Record.Size + changeCipherSpecRecord.Length + Record.Size + finishedRecord.Length];
+ writer = packet;
+ changeCipherSpecRecord.Encode(writer);
+ writer = writer.Slice(Record.Size);
+ ChangeCipherSpec.Encode(writer);
+
+ ByteSpan startOfFinishedRecord = packet.Slice(Record.Size + changeCipherSpecRecord.Length);
+ writer = startOfFinishedRecord;
+ finishedRecord.Encode(writer);
+ writer = writer.Slice(Record.Size);
+ outgoingHandshake.Encode(writer);
+ writer = writer.Slice(Handshake.Size);
+ peer.CurrentEpoch.ServerFinishedVerification.CopyTo(writer);
+
+ // Protect the ChangeChipherSpec record
+ peer.CurrentEpoch.PreviousRecordProtection.EncryptServerPlaintext(
+ packet.Slice(Record.Size, changeCipherSpecRecord.Length),
+ packet.Slice(Record.Size, ChangeCipherSpec.Size),
+ ref changeCipherSpecRecord
+ );
+
+ // Protect the Finished Handshake record
+ peer.CurrentEpoch.RecordProtection.EncryptServerPlaintext(
+ startOfFinishedRecord.Slice(Record.Size, finishedRecord.Length),
+ startOfFinishedRecord.Slice(Record.Size, plaintextFinishedPayloadSize),
+ ref finishedRecord
+ );
+
+ // Current epoch can now handle application data
+ peer.CanHandleApplicationData = true;
+
+ base.QueueRawData(packet, peerAddress);
+ break;
+
+ // Drop messages that we do not support
+ case HandshakeType.CertificateVerify:
+ this.Logger.WriteError($"Dropping unsupported Handshake message from `{peerAddress}` MessageType({handshake.MessageType})");
+ continue;
+
+ // Drop messages that originate from the server
+ case HandshakeType.HelloRequest:
+ case HandshakeType.ServerHello:
+ case HandshakeType.HelloVerifyRequest:
+ case HandshakeType.Certificate:
+ case HandshakeType.ServerKeyExchange:
+ case HandshakeType.CertificateRequest:
+ case HandshakeType.ServerHelloDone:
+ this.Logger.WriteError($"Dropping server Handshake message from `{peerAddress}` MessageType({handshake.MessageType})");
+ continue;
+ }
+ }
+
+ return true;
+ }
+
+ /// <summary>
+ /// Handle a ClientHello message for a peer
+ /// </summary>
+ /// <param name="peer">Originating peer</param>
+ /// <param name="peerAddress">Peer address</param>
+ /// <param name="record">Parent record</param>
+ /// <param name="handshake">Parent Handshake header</param>
+ /// <param name="payload">Handshake payload</param>
+ private bool HandleClientHello(PeerData peer, IPEndPoint peerAddress, ref Record record, ref Handshake handshake, ByteSpan originalMessage, ByteSpan payload)
+ {
+ // Verify message sequence
+ if (handshake.MessageSequence != 0)
+ {
+ this.Logger.WriteError($"Dropping bad-sequence ClientHello from `{peerAddress}` MessageSequence({handshake.MessageSequence})`");
+ return true;
+ }
+
+ // Make sure we can handle a ClientHello message
+ if (peer.NextEpoch.State != HandshakeState.ExpectingHello && peer.NextEpoch.State != HandshakeState.ExpectingClientKeyExchange)
+ {
+ // Always handle ClientHello for epoch 0
+ if (record.Epoch != 0)
+ {
+ this.Logger.WriteError($"Dropping ClientHello from `{peer}` Not expecting ClientHello");
+ return true;
+ }
+ }
+
+ ProtocolVersion protocolVersion = peer.ProtocolVersion;
+ if (!ClientHello.Parse(out ClientHello clientHello, protocolVersion, payload))
+ {
+ this.Logger.WriteError($"Dropping malformed ClientHello Handshake message from `{peerAddress}`");
+ return false;
+ }
+
+ // Find an acceptable cipher suite we can use
+ CipherSuite selectedCipherSuite = CipherSuite.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256;
+ if (!clientHello.ContainsCipherSuite(CipherSuite.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256) || !clientHello.ContainsCurve(NamedCurve.x25519))
+ {
+ this.Logger.WriteError($"Dropping ClientHello from `{peerAddress}` No compatible cipher suite");
+ return false;
+ }
+
+ // If this message was not signed by us,
+ // request a signed message before doing anything else
+ if (!HelloVerifyRequest.VerifyCookie(clientHello.Cookie, peerAddress, this.CurrentCookieHmac))
+ {
+ if (!HelloVerifyRequest.VerifyCookie(clientHello.Cookie, peerAddress, this.PreviousCookieHmac))
+ {
+ ulong outgoingSequence = 1;
+ IRecordProtection recordProtection = NullRecordProtection.Instance;
+ if (record.Epoch != 0)
+ {
+ outgoingSequence = peer.CurrentEpoch.NextExpectedSequence;
+ ++peer.CurrentEpoch.NextOutgoingSequenceForPreviousEpoch;
+
+ recordProtection = peer.CurrentEpoch.RecordProtection;
+ }
+
+#if DEBUG
+ this.Logger.WriteError($"Sending HelloVerifyRequest to peer `{peerAddress}`");
+#else
+ Interlocked.Increment(ref this.PeerVerifyHelloRequests);
+#endif
+ this.SendHelloVerifyRequest(peerAddress, outgoingSequence, record.Epoch, recordProtection, protocolVersion);
+ return true;
+ }
+ }
+
+ // Client is initiating a brand new connection. We need
+ // to destroy the existing connection and establish a
+ // new session.
+ if (record.Epoch == 0 && peer.Epoch != 0)
+ {
+ ConnectionId oldConnectionId = peer.ConnectionId;
+ peer.ResetPeer(this.AllocateConnectionId(peerAddress), record.SequenceNumber + 1);
+
+ // Inform the parent layer that the existing
+ // connection should be abandoned.
+ MarkConnectionAsStale(oldConnectionId);
+ }
+
+ // Determine if this is an original message, or a retransmission
+ bool recordMessagesForVerifyData = false;
+ if (peer.NextEpoch.State == HandshakeState.ExpectingHello)
+ {
+ // Create our handhake cipher suite
+ IHandshakeCipherSuite handshakeCipherSuite = null;
+ switch (selectedCipherSuite)
+ {
+ case CipherSuite.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256:
+ if (clientHello.ContainsCurve(NamedCurve.x25519))
+ {
+ handshakeCipherSuite = new X25519EcdheRsaSha256(this.random);
+ }
+ else
+ {
+ this.Logger.WriteError($"Dropping ClientHello from `{peerAddress}` Could not create TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 cipher suite");
+ return false;
+ }
+
+ break;
+
+ default:
+ this.Logger.WriteError($"Dropping ClientHello from `{peerAddress}` Could not create handshake cipher suite");
+ return false;
+ }
+
+ peer.Session = clientHello.Session;
+
+ // Update the state of our epoch transition
+ peer.NextEpoch.Epoch = (ushort)(record.Epoch + 1);
+ peer.NextEpoch.State = HandshakeState.ExpectingClientKeyExchange;
+ peer.NextEpoch.SelectedCipherSuite = selectedCipherSuite;
+ peer.NextEpoch.Handshake = handshakeCipherSuite;
+ clientHello.Random.CopyTo(peer.NextEpoch.ClientRandom);
+ peer.NextEpoch.ServerRandom.FillWithRandom(this.random);
+ recordMessagesForVerifyData = true;
+
+#if DEBUG
+ this.Logger.WriteVerbose($"ClientRandom: {peer.NextEpoch.ClientRandom} ServerRandom: {peer.NextEpoch.ServerRandom}");
+#endif
+
+ // Copy the original ClientHello
+ // handshake to our verification stream
+ peer.NextEpoch.VerificationStream.AddData(
+ originalMessage.Slice(
+ 0
+ , Handshake.Size + (int)handshake.Length
+ )
+ );
+ }
+
+ // The initial record flight from the server
+ // contains the following Handshake messages:
+ // * ServerHello
+ // * Certificate
+ // * ServerKeyExchange
+ // * ServerHelloDone
+ //
+ // The Certificate message is almost always
+ // too large to fit into a single datagram,
+ // so it is pre-fragmented
+ // (see `SetCertificates`). Therefore, we
+ // need to send multiple record packets for
+ // this flight.
+ //
+ // The first record contains the ServerHello
+ // handshake message, as well as the first
+ // portion of the Certificate message.
+ //
+ // We then send a record packet until the
+ // entire Certificate message has been sent
+ // to the client.
+ //
+ // The final record packet contains the
+ // ServerKeyExchange and the ServerHelloDone
+ // messages.
+
+ // Describe first record of the flight
+ ServerHello serverHello = new ServerHello();
+ serverHello.ServerProtocolVersion = protocolVersion;
+ serverHello.Random = peer.NextEpoch.ServerRandom;
+ serverHello.CipherSuite = selectedCipherSuite;
+
+ Handshake serverHelloHandshake = new Handshake();
+ serverHelloHandshake.MessageType = HandshakeType.ServerHello;
+ serverHelloHandshake.Length = ServerHello.MinSize;
+ serverHelloHandshake.MessageSequence = 1;
+ serverHelloHandshake.FragmentOffset = 0;
+ serverHelloHandshake.FragmentLength = serverHelloHandshake.Length;
+
+ int maxCertFragmentSize = peer.Session.Version == 0 ? MaxCertFragmentSizeV0 : MaxCertFragmentSizeV1;
+
+ // The first certificate data needs to leave room for
+ // * Record header
+ // * ServerHello header
+ // * ServerHello payload
+ // * Certificate header
+
+ var certificateData = this.encodedCertificate;
+ int initialCertPadding = Record.Size + Handshake.Size + serverHello.Size + Handshake.Size;
+ int certInitialFragmentSize = Math.Min(certificateData.Length, maxCertFragmentSize - initialCertPadding);
+
+ Handshake certificateHandshake = new Handshake();
+ certificateHandshake.MessageType = HandshakeType.Certificate;
+ certificateHandshake.Length = (uint)certificateData.Length;
+ certificateHandshake.MessageSequence = 2;
+ certificateHandshake.FragmentOffset = 0;
+ certificateHandshake.FragmentLength = (uint)certInitialFragmentSize;
+
+ int initialRecordPayloadSize = 0
+ + Handshake.Size + serverHello.Size
+ + Handshake.Size + (int)certificateHandshake.FragmentLength
+ ;
+ Record initialRecord = new Record();
+ initialRecord.ContentType = ContentType.Handshake;
+ initialRecord.ProtocolVersion = protocolVersion;
+ initialRecord.Epoch = peer.Epoch;
+ initialRecord.SequenceNumber = peer.CurrentEpoch.NextOutgoingSequence;
+ initialRecord.Length = (ushort)peer.CurrentEpoch.RecordProtection.GetEncryptedSize(initialRecordPayloadSize);
+ ++peer.CurrentEpoch.NextOutgoingSequence;
+
+ // Convert initial record of the flight to
+ // wire format
+ ByteSpan packet = new byte[Record.Size + initialRecord.Length];
+ ByteSpan writer = packet;
+ initialRecord.Encode(writer);
+ writer = writer.Slice(Record.Size);
+ serverHelloHandshake.Encode(writer);
+ writer = writer.Slice(Handshake.Size);
+ serverHello.Encode(writer);
+ writer = writer.Slice(ServerHello.MinSize);
+ certificateHandshake.Encode(writer);
+ writer = writer.Slice(Handshake.Size);
+ certificateData.Slice(0, certInitialFragmentSize).CopyTo(writer);
+ certificateData = certificateData.Slice(certInitialFragmentSize);
+
+ // Protect initial record of the flight
+ peer.CurrentEpoch.RecordProtection.EncryptServerPlaintext(
+ packet.Slice(Record.Size, initialRecord.Length),
+ packet.Slice(Record.Size, initialRecordPayloadSize),
+ ref initialRecord
+ );
+
+ base.QueueRawData(packet, peerAddress);
+
+ // Record record payload for verification
+ if (recordMessagesForVerifyData)
+ {
+ Handshake fullCeritficateHandshake = certificateHandshake;
+ fullCeritficateHandshake.FragmentLength = fullCeritficateHandshake.Length;
+
+ packet = new byte[Handshake.Size + ServerHello.MinSize + Handshake.Size];
+ writer = packet;
+ serverHelloHandshake.Encode(writer);
+ writer = writer.Slice(Handshake.Size);
+ serverHello.Encode(writer);
+ writer = writer.Slice(ServerHello.MinSize);
+ fullCeritficateHandshake.Encode(writer);
+ writer = writer.Slice(Handshake.Size);
+
+ peer.NextEpoch.VerificationStream.AddData(packet);
+ peer.NextEpoch.VerificationStream.AddData(this.encodedCertificate);
+ }
+
+ // Process additional certificate records
+ // Subsequent certificate data needs to leave room for
+ // * Record header
+ // * Certificate header
+ const int CertPadding = Record.Size + Handshake.Size;
+ while (certificateData.Length > 0)
+ {
+ int certFragmentSize = Math.Min(certificateData.Length, maxCertFragmentSize - CertPadding);
+
+ certificateHandshake.FragmentOffset += certificateHandshake.FragmentLength;
+ certificateHandshake.FragmentLength = (uint)certFragmentSize;
+
+ int additionalRecordPayloadSize = Handshake.Size + (int)certificateHandshake.FragmentLength;
+ Record additionalRecord = new Record();
+ additionalRecord.ContentType = ContentType.Handshake;
+ additionalRecord.ProtocolVersion = protocolVersion;
+ additionalRecord.Epoch = peer.Epoch;
+ additionalRecord.SequenceNumber = peer.CurrentEpoch.NextOutgoingSequence;
+ additionalRecord.Length = (ushort)peer.CurrentEpoch.RecordProtection.GetEncryptedSize(additionalRecordPayloadSize);
+ ++peer.CurrentEpoch.NextOutgoingSequence;
+
+ // Convert record to wire format
+ packet = new byte[Record.Size + additionalRecord.Length];
+ writer = packet;
+ additionalRecord.Encode(writer);
+ writer = writer.Slice(Record.Size);
+ certificateHandshake.Encode(writer);
+ writer = writer.Slice(Handshake.Size);
+ certificateData.Slice(0, certFragmentSize).CopyTo(writer);
+
+ certificateData = certificateData.Slice(certFragmentSize);
+
+ // Protect record
+ peer.CurrentEpoch.RecordProtection.EncryptServerPlaintext(
+ packet.Slice(Record.Size, additionalRecord.Length),
+ packet.Slice(Record.Size, additionalRecordPayloadSize),
+ ref additionalRecord
+ );
+
+ base.QueueRawData(packet, peerAddress);
+ }
+
+ // Describe final record of the flight
+ Handshake serverKeyExchangeHandshake = new Handshake();
+ serverKeyExchangeHandshake.MessageType = HandshakeType.ServerKeyExchange;
+ serverKeyExchangeHandshake.Length = (uint)peer.NextEpoch.Handshake.CalculateServerMessageSize(this.certificatePrivateKey);
+ serverKeyExchangeHandshake.MessageSequence = 3;
+ serverKeyExchangeHandshake.FragmentOffset = 0;
+ serverKeyExchangeHandshake.FragmentLength = serverKeyExchangeHandshake.Length;
+
+ Handshake serverHelloDoneHandshake = new Handshake();
+ serverHelloDoneHandshake.MessageType = HandshakeType.ServerHelloDone;
+ serverHelloDoneHandshake.Length = 0;
+ serverHelloDoneHandshake.MessageSequence = 4;
+ serverHelloDoneHandshake.FragmentOffset = 0;
+ serverHelloDoneHandshake.FragmentLength = 0;
+
+ int finalRecordPayloadSize = 0
+ + Handshake.Size + (int)serverKeyExchangeHandshake.Length
+ + Handshake.Size + (int)serverHelloDoneHandshake.Length
+ ;
+ Record finalRecord = new Record();
+ finalRecord.ContentType = ContentType.Handshake;
+ finalRecord.ProtocolVersion = protocolVersion;
+ finalRecord.Epoch = peer.Epoch;
+ finalRecord.SequenceNumber = peer.CurrentEpoch.NextOutgoingSequence;
+ finalRecord.Length = (ushort)peer.CurrentEpoch.RecordProtection.GetEncryptedSize(finalRecordPayloadSize);
+ ++peer.CurrentEpoch.NextOutgoingSequence;
+
+ // Convert final record of the flight to wire
+ // format
+ packet = new byte[Record.Size + finalRecord.Length];
+ writer = packet;
+ finalRecord.Encode(writer);
+ writer = writer.Slice(Record.Size);
+ serverKeyExchangeHandshake.Encode(writer);
+ writer = writer.Slice(Handshake.Size);
+ peer.NextEpoch.Handshake.EncodeServerKeyExchangeMessage(writer, this.certificatePrivateKey);
+ writer = writer.Slice((int)serverKeyExchangeHandshake.Length);
+ serverHelloDoneHandshake.Encode(writer);
+
+ // Record record payload for verification
+ if (recordMessagesForVerifyData)
+ {
+ peer.NextEpoch.VerificationStream.AddData(
+ packet.Slice(
+ packet.Offset + Record.Size
+ , finalRecordPayloadSize
+ )
+ );
+ }
+
+ // Protect final record of the flight
+ peer.CurrentEpoch.RecordProtection.EncryptServerPlaintext(
+ packet.Slice(Record.Size, finalRecord.Length),
+ packet.Slice(Record.Size, finalRecordPayloadSize),
+ ref finalRecord
+ );
+
+ base.QueueRawData(packet, peerAddress);
+
+ return true;
+ }
+
+ /// <summary>
+ /// Handle an incoming packet that is not tied to an existing peer
+ /// </summary>
+ /// <param name="message">Incoming datagram</param>
+ /// <param name="peerAddress">Originating address</param>
+ private void HandleNonPeerRecord(ByteSpan message, IPEndPoint peerAddress)
+ {
+ Record record;
+ if (!Record.Parse(out record, expectedProtocolVersion: null, message))
+ {
+ this.Logger.WriteError($"Dropping malformed record from non-peer `{peerAddress}`");
+ return;
+ }
+ message = message.Slice(Record.Size);
+
+ // The protocol only supports receiving a single record
+ // from a non-peer.
+ if (record.Length != message.Length)
+ {
+ // NOTE(mendsley): This isn't always fatal.
+ // However, this is an indication that something
+ // fishy is going on. In the best case, there's a
+ // bug on the client or in the UDP stack (some
+ // stacks don't both to verify the checksum). In the
+ // worst case we're dealing with a malicious actor.
+ // In the malicious case, we'll end up dropping the
+ // connection later in the process.
+ if (message.Length < record.Length)
+ {
+ this.Logger.WriteInfo($"Dropping bad record from non-peer `{peerAddress}`. Msg length {message.Length} < {record.Length}");
+ return;
+ }
+ }
+
+ // We only accept zero-epoch records from non-peers
+ if (record.Epoch != 0)
+ {
+ ///NOTE(mendsley): Not logging anything here, as
+ /// this could easily be latent data arriving from a
+ /// recently disconnected peer.
+ return;
+ }
+
+ // We only accept Handshake protocol messages from non-peers
+ if (record.ContentType != ContentType.Handshake)
+ {
+ this.Logger.WriteError($"Dropping non-handhsake message from non-peer `{peerAddress}`");
+ return;
+ }
+
+ ByteSpan originalMessage = message;
+
+ Handshake handshake;
+ if (!Handshake.Parse(out handshake, message))
+ {
+ this.Logger.WriteError($"Dropping malformed handshake message from non-peer `{peerAddress}`");
+ return;
+ }
+
+ // We only accept ClientHello messages from non-peers
+ if (handshake.MessageType != HandshakeType.ClientHello)
+ {
+#if DEBUG
+ this.Logger.WriteError($"Dropping non-ClientHello ({handshake.MessageType}) message from non-peer `{peerAddress}`");
+#else
+ Interlocked.Increment(ref this.NonPeerNonHelloPacketsDropped);
+#endif
+ return;
+ }
+ message = message.Slice(Handshake.Size);
+
+ if (!ClientHello.Parse(out ClientHello clientHello, expectedProtocolVersion: null, message))
+ {
+ this.Logger.WriteError($"Dropping malformed ClientHello message from non-peer `{peerAddress}`");
+ return;
+ }
+
+ // If this ClientHello is not signed by us, request the
+ // client send us a signed message
+ if (!HelloVerifyRequest.VerifyCookie(clientHello.Cookie, peerAddress, this.CurrentCookieHmac))
+ {
+ if (!HelloVerifyRequest.VerifyCookie(clientHello.Cookie, peerAddress, this.PreviousCookieHmac))
+ {
+#if DEBUG
+ this.Logger.WriteVerbose($"Sending HelloVerifyRequest to non-peer `{peerAddress}`");
+#else
+ Interlocked.Increment(ref this.NonPeerVerifyHelloRequests);
+#endif
+ this.SendHelloVerifyRequest(peerAddress, 1, 0, NullRecordProtection.Instance, clientHello.ClientProtocolVersion);
+ return;
+ }
+ }
+
+ // Allocate state for the new peer and register it
+ PeerData peer = new PeerData(this.AllocateConnectionId(peerAddress), record.SequenceNumber + 1, clientHello.ClientProtocolVersion);
+ this.ProcessHandshake(peer, peerAddress, ref record, originalMessage);
+ this.existingPeers[peerAddress] = peer;
+ }
+
+ //Send a HelloVerifyRequest handshake message to a peer
+ private void SendHelloVerifyRequest(IPEndPoint peerAddress, ulong recordSequence, ushort epoch, IRecordProtection recordProtection, ProtocolVersion protocolVersion)
+ {
+ Handshake handshake = new Handshake();
+ handshake.MessageType = HandshakeType.HelloVerifyRequest;
+ handshake.Length = HelloVerifyRequest.Size;
+ handshake.MessageSequence = 0;
+ handshake.FragmentOffset = 0;
+ handshake.FragmentLength = handshake.Length;
+
+ int plaintextPayloadSize = Handshake.Size + (int)handshake.Length;
+
+ Record record = new Record();
+ record.ContentType = ContentType.Handshake;
+ record.ProtocolVersion = protocolVersion;
+ record.Epoch = epoch;
+ record.SequenceNumber = recordSequence;
+ record.Length = (ushort)recordProtection.GetEncryptedSize(plaintextPayloadSize);
+
+ // Encode record to wire format
+ ByteSpan packet = new byte[Record.Size + record.Length];
+ ByteSpan writer = packet;
+ record.Encode(writer);
+ writer = writer.Slice(Record.Size);
+ handshake.Encode(writer);
+ writer = writer.Slice(Handshake.Size);
+ HelloVerifyRequest.Encode(writer, peerAddress, this.CurrentCookieHmac, protocolVersion);
+
+ // Protect record payload
+ recordProtection.EncryptServerPlaintext(
+ packet.Slice(Record.Size, record.Length),
+ packet.Slice(Record.Size, plaintextPayloadSize),
+ ref record
+ );
+
+ base.QueueRawData(packet, peerAddress);
+ }
+
+ /// <summary>
+ /// Handle a requrest to send a datagram to the network
+ /// </summary>
+ protected override void QueueRawData(ByteSpan span, IPEndPoint remoteEndPoint)
+ {
+ PeerData peer;
+ if (!this.existingPeers.TryGetValue(remoteEndPoint, out peer))
+ {
+ // Drop messages if we don't know how to send them
+ return;
+ }
+
+ lock (peer)
+ {
+ // If we're negotiating a new epoch, queue data
+ if (peer.Epoch == 0 || peer.NextEpoch.State != HandshakeState.ExpectingHello)
+ {
+ ByteSpan copyOfSpan = new byte[span.Length];
+ span.CopyTo(copyOfSpan);
+
+ peer.QueuedApplicationDataMessage.Add(copyOfSpan);
+ return;
+ }
+
+ ProtocolVersion protocolVersion = peer.ProtocolVersion;
+
+ // Send any queued application data now
+ for (int ii = 0, nn = peer.QueuedApplicationDataMessage.Count; ii != nn; ++ii)
+ {
+ ByteSpan queuedSpan = peer.QueuedApplicationDataMessage[ii];
+
+ Record outgoingRecord = new Record();
+ outgoingRecord.ContentType = ContentType.ApplicationData;
+ outgoingRecord.ProtocolVersion = protocolVersion;
+ outgoingRecord.Epoch = peer.Epoch;
+ outgoingRecord.SequenceNumber = peer.CurrentEpoch.NextOutgoingSequence;
+ outgoingRecord.Length = (ushort)peer.CurrentEpoch.RecordProtection.GetEncryptedSize(queuedSpan.Length);
+ ++peer.CurrentEpoch.NextOutgoingSequence;
+
+ // Encode the record to wire format
+ ByteSpan packet = new byte[Record.Size + outgoingRecord.Length];
+ ByteSpan writer = packet;
+ outgoingRecord.Encode(writer);
+ writer = writer.Slice(Record.Size);
+ queuedSpan.CopyTo(writer);
+
+ // Protect the record
+ peer.CurrentEpoch.RecordProtection.EncryptServerPlaintext(
+ packet.Slice(Record.Size, outgoingRecord.Length),
+ packet.Slice(Record.Size, queuedSpan.Length),
+ ref outgoingRecord
+ );
+
+ base.QueueRawData(packet, remoteEndPoint);
+ }
+ peer.QueuedApplicationDataMessage.Clear();
+
+ {
+ Record outgoingRecord = new Record();
+ outgoingRecord.ContentType = ContentType.ApplicationData;
+ outgoingRecord.ProtocolVersion = protocolVersion;
+ outgoingRecord.Epoch = peer.Epoch;
+ outgoingRecord.SequenceNumber = peer.CurrentEpoch.NextOutgoingSequence;
+ outgoingRecord.Length = (ushort)peer.CurrentEpoch.RecordProtection.GetEncryptedSize(span.Length);
+ ++peer.CurrentEpoch.NextOutgoingSequence;
+
+ // Encode the record to wire format
+ ByteSpan packet = new byte[Record.Size + outgoingRecord.Length];
+ ByteSpan writer = packet;
+ outgoingRecord.Encode(writer);
+ writer = writer.Slice(Record.Size);
+ span.CopyTo(writer);
+
+ // Protect the record
+ peer.CurrentEpoch.RecordProtection.EncryptServerPlaintext(
+ packet.Slice(Record.Size, outgoingRecord.Length),
+ packet.Slice(Record.Size, span.Length),
+ ref outgoingRecord
+ );
+
+ base.QueueRawData(packet, remoteEndPoint);
+ }
+ }
+ }
+
+ private void HandleStaleConnections(object _)
+ {
+ TimeSpan maxAge = TimeSpan.FromSeconds(2.5f);
+ DateTime now = DateTime.UtcNow;
+ foreach (KeyValuePair<IPEndPoint, PeerData> kvp in this.existingPeers)
+ {
+ PeerData peer = kvp.Value;
+ lock (peer)
+ {
+ if (peer.Epoch == 0 || peer.NextEpoch.State != HandshakeState.ExpectingHello)
+ {
+ TimeSpan negotiationAge = now - peer.StartOfNegotiation;
+ if (negotiationAge > maxAge)
+ {
+ MarkConnectionAsStale(peer.ConnectionId);
+ }
+ }
+ }
+ }
+
+ ConnectionId connectionId;
+ while (this.staleConnections.TryPop(out connectionId))
+ {
+ ThreadLimitedUdpServerConnection connection;
+ if (this.allConnections.TryGetValue(connectionId, out connection))
+ {
+ connection.Disconnect("Stale Connection", null);
+ }
+ }
+ }
+
+ protected void MarkConnectionAsStale(ConnectionId connectionId)
+ {
+ if (this.allConnections.ContainsKey(connectionId))
+ {
+ this.staleConnections.Push(connectionId);
+ }
+ }
+
+ /// <inheritdoc />
+ internal override void RemovePeerRecord(ConnectionId connectionId)
+ {
+ if (this.existingPeers.TryRemove(connectionId.EndPoint, out var peer))
+ {
+ peer.Dispose();
+ }
+ }
+
+ /// <summary>
+ /// Allocate a new connection id
+ /// </summary>
+ private ConnectionId AllocateConnectionId(IPEndPoint endPoint)
+ {
+ int rawSerialId = Interlocked.Increment(ref this.connectionSerial_unsafe);
+ return ConnectionId.Create(endPoint, rawSerialId);
+ }
+
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/Dtls/DtlsUnityConnection.cs b/Tools/Hazel-Networking/Hazel/Dtls/DtlsUnityConnection.cs
new file mode 100644
index 0000000..4da2051
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/Dtls/DtlsUnityConnection.cs
@@ -0,0 +1,1246 @@
+using Hazel.Crypto;
+using Hazel.Udp;
+using System;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Net;
+using System.Security.Cryptography;
+using System.Security.Cryptography.X509Certificates;
+
+namespace Hazel.Dtls
+{
+ /// <summary>
+ /// Connects to a UDP-DTLS server
+ /// </summary>
+ /// <inheritdoc />
+ public class DtlsUnityConnection : UnityUdpClientConnection
+ {
+ /// <summary>
+ /// Current state of the handshake sequence
+ /// </summary>
+ enum HandshakeState
+ {
+ Initializing,
+
+ ExpectingServerHello,
+ ExpectingCertificate,
+ ExpectingServerKeyExchange,
+ ExpectingServerHelloDone,
+
+ ExpectingChangeCipherSpec,
+ ExpectingFinished,
+
+ Established,
+ }
+
+ /// <summary>
+ /// State data for the current epoch
+ /// </summary>
+ struct CurrentEpoch
+ {
+ public ulong NextOutgoingSequence;
+
+ public ulong NextExpectedSequence;
+ public ulong PreviousSequenceWindowBitmask;
+
+ public IRecordProtection RecordProtection;
+ }
+
+ struct FragmentRange
+ {
+ public int Offset;
+ public int Length;
+ }
+
+ /// <summary>
+ /// State data for the next epoch
+ /// </summary>
+ struct NextEpoch
+ {
+ public ushort Epoch;
+
+ public HandshakeState State;
+
+ public ulong NextOutgoingSequence;
+
+ public DateTime NegotiationStartTime;
+ public DateTime NextPacketResendTime;
+ public int PacketResendCount;
+
+ public CipherSuite SelectedCipherSuite;
+ public IRecordProtection RecordProtection;
+ public IHandshakeCipherSuite Handshake;
+ public ByteSpan Cookie;
+ public Sha256Stream VerificationStream;
+ public RSA ServerPublicKey;
+
+ public ByteSpan ClientRandom;
+ public ByteSpan ServerRandom;
+
+ public ByteSpan MasterSecret;
+ public ByteSpan ServerVerification;
+
+ public List<FragmentRange> CertificateFragments;
+ public ByteSpan CertificatePayload;
+ }
+
+ struct QueuedAppData
+ {
+ public byte[] Bytes;
+ public byte SendOption;
+ public Action AckCallback;
+ }
+
+ private const ProtocolVersion DtlsVersion = ProtocolVersion.UDP;
+
+ internal byte HazelSessionVersion = HazelDtlsSessionInfo.CurrentClientSessionVersion;
+
+ private readonly object syncRoot = new object();
+ private readonly RandomNumberGenerator random = RandomNumberGenerator.Create();
+
+ private ushort epoch;
+ private CurrentEpoch currentEpoch;
+ private NextEpoch nextEpoch;
+ private TimeSpan handshakeResendTimeout = TimeSpan.FromMilliseconds(200);
+
+ private readonly Queue<QueuedAppData> queuedApplicationData = new Queue<QueuedAppData>();
+
+ private X509Certificate2Collection serverCertificates = new X509Certificate2Collection();
+
+ public bool HandshakeComplete
+ {
+ get
+ {
+ lock (this.syncRoot)
+ {
+ return this.nextEpoch.State == HandshakeState.Established;
+ }
+ }
+ }
+
+ /// <summary>
+ /// Create a new instance of the DTLS connection
+ /// </summary>
+ /// <inheritdoc />
+ public DtlsUnityConnection(ILogger logger, IPEndPoint remoteEndPoint, IPMode ipMode = IPMode.IPv4)
+ : base(logger, remoteEndPoint, ipMode)
+ {
+ this.nextEpoch.ServerRandom = new byte[Random.Size];
+ this.nextEpoch.ClientRandom = new byte[Random.Size];
+ this.nextEpoch.ServerVerification = new byte[Finished.Size];
+ this.nextEpoch.CertificateFragments = new List<FragmentRange>();
+
+ this.ResetConnectionState();
+ }
+
+ /// <inheritdoc />
+ protected override void Dispose(bool disposing)
+ {
+ base.Dispose(disposing);
+
+ lock (this.syncRoot)
+ {
+ this.ResetConnectionState();
+ }
+ }
+
+ /// <summary>
+ /// Set the list of valid server certificates
+ /// </summary>
+ /// <param name="certificateCollection">
+ /// List of certificates of authentic servers
+ /// </param>
+ public void SetValidServerCertificates(X509Certificate2Collection certificateCollection)
+ {
+ lock (this.syncRoot)
+ {
+ foreach (X509Certificate2 certificate in certificateCollection)
+ {
+ if (!(certificate.PublicKey.Key is RSA))
+ {
+ throw new ArgumentException("Certificate must be signed with an RSA key", nameof(certificateCollection));
+ }
+ }
+
+ this.serverCertificates = certificateCollection;
+ }
+ }
+
+ /// <summary>
+ /// Set the packet resend timer for handshake messages
+ /// </summary>
+ public void SetHandshakeResendTimeout(TimeSpan timeout)
+ {
+ lock (this.syncRoot)
+ {
+ this.handshakeResendTimeout = timeout;
+ }
+ }
+
+ /// <summary>
+ /// Reset existing connection state
+ /// </summary>
+ private void ResetConnectionState()
+ {
+ this.currentEpoch.NextOutgoingSequence = 1;
+ this.currentEpoch.NextExpectedSequence = 1;
+ this.currentEpoch.PreviousSequenceWindowBitmask = 0;
+ this.currentEpoch.RecordProtection?.Dispose();
+ this.currentEpoch.RecordProtection = NullRecordProtection.Instance;
+
+ this.nextEpoch.Epoch = 1;
+ this.nextEpoch.State = HandshakeState.Initializing;
+ this.nextEpoch.NextOutgoingSequence = 1;
+ this.nextEpoch.NegotiationStartTime = DateTime.MinValue;
+ this.nextEpoch.NextPacketResendTime = DateTime.MinValue;
+ this.nextEpoch.SelectedCipherSuite = CipherSuite.TLS_NULL_WITH_NULL_NULL;
+ this.nextEpoch.RecordProtection?.Dispose();
+ this.nextEpoch.RecordProtection = null;
+ this.nextEpoch.Handshake?.Dispose();
+ this.nextEpoch.Handshake = null;
+ this.nextEpoch.Cookie = ByteSpan.Empty;
+ this.nextEpoch.VerificationStream?.Dispose();
+ this.nextEpoch.VerificationStream = new Sha256Stream();
+ this.nextEpoch.ServerPublicKey = null;
+ this.nextEpoch.ServerRandom.SecureClear();
+ this.nextEpoch.ClientRandom.SecureClear();
+ this.nextEpoch.MasterSecret.SecureClear();
+ this.nextEpoch.ServerVerification.SecureClear();
+ this.nextEpoch.CertificateFragments.Clear();
+ this.nextEpoch.CertificatePayload = ByteSpan.Empty;
+
+ this.epoch = 0;
+ while (this.queuedApplicationData.TryDequeue(out _)) ;
+ }
+
+ /// <summary>
+ /// Abort the existing connection and restart the process
+ /// </summary>
+ protected override void RestartConnection()
+ {
+ lock (this.syncRoot)
+ {
+ this.ResetConnectionState();
+ this.nextEpoch.ClientRandom.FillWithRandom(this.random);
+ this.SendClientHello(isRetransmit: false);
+ }
+
+ base.RestartConnection();
+ }
+
+ /// <inheritdoc />
+ protected override void ResendPacketsIfNeeded()
+ {
+ lock (this.syncRoot)
+ {
+ // Check if we need to resend handshake message
+ if (this.nextEpoch.State != HandshakeState.Established)
+ {
+ DateTime now = DateTime.UtcNow;
+ if (now >= this.nextEpoch.NextPacketResendTime)
+ {
+ double negotiationDurationMs = (now - this.nextEpoch.NegotiationStartTime).TotalMilliseconds;
+ this.nextEpoch.PacketResendCount++;
+
+ if ((this.ResendLimit > 0 && this.nextEpoch.PacketResendCount > this.ResendLimit)
+ || negotiationDurationMs > this.DisconnectTimeoutMs)
+ {
+ this.DisconnectInternal(HazelInternalErrors.DtlsNegotiationFailed, $"DTLS negotiation failed after {this.nextEpoch.PacketResendCount} resends ({(int)negotiationDurationMs} ms).");
+ }
+ else
+ {
+ switch (this.nextEpoch.State)
+ {
+ case HandshakeState.ExpectingServerHello:
+ case HandshakeState.ExpectingCertificate:
+ case HandshakeState.ExpectingServerKeyExchange:
+ case HandshakeState.ExpectingServerHelloDone:
+ this.SendClientHello(isRetransmit: true);
+ break;
+
+ case HandshakeState.ExpectingChangeCipherSpec:
+ case HandshakeState.ExpectingFinished:
+ this.SendClientKeyExchangeFlight(isRetransmit: true);
+ break;
+
+ case HandshakeState.Established:
+ default:
+ break;
+ }
+ }
+ }
+ }
+ }
+
+ base.ResendPacketsIfNeeded();
+ }
+
+ /// <summary>
+ /// Flush any queued application data packets
+ /// </summary>
+ private void FlushQueuedApplicationData()
+ {
+ while (this.queuedApplicationData.TryDequeue(out var queuedData))
+ {
+ base.HandleSend(queuedData.Bytes, queuedData.SendOption, queuedData.AckCallback);
+ }
+ }
+
+ /// <summary>
+ /// Request from the application to write data to the DTLS
+ /// stream. If appropriate, returns a byte span to send to
+ /// the wire.
+ /// </summary>
+ /// <param name="bytes">Plaintext bytes to write</param>
+ /// <param name="length">Length of the bytes to write</param>
+ /// <returns>
+ /// Encrypted data to put on the wire if appropriate,
+ /// otherwise an empty span
+ /// </returns>
+ private ByteSpan WriteBytesToConnectionInternal(byte[] bytes, int length)
+ {
+ lock (this.syncRoot)
+ {
+ Record outgoinRecord = new Record();
+ outgoinRecord.ContentType = ContentType.ApplicationData;
+ outgoinRecord.ProtocolVersion = DtlsVersion;
+ outgoinRecord.Epoch = this.epoch;
+ outgoinRecord.SequenceNumber = this.currentEpoch.NextOutgoingSequence;
+ outgoinRecord.Length = (ushort)this.currentEpoch.RecordProtection.GetEncryptedSize(length);
+ ++this.currentEpoch.NextOutgoingSequence;
+
+ // Encode the record to wire format
+ ByteSpan packet = new byte[Record.Size + outgoinRecord.Length];
+ ByteSpan writer = packet;
+ outgoinRecord.Encode(writer);
+ writer = writer.Slice(Record.Size);
+ new ByteSpan(bytes, 0, length).CopyTo(writer);
+
+ // Protect the record
+ this.currentEpoch.RecordProtection.EncryptClientPlaintext(
+ packet.Slice(Record.Size, outgoinRecord.Length),
+ packet.Slice(Record.Size, length),
+ ref outgoinRecord
+ );
+
+ return packet;
+ }
+ }
+
+ protected override void HandleSend(byte[] data, byte sendOption, Action ackCallback = null)
+ {
+ lock (this.syncRoot)
+ {
+ // If we're negotiating a new epoch, queue data
+ if (this.nextEpoch.State != HandshakeState.Established)
+ {
+ this.queuedApplicationData.Enqueue(new QueuedAppData
+ {
+ Bytes = data,
+ SendOption = sendOption,
+ AckCallback = ackCallback
+ });
+
+ return;
+ }
+ }
+
+ base.HandleSend(data, sendOption, ackCallback);
+ }
+
+ /// <inheritdoc />
+ protected override void WriteBytesToConnection(byte[] bytes, int length)
+ {
+ ByteSpan wireData = this.WriteBytesToConnectionInternal(bytes, length);
+ if (wireData.Length > 0)
+ {
+ Debug.Assert(wireData.Offset == 0, "Got a non-zero write data offset");
+ base.WriteBytesToConnection(wireData.GetUnderlyingArray(), wireData.Length);
+ }
+ }
+
+ /// <inheritdoc />
+ protected override void WriteBytesToConnectionSync(byte[] bytes, int length)
+ {
+ ByteSpan wireData = this.WriteBytesToConnectionInternal(bytes, length);
+ if (wireData.Length > 0)
+ {
+ Debug.Assert(wireData.Offset == 0, "Got a non-zero write data offset");
+ base.WriteBytesToConnectionSync(wireData.GetUnderlyingArray(), wireData.Length);
+ }
+ }
+
+ /// <inheritdoc />
+ protected internal override void HandleReceive(MessageReader reader, int bytesReceived)
+ {
+ ByteSpan message = new ByteSpan(reader.Buffer, reader.Offset + reader.Position, reader.BytesRemaining);
+ lock (this.syncRoot)
+ {
+ this.HandleReceive(message);
+ }
+
+ reader.Recycle();
+ }
+
+ /// <summary>
+ /// Handle an incoming datagram
+ /// </summary>
+ /// <param name="span">Bytes of the datagram</param>
+ private void HandleReceive(ByteSpan span)
+ {
+ // Each incoming packet may contain multiple DTLS
+ // records
+ while (span.Length > 0)
+ {
+ Record record;
+ if (!Record.Parse(out record, DtlsVersion, span))
+ {
+ this.logger.WriteError("Dropping malformed record");
+ return;
+ }
+ span = span.Slice(Record.Size);
+
+ if (span.Length < record.Length)
+ {
+ this.logger.WriteError($"Dropping malformed record. Length({record.Length}) Available Bytes({span.Length})");
+ return;
+ }
+
+ ByteSpan recordPayload = span.Slice(0, record.Length);
+ span = span.Slice(record.Length);
+
+ // Early out and drop ApplicationData records
+ if (record.ContentType == ContentType.ApplicationData && this.nextEpoch.State != HandshakeState.Established)
+ {
+ this.logger.WriteError("Dropping ApplicationData record. Cannot process yet");
+ continue;
+ }
+
+ // Drop records from a different epoch
+ if (record.Epoch != this.epoch)
+ {
+ this.logger.WriteError($"Dropping bad-epoch record. RecordEpoch({record.Epoch}) Epoch({this.epoch})");
+ continue;
+ }
+
+ // Prevent replay attacks by dropping records
+ // we've already processed
+ int windowIndex = (int)(this.currentEpoch.NextExpectedSequence - record.SequenceNumber - 1);
+ ulong windowMask = 1ul << windowIndex;
+ if (record.SequenceNumber < this.currentEpoch.NextExpectedSequence)
+ {
+ if (windowIndex >= 64)
+ {
+ this.logger.WriteError($"Dropping too-old record: Sequnce({record.SequenceNumber}) Expected({this.currentEpoch.NextExpectedSequence})");
+ continue;
+ }
+
+ if ((this.currentEpoch.PreviousSequenceWindowBitmask & windowMask) != 0)
+ {
+ this.logger.WriteWarning("Dropping duplicate record");
+ continue;
+ }
+ }
+
+ // Verify record authenticity
+ int decryptedSize = this.currentEpoch.RecordProtection.GetDecryptedSize(recordPayload.Length);
+ ByteSpan decryptedPayload = recordPayload.ReuseSpanIfPossible(decryptedSize);
+
+ if (!this.currentEpoch.RecordProtection.DecryptCiphertextFromServer(decryptedPayload, recordPayload, ref record))
+ {
+ this.logger.WriteError("Dropping non-authentic record");
+ return;
+ }
+
+ recordPayload = decryptedPayload;
+
+ // Update out sequence number bookkeeping
+ if (record.SequenceNumber >= this.currentEpoch.NextExpectedSequence)
+ {
+ int windowShift = (int)(record.SequenceNumber + 1 - this.currentEpoch.NextExpectedSequence);
+ this.currentEpoch.PreviousSequenceWindowBitmask <<= windowShift;
+ this.currentEpoch.NextExpectedSequence = record.SequenceNumber + 1;
+ }
+ else
+ {
+ this.currentEpoch.PreviousSequenceWindowBitmask |= windowMask;
+ }
+
+ // This is handy for debugging, but too verbose even for verbose.
+ // this.logger.WriteVerbose($"Content type was {record.ContentType} ({this.nextEpoch.State})");
+ switch (record.ContentType)
+ {
+ case ContentType.ChangeCipherSpec:
+ if (this.nextEpoch.State != HandshakeState.ExpectingChangeCipherSpec)
+ {
+ this.logger.WriteError($"Dropping unexpected ChangeCipherSpec State({this.nextEpoch.State})");
+ break;
+ }
+ else if (this.nextEpoch.RecordProtection == null)
+ {
+ ///NOTE(mendsley): This _should_ not
+ /// happen on a well-formed client.
+ Debug.Assert(false, "How did we receive a ChangeCipherSpec message without a pending record protection instance?");
+ break;
+ }
+
+ if (!ChangeCipherSpec.Parse(recordPayload))
+ {
+ this.logger.WriteError("Dropping malformed ChangeCipherSpec message");
+ break;
+ }
+
+ // Migrate to the next epoch
+ this.epoch = this.nextEpoch.Epoch;
+ this.currentEpoch.RecordProtection = this.nextEpoch.RecordProtection;
+ this.currentEpoch.NextOutgoingSequence = this.nextEpoch.NextOutgoingSequence;
+ this.currentEpoch.NextExpectedSequence = 1;
+ this.currentEpoch.PreviousSequenceWindowBitmask = 0;
+
+ this.nextEpoch.State = HandshakeState.ExpectingFinished;
+ this.nextEpoch.SelectedCipherSuite = CipherSuite.TLS_NULL_WITH_NULL_NULL;
+ this.nextEpoch.RecordProtection = null;
+ this.nextEpoch.Handshake?.Dispose();
+ this.nextEpoch.Cookie = ByteSpan.Empty;
+ this.nextEpoch.VerificationStream.Reset();
+ this.nextEpoch.ServerPublicKey = null;
+ this.nextEpoch.ServerRandom.SecureClear();
+ this.nextEpoch.ClientRandom.SecureClear();
+ this.nextEpoch.MasterSecret.SecureClear();
+ break;
+
+ case ContentType.Alert:
+ this.logger.WriteError("Dropping unsupported alert record");
+ continue;
+
+ case ContentType.Handshake:
+ if (!ProcessHandshake(ref record, recordPayload))
+ {
+ return;
+ }
+ break;
+
+ case ContentType.ApplicationData:
+ // Forward data to the application
+ MessageReader reader = MessageReader.GetSized(recordPayload.Length);
+ reader.Length = recordPayload.Length;
+ recordPayload.CopyTo(reader.Buffer);
+
+ base.HandleReceive(reader, recordPayload.Length);
+ break;
+ }
+ }
+ }
+
+ /// <summary>
+ /// Process an incoming Handshake protocol message
+ /// </summary>
+ /// <param name="record">Parent record</param>
+ /// <param name="message">Record payload</param>
+ /// <returns>
+ /// True if further processing of the underlying datagram
+ /// should be continues. Otherwise, false.
+ /// </returns>
+ private bool ProcessHandshake(ref Record record, ByteSpan message)
+ {
+ // Each record may have multiple Handshake messages
+ while (message.Length > 0)
+ {
+ ByteSpan originalPayload = message;
+
+ Handshake handshake;
+ if (!Handshake.Parse(out handshake, message))
+ {
+ this.logger.WriteError("Dropping malformed handshake message");
+ return false;
+ }
+ message = message.Slice(Handshake.Size);
+
+ // Check for fragmented messages
+ if (handshake.FragmentOffset != 0 || handshake.FragmentLength != handshake.Length)
+ {
+ // We only support fragmentation on Certificate messages
+ if (handshake.MessageType != HandshakeType.Certificate)
+ {
+ this.logger.WriteError($"Dropping fragmented handshake message Type({handshake.MessageType}) Offset({handshake.FragmentOffset}) FragmentLength({handshake.FragmentLength}) Length({handshake.Length})");
+ continue;
+ }
+
+ if (message.Length < handshake.FragmentLength)
+ {
+ this.logger.WriteError($"Dropping malformed fragmented handshake message: AvailableBytes({message.Length}) Size({handshake.FragmentLength})");
+ return false;
+ }
+
+ originalPayload = originalPayload.Slice(0, (int)(Handshake.Size + handshake.FragmentLength));
+ message = message.Slice((int)handshake.FragmentLength);
+ }
+ else
+ {
+ if (message.Length < handshake.Length)
+ {
+ this.logger.WriteError($"Dropping malformed handshake message: AvailableBytes({message.Length}) Size({handshake.Length})");
+ return false;
+ }
+
+ originalPayload = originalPayload.Slice(0, (int)(Handshake.Size + handshake.Length));
+ message = message.Slice((int)handshake.Length);
+ }
+
+ ByteSpan payload = originalPayload.Slice(Handshake.Size);
+
+#if DEBUG
+ this.logger.WriteVerbose($"Handshake record was {handshake.MessageType} (Frag: {handshake.FragmentOffset}) ({this.nextEpoch.State})");
+#endif
+ switch (handshake.MessageType)
+ {
+ case HandshakeType.HelloVerifyRequest:
+ if (this.nextEpoch.State != HandshakeState.ExpectingServerHello)
+ {
+ this.logger.WriteError($"Dropping unexpected HelloVerifyRequest handshake message State({this.nextEpoch.State})");
+ continue;
+ }
+ else if (handshake.MessageSequence != 0)
+ {
+ this.logger.WriteError($"Dropping bad-sequence HelloVerifyRequest MessageSequence({handshake.MessageSequence})");
+ continue;
+ }
+
+ HelloVerifyRequest helloVerifyRequest;
+ if (!HelloVerifyRequest.Parse(out helloVerifyRequest, DtlsVersion, payload))
+ {
+ this.logger.WriteError("Dropping malformed HelloVerifyRequest handshake message");
+ continue;
+ }
+
+ // If the cookie differs, save it and restart the handshake
+ if (this.nextEpoch.Cookie.Length == helloVerifyRequest.Cookie.Length
+ && Const.ConstantCompareSpans(this.nextEpoch.Cookie, helloVerifyRequest.Cookie) == 1)
+ {
+ this.logger.WriteWarning("Dropping duplicate HelloVerifyRequest handshake message");
+ continue;
+ }
+
+ this.nextEpoch.Cookie = new byte[helloVerifyRequest.Cookie.Length];
+ helloVerifyRequest.Cookie.CopyTo(this.nextEpoch.Cookie);
+ this.nextEpoch.ClientRandom.FillWithRandom(this.random);
+
+ // We don't need to resend here. We already have the cookie so we already sent it once.
+ this.SendClientHello(isRetransmit: false);
+
+ break;
+
+ case HandshakeType.ServerHello:
+ if (this.nextEpoch.State != HandshakeState.ExpectingServerHello)
+ {
+ this.logger.WriteError($"Dropping unexpected ServerHello handshake message State({this.nextEpoch.State})");
+ continue;
+ }
+ else if (handshake.MessageSequence != 1)
+ {
+ this.logger.WriteError($"Dropping bad-sequence ServerHello MessageSequence({handshake.MessageSequence})");
+ continue;
+ }
+
+ ServerHello serverHello;
+ if (!ServerHello.Parse(out serverHello, payload))
+ {
+ this.logger.WriteError("Dropping malformed ServerHello message");
+ continue;
+ }
+
+ switch (serverHello.CipherSuite)
+ {
+ case CipherSuite.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256:
+ this.nextEpoch.Handshake = new X25519EcdheRsaSha256(this.random);
+ break;
+
+ default:
+ this.logger.WriteError($"Dropping malformed ServerHello message. Unsupported CipherSuite({serverHello.CipherSuite})");
+ continue;
+ }
+
+ // Save server parameters
+ this.nextEpoch.SelectedCipherSuite = serverHello.CipherSuite;
+ serverHello.Random.CopyTo(this.nextEpoch.ServerRandom);
+ this.nextEpoch.State = HandshakeState.ExpectingCertificate;
+ this.nextEpoch.CertificateFragments.Clear();
+ this.nextEpoch.CertificatePayload = ByteSpan.Empty;
+
+#if DEBUG
+ this.logger.WriteVerbose($"ClientRandom: {this.nextEpoch.ClientRandom} ServerRandom: {this.nextEpoch.ServerRandom}");
+#endif
+
+ // Append ServerHelllo message to the verification stream
+ this.nextEpoch.VerificationStream.AddData(originalPayload);
+ break;
+
+ case HandshakeType.Certificate:
+ if (this.nextEpoch.State != HandshakeState.ExpectingCertificate)
+ {
+ this.logger.WriteError($"Dropping unexpected Certificate handshake message State({this.nextEpoch.State})");
+ continue;
+ }
+ else if (handshake.MessageSequence != 2)
+ {
+ this.logger.WriteError($"Dropping bad-sequence Certificate MessageSequence({handshake.MessageSequence})");
+ continue;
+ }
+
+ // If this is a fragmented message
+ if (handshake.FragmentLength != handshake.Length)
+ {
+ if (this.nextEpoch.CertificatePayload.Length != handshake.Length)
+ {
+ this.nextEpoch.CertificatePayload = new byte[handshake.Length];
+ this.nextEpoch.CertificateFragments.Clear();
+ }
+
+ // Should we add this fragment?
+ // According to the RFC 9147 Section 5.5, we are supposed to be tolerant of overlapping segments
+ // But if we... weren't... Hazel isn't going to change the fragment sizes. So would it really hurt?
+ // So let's just ignore that and assume that the sender always wants to send the same fragments.
+ if (IsFragmentOverlapping(this.nextEpoch.CertificateFragments, handshake.FragmentOffset, handshake.FragmentLength))
+ {
+ continue;
+ }
+
+ payload.CopyTo(this.nextEpoch.CertificatePayload.Slice((int)handshake.FragmentOffset, (int)handshake.FragmentLength));
+ this.nextEpoch.CertificateFragments.Add(new FragmentRange {Offset = (int)handshake.FragmentOffset, Length = (int)handshake.FragmentLength });
+ this.nextEpoch.CertificateFragments.Sort((FragmentRange lhs, FragmentRange rhs) => {
+ return lhs.Offset.CompareTo(rhs.Offset);
+ });
+
+ // Have we completed the message?
+ int currentOffset = 0;
+ bool valid = true;
+ foreach (FragmentRange range in this.nextEpoch.CertificateFragments)
+ {
+ if (range.Offset != currentOffset)
+ {
+ valid = false;
+ break;
+ }
+
+ currentOffset += range.Length;
+ }
+
+ if (currentOffset != this.nextEpoch.CertificatePayload.Length)
+ {
+ valid = false;
+ }
+
+ // Still waiting on more fragments?
+ if (!valid)
+ {
+ continue;
+ }
+
+ // Replace the message payload, and continue
+ this.nextEpoch.CertificateFragments.Clear();
+ payload = this.nextEpoch.CertificatePayload;
+ }
+
+ X509Certificate2 certificate;
+ if (!Certificate.Parse(out certificate, payload))
+ {
+ this.logger.WriteError("Dropping malformed Certificate message");
+ continue;
+ }
+
+ // Verify the certificate is authenticate
+ if (!this.serverCertificates.Contains(certificate))
+ {
+ this.logger.WriteError("Dropping malformed Certificate message: Certificate not authentic");
+ continue;
+ }
+
+ RSA publicKey = certificate.PublicKey.Key as RSA;
+ if (publicKey == null)
+ {
+ this.logger.WriteError("Dropping malfomed Certificate message: Certificate is not RSA signed");
+ continue;
+ }
+
+ // Add the final Certificate message to the verification stream
+ Handshake fullCertificateHandhake = handshake;
+ fullCertificateHandhake.FragmentOffset = 0;
+ fullCertificateHandhake.FragmentLength = fullCertificateHandhake.Length;
+
+ ByteSpan serializedCertificateHandshake = new byte[Handshake.Size];
+ fullCertificateHandhake.Encode(serializedCertificateHandshake);
+ this.nextEpoch.VerificationStream.AddData(serializedCertificateHandshake);
+ this.nextEpoch.VerificationStream.AddData(payload);
+
+ this.nextEpoch.ServerPublicKey = publicKey;
+ this.nextEpoch.State = HandshakeState.ExpectingServerKeyExchange;
+ break;
+
+ case HandshakeType.ServerKeyExchange:
+ if (this.nextEpoch.State != HandshakeState.ExpectingServerKeyExchange)
+ {
+ this.logger.WriteError($"Dropping unexpected ServerKeyExchange handshake message State({this.nextEpoch.State})");
+ continue;
+ }
+ else if (this.nextEpoch.ServerPublicKey == null)
+ {
+ ///NOTE(mendsley): This _should_ not
+ /// happen on a well-formed client
+ Debug.Assert(false, "How are we processing a ServerKeyExchange message without a server public key?");
+
+ this.logger.WriteError($"Dropping unexpected ServerKeyExchange handshake message: No server public key");
+ continue;
+ }
+ else if (this.nextEpoch.Handshake == null)
+ {
+ ///NOTE(mendsley): This _should_ not
+ /// happen on a well-formed client
+ Debug.Assert(false, "How did we receive a ServerKeyExchange message without a handshake instance?");
+
+ this.logger.WriteError($"Dropping unexpected ServerKeyExchange handshake message: No key agreement interface");
+ continue;
+ }
+ else if (handshake.MessageSequence != 3)
+ {
+ this.logger.WriteError($"Dropping bad-sequence ServerKeyExchange MessageSequence({handshake.MessageSequence})");
+ continue;
+ }
+
+ ByteSpan sharedSecret = new byte[this.nextEpoch.Handshake.SharedKeySize()];
+ if (!this.nextEpoch.Handshake.VerifyServerMessageAndGenerateSharedKey(sharedSecret, payload, this.nextEpoch.ServerPublicKey))
+ {
+ this.logger.WriteError("Dropping malformed ServerKeyExchangeMessage");
+ return false;
+ }
+
+ // Generate the session master secret
+ ByteSpan randomSeed = new byte[2 * Random.Size];
+ this.nextEpoch.ClientRandom.CopyTo(randomSeed);
+ this.nextEpoch.ServerRandom.CopyTo(randomSeed.Slice(Random.Size));
+
+ const int MasterSecretSize = 48;
+ ByteSpan masterSecret = new byte[MasterSecretSize];
+ PrfSha256.ExpandSecret(
+ masterSecret
+ , sharedSecret
+ , PrfLabel.MASTER_SECRET
+ , randomSeed
+ );
+
+ // Create record protection for the upcoming epoch
+ switch (this.nextEpoch.SelectedCipherSuite)
+ {
+ case CipherSuite.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256:
+ this.nextEpoch.RecordProtection = new Aes128GcmRecordProtection(
+ masterSecret
+ , this.nextEpoch.ServerRandom
+ , this.nextEpoch.ClientRandom
+ );
+ break;
+
+ default:
+ ///NOTE(mendsley): this _should_ not
+ /// happen on a well-formed client.
+ Debug.Assert(false, "SeverHello processing already approved this ciphersuite");
+
+ this.logger.WriteError($"Dropping malformed ServerKeyExchangeMessage: Could not create record protection");
+ return false;
+ }
+
+ this.nextEpoch.State = HandshakeState.ExpectingServerHelloDone;
+ this.nextEpoch.MasterSecret = masterSecret;
+
+ // Append ServerKeyExchange to the verification stream
+ this.nextEpoch.VerificationStream.AddData(originalPayload);
+ break;
+
+ case HandshakeType.ServerHelloDone:
+ if (this.nextEpoch.State != HandshakeState.ExpectingServerHelloDone)
+ {
+ this.logger.WriteError($"Dropping unexpected ServerHelloDone handshake message State({this.nextEpoch.State})");
+ continue;
+ }
+ else if (handshake.MessageSequence != 4)
+ {
+ this.logger.WriteError($"Dropping bad-sequence ServerHelloDone MessageSequence({handshake.MessageSequence})");
+ continue;
+ }
+
+ this.nextEpoch.State = HandshakeState.ExpectingChangeCipherSpec;
+
+ // Append ServerHelloDone to the verification stream
+ this.nextEpoch.VerificationStream.AddData(originalPayload);
+
+ this.SendClientKeyExchangeFlight(isRetransmit: false);
+ break;
+
+ case HandshakeType.Finished:
+ if (this.nextEpoch.State != HandshakeState.ExpectingFinished)
+ {
+ this.logger.WriteError($"Dropping unexpected Finished handshake message State({this.nextEpoch.State})");
+ continue;
+ }
+ else if (payload.Length != Finished.Size)
+ {
+ this.logger.WriteError($"Dropping malformed Finished handshake message Size({payload.Length})");
+ continue;
+ }
+ else if (handshake.MessageSequence != 7)
+ {
+ this.logger.WriteError($"Dropping bad-sequence Finished MessageSequence({handshake.MessageSequence})");
+ continue;
+ }
+
+ // Verify the digest from the server
+ if (1 != Crypto.Const.ConstantCompareSpans(payload, this.nextEpoch.ServerVerification))
+ {
+ this.logger.WriteError("Dropping non-verified Finished handshake message");
+ return false;
+ }
+
+ ++this.nextEpoch.Epoch;
+ this.nextEpoch.State = HandshakeState.Established;
+ this.nextEpoch.NegotiationStartTime = DateTime.MinValue;
+ this.nextEpoch.NextPacketResendTime = DateTime.MinValue;
+ this.nextEpoch.ServerVerification.SecureClear();
+ this.nextEpoch.MasterSecret.SecureClear();
+
+ this.FlushQueuedApplicationData();
+ break;
+
+ // Drop messages we do not support
+ case HandshakeType.CertificateRequest:
+ case HandshakeType.HelloRequest:
+ this.logger.WriteError($"Dropping unsupported handshake message MessageType({handshake.MessageType})");
+ break;
+
+ // Drop messages that originate from the client
+ case HandshakeType.ClientHello:
+ case HandshakeType.ClientKeyExchange:
+ case HandshakeType.CertificateVerify:
+ this.logger.WriteError($"Dropping client handshake message MessageType({handshake.MessageType})");
+ break;
+ }
+ }
+
+ return true;
+ }
+
+ private bool IsFragmentOverlapping(List<FragmentRange> fragments, uint newOffset, uint newLength)
+ {
+ foreach (var frag in fragments)
+ {
+ // New fragment overlaps an existing one
+ if (newOffset <= frag.Offset
+ && frag.Offset < newOffset + newLength)
+ {
+ return true;
+ }
+
+ // Existing fragment overlaps this new one
+ if (frag.Offset <= newOffset
+ && newOffset < frag.Offset + frag.Length)
+ {
+ return true;
+ }
+ }
+
+ return false;
+ }
+
+ /// <summary>
+ /// Send (resend) a ClientHello message to the server
+ /// </summary>
+ protected virtual void SendClientHello(bool isRetransmit)
+ {
+#if DEBUG
+ var verb = isRetransmit ? "Resending" : "Sending";
+ this.logger.WriteVerbose($"{verb} ClientHello in state: {this.nextEpoch.State}. Epoch: {this.epoch} Cookie: {this.nextEpoch.Cookie} Random: {this.nextEpoch.ClientRandom}");
+#endif
+
+ // Describe our ClientHello flight
+ ClientHello clientHello = new ClientHello();
+ clientHello.ClientProtocolVersion = DtlsVersion;
+ clientHello.Random = this.nextEpoch.ClientRandom;
+ clientHello.Cookie = this.nextEpoch.Cookie;
+ clientHello.Session = new HazelDtlsSessionInfo(this.HazelSessionVersion);
+ clientHello.CipherSuites = new byte[2];
+ clientHello.CipherSuites.WriteBigEndian16((ushort)CipherSuite.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256);
+ clientHello.SupportedCurves = new byte[2];
+ clientHello.SupportedCurves.WriteBigEndian16((ushort)NamedCurve.x25519);
+
+ Handshake handshake = new Handshake();
+ handshake.MessageType = HandshakeType.ClientHello;
+ handshake.Length = (uint)clientHello.CalculateSize();
+ handshake.MessageSequence = 0;
+ handshake.FragmentOffset = 0;
+ handshake.FragmentLength = handshake.Length;
+
+ // Describe the record
+ int plaintextLength = (int)(Handshake.Size + handshake.Length);
+ Record outgoingRecord = new Record();
+ outgoingRecord.ContentType = ContentType.Handshake;
+ outgoingRecord.ProtocolVersion = DtlsVersion;
+ outgoingRecord.Epoch = this.epoch;
+ outgoingRecord.SequenceNumber = this.currentEpoch.NextOutgoingSequence;
+ outgoingRecord.Length = (ushort)this.currentEpoch.RecordProtection.GetEncryptedSize(plaintextLength);
+ ++this.currentEpoch.NextOutgoingSequence;
+
+ // Convert the record to wire format
+ ByteSpan packet = new byte[Record.Size + outgoingRecord.Length];
+ ByteSpan writer = packet;
+ outgoingRecord.Encode(packet);
+ writer = writer.Slice(Record.Size);
+ handshake.Encode(writer);
+ writer = writer.Slice(Handshake.Size);
+ clientHello.Encode(writer);
+
+ // If this is our first valid attempt at contacting the server:
+ // - Reset our verification stream
+ // - Write ClientHello to the verification stream
+ // - We next expect a ServerHello
+ //
+ // ClientHello+Cookie triggers many sequential packets in response
+ // It's important to make forward progress as the packets may be reordered in-flight
+ // But with enough resends, we will read them all in an appropriate order
+ if (!isRetransmit)
+ {
+ this.nextEpoch.VerificationStream.Reset();
+ this.nextEpoch.VerificationStream.AddData(
+ packet.Slice(Record.Size, Handshake.Size + (int)handshake.Length)
+ );
+
+ this.nextEpoch.State = HandshakeState.ExpectingServerHello;
+ }
+
+ // Protect the record
+ this.currentEpoch.RecordProtection.EncryptClientPlaintext(
+ packet.Slice(Record.Size, outgoingRecord.Length),
+ packet.Slice(Record.Size, plaintextLength),
+ ref outgoingRecord
+ );
+
+ if (this.nextEpoch.NegotiationStartTime == DateTime.MinValue) this.nextEpoch.NegotiationStartTime = DateTime.UtcNow;
+ this.nextEpoch.NextPacketResendTime = DateTime.UtcNow + this.handshakeResendTimeout;
+
+ base.WriteBytesToConnection(packet.GetUnderlyingArray(), packet.Length);
+ }
+
+ protected void Test_SendClientHello(Func<ClientHello, ByteSpan, ByteSpan> encodeCallback)
+ {
+ // Reset our verification stream
+ this.nextEpoch.VerificationStream.Reset();
+
+ // Describe our ClientHello flight
+ ClientHello clientHello = new ClientHello();
+ clientHello.Random = this.nextEpoch.ClientRandom;
+ clientHello.Cookie = this.nextEpoch.Cookie;
+ clientHello.CipherSuites = new byte[2];
+ clientHello.CipherSuites.WriteBigEndian16((ushort)CipherSuite.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256);
+ clientHello.SupportedCurves = new byte[2];
+ clientHello.SupportedCurves.WriteBigEndian16((ushort)NamedCurve.x25519);
+
+ Handshake handshake = new Handshake();
+ handshake.MessageType = HandshakeType.ClientHello;
+ handshake.Length = (uint)clientHello.CalculateSize();
+ handshake.MessageSequence = 0;
+ handshake.FragmentOffset = 0;
+ handshake.FragmentLength = handshake.Length;
+
+ // Describe the record
+ int plaintextLength = (int)(Handshake.Size + handshake.Length);
+ Record outgoingRecord = new Record();
+ outgoingRecord.ContentType = ContentType.Handshake;
+ outgoingRecord.ProtocolVersion = DtlsVersion;
+ outgoingRecord.Epoch = this.epoch;
+ outgoingRecord.SequenceNumber = this.currentEpoch.NextOutgoingSequence;
+ outgoingRecord.Length = (ushort)this.currentEpoch.RecordProtection.GetEncryptedSize(plaintextLength);
+ ++this.currentEpoch.NextOutgoingSequence;
+
+ // Convert the record to wire format
+ ByteSpan packet = new byte[Record.Size + outgoingRecord.Length];
+ ByteSpan writer = packet;
+ outgoingRecord.Encode(packet);
+ writer = writer.Slice(Record.Size);
+ handshake.Encode(writer);
+ writer = writer.Slice(Handshake.Size);
+
+ writer = encodeCallback(clientHello, writer);
+
+ // Write ClientHello to the verification stream
+ this.nextEpoch.VerificationStream.AddData(
+ packet.Slice(
+ Record.Size
+ , Handshake.Size + (int)handshake.Length
+ )
+ );
+
+ // Protect the record
+ this.currentEpoch.RecordProtection.EncryptClientPlaintext(
+ packet.Slice(Record.Size, outgoingRecord.Length),
+ packet.Slice(Record.Size, plaintextLength),
+ ref outgoingRecord
+ );
+
+ this.nextEpoch.State = HandshakeState.ExpectingServerHello;
+ if (this.nextEpoch.NegotiationStartTime == DateTime.MinValue) this.nextEpoch.NegotiationStartTime = DateTime.UtcNow;
+ this.nextEpoch.NextPacketResendTime = DateTime.UtcNow + this.handshakeResendTimeout;
+ base.WriteBytesToConnection(packet.GetUnderlyingArray(), packet.Length);
+ }
+
+ /// <summary>
+ /// Send (resend) the ClientKeyExchange flight
+ /// </summary>
+ /// <param name="isRetransmit">
+ /// True if this is a retransmit of the flight. Otherwise,
+ /// false
+ /// </param>
+ protected virtual void SendClientKeyExchangeFlight(bool isRetransmit)
+ {
+#if DEBUG
+ var verb = isRetransmit ? "Resending" : "Sending";
+ this.logger.WriteVerbose($"{verb} ClientKeyExchangeFlight in state: {this.nextEpoch.State}");
+#endif
+ if (this.nextEpoch.State == HandshakeState.Established)
+ {
+ return;
+ }
+
+ // Describe our flight
+ Handshake keyExchangeHandshake = new Handshake();
+ keyExchangeHandshake.MessageType = HandshakeType.ClientKeyExchange;
+ keyExchangeHandshake.Length = (ushort)this.nextEpoch.Handshake.CalculateClientMessageSize();
+ keyExchangeHandshake.MessageSequence = 5;
+ keyExchangeHandshake.FragmentOffset = 0;
+ keyExchangeHandshake.FragmentLength = keyExchangeHandshake.Length;
+
+ Record keyExchangeRecord = new Record();
+ keyExchangeRecord.ContentType = ContentType.Handshake;
+ keyExchangeRecord.ProtocolVersion = DtlsVersion;
+ keyExchangeRecord.Epoch = this.epoch;
+ keyExchangeRecord.SequenceNumber = this.currentEpoch.NextOutgoingSequence;
+ keyExchangeRecord.Length = (ushort)this.currentEpoch.RecordProtection.GetEncryptedSize(Handshake.Size + (int)keyExchangeHandshake.Length);
+ ++this.currentEpoch.NextOutgoingSequence;
+
+ Record changeCipherSpecRecord = new Record();
+ changeCipherSpecRecord.ContentType = ContentType.ChangeCipherSpec;
+ changeCipherSpecRecord.ProtocolVersion = DtlsVersion;
+ changeCipherSpecRecord.Epoch = this.epoch;
+ changeCipherSpecRecord.SequenceNumber = this.currentEpoch.NextOutgoingSequence;
+ changeCipherSpecRecord.Length = (ushort)this.currentEpoch.RecordProtection.GetEncryptedSize(ChangeCipherSpec.Size);
+ ++this.currentEpoch.NextOutgoingSequence;
+
+ Handshake finishedHandshake = new Handshake();
+ finishedHandshake.MessageType = HandshakeType.Finished;
+ finishedHandshake.Length = Finished.Size;
+ finishedHandshake.MessageSequence = 6;
+ finishedHandshake.FragmentOffset = 0;
+ finishedHandshake.FragmentLength = finishedHandshake.Length;
+
+ Record finishedRecord = new Record();
+ finishedRecord.ContentType = ContentType.Handshake;
+ finishedRecord.ProtocolVersion = DtlsVersion;
+ finishedRecord.Epoch = this.nextEpoch.Epoch;
+ finishedRecord.SequenceNumber = this.nextEpoch.NextOutgoingSequence;
+ finishedRecord.Length = (ushort)this.nextEpoch.RecordProtection.GetEncryptedSize(Handshake.Size + (int)finishedHandshake.Length);
+ ++this.nextEpoch.NextOutgoingSequence;
+
+ // Encode flight to wire format
+ int packetLength = 0
+ + Record.Size + keyExchangeRecord.Length
+ + Record.Size + changeCipherSpecRecord.Length
+ + Record.Size + finishedRecord.Length;
+ ;
+ ByteSpan packet = new byte[packetLength];
+ ByteSpan writer = packet;
+
+ keyExchangeRecord.Encode(writer);
+ writer = writer.Slice(Record.Size);
+ keyExchangeHandshake.Encode(writer);
+ writer = writer.Slice(Handshake.Size);
+ this.nextEpoch.Handshake.EncodeClientKeyExchangeMessage(writer);
+
+ ByteSpan startOfChangeCipherSpecRecord = packet.Slice(Record.Size + keyExchangeRecord.Length);
+ writer = startOfChangeCipherSpecRecord;
+ changeCipherSpecRecord.Encode(writer);
+ writer = writer.Slice(Record.Size);
+ ChangeCipherSpec.Encode(writer);
+ writer = writer.Slice(ChangeCipherSpec.Size);
+
+ ByteSpan startOfFinishedRecord = startOfChangeCipherSpecRecord.Slice(Record.Size + changeCipherSpecRecord.Length);
+ writer = startOfFinishedRecord;
+ finishedRecord.Encode(writer);
+ writer = writer.Slice(Record.Size);
+ finishedHandshake.Encode(writer);
+ writer = writer.Slice(Handshake.Size);
+
+ // Interject here to writer our client key exchange
+ // message into the verification stream
+ if (!isRetransmit)
+ {
+ this.nextEpoch.VerificationStream.AddData(
+ packet.Slice(
+ Record.Size
+ , Handshake.Size + (int)keyExchangeHandshake.Length
+ )
+ );
+ }
+
+ // Calculate the hash of the verification stream
+ ByteSpan handshakeHash = new byte[Sha256Stream.DigestSize];
+ this.nextEpoch.VerificationStream.CopyOrCalculateFinalHash(handshakeHash);
+
+ // Expand our master secret into Finished digests for the client and server
+ PrfSha256.ExpandSecret(
+ this.nextEpoch.ServerVerification
+ , this.nextEpoch.MasterSecret
+ , PrfLabel.SERVER_FINISHED
+ , handshakeHash
+ );
+
+ PrfSha256.ExpandSecret(
+ writer.Slice(0, Finished.Size)
+ , this.nextEpoch.MasterSecret
+ , PrfLabel.CLIENT_FINISHED
+ , handshakeHash
+ );
+ writer = writer.Slice(Finished.Size);
+
+ // Protect the ClientKeyExchange record
+ this.currentEpoch.RecordProtection.EncryptClientPlaintext(
+ packet.Slice(Record.Size, keyExchangeRecord.Length),
+ packet.Slice(Record.Size, Handshake.Size + (int)keyExchangeHandshake.Length),
+ ref keyExchangeRecord
+ );
+
+ // Protect the ChangeCipherSpec record
+ this.currentEpoch.RecordProtection.EncryptClientPlaintext(
+ startOfChangeCipherSpecRecord.Slice(Record.Size, changeCipherSpecRecord.Length),
+ startOfChangeCipherSpecRecord.Slice(Record.Size, ChangeCipherSpec.Size),
+ ref changeCipherSpecRecord
+ );
+
+ // Protect the Finished record
+ this.nextEpoch.RecordProtection.EncryptClientPlaintext(
+ startOfFinishedRecord.Slice(Record.Size, finishedRecord.Length),
+ startOfFinishedRecord.Slice(Record.Size, Handshake.Size + (int)finishedHandshake.Length),
+ ref finishedRecord
+ );
+
+ this.nextEpoch.State = HandshakeState.ExpectingChangeCipherSpec;
+ this.nextEpoch.NextPacketResendTime = DateTime.UtcNow + this.handshakeResendTimeout;
+#if DEBUG
+ if (DropClientKeyExchangeFlight())
+ {
+ return;
+ }
+#endif
+ base.WriteBytesToConnection(packet.GetUnderlyingArray(), packet.Length);
+ }
+
+ protected virtual bool DropClientKeyExchangeFlight()
+ {
+ return false;
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/Dtls/Handshake.cs b/Tools/Hazel-Networking/Hazel/Dtls/Handshake.cs
new file mode 100644
index 0000000..f840053
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/Dtls/Handshake.cs
@@ -0,0 +1,734 @@
+using System;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Net;
+using System.Security.Cryptography;
+using System.Security.Cryptography.X509Certificates;
+
+namespace Hazel.Dtls
+{
+ /// <summary>
+ /// Handshake message type
+ /// </summary>
+ public enum HandshakeType : byte
+ {
+ HelloRequest = 0,
+ ClientHello = 1,
+ ServerHello = 2,
+ HelloVerifyRequest = 3,
+ Certificate = 11,
+ ServerKeyExchange = 12,
+ CertificateRequest = 13,
+ ServerHelloDone = 14,
+ CertificateVerify = 15,
+ ClientKeyExchange = 16,
+ Finished = 20,
+ }
+
+ /// <summary>
+ /// List of cipher suites
+ /// </summary>
+ public enum CipherSuite
+ {
+ TLS_NULL_WITH_NULL_NULL = 0x0000,
+ TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 = 0xC02F,
+ }
+
+ /// <summary>
+ /// List of compression methods
+ /// </summary>
+ public enum CompressionMethod : byte
+ {
+ Null = 0,
+ }
+
+ /// <summary>
+ /// Extension type
+ /// </summary>
+ public enum ExtensionType : ushort
+ {
+ EllipticCurves = 10,
+ }
+
+ /// <summary>
+ /// Named curves
+ /// </summary>
+ public enum NamedCurve : ushort
+ {
+ Reserved = 0,
+ secp256r1 = 23,
+ x25519 = 29,
+ }
+
+ /// <summary>
+ /// Elliptic curve type
+ /// </summary>
+ public enum ECCurveType : byte
+ {
+ NamedCurve = 3,
+ }
+
+ /// <summary>
+ /// Hash algorithms
+ /// </summary>
+ public enum HashAlgorithm : byte
+ {
+ None = 0,
+ Sha256 = 4,
+ }
+
+ /// <summary>
+ /// Signature algorithms
+ /// </summary>
+ public enum SignatureAlgorithm : byte
+ {
+ Anonymous = 0,
+ RSA = 1,
+ ECDSA = 3,
+ }
+
+ /// <summary>
+ /// Random state for entropy
+ /// </summary>
+ public struct Random
+ {
+ public const int Size = 0
+ + 4 // gmt_unix_time
+ + 28 // random_bytes
+ ;
+ }
+
+ /// <summary>
+ /// Encode/decode handshake protocol header
+ /// </summary>
+ public struct Handshake
+ {
+ public HandshakeType MessageType;
+ public uint Length;
+ public ushort MessageSequence;
+ public uint FragmentOffset;
+ public uint FragmentLength;
+
+ public const int Size = 12;
+
+ /// <summary>
+ /// Parse a Handshake protocol header from wire format
+ /// </summary>
+ /// <returns>True if we successfully decode a handshake header. Otherwise false</returns>
+ public static bool Parse(out Handshake header, ByteSpan span)
+ {
+ header = new Handshake();
+
+ if (span.Length < Size)
+ {
+ return false;
+ }
+
+ header.MessageType = (HandshakeType)span[0];
+ header.Length = span.ReadBigEndian24(1);
+ header.MessageSequence = span.ReadBigEndian16(4);
+ header.FragmentOffset = span.ReadBigEndian24(6);
+ header.FragmentLength = span.ReadBigEndian24(9);
+ return true;
+ }
+
+ /// <summary>
+ /// Encode the Handshake protocol header to wire format
+ /// </summary>
+ /// <param name="span"></param>
+ public void Encode(ByteSpan span)
+ {
+ span[0] = (byte)this.MessageType;
+ span.WriteBigEndian24(this.Length, 1);
+ span.WriteBigEndian16(this.MessageSequence, 4);
+ span.WriteBigEndian24(this.FragmentOffset, 6);
+ span.WriteBigEndian24(this.FragmentLength, 9);
+ }
+ }
+
+ /// <summary>
+ /// Encode/decode ClientHello Handshake message
+ /// </summary>
+ public struct ClientHello
+ {
+ public ProtocolVersion ClientProtocolVersion;
+ public ByteSpan Random;
+ public ByteSpan Cookie;
+ public HazelDtlsSessionInfo Session;
+ public ByteSpan CipherSuites;
+ public ByteSpan SupportedCurves;
+
+ public const int MinSize = 0
+ + 2 // client_version
+ + Dtls.Random.Size // random
+ + 1 // session_id (size)
+ + 1 // cookie (size)
+ + 2 // cipher_suites (size)
+ + 1 // compression_methods (size)
+ + 1 // compression_method[0] (NULL)
+
+ + 2 // extensions size
+
+ + 0 // NamedCurveList extensions[0]
+ + 2 // extensions[0].extension_type
+ + 2 // extensions[0].extension_data (length)
+ + 2 // extensions[0].named_curve_list (size)
+ ;
+
+ /// <summary>
+ /// Calculate the size in bytes required for the ClientHello payload
+ /// </summary>
+ /// <returns></returns>
+ public int CalculateSize()
+ {
+ return MinSize
+ + this.Session.PayloadSize
+ + this.Cookie.Length
+ + this.CipherSuites.Length
+ + this.SupportedCurves.Length
+ ;
+ }
+
+ /// <summary>
+ /// Parse a Handshake ClientHello payload from wire format
+ /// </summary>
+ /// <returns>True if we successfully decode the ClientHello message. Otherwise false</returns>
+ public static bool Parse(out ClientHello result, ProtocolVersion? expectedProtocolVersion, ByteSpan span)
+ {
+ result = new ClientHello();
+ if (span.Length < MinSize)
+ {
+ return false;
+ }
+
+ result.ClientProtocolVersion = (ProtocolVersion)span.ReadBigEndian16();
+ if (expectedProtocolVersion.HasValue && result.ClientProtocolVersion != expectedProtocolVersion.Value)
+ {
+ return false;
+ }
+
+ span = span.Slice(2);
+
+ result.Random = span.Slice(0, Dtls.Random.Size);
+ span = span.Slice(Dtls.Random.Size);
+
+ if (!HazelDtlsSessionInfo.Parse(out result.Session, span))
+ {
+ return false;
+ }
+
+ span = span.Slice(result.Session.FullSize);
+
+ byte cookieSize = span[0];
+ if (span.Length < 1 + cookieSize)
+ {
+ return false;
+ }
+ result.Cookie = span.Slice(1, cookieSize);
+ span = span.Slice(1 + cookieSize);
+
+ ushort cipherSuiteSize = span.ReadBigEndian16();
+ if (span.Length < 2 + cipherSuiteSize)
+ {
+ return false;
+ }
+ else if (cipherSuiteSize % 2 != 0)
+ {
+ return false;
+ }
+ result.CipherSuites = span.Slice(2, cipherSuiteSize);
+ span = span.Slice(2 + cipherSuiteSize);
+
+ int compressionMethodsSize = span[0];
+ bool foundNullCompressionMethod = false;
+ for (int ii = 0; ii != compressionMethodsSize; ++ii)
+ {
+ if (span[1+ii] == (byte)CompressionMethod.Null)
+ {
+ foundNullCompressionMethod = true;
+ break;
+ }
+ }
+
+ if (!foundNullCompressionMethod
+ || span.Length < 1 + compressionMethodsSize)
+ {
+ return false;
+ }
+
+ span = span.Slice(1 + compressionMethodsSize);
+
+ // Parse extensions
+ if (span.Length > 0)
+ {
+ if (span.Length < 2)
+ {
+ return false;
+ }
+
+ ushort extensionsSize = span.ReadBigEndian16();
+ span = span.Slice(2);
+ if (span.Length != extensionsSize)
+ {
+ return false;
+ }
+
+ while (span.Length > 0)
+ {
+ // Parse extension header
+ if (span.Length < 4)
+ {
+ return false;
+ }
+
+ ExtensionType extensionType = (ExtensionType)span.ReadBigEndian16(0);
+ ushort extensionLength = span.ReadBigEndian16(2);
+
+ if (span.Length < 4 + extensionLength)
+ {
+ return false;
+ }
+
+ ByteSpan extensionData = span.Slice(4, extensionLength);
+ span = span.Slice(4 + extensionLength);
+ result.ParseExtension(extensionType, extensionData);
+ }
+ }
+
+ return true;
+ }
+
+ /// <summary>
+ /// Decode a ClientHello extension
+ /// </summary>
+ /// <param name="extensionType">Extension type</param>
+ /// <param name="extensionData">Extension data</param>
+ private void ParseExtension(ExtensionType extensionType, ByteSpan extensionData)
+ {
+ switch (extensionType)
+ {
+ case ExtensionType.EllipticCurves:
+ if (extensionData.Length % 2 != 0)
+ {
+ break;
+ }
+ else if (extensionData.Length < 2)
+ {
+ break;
+ }
+
+ ushort namedCurveSize = extensionData.ReadBigEndian16(0);
+ if (namedCurveSize % 2 != 0)
+ {
+ break;
+ }
+
+ this.SupportedCurves = extensionData.Slice(2, namedCurveSize);
+ break;
+ }
+ }
+
+ /// <summary>
+ /// Determines if the ClientHello message advertises support
+ /// for the specified cipher suite
+ /// </summary>
+ public bool ContainsCipherSuite(CipherSuite cipherSuite)
+ {
+ ByteSpan iterator = this.CipherSuites;
+ while (iterator.Length >= 2)
+ {
+ if (iterator.ReadBigEndian16() == (ushort)cipherSuite)
+ {
+ return true;
+ }
+
+ iterator = iterator.Slice(2);
+ }
+
+ return false;
+ }
+
+ /// <summary>
+ /// Determines if the ClientHello message advertises support
+ /// for the specified curve
+ /// </summary>
+ public bool ContainsCurve(NamedCurve curve)
+ {
+ ByteSpan iterator = this.SupportedCurves;
+ while (iterator.Length >= 2)
+ {
+ if (iterator.ReadBigEndian16() == (ushort)curve)
+ {
+ return true;
+ }
+
+ iterator = iterator.Slice(2);
+ }
+
+ return false;
+ }
+
+ /// <summary>
+ /// Encode Handshake ClientHello payload to wire format
+ /// </summary>
+ public void Encode(ByteSpan span)
+ {
+ span.WriteBigEndian16((ushort)this.ClientProtocolVersion);
+ span = span.Slice(2);
+
+ Debug.Assert(this.Random.Length == Dtls.Random.Size);
+ this.Random.CopyTo(span);
+ span = span.Slice(Dtls.Random.Size);
+
+ this.Session.Encode(span);
+ span = span.Slice(this.Session.FullSize);
+
+ span[0] = (byte)this.Cookie.Length;
+ this.Cookie.CopyTo(span.Slice(1));
+ span = span.Slice(1 + this.Cookie.Length);
+
+ span.WriteBigEndian16((ushort)this.CipherSuites.Length);
+ this.CipherSuites.CopyTo(span.Slice(2));
+ span = span.Slice(2 + this.CipherSuites.Length);
+
+ span[0] = 1;
+ span[1] = (byte)CompressionMethod.Null;
+ span = span.Slice(2);
+
+ // Extensions size
+ span.WriteBigEndian16((ushort)(6 + this.SupportedCurves.Length));
+ span = span.Slice(2);
+
+ // Supported curves extension
+ span.WriteBigEndian16((ushort)ExtensionType.EllipticCurves);
+ span.WriteBigEndian16((ushort)(2 + this.SupportedCurves.Length), 2);
+ span.WriteBigEndian16((ushort)this.SupportedCurves.Length, 4);
+ this.SupportedCurves.CopyTo(span.Slice(6));
+ }
+ }
+
+ /// <summary>
+ /// Encode/Decode session information in ClientHello
+ /// </summary>
+ public struct HazelDtlsSessionInfo
+ {
+ public const byte CurrentClientSessionSize = 1;
+ public const byte CurrentClientSessionVersion = 1;
+
+ public byte FullSize => (byte)(1 + this.PayloadSize);
+ public byte PayloadSize;
+ public byte Version;
+
+ public HazelDtlsSessionInfo(byte version)
+ {
+ this.Version = version;
+ switch (version)
+ {
+ case 0: // Does not write version byte
+ this.PayloadSize = 0;
+ return;
+ case 1: // Writes version byte only
+ this.PayloadSize = 1;
+ return;
+ }
+
+ throw new ArgumentOutOfRangeException("Unimplemented Hazel session version");
+ }
+
+ public void Encode(ByteSpan writer)
+ {
+ writer[0] = this.PayloadSize;
+
+ if (this.Version > 0)
+ {
+ writer[1] = this.Version;
+ }
+ }
+
+ public static bool Parse(out HazelDtlsSessionInfo result, ByteSpan reader)
+ {
+ result = new HazelDtlsSessionInfo();
+ if (reader.Length < 1)
+ {
+ return false;
+ }
+
+ result.PayloadSize = reader[0];
+
+ // Back compat, length may be zero, version defaults to 0.
+ if (result.PayloadSize == 0)
+ {
+ result.Version = 0;
+ return true;
+ }
+
+ // Forward compat, if length > 1, ignore the rest
+ result.Version = reader[1];
+ return true;
+ }
+ }
+
+ /// <summary>
+ /// Encode/decode Handshake HelloVerifyRequest message
+ /// </summary>
+ public struct HelloVerifyRequest
+ {
+ public const int CookieSize = 20;
+ public const int Size = 0
+ + 2 // server_version
+ + 1 // cookie (size)
+ + CookieSize // cookie (data)
+ ;
+
+ public ProtocolVersion ServerProtocolVersion;
+ public ByteSpan Cookie;
+
+ /// <summary>
+ /// Parse a Handshake HelloVerifyRequest payload from wire
+ /// format
+ /// </summary>
+ /// <returns>
+ /// True if we successfully decode the HelloVerifyRequest
+ /// message. Otherwise false.
+ /// </returns>
+ public static bool Parse(out HelloVerifyRequest result, ProtocolVersion? expectedProtocolVersion, ByteSpan span)
+ {
+ result = new HelloVerifyRequest();
+ if (span.Length < 3)
+ {
+ return false;
+ }
+
+ result.ServerProtocolVersion = (ProtocolVersion)span.ReadBigEndian16(0);
+ if (expectedProtocolVersion.HasValue && result.ServerProtocolVersion != expectedProtocolVersion.Value)
+ {
+ return false;
+ }
+
+ byte cookieSize = span[2];
+ span = span.Slice(3);
+
+ if (span.Length < cookieSize)
+ {
+ return false;
+ }
+
+ result.Cookie = span;
+ return true;
+ }
+
+ /// <summary>
+ /// Encode a HelloVerifyRequest payload to wire format
+ /// </summary>
+ /// <param name="peerAddress">Address of the remote peer</param>
+ /// <param name="hmac">Listener HMAC signature provider</param>
+ public static void Encode(ByteSpan span, EndPoint peerAddress, HMAC hmac, ProtocolVersion protocolVersion)
+ {
+ ByteSpan cookie = ComputeAddressMac(peerAddress, hmac);
+
+ span.WriteBigEndian16((ushort)protocolVersion);
+ span[2] = (byte)CookieSize;
+ cookie.CopyTo(span.Slice(3));
+ }
+
+ /// <summary>
+ /// Generate an HMAC for a peer address
+ /// </summary>
+ /// <param name="peerAddress">Address of the remote peer</param>
+ /// <param name="hmac">Listener HMAC signature provider</param>
+ public static ByteSpan ComputeAddressMac(EndPoint peerAddress, HMAC hmac)
+ {
+ SocketAddress address = peerAddress.Serialize();
+ byte[] data = new byte[address.Size];
+ for (int ii = 0, nn = data.Length; ii != nn; ++ii)
+ {
+ data[ii] = address[ii];
+ }
+
+ ///NOTE(mendsley): Lame that we need to allocate+copy here
+ ByteSpan signature = hmac.ComputeHash(data);
+ return signature.Slice(0, CookieSize);
+ }
+
+ /// <summary>
+ /// Verify a client's cookie was signed by our listener
+ /// </summary>
+ /// <param name="cookie">Wire format cookie</param>
+ /// <param name="peerAddress">Address of the remote peer</param>
+ /// <param name="hmac">Listener HMAC signature provider</param>
+ /// <returns>True if the cookie is valid. Otherwise false</returns>
+ public static bool VerifyCookie(ByteSpan cookie, EndPoint peerAddress, HMAC hmac)
+ {
+ if (cookie.Length != CookieSize)
+ {
+ return false;
+ }
+
+ ByteSpan expectedHash = ComputeAddressMac(peerAddress, hmac);
+ if (expectedHash.Length != cookie.Length)
+ {
+ return false;
+ }
+
+ return (1 == Crypto.Const.ConstantCompareSpans(cookie, expectedHash));
+ }
+ }
+
+ /// <summary>
+ /// Encode/decode Handshake ServerHello message
+ /// </summary>
+ public struct ServerHello
+ {
+ public ProtocolVersion ServerProtocolVersion;
+ public ByteSpan Random;
+ public CipherSuite CipherSuite;
+ public HazelDtlsSessionInfo Session;
+
+ public const int MinSize = 0
+ + 2 // server_version
+ + Dtls.Random.Size // random
+ + 1 // session_id (size)
+ + 2 // cipher_suite
+ + 1 // compression_method
+ ;
+
+ public int Size => MinSize + Session.PayloadSize;
+
+ /// <summary>
+ /// Parse a Handshake ServerHello payload from wire format
+ /// </summary>
+ /// <returns>
+ /// True if we successfully decode the ServerHello
+ /// message. Otherwise false.
+ /// </returns>
+ public static bool Parse(out ServerHello result, ByteSpan span)
+ {
+ result = new ServerHello();
+ if (span.Length < MinSize)
+ {
+ return false;
+ }
+
+ result.ServerProtocolVersion = (ProtocolVersion)span.ReadBigEndian16();
+ span = span.Slice(2);
+
+ result.Random = span.Slice(0, Dtls.Random.Size);
+ span = span.Slice(Dtls.Random.Size);
+
+ if (!HazelDtlsSessionInfo.Parse(out result.Session, span))
+ {
+ return false;
+ }
+
+ span = span.Slice(result.Session.FullSize);
+
+ result.CipherSuite = (CipherSuite)span.ReadBigEndian16();
+ span = span.Slice(2);
+
+ CompressionMethod compressionMethod = (CompressionMethod)span[0];
+ if (compressionMethod != CompressionMethod.Null)
+ {
+ return false;
+ }
+
+ return true;
+ }
+
+ /// <summary>
+ /// Encode Handshake ServerHello to wire format
+ /// </summary>
+ public void Encode(ByteSpan span)
+ {
+ Debug.Assert(this.Random.Length == Dtls.Random.Size);
+
+ span.WriteBigEndian16((ushort)this.ServerProtocolVersion, 0);
+ span = span.Slice(2);
+
+ this.Random.CopyTo(span);
+ span = span.Slice(Dtls.Random.Size);
+
+ this.Session.Encode(span);
+ span = span.Slice(this.Session.FullSize);
+
+ span.WriteBigEndian16((ushort)this.CipherSuite);
+ span = span.Slice(2);
+
+ span[0] = (byte)CompressionMethod.Null;
+ }
+ }
+
+ /// <summary>
+ /// Encode/decode Handshake Certificate message
+ /// </summary>
+ public struct Certificate
+ {
+ /// <summary>
+ /// Encode a certificate to wire formate
+ /// </summary>
+ public static ByteSpan Encode(X509Certificate2 certificate)
+ {
+ ByteSpan certData = certificate.GetRawCertData();
+ int totalSize = certData.Length + 3 + 3;
+
+ ByteSpan result = new byte[totalSize];
+
+ ByteSpan writer = result;
+ writer.WriteBigEndian24((uint)certData.Length + 3);
+ writer = writer.Slice(3);
+ writer.WriteBigEndian24((uint)certData.Length);
+ writer = writer.Slice(3);
+
+ certData.CopyTo(writer);
+ return result;
+ }
+
+ /// <summary>
+ /// Parse a Handshake Certificate payload from wire format
+ /// </summary>
+ /// <returns>True if we successfully decode the Certificate message. Otherwise false</returns>
+ public static bool Parse(out X509Certificate2 certificate, ByteSpan span)
+ {
+ certificate = null;
+ if (span.Length < 6)
+ {
+ return false;
+ }
+
+ uint totalSize = span.ReadBigEndian24();
+ span = span.Slice(3);
+
+ if (span.Length < totalSize)
+ {
+ return false;
+ }
+
+ uint certificateSize = span.ReadBigEndian24();
+ span = span.Slice(3);
+ if (span.Length < certificateSize)
+ {
+ return false;
+ }
+
+ byte[] rawData = new byte[certificateSize];
+ span.CopyTo(rawData, 0);
+ try
+ {
+ certificate = new X509Certificate2(rawData);
+ }
+ catch (Exception)
+ {
+ return false;
+ }
+
+ return true;
+ }
+ }
+
+ /// <summary>
+ /// Encode/decode Handshake Finished message
+ /// </summary>
+ public struct Finished
+ {
+ public const int Size = 12;
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/Dtls/IHandshakeCipherSuite.cs b/Tools/Hazel-Networking/Hazel/Dtls/IHandshakeCipherSuite.cs
new file mode 100644
index 0000000..eedd977
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/Dtls/IHandshakeCipherSuite.cs
@@ -0,0 +1,63 @@
+using System;
+
+namespace Hazel.Dtls
+{
+ /// <summary>
+ /// DTLS cipher suite interface for the handshake portion of
+ /// the connection.
+ /// </summary>
+ public interface IHandshakeCipherSuite : IDisposable
+ {
+ /// <summary>
+ /// Gets the size of the shared key
+ /// </summary>
+ /// <returns>Size of the shared key in bytes </returns>
+ int SharedKeySize();
+
+ /// <summary>
+ /// Calculate the size of the ServerKeyExchnage message
+ /// </summary>
+ /// <param name="privateKey">
+ /// Private key that will be used to sign the message
+ /// </param>
+ /// <returns>Size of the message in bytes</returns>
+ int CalculateServerMessageSize(object privateKey);
+
+ /// <summary>
+ /// Encodes the ServerKeyExchange message
+ /// </summary>
+ /// <param name="privateKey">Private key to use for signing</param>
+ void EncodeServerKeyExchangeMessage(ByteSpan output, object privateKey);
+
+ /// <summary>
+ /// Verifies the authenticity of a server key exchange
+ /// message and calculates the shared secret.
+ /// </summary>
+ /// <returns>
+ /// True if the authenticity has been validated and a shared key
+ /// was generated. Otherwise, false.
+ /// </returns>
+ bool VerifyServerMessageAndGenerateSharedKey(ByteSpan output, ByteSpan serverKeyExchangeMessage, object publicKey);
+
+ /// <summary>
+ /// Calculate the size of the ClientKeyExchange message
+ /// </summary>
+ /// <returns>Size of the message in bytes</returns>
+ int CalculateClientMessageSize();
+
+ /// <summary>
+ /// Encodes the ClientKeyExchangeMessage
+ /// </summary>
+ void EncodeClientKeyExchangeMessage(ByteSpan output);
+
+ /// <summary>
+ /// Verifies the validity of a client key exchange message
+ /// and calculats the hsared secret.
+ /// </summary>
+ /// <returns>
+ /// True if the client exchange message is valid and a
+ /// shared key was generated. Otherwise, false.
+ /// </returns>
+ bool VerifyClientMessageAndGenerateSharedKey(ByteSpan output, ByteSpan clientKeyExchangeMessage);
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/Dtls/IRecordProtection.cs b/Tools/Hazel-Networking/Hazel/Dtls/IRecordProtection.cs
new file mode 100644
index 0000000..cbee1b0
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/Dtls/IRecordProtection.cs
@@ -0,0 +1,84 @@
+using System;
+
+namespace Hazel.Dtls
+{
+ /// <summary>
+ /// DTLS cipher suite interface for protection of record payload.
+ /// </summary>
+ public interface IRecordProtection : IDisposable
+ {
+ /// <summary>
+ /// Calculate the size of an encrypted plaintext
+ /// </summary>
+ /// <param name="dataSize">Size of plaintext in bytes</param>
+ /// <returns>Size of encrypted ciphertext in bytes</returns>
+ int GetEncryptedSize(int dataSize);
+
+ /// <summary>
+ /// Calculate the size of decrypted ciphertext
+ /// </summary>
+ /// <param name="dataSize">Size of ciphertext in bytes</param>
+ /// <returns>Size of decrypted plaintext in bytes</returns>
+ int GetDecryptedSize(int dataSize);
+
+ /// <summary>
+ /// Encrypt a plaintext intput with server keys
+ ///
+ /// Output may overlap with input.
+ /// </summary>
+ /// <param name="output">Output ciphertext</param>
+ /// <param name="input">Input plaintext</param>
+ /// <param name="record">Parent DTLS record</param>
+ void EncryptServerPlaintext(ByteSpan output, ByteSpan input, ref Record record);
+
+ /// <summary>
+ /// Encrypt a plaintext intput with client keys
+ ///
+ /// Output may overlap with input.
+ /// </summary>
+ /// <param name="output">Output ciphertext</param>
+ /// <param name="input">Input plaintext</param>
+ /// <param name="record">Parent DTLS record</param>
+ void EncryptClientPlaintext(ByteSpan output, ByteSpan input, ref Record record);
+
+ /// <summary>
+ /// Decrypt a ciphertext intput with server keys
+ ///
+ /// Output may overlap with input.
+ /// </summary>
+ /// <param name="output">Output plaintext</param>
+ /// <param name="input">Input ciphertext</param>
+ /// <param name="record">Parent DTLS record</param>
+ /// <returns>True if the input was authenticated and decrypted. Otherwise false</returns>
+ bool DecryptCiphertextFromServer(ByteSpan output, ByteSpan input, ref Record record);
+
+ /// <summary>
+ /// Decrypt a ciphertext intput with client keys
+ ///
+ /// Output may overlap with input.
+ /// </summary>
+ /// <param name="output">Output plaintext</param>
+ /// <param name="input">Input ciphertext</param>
+ /// <param name="record">Parent DTLS record</param>
+ /// <returns>True if the input was authenticated and decrypted. Otherwise false</returns>
+ bool DecryptCiphertextFromClient(ByteSpan output, ByteSpan input, ref Record record);
+ }
+
+ /// <summary>
+ /// Factory to create record protection from cipher suite identifiers
+ /// </summary>
+ public sealed class RecordProtectionFactory
+ {
+ public static IRecordProtection Create(CipherSuite cipherSuite, ByteSpan masterSecret, ByteSpan serverRandom, ByteSpan clientRandom)
+ {
+ switch (cipherSuite)
+ {
+ case CipherSuite.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256:
+ return new Aes128GcmRecordProtection(masterSecret, serverRandom, clientRandom);
+
+ default:
+ return null;
+ }
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/Dtls/NullRecordProtection.cs b/Tools/Hazel-Networking/Hazel/Dtls/NullRecordProtection.cs
new file mode 100644
index 0000000..76fa132
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/Dtls/NullRecordProtection.cs
@@ -0,0 +1,66 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+
+namespace Hazel.Dtls
+{
+ /// <summary>
+ /// Passthrough record protection implementaion
+ /// </summary>
+ public class NullRecordProtection : IRecordProtection
+ {
+ public readonly static NullRecordProtection Instance = new NullRecordProtection();
+
+ public void Dispose()
+ {
+ }
+
+ public int GetEncryptedSize(int dataSize)
+ {
+ return dataSize;
+ }
+
+ public int GetDecryptedSize(int dataSize)
+ {
+ return dataSize;
+ }
+
+ public void EncryptServerPlaintext(ByteSpan output, ByteSpan input, ref Record record)
+ {
+ CopyMaybeOverlappingSpans(output, input);
+ }
+
+ public void EncryptClientPlaintext(ByteSpan output, ByteSpan input, ref Record record)
+ {
+ CopyMaybeOverlappingSpans(output, input);
+ }
+
+ public bool DecryptCiphertextFromServer(ByteSpan output, ByteSpan input, ref Record record)
+ {
+ CopyMaybeOverlappingSpans(output, input);
+ return true;
+ }
+
+ public bool DecryptCiphertextFromClient(ByteSpan output, ByteSpan input, ref Record record)
+ {
+ CopyMaybeOverlappingSpans(output, input);
+ return true;
+ }
+
+ private static void CopyMaybeOverlappingSpans(ByteSpan output, ByteSpan input)
+ {
+ // Early out if the ranges `output` is equal to `input`
+ if (output.GetUnderlyingArray() == input.GetUnderlyingArray())
+ {
+ if (output.Offset == input.Offset && output.Length == input.Length)
+ {
+ return;
+ }
+ }
+
+ input.CopyTo(output);
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/Dtls/PrfSha256.cs b/Tools/Hazel-Networking/Hazel/Dtls/PrfSha256.cs
new file mode 100644
index 0000000..1fa7f17
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/Dtls/PrfSha256.cs
@@ -0,0 +1,84 @@
+using System.Text;
+using System.Security.Cryptography;
+
+namespace Hazel.Dtls
+{
+ /// <summary>
+ /// Common Psuedorandom Function labels for TLS
+ /// </summary>
+ public struct PrfLabel
+ {
+ public static readonly ByteSpan MASTER_SECRET = LabelToBytes("master secert");
+ public static readonly ByteSpan KEY_EXPANSION = LabelToBytes("key expansion");
+ public static readonly ByteSpan CLIENT_FINISHED = LabelToBytes("client finished");
+ public static readonly ByteSpan SERVER_FINISHED = LabelToBytes("server finished");
+
+ /// <summary>
+ /// Convert a text label to a byte sequence
+ /// </summary>
+ public static ByteSpan LabelToBytes(string label)
+ {
+ return Encoding.ASCII.GetBytes(label);
+ }
+ }
+
+ /// <summary>
+ /// The P_SHA256 Psuedorandom Function
+ /// </summary>
+ public struct PrfSha256
+ {
+ /// <summary>
+ /// Expand a secret key
+ /// </summary>
+ /// <param name="output">Output span. Length determines how much data to generate</param>
+ /// <param name="key">Original key to expand</param>
+ /// <param name="label">Label (treated as a salt)</param>
+ /// <param name="initialSeed">Seed for expansion (treated as a salt)</param>
+ public static void ExpandSecret(ByteSpan output, ByteSpan key, string label, ByteSpan initialSeed)
+ {
+ ExpandSecret(output, key, PrfLabel.LabelToBytes(label), initialSeed);
+ }
+
+ /// <summary>
+ /// Expand a secret key
+ /// </summary>
+ /// <param name="output">Output span. Length determines how much data to generate</param>
+ /// <param name="key">Original key to expand</param>
+ /// <param name="label">Label (treated as a salt)</param>
+ /// <param name="initialSeed">Seed for expansion (treated as a salt)</param>
+ public static void ExpandSecret(ByteSpan output, ByteSpan key, ByteSpan label, ByteSpan initialSeed)
+ {
+ ByteSpan writer = output;
+
+ byte[] roundSeed = new byte[label.Length + initialSeed.Length];
+ label.CopyTo(roundSeed);
+ initialSeed.CopyTo(roundSeed, label.Length);
+
+ byte[] hashA = roundSeed;
+
+ using (HMACSHA256 hmac = new HMACSHA256(key.ToArray()))
+ {
+ byte[] input = new byte[hmac.OutputBlockSize + roundSeed.Length];
+ new ByteSpan(roundSeed).CopyTo(input, hmac.OutputBlockSize);
+
+ while (writer.Length > 0)
+ {
+ // Update hashA
+ hashA = hmac.ComputeHash(hashA);
+
+ // generate hash input
+ new ByteSpan(hashA).CopyTo(input);
+
+ ByteSpan roundOutput = hmac.ComputeHash(input);
+ if (roundOutput.Length > writer.Length)
+ {
+ roundOutput = roundOutput.Slice(0, writer.Length);
+ }
+
+ roundOutput.CopyTo(writer);
+ writer = writer.Slice(roundOutput.Length);
+ }
+ }
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/Dtls/Record.cs b/Tools/Hazel-Networking/Hazel/Dtls/Record.cs
new file mode 100644
index 0000000..23eaa95
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/Dtls/Record.cs
@@ -0,0 +1,123 @@
+namespace Hazel.Dtls
+{
+ /// <summary>
+ /// DTLS version constants
+ /// </summary>
+ public enum ProtocolVersion : ushort
+ {
+ /// <summary>
+ /// Use to obfuscate DTLS as regular UDP packets
+ /// </summary>
+ UDP = 0,
+
+ /// <summary>
+ /// DTLS 1.2
+ /// </summary>
+ DTLS1_2 = 0xFEFD,
+ }
+
+ /// <summary>
+ /// DTLS record content type
+ /// </summary>
+ public enum ContentType : byte
+ {
+ ChangeCipherSpec = 20,
+ Alert = 21,
+ Handshake = 22,
+ ApplicationData = 23,
+ }
+
+ /// <summary>
+ /// Encode/decode DTLS record header
+ /// </summary>
+ public struct Record
+ {
+ public ContentType ContentType;
+ public ProtocolVersion ProtocolVersion;
+ public ushort Epoch;
+ public ulong SequenceNumber;
+ public ushort Length;
+
+ public const int Size = 13;
+
+ /// <summary>
+ /// Parse a DTLS record from wire format
+ /// </summary>
+ /// <returns>True if we successfully parse the record header. Otherwise false</returns>
+ public static bool Parse(out Record record, ProtocolVersion? expectedProtocolVersion, ByteSpan span)
+ {
+ record = new Record();
+
+ if (span.Length < Size)
+ {
+ return false;
+ }
+
+ record.ContentType = (ContentType)span[0];
+ record.ProtocolVersion = (ProtocolVersion)span.ReadBigEndian16(1);
+ record.Epoch = span.ReadBigEndian16(3);
+ record.SequenceNumber = span.ReadBigEndian48(5);
+ record.Length = span.ReadBigEndian16(11);
+
+ if (expectedProtocolVersion.HasValue && record.ProtocolVersion != expectedProtocolVersion.Value)
+ {
+ return false;
+ }
+
+ return true;
+ }
+
+ /// <summary>
+ /// Encode a DTLS record to wire format
+ /// </summary>
+ public void Encode(ByteSpan span)
+ {
+ span[0] = (byte)this.ContentType;
+ span.WriteBigEndian16((ushort)this.ProtocolVersion, 1);
+ span.WriteBigEndian16(this.Epoch, 3);
+ span.WriteBigEndian48(this.SequenceNumber, 5);
+ span.WriteBigEndian16(this.Length, 11);
+ }
+ }
+
+ public struct ChangeCipherSpec
+ {
+ public const int Size = 1;
+
+ enum Value : byte
+ {
+ ChangeCipherSpec = 1,
+ }
+
+ /// <summary>
+ /// Parse a ChangeCipherSpec record from wire format
+ /// </summary>
+ /// <returns>
+ /// True if we successfully parse the ChangeCipherSpec
+ /// record. Otherwise, false.
+ /// </returns>
+ public static bool Parse(ByteSpan span)
+ {
+ if (span.Length != 1)
+ {
+ return false;
+ }
+
+ Value value = (Value)span[0];
+ if (value != Value.ChangeCipherSpec)
+ {
+ return false;
+ }
+
+ return true;
+ }
+
+ /// <summary>
+ /// Encode a ChangeCipherSpec record to wire format
+ /// </summary>
+ public static void Encode(ByteSpan span)
+ {
+ span[0] = (byte)Value.ChangeCipherSpec;
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/Dtls/ThreadedHmacHelper.cs b/Tools/Hazel-Networking/Hazel/Dtls/ThreadedHmacHelper.cs
new file mode 100644
index 0000000..38da061
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/Dtls/ThreadedHmacHelper.cs
@@ -0,0 +1,159 @@
+using System;
+using System.Collections.Concurrent;
+using System.Security.Cryptography;
+using System.Threading;
+
+namespace Hazel.Dtls
+{
+ internal class ThreadedHmacHelper : IDisposable
+ {
+ private class ThreadHmacs
+ {
+ public HMAC currentHmac;
+ public HMAC previousHmac;
+ public HMAC hmacToDispose;
+ }
+
+ private static readonly int CookieHmacRotationTimeout = (int)TimeSpan.FromHours(1.0).TotalMilliseconds;
+
+ private readonly ILogger logger;
+ private readonly ConcurrentDictionary<int, ThreadHmacs> hmacs;
+ private Timer rotateKeyTimer;
+ private RandomNumberGenerator cryptoRandom;
+ private byte[] currentHmacKey;
+
+ public ThreadedHmacHelper(ILogger logger)
+ {
+ this.hmacs = new ConcurrentDictionary<int, ThreadHmacs>();
+ this.rotateKeyTimer = new Timer(RotateKeys, null, CookieHmacRotationTimeout, CookieHmacRotationTimeout);
+ this.cryptoRandom = RandomNumberGenerator.Create();
+
+ this.logger = logger;
+ SetHmacKey();
+ }
+
+ /// <summary>
+ /// [ThreadSafe] Get the current cookie hmac for the current thread.
+ /// </summary>
+ public HMAC GetCurrentCookieHmacsForThread()
+ {
+ return GetHmacsForThread().currentHmac;
+ }
+
+ /// <summary>
+ /// [ThreadSafe] Get the previous cookie hmac for the current thread.
+ /// </summary>
+ public HMAC GetPreviousCookieHmacsForThread()
+ {
+ return GetHmacsForThread().previousHmac;
+ }
+
+ public void Dispose()
+ {
+ ManualResetEvent signalRotateKeyTimerEnded = new ManualResetEvent(false);
+ this.rotateKeyTimer.Dispose(signalRotateKeyTimerEnded);
+ signalRotateKeyTimerEnded.WaitOne();
+ signalRotateKeyTimerEnded.Dispose();
+ signalRotateKeyTimerEnded = null;
+ this.rotateKeyTimer = null;
+
+ this.cryptoRandom.Dispose();
+ this.cryptoRandom = null;
+
+ foreach (var threadIdToHmac in this.hmacs)
+ {
+ ThreadHmacs threadHmacs = threadIdToHmac.Value;
+ threadHmacs.currentHmac?.Dispose();
+ threadHmacs.currentHmac = null;
+ threadHmacs.previousHmac?.Dispose();
+ threadHmacs.previousHmac = null;
+ threadHmacs.hmacToDispose?.Dispose();
+ threadHmacs.hmacToDispose = null;
+ }
+
+ this.hmacs.Clear();
+ }
+
+ private ThreadHmacs GetHmacsForThread()
+ {
+ int threadId = Thread.CurrentThread.ManagedThreadId;
+
+ if (!this.hmacs.TryGetValue(threadId, out ThreadHmacs threadHmacs))
+ {
+ threadHmacs = CreateNewThreadHmacs();
+
+ if (!this.hmacs.TryAdd(threadId, threadHmacs))
+ {
+ this.logger.WriteError($"Cannot add threadHmacs for thread {threadId} during GetHmacsForThread! Should never happen!");
+ }
+ }
+
+ return threadHmacs;
+ }
+
+ /// <summary>
+ /// Rotates the hmacs of all active threads
+ /// </summary>
+ private void RotateKeys(object _)
+ {
+ SetHmacKey();
+
+ foreach (var threadIds in this.hmacs)
+ {
+ RotateKey(threadIds.Key);
+ }
+ }
+
+ /// <summary>
+ /// Rotate hmacs of single thread
+ /// </summary>
+ /// <param name="threadId">Managed thread Id of thread calling this method.</param>
+ private void RotateKey(int threadId)
+ {
+ ThreadHmacs threadHmacs;
+
+ if (!this.hmacs.TryGetValue(threadId, out threadHmacs))
+ {
+ this.logger.WriteError($"Cannot find thread {threadId} in hmacs during rotation! Should never happen!");
+ return;
+ }
+
+ // No thread should still have a reference to hmacToDispose, which should now have a lifetime of > 1 hour
+ threadHmacs.hmacToDispose?.Dispose();
+ threadHmacs.hmacToDispose = threadHmacs.previousHmac;
+ threadHmacs.previousHmac = threadHmacs.currentHmac;
+ threadHmacs.currentHmac = CreateNewCookieHMAC();
+ }
+
+ private ThreadHmacs CreateNewThreadHmacs()
+ {
+ return new ThreadHmacs
+ {
+ previousHmac = CreateNewCookieHMAC(),
+ currentHmac = CreateNewCookieHMAC()
+ };
+ }
+
+ /// <summary>
+ /// Create a new cookie HMAC signer
+ /// </summary>
+ private HMAC CreateNewCookieHMAC()
+ {
+ const string HMACProvider = "System.Security.Cryptography.HMACSHA1";
+ HMAC hmac = HMAC.Create(HMACProvider);
+ hmac.Key = this.currentHmacKey;
+ return hmac;
+ }
+
+ /// <summary>
+ /// Creates a new cryptographically secure random Hmac key
+ /// </summary>
+ private void SetHmacKey()
+ {
+ // MSDN recommends 64 bytes key for HMACSHA-1
+ byte[] newKey = new byte[64];
+ this.cryptoRandom.GetBytes(newKey);
+ this.currentHmacKey = newKey;
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/Dtls/X25519EcdheRsaSha256.cs b/Tools/Hazel-Networking/Hazel/Dtls/X25519EcdheRsaSha256.cs
new file mode 100644
index 0000000..f567252
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/Dtls/X25519EcdheRsaSha256.cs
@@ -0,0 +1,202 @@
+using Hazel.Crypto;
+using System;
+using System.Diagnostics;
+using System.Security.Cryptography;
+
+namespace Hazel.Dtls
+{
+ /// <summary>
+ /// ECDHE_RSA_*_256 cipher suite
+ /// </summary>
+ public class X25519EcdheRsaSha256 : IHandshakeCipherSuite
+ {
+ private readonly ByteSpan privateAgreementKey;
+ private SHA256 sha256 = SHA256.Create();
+
+ /// <summary>
+ /// Create a new instance of the x25519 key exchange
+ /// </summary>
+ /// <param name="random">Random data source</param>
+ public X25519EcdheRsaSha256(RandomNumberGenerator random)
+ {
+ byte[] buffer = new byte[X25519.KeySize];
+ random.GetBytes(buffer);
+ this.privateAgreementKey = buffer;
+ }
+
+ /// <inheritdoc />
+ public void Dispose()
+ {
+ this.sha256?.Dispose();
+ this.sha256 = null;
+ }
+
+ /// <inheritdoc />
+ public int SharedKeySize()
+ {
+ return X25519.KeySize;
+ }
+
+ /// <summary>
+ /// Calculate the server message size given an RSA key size
+ /// </summary>
+ /// <param name="keySize">
+ /// Size of the private key (in bits)
+ /// </param>
+ /// <returns>
+ /// Size of the ServerKeyExchange message in bytes
+ /// </returns>
+ private static int CalculateServerMessageSize(int keySize)
+ {
+ int signatureSize = keySize / 8;
+
+ return 0
+ + 1 // ECCurveType ServerKeyExchange.params.curve_params.curve_type
+ + 2 // NamedCurve ServerKeyExchange.params.curve_params.namedcurve
+ + 1 + X25519.KeySize // ECPoint ServerKeyExchange.params.public
+ + 1 // HashAlgorithm ServerKeyExchange.algorithm.hash
+ + 1 // SignatureAlgorithm ServerKeyExchange.signed_params.algorithm.signature
+ + 2 // ServerKeyExchange.signed_params.size
+ + signatureSize // ServerKeyExchange.signed_params.opaque
+ ;
+ }
+
+ /// <inheritdoc />
+ public int CalculateServerMessageSize(object privateKey)
+ {
+ RSA rsaPrivateKey = privateKey as RSA;
+ if (rsaPrivateKey == null)
+ {
+ throw new ArgumentException("Invalid private key", nameof(privateKey));
+ }
+
+ return CalculateServerMessageSize(rsaPrivateKey.KeySize);
+ }
+
+ /// <inheritdoc />
+ public void EncodeServerKeyExchangeMessage(ByteSpan output, object privateKey)
+ {
+ RSA rsaPrivateKey = privateKey as RSA;
+ if (rsaPrivateKey == null)
+ {
+ throw new ArgumentException("Invalid private key", nameof(privateKey));
+ }
+
+ output[0] = (byte)ECCurveType.NamedCurve;
+ output.WriteBigEndian16((ushort)NamedCurve.x25519, 1);
+ output[3] = (byte)X25519.KeySize;
+ X25519.Func(output.Slice(4, X25519.KeySize), this.privateAgreementKey);
+
+ // Hash the key parameters
+ byte[] paramterDigest = this.sha256.ComputeHash(output.GetUnderlyingArray(), output.Offset, 4 + X25519.KeySize);
+
+ // Sign the paramter digest
+ RSAPKCS1SignatureFormatter signer = new RSAPKCS1SignatureFormatter(rsaPrivateKey);
+ signer.SetHashAlgorithm("SHA256");
+ ByteSpan signature = signer.CreateSignature(paramterDigest);
+
+ Debug.Assert(signature.Length == rsaPrivateKey.KeySize/8);
+ output[4 + X25519.KeySize] = (byte)HashAlgorithm.Sha256;
+ output[5 + X25519.KeySize] = (byte)SignatureAlgorithm.RSA;
+ output.Slice(6+X25519.KeySize).WriteBigEndian16((ushort)signature.Length);
+ signature.CopyTo(output.Slice(8+X25519.KeySize));
+ }
+
+ /// <inheritdoc />
+ public bool VerifyServerMessageAndGenerateSharedKey(ByteSpan output, ByteSpan serverKeyExchangeMessage, object publicKey)
+ {
+ RSA rsaPublicKey = publicKey as RSA;
+ if (rsaPublicKey == null)
+ {
+ return false;
+ }
+ else if (output.Length != X25519.KeySize)
+ {
+ return false;
+ }
+
+ // Verify message is compatible with this cipher suite
+ if (serverKeyExchangeMessage.Length != CalculateServerMessageSize(rsaPublicKey.KeySize))
+ {
+ return false;
+ }
+ else if (serverKeyExchangeMessage[0] != (byte)ECCurveType.NamedCurve)
+ {
+ return false;
+ }
+ else if (serverKeyExchangeMessage.ReadBigEndian16(1) != (ushort)NamedCurve.x25519)
+ {
+ return false;
+ }
+ else if (serverKeyExchangeMessage[3] != X25519.KeySize)
+ {
+ return false;
+ }
+ else if (serverKeyExchangeMessage[4 + X25519.KeySize] != (byte)HashAlgorithm.Sha256)
+ {
+ return false;
+ }
+ else if (serverKeyExchangeMessage[5 + X25519.KeySize] != (byte)SignatureAlgorithm.RSA)
+ {
+ return false;
+ }
+
+ ByteSpan keyParameters = serverKeyExchangeMessage.Slice(0, 4+X25519.KeySize);
+ ByteSpan othersPublicKey = keyParameters.Slice(4);
+ ushort signatureSize = serverKeyExchangeMessage.ReadBigEndian16(6 + X25519.KeySize);
+ ByteSpan signature = serverKeyExchangeMessage.Slice(4+keyParameters.Length);
+
+ if (signatureSize != signature.Length)
+ {
+ return false;
+ }
+
+ // Hash the key parameters
+ byte[] parameterDigest = this.sha256.ComputeHash(keyParameters.GetUnderlyingArray(), keyParameters.Offset, keyParameters.Length);
+
+ // Verify the signature
+ RSAPKCS1SignatureDeformatter verifier = new RSAPKCS1SignatureDeformatter(rsaPublicKey);
+ verifier.SetHashAlgorithm("SHA256");
+ if (!verifier.VerifySignature(parameterDigest, signature.ToArray()))
+ {
+ return false;
+ }
+
+ // Signature has been validated, generate the shared key
+ return X25519.Func(output, this.privateAgreementKey, othersPublicKey);
+ }
+
+ private static int ClientMessageSize = 0
+ + 1 + X25519.KeySize // ECPoint ClientKeyExchange.ecdh_Yc
+ ;
+
+ /// <inheritdoc />
+ public int CalculateClientMessageSize()
+ {
+ return ClientMessageSize;
+ }
+
+ /// <inheritdoc />
+ public void EncodeClientKeyExchangeMessage(ByteSpan output)
+ {
+ output[0] = (byte)X25519.KeySize;
+ X25519.Func(output.Slice(1, X25519.KeySize), this.privateAgreementKey);
+ }
+
+ /// <inheritdoc />
+ public bool VerifyClientMessageAndGenerateSharedKey(ByteSpan output, ByteSpan clientKeyExchangeMessage)
+ {
+ if (clientKeyExchangeMessage.Length != ClientMessageSize)
+ {
+ return false;
+ }
+ else if (clientKeyExchangeMessage[0] != (byte)X25519.KeySize)
+ {
+ return false;
+ }
+
+ ByteSpan othersPublicKey = clientKeyExchangeMessage.Slice(1);
+ return X25519.Func(output, this.privateAgreementKey, othersPublicKey);
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/Extensions.cs b/Tools/Hazel-Networking/Hazel/Extensions.cs
new file mode 100644
index 0000000..dd3a1bc
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/Extensions.cs
@@ -0,0 +1,34 @@
+using System.Collections.Generic;
+
+namespace Hazel
+{
+ public static class Extensions
+ {
+ public static void Swap<T>(this IList<T> self, int idx0, int idx1)
+ {
+ var temp = self[idx0];
+ self[idx0] = self[idx1];
+ self[idx1] = temp;
+ }
+
+ public static int ClampToInt(this float value, int min, int max)
+ {
+ int output = (int)value;
+ if (output < min) output = min;
+ else if (output > max) output = max;
+ return output;
+ }
+
+ public static bool TryDequeue<T>(this Queue<T> self, out T item)
+ {
+ if (self.Count > 0)
+ {
+ item = self.Dequeue();
+ return true;
+ }
+
+ item = default;
+ return false;
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/FewerThreads/HazelThreadPool.cs b/Tools/Hazel-Networking/Hazel/FewerThreads/HazelThreadPool.cs
new file mode 100644
index 0000000..fb36b00
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/FewerThreads/HazelThreadPool.cs
@@ -0,0 +1,44 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading;
+using System.Threading.Tasks;
+
+namespace Hazel
+{
+ internal class HazelThreadPool
+ {
+ private Thread[] threads;
+
+ public HazelThreadPool(int numThreads, ThreadStart action)
+ {
+ this.threads = new Thread[numThreads];
+ for (int i = 0; i < this.threads.Length; ++i)
+ {
+ this.threads[i] = new Thread(action);
+ }
+ }
+
+ public void Start()
+ {
+ for (int i = 0; i < this.threads.Length; ++i)
+ {
+ this.threads[i].Start();
+ }
+ }
+
+ public void Join()
+ {
+ for (int i = 0; i < this.threads.Length; ++i)
+ {
+ var thread = this.threads[i];
+ try
+ {
+ thread.Join();
+ }
+ catch { }
+ }
+ }
+ }
+} \ No newline at end of file
diff --git a/Tools/Hazel-Networking/Hazel/FewerThreads/ThreadLimitedUdpConnectionListener.cs b/Tools/Hazel-Networking/Hazel/FewerThreads/ThreadLimitedUdpConnectionListener.cs
new file mode 100644
index 0000000..e37be45
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/FewerThreads/ThreadLimitedUdpConnectionListener.cs
@@ -0,0 +1,402 @@
+using System;
+using System.Collections.Concurrent;
+using System.Linq;
+using System.Net;
+using System.Net.Sockets;
+using System.Threading;
+
+namespace Hazel.Udp.FewerThreads
+{
+ /// <summary>
+ /// Listens for new UDP connections and creates UdpConnections for them.
+ /// </summary>
+ /// <inheritdoc />
+ public class ThreadLimitedUdpConnectionListener : NetworkConnectionListener
+ {
+ private struct SendMessageInfo
+ {
+ public ByteSpan Span;
+ public IPEndPoint Recipient;
+ }
+
+ private struct ReceiveMessageInfo
+ {
+ public MessageReader Message;
+ public IPEndPoint Sender;
+ public ConnectionId ConnectionId;
+ }
+
+ private const int SendReceiveBufferSize = 1024 * 1024;
+
+ private Socket socket;
+ protected ILogger Logger;
+
+ private Thread reliablePacketThread;
+ private Thread receiveThread;
+ private Thread sendThread;
+ private HazelThreadPool processThreads;
+
+ public bool ReceiveThreadRunning => this.receiveThread.ThreadState == ThreadState.Running;
+
+ public struct ConnectionId : IEquatable<ConnectionId>
+ {
+ public IPEndPoint EndPoint;
+ public int Serial;
+
+ public static ConnectionId Create(IPEndPoint endPoint, int serial)
+ {
+ return new ConnectionId{
+ EndPoint = endPoint,
+ Serial = serial,
+ };
+ }
+
+ public bool Equals(ConnectionId other)
+ {
+ return this.Serial == other.Serial
+ && this.EndPoint.Equals(other.EndPoint)
+ ;
+ }
+
+ public override bool Equals(object obj)
+ {
+ if (obj is ConnectionId)
+ {
+ return this.Equals((ConnectionId)obj);
+ }
+
+ return false;
+ }
+
+ public override int GetHashCode()
+ {
+ ///NOTE(mendsley): We're only hashing the endpoint
+ /// here, as the common case will have one
+ /// connection per address+port tuple.
+ return this.EndPoint.GetHashCode();
+ }
+ }
+
+ protected ConcurrentDictionary<ConnectionId, ThreadLimitedUdpServerConnection> allConnections = new ConcurrentDictionary<ConnectionId, ThreadLimitedUdpServerConnection>();
+
+ private BlockingCollection<ReceiveMessageInfo> receiveQueue;
+ private BlockingCollection<SendMessageInfo> sendQueue = new BlockingCollection<SendMessageInfo>();
+
+ public int MaxAge
+ {
+ get
+ {
+ var now = DateTime.UtcNow;
+ TimeSpan max = new TimeSpan();
+ foreach (var con in allConnections.Values)
+ {
+ var val = now - con.CreationTime;
+ if (val > max) max = val;
+ }
+
+ return (int)max.TotalSeconds;
+ }
+ }
+
+ public override double AveragePing => this.allConnections.Values.Sum(c => c.AveragePingMs) / this.allConnections.Count;
+ public override int ConnectionCount { get { return this.allConnections.Count; } }
+ public override int SendQueueLength { get { return this.sendQueue.Count; } }
+ public override int ReceiveQueueLength { get { return this.receiveQueue.Count; } }
+
+ private bool isActive;
+
+ public ThreadLimitedUdpConnectionListener(int numWorkers, IPEndPoint endPoint, ILogger logger, IPMode ipMode = IPMode.IPv4)
+ {
+ this.Logger = logger;
+ this.EndPoint = endPoint;
+ this.IPMode = ipMode;
+
+ this.receiveQueue = new BlockingCollection<ReceiveMessageInfo>(10000);
+
+ this.socket = UdpConnection.CreateSocket(this.IPMode);
+ this.socket.ExclusiveAddressUse = true;
+ this.socket.Blocking = false;
+
+ this.socket.ReceiveBufferSize = SendReceiveBufferSize;
+ this.socket.SendBufferSize = SendReceiveBufferSize;
+
+ this.reliablePacketThread = new Thread(ManageReliablePackets);
+ this.sendThread = new Thread(SendLoop);
+ this.receiveThread = new Thread(ReceiveLoop);
+ this.processThreads = new HazelThreadPool(numWorkers, ProcessingLoop);
+ }
+
+ ~ThreadLimitedUdpConnectionListener()
+ {
+ this.Dispose(false);
+ }
+
+ // This is just for booting people after they've been connected a certain amount of time...
+ public void DisconnectOldConnections(TimeSpan maxAge, MessageWriter disconnectMessage)
+ {
+ var now = DateTime.UtcNow;
+ foreach (var conn in this.allConnections.Values)
+ {
+ if (now - conn.CreationTime > maxAge)
+ {
+ conn.Disconnect("Stale Connection", disconnectMessage);
+ }
+ }
+ }
+
+ private void ManageReliablePackets()
+ {
+ while (this.isActive)
+ {
+ foreach (var kvp in this.allConnections)
+ {
+ var sock = kvp.Value;
+ sock.ManageReliablePackets();
+ }
+
+ Thread.Sleep(100);
+ }
+ }
+
+ public override void Start()
+ {
+ try
+ {
+ socket.Bind(EndPoint);
+ }
+ catch (SocketException e)
+ {
+ throw new HazelException("Could not start listening as a SocketException occurred", e);
+ }
+
+ this.isActive = true;
+ this.reliablePacketThread.Start();
+ this.sendThread.Start();
+ this.receiveThread.Start();
+ this.processThreads.Start();
+ }
+
+ private void ReceiveLoop()
+ {
+ while (this.isActive)
+ {
+ if (this.socket.Poll(1000, SelectMode.SelectRead))
+ {
+ if (!isActive) break;
+
+ EndPoint remoteEP = new IPEndPoint(this.EndPoint.Address, this.EndPoint.Port);
+ var message = MessageReader.GetSized(this.ReceiveBufferSize);
+ try
+ {
+ message.Length = socket.ReceiveFrom(message.Buffer, 0, message.Buffer.Length, SocketFlags.None, ref remoteEP);
+ }
+ catch (SocketException sx)
+ {
+ message.Recycle();
+ if (sx.SocketErrorCode == SocketError.NotConnected)
+ {
+ this.InvokeInternalError(HazelInternalErrors.ConnectionDisconnected);
+ return;
+ }
+
+ this.Logger.WriteError("Socket Ex in ReceiveLoop: " + sx.Message);
+ continue;
+ }
+ catch (Exception ex)
+ {
+ message.Recycle();
+ this.Logger.WriteError("Stopped due to: " + ex.Message);
+ return;
+ }
+
+ ConnectionId connectionId = ConnectionId.Create((IPEndPoint)remoteEP, 0);
+ this.ProcessIncomingMessageFromOtherThread(message, (IPEndPoint)remoteEP, connectionId);
+ }
+ }
+ }
+
+ private void ProcessingLoop()
+ {
+ foreach (ReceiveMessageInfo msg in this.receiveQueue.GetConsumingEnumerable())
+ {
+ try
+ {
+ this.ReadCallback(msg.Message, msg.Sender, msg.ConnectionId);
+ }
+ catch
+ {
+
+ }
+ }
+ }
+
+ protected void ProcessIncomingMessageFromOtherThread(MessageReader message, IPEndPoint remoteEndPoint, ConnectionId connectionId)
+ {
+ var info = new ReceiveMessageInfo() { Message = message, Sender = remoteEndPoint, ConnectionId = connectionId };
+ if (!this.receiveQueue.TryAdd(info))
+ {
+ this.Statistics.AddReceiveThreadBlocking();
+ this.receiveQueue.Add(info);
+ }
+ }
+
+ private void SendLoop()
+ {
+ foreach (SendMessageInfo msg in this.sendQueue.GetConsumingEnumerable())
+ {
+ try
+ {
+ if (this.socket.Poll(Timeout.Infinite, SelectMode.SelectWrite))
+ {
+ this.socket.SendTo(msg.Span.GetUnderlyingArray(), msg.Span.Offset, msg.Span.Length, SocketFlags.None, msg.Recipient);
+ this.Statistics.AddBytesSent(msg.Span.Length - msg.Span.Offset);
+ }
+ else
+ {
+ this.Logger.WriteError("Socket is no longer able to send");
+ break;
+ }
+ }
+ catch (Exception e)
+ {
+ this.Logger.WriteError("Error in loop while sending: " + e.Message);
+ Thread.Sleep(1);
+ }
+ }
+ }
+
+ protected virtual void ReadCallback(MessageReader message, IPEndPoint remoteEndPoint, ConnectionId connectionId)
+ {
+ int bytesReceived = message.Length;
+ bool aware = true;
+ bool isHello = message.Buffer[0] == (byte)UdpSendOption.Hello;
+
+ // If we're aware of this connection use the one already
+ // If this is a new client then connect with them!
+ ThreadLimitedUdpServerConnection connection;
+ if (!this.allConnections.TryGetValue(connectionId, out connection))
+ {
+ lock (this.allConnections)
+ {
+ if (!this.allConnections.TryGetValue(connectionId, out connection))
+ {
+ // Check for malformed connection attempts
+ if (!isHello)
+ {
+ message.Recycle();
+ return;
+ }
+
+ if (AcceptConnection != null)
+ {
+ if (!AcceptConnection(remoteEndPoint, message.Buffer, out var response))
+ {
+ message.Recycle();
+ if (response != null)
+ {
+ SendDataRaw(response, remoteEndPoint);
+ }
+
+ return;
+ }
+ }
+
+ aware = false;
+ connection = new ThreadLimitedUdpServerConnection(this, connectionId, remoteEndPoint, this.IPMode, this.Logger);
+ if (!this.allConnections.TryAdd(connectionId, connection))
+ {
+ throw new HazelException("Failed to add a connection. This should never happen.");
+ }
+ }
+ }
+ }
+
+ // If it's a new connection invoke the NewConnection event.
+ // This needs to happen before handling the message because in localhost scenarios, the ACK and
+ // subsequent messages can happen before the NewConnection event sets up OnDataRecieved handlers
+ if (!aware)
+ {
+ // Skip header and hello byte;
+ message.Offset = 4;
+ message.Length = bytesReceived - 4;
+ message.Position = 0;
+ try
+ {
+ this.InvokeNewConnection(message, connection);
+ }
+ catch (Exception e)
+ {
+ this.Logger.WriteError("NewConnection handler threw: " + e);
+ }
+ }
+
+ // Inform the connection of the buffer (new connections need to send an ack back to client)
+ connection.HandleReceive(message, bytesReceived);
+ }
+
+ internal void SendDataRaw(byte[] response, IPEndPoint remoteEndPoint)
+ {
+ QueueRawData(response, remoteEndPoint);
+ }
+
+ protected virtual void QueueRawData(ByteSpan span, IPEndPoint remoteEndPoint)
+ {
+ this.sendQueue.TryAdd(new SendMessageInfo() { Span = span, Recipient = remoteEndPoint });
+ }
+
+ /// <summary>
+ /// Removes a virtual connection from the list.
+ /// </summary>
+ /// <param name="endPoint">Connection key of the virtual connection.</param>
+ internal bool RemoveConnectionTo(ConnectionId connectionId)
+ {
+ return this.allConnections.TryRemove(connectionId, out _);
+ }
+
+ /// <summary>
+ /// This is after all messages could be sent. Clean up anything extra.
+ /// </summary>
+ internal virtual void RemovePeerRecord(ConnectionId connectionId)
+ {
+ }
+
+ protected override void Dispose(bool disposing)
+ {
+ foreach (var kvp in this.allConnections)
+ {
+ kvp.Value.Dispose();
+ }
+
+ bool wasActive = this.isActive;
+ this.isActive = false;
+
+ // Flush outgoing packets
+ this.sendQueue?.CompleteAdding();
+
+ if (wasActive)
+ {
+ this.sendThread.Join();
+ }
+
+ try { this.socket.Shutdown(SocketShutdown.Both); } catch { }
+ try { this.socket.Close(); } catch { }
+ try { this.socket.Dispose(); } catch { }
+
+ this.receiveQueue?.CompleteAdding();
+
+ if (wasActive)
+ {
+ this.reliablePacketThread.Join();
+ this.receiveThread.Join();
+ this.processThreads.Join();
+ }
+
+ this.receiveQueue?.Dispose();
+ this.receiveQueue = null;
+ this.sendQueue?.Dispose();
+ this.sendQueue = null;
+
+ base.Dispose(disposing);
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/FewerThreads/ThreadLimitedUdpServerConnection.cs b/Tools/Hazel-Networking/Hazel/FewerThreads/ThreadLimitedUdpServerConnection.cs
new file mode 100644
index 0000000..bb139c7
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/FewerThreads/ThreadLimitedUdpServerConnection.cs
@@ -0,0 +1,110 @@
+using System;
+using System.Net;
+
+namespace Hazel.Udp.FewerThreads
+{
+ /// <summary>
+ /// Represents a servers's connection to a client that uses the UDP protocol.
+ /// </summary>
+ /// <inheritdoc/>
+ public sealed class ThreadLimitedUdpServerConnection : UdpConnection
+ {
+ public readonly DateTime CreationTime = DateTime.UtcNow;
+
+ /// <summary>
+ /// The connection listener that we use the socket of.
+ /// </summary>
+ /// <remarks>
+ /// Udp server connections utilize the same socket in the listener for sends/receives, this is the listener that
+ /// created this connection and is hence the listener this conenction sends and receives via.
+ /// </remarks>
+ public ThreadLimitedUdpConnectionListener Listener { get; private set; }
+
+ public ThreadLimitedUdpConnectionListener.ConnectionId ConnectionId { get; private set; }
+
+ /// <summary>
+ /// Creates a UdpConnection for the virtual connection to the endpoint.
+ /// </summary>
+ /// <param name="listener">The listener that created this connection.</param>
+ /// <param name="endPoint">The endpoint that we are connected to.</param>
+ /// <param name="IPMode">The IPMode we are connected using.</param>
+ internal ThreadLimitedUdpServerConnection(ThreadLimitedUdpConnectionListener listener, ThreadLimitedUdpConnectionListener.ConnectionId connectionId, IPEndPoint endPoint, IPMode IPMode, ILogger logger)
+ : base(logger)
+ {
+ this.Listener = listener;
+ this.ConnectionId = connectionId;
+ this.EndPoint = endPoint;
+ this.IPMode = IPMode;
+
+ State = ConnectionState.Connected;
+ this.InitializeKeepAliveTimer();
+ }
+
+ /// <inheritdoc />
+ protected override void WriteBytesToConnection(byte[] bytes, int length)
+ {
+ if (bytes.Length != length) throw new ArgumentException("I made an assumption here. I hope you see this error.");
+
+ // Hrm, well this is inaccurate for DTLS connections because the Listener does the encryption which may change the size.
+ // but I don't want to have a bunch of client references in the send queue...
+ // Does this perhaps mean the encryption is being done in the wrong class?
+ this.Statistics.LogPacketSend(length);
+ Listener.SendDataRaw(bytes, EndPoint);
+ }
+
+ /// <inheritdoc />
+ /// <remarks>
+ /// This will always throw a HazelException.
+ /// </remarks>
+ public override void Connect(byte[] bytes = null, int timeout = 5000)
+ {
+ throw new InvalidOperationException("Cannot manually connect a UdpServerConnection, did you mean to use UdpClientConnection?");
+ }
+
+ /// <inheritdoc />
+ /// <remarks>
+ /// This will always throw a HazelException.
+ /// </remarks>
+ public override void ConnectAsync(byte[] bytes = null)
+ {
+ throw new InvalidOperationException("Cannot manually connect a UdpServerConnection, did you mean to use UdpClientConnection?");
+ }
+
+ /// <summary>
+ /// Sends a disconnect message to the end point.
+ /// </summary>
+ protected override bool SendDisconnect(MessageWriter data = null)
+ {
+ if (!Listener.RemoveConnectionTo(this.ConnectionId)) return false;
+ this._state = ConnectionState.NotConnected;
+
+ var bytes = EmptyDisconnectBytes;
+ if (data != null && data.Length > 0)
+ {
+ if (data.SendOption != SendOption.None) throw new ArgumentException("Disconnect messages can only be unreliable.");
+
+ bytes = data.ToByteArray(true);
+ bytes[0] = (byte)UdpSendOption.Disconnect;
+ }
+
+ try
+ {
+ this.WriteBytesToConnection(bytes, bytes.Length);
+ }
+ catch { }
+
+ return true;
+ }
+
+ protected override void Dispose(bool disposing)
+ {
+ if (disposing)
+ {
+ SendDisconnect();
+ }
+
+ Listener.RemovePeerRecord(this.ConnectionId);
+ base.Dispose(disposing);
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/Hazel.csproj b/Tools/Hazel-Networking/Hazel/Hazel.csproj
new file mode 100644
index 0000000..3a7ea17
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/Hazel.csproj
@@ -0,0 +1,14 @@
+<Project Sdk="Microsoft.NET.Sdk">
+
+ <PropertyGroup>
+ <TargetFrameworks>netstandard2.0;net472</TargetFrameworks>
+ <AllowUnsafeBlocks>true</AllowUnsafeBlocks>
+ </PropertyGroup>
+
+ <ItemGroup>
+ <AssemblyAttribute Include="System.Runtime.CompilerServices.InternalsVisibleToAttribute">
+ <_Parameter1>Hazel.UnitTests</_Parameter1>
+ </AssemblyAttribute>
+ </ItemGroup>
+
+</Project>
diff --git a/Tools/Hazel-Networking/Hazel/HazelException.cs b/Tools/Hazel-Networking/Hazel/HazelException.cs
new file mode 100644
index 0000000..c0db05a
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/HazelException.cs
@@ -0,0 +1,24 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+
+namespace Hazel
+{
+ /// <summary>
+ /// Wrapper for exceptions thrown from Hazel.
+ /// </summary>
+ [Serializable]
+ public class HazelException : Exception
+ {
+ internal HazelException(string msg) : base (msg)
+ {
+
+ }
+
+ internal HazelException(string msg, Exception e) : base (msg, e)
+ {
+
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/IPMode.cs b/Tools/Hazel-Networking/Hazel/IPMode.cs
new file mode 100644
index 0000000..04c8c38
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/IPMode.cs
@@ -0,0 +1,30 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+
+
+namespace Hazel
+{
+ /// <summary>
+ /// Represents the IP version that a connection or listener will use.
+ /// </summary>
+ /// <remarks>
+ /// If you wand a client to connect or be able to connect using IPv6 then you should use <see cref="IPv4AndIPv6"/>,
+ /// this sets the underlying sockets to use IPv6 but still allow IPv4 sockets to connect for backwards compatability
+ /// and hence it is the default IPMode in most cases.
+ /// </remarks>
+ public enum IPMode
+ {
+ /// <summary>
+ /// Instruction to use IPv4 only, IPv6 connections will not be able to connect.
+ /// </summary>
+ IPv4,
+
+ /// <summary>
+ /// Instruction to use IPv6 only, IPv4 connections will not be able to connect. IPv4 addresses can be connected
+ /// by converting to IPv6 addresses.
+ /// </summary>
+ IPv6
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/IRecyclable.cs b/Tools/Hazel-Networking/Hazel/IRecyclable.cs
new file mode 100644
index 0000000..3e9769e
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/IRecyclable.cs
@@ -0,0 +1,29 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+
+namespace Hazel
+{
+ /// <summary>
+ /// Interface for all items that can be returned to an object pool.
+ /// </summary>
+ /// <threadsafety static="true" instance="true"/>
+ public interface IRecyclable
+ {
+ /// <summary>
+ /// Returns this object back to the object pool.
+ /// </summary>
+ /// <remarks>
+ /// <para>
+ /// Calling this when you are done with the object returns the object back to a pool in order to be reused.
+ /// This can reduce the amount of work the GC has to do dramatically but it is optional to call this.
+ /// </para>
+ /// <para>
+ /// Calling this indicates to Hazel that this can be reused and thus you should only call this when you are
+ /// completely finished with the object as the contents can be overwritten at any point after.
+ /// </para>
+ /// </remarks>
+ void Recycle();
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/ListenerStatistics.cs b/Tools/Hazel-Networking/Hazel/ListenerStatistics.cs
new file mode 100644
index 0000000..428c567
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/ListenerStatistics.cs
@@ -0,0 +1,23 @@
+using System.Threading;
+
+namespace Hazel
+{
+ public class ListenerStatistics
+ {
+ private int _receiveThreadBlocked;
+ public int ReceiveThreadBlocked => this._receiveThreadBlocked;
+
+ private long _bytesSent;
+ public long BytesSent => this._bytesSent;
+
+ internal void AddReceiveThreadBlocking()
+ {
+ Interlocked.Increment(ref _receiveThreadBlocked);
+ }
+
+ internal void AddBytesSent(long bytes)
+ {
+ Interlocked.Add(ref _bytesSent, bytes);
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/MessageReader.cs b/Tools/Hazel-Networking/Hazel/MessageReader.cs
new file mode 100644
index 0000000..bd3b0d8
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/MessageReader.cs
@@ -0,0 +1,452 @@
+using System;
+using System.IO;
+using System.Linq;
+using System.Runtime.CompilerServices;
+using System.Text;
+
+namespace Hazel
+{
+ public class MessageReader : IRecyclable
+ {
+ public static readonly ObjectPool<MessageReader> ReaderPool = new ObjectPool<MessageReader>(() => new MessageReader());
+
+ public byte[] Buffer;
+ public byte Tag;
+
+ public int Length; // 总长度
+ public int Offset; // length和tag后面
+
+ public int BytesRemaining => this.Length - this.Position;
+
+ private MessageReader Parent;
+
+ public int Position
+ {
+ get { return this._position; }
+ set
+ {
+ this._position = value;
+ this.readHead = value + Offset;
+ }
+ }
+
+ private int _position;
+ private int readHead;
+
+ public static MessageReader GetSized(int minSize)
+ {
+ var output = ReaderPool.GetObject();
+
+ if (output.Buffer == null || output.Buffer.Length < minSize)
+ {
+ output.Buffer = new byte[minSize];
+ }
+ else
+ {
+ Array.Clear(output.Buffer, 0, output.Buffer.Length);
+ }
+
+ output.Offset = 0;
+ output.Position = 0;
+ output.Tag = byte.MaxValue;
+ return output;
+ }
+
+ public static MessageReader Get(byte[] buffer)
+ {
+ var output = ReaderPool.GetObject();
+
+ output.Buffer = buffer;
+ output.Offset = 0;
+ output.Position = 0;
+ output.Length = buffer.Length;
+ output.Tag = byte.MaxValue;
+
+ return output;
+ }
+
+ public static MessageReader CopyMessageIntoParent(MessageReader source)
+ {
+ var output = MessageReader.GetSized(source.Length + 3);
+ System.Buffer.BlockCopy(source.Buffer, source.Offset - 3, output.Buffer, 0, source.Length + 3);
+
+ output.Offset = 0;
+ output.Position = 0;
+ output.Length = source.Length + 3;
+
+ return output;
+ }
+
+ public static MessageReader Get(MessageReader source)
+ {
+ var output = MessageReader.GetSized(source.Buffer.Length);
+ System.Buffer.BlockCopy(source.Buffer, 0, output.Buffer, 0, source.Buffer.Length);
+
+ output.Offset = source.Offset;
+
+ output._position = source._position;
+ output.readHead = source.readHead;
+
+ output.Length = source.Length;
+ output.Tag = source.Tag;
+
+ return output;
+ }
+
+ public static MessageReader Get(byte[] buffer, int offset)
+ {
+ // Ensure there is at least a header
+ if (offset + 3 > buffer.Length) return null;
+
+ var output = ReaderPool.GetObject();
+
+ output.Buffer = buffer;
+ output.Offset = offset;
+ output.Position = 0;
+
+ output.Length = output.ReadUInt16();
+ output.Tag = output.ReadByte();
+
+ output.Offset += 3;
+ output.Position = 0;
+
+ return output;
+ }
+
+ /// <summary>
+ /// Produces a MessageReader using the parent's buffer. This MessageReader should **NOT** be recycled.
+ /// </summary>
+ public MessageReader ReadMessage()
+ {
+ // Ensure there is at least a header
+ if (this.BytesRemaining < 3) throw new InvalidDataException($"ReadMessage header is longer than message length: 3 of {this.BytesRemaining}");
+
+ var output = new MessageReader();
+
+ output.Parent = this;
+ output.Buffer = this.Buffer;
+ output.Offset = this.readHead;
+ output.Position = 0;
+
+ output.Length = output.ReadUInt16();
+ output.Tag = output.ReadByte();
+
+ output.Offset += 3;
+ output.Position = 0;
+
+ if (this.BytesRemaining < output.Length + 3) throw new InvalidDataException($"Message Length at Position {this.readHead} is longer than message length: {output.Length + 3} of {this.BytesRemaining}");
+
+ this.Position += output.Length + 3;
+ return output;
+ }
+
+ /// <summary>
+ /// Produces a MessageReader with a new buffer. This MessageReader should be recycled.
+ /// </summary>
+ public MessageReader ReadMessageAsNewBuffer()
+ {
+ if (this.BytesRemaining < 3) throw new InvalidDataException($"ReadMessage header is longer than message length: 3 of {this.BytesRemaining}");
+
+ var len = this.ReadUInt16();
+ var tag = this.ReadByte();
+
+ if (this.BytesRemaining < len) throw new InvalidDataException($"Message Length at Position {this.readHead} is longer than message length: {len} of {this.BytesRemaining}");
+
+ var output = MessageReader.GetSized(len);
+
+ Array.Copy(this.Buffer, this.readHead, output.Buffer, 0, len);
+
+ output.Length = len;
+ output.Tag = tag;
+
+ this.Position += output.Length;
+ return output;
+ }
+
+ public MessageWriter StartWriter()
+ {
+ var output = new MessageWriter(this.Buffer);
+ output.Position = this.readHead;
+ return output;
+ }
+
+ public MessageReader Duplicate()
+ {
+ var output = GetSized(this.Length);
+ Array.Copy(this.Buffer, this.Offset, output.Buffer, 0, this.Length);
+ output.Length = this.Length;
+ output.Offset = 0;
+ output.Position = 0;
+
+ return output;
+ }
+
+ public void RemoveMessage(MessageReader reader)
+ {
+ var temp = MessageReader.GetSized(reader.Buffer.Length);
+ try
+ {
+ var headerOffset = reader.Offset - 3;
+ var endOfMessage = reader.Offset + reader.Length;
+ var len = reader.Buffer.Length - endOfMessage;
+
+ Array.Copy(reader.Buffer, endOfMessage, temp.Buffer, 0, len);
+ Array.Copy(temp.Buffer, 0, this.Buffer, headerOffset, len);
+
+ this.AdjustLength(reader.Offset, reader.Length + 3);
+ }
+ finally
+ {
+ temp.Recycle();
+ }
+ }
+
+ public void InsertMessage(MessageReader reader, MessageWriter writer)
+ {
+ var temp = MessageReader.GetSized(reader.Buffer.Length);
+ try
+ {
+ var headerOffset = reader.Offset - 3;
+ var startOfMessage = reader.Offset;
+ var len = reader.Buffer.Length - startOfMessage;
+ int writerOffset = 3;
+ switch (writer.SendOption)
+ {
+ case SendOption.Reliable:
+ writerOffset = 3;
+ break;
+ case SendOption.None:
+ writerOffset = 1;
+ break;
+ }
+
+ //store the original buffer in temp
+ Array.Copy(reader.Buffer, headerOffset, temp.Buffer, 0, len);
+
+ //put the contents of writer in at headerOffset
+ Array.Copy(writer.Buffer, writerOffset, this.Buffer, headerOffset, writer.Length-writerOffset);
+
+ //put the original buffer in after that
+ Array.Copy(temp.Buffer, 0, this.Buffer, headerOffset + (writer.Length-writerOffset), len - writer.Length);
+
+ this.AdjustLength(-1 * reader.Offset , -1 * (writer.Length - writerOffset));
+ }
+ finally
+ {
+ temp.Recycle();
+ }
+ }
+
+ private void AdjustLength(int offset, int amount)
+ {
+ if (this.readHead > offset)
+ {
+ this.Position -= amount;
+ }
+
+ if (Parent != null)
+ {
+ var lengthOffset = this.Offset - 3;
+ var curLen = this.Buffer[lengthOffset]
+ | (this.Buffer[lengthOffset + 1] << 8);
+
+ curLen -= amount;
+ this.Length -= amount;
+
+ this.Buffer[lengthOffset] = (byte)curLen;
+ this.Buffer[lengthOffset + 1] = (byte)(this.Buffer[lengthOffset + 1] >> 8);
+
+ Parent.AdjustLength(offset, amount);
+ }
+ }
+
+ public void Recycle()
+ {
+ this.Parent = null;
+ ReaderPool.PutObject(this);
+ }
+
+ #region Read Methods
+ public bool ReadBoolean()
+ {
+ byte val = this.FastByte();
+ return val != 0;
+ }
+
+ public sbyte ReadSByte()
+ {
+ return (sbyte)this.FastByte();
+ }
+
+ public byte ReadByte()
+ {
+ return this.FastByte();
+ }
+
+ public ushort ReadUInt16()
+ {
+ ushort output =
+ (ushort)(this.FastByte()
+ | this.FastByte() << 8);
+ return output;
+ }
+
+ public short ReadInt16()
+ {
+ short output =
+ (short)(this.FastByte()
+ | this.FastByte() << 8);
+ return output;
+ }
+
+ public uint ReadUInt32()
+ {
+ uint output = this.FastByte()
+ | (uint)this.FastByte() << 8
+ | (uint)this.FastByte() << 16
+ | (uint)this.FastByte() << 24;
+
+ return output;
+ }
+
+ public int ReadInt32()
+ {
+ int output = this.FastByte()
+ | this.FastByte() << 8
+ | this.FastByte() << 16
+ | this.FastByte() << 24;
+
+ return output;
+ }
+
+ public ulong ReadUInt64()
+ {
+ ulong output = (ulong)this.FastByte()
+ | (ulong)this.FastByte() << 8
+ | (ulong)this.FastByte() << 16
+ | (ulong)this.FastByte() << 24
+ | (ulong)this.FastByte() << 32
+ | (ulong)this.FastByte() << 40
+ | (ulong)this.FastByte() << 48
+ | (ulong)this.FastByte() << 56;
+
+ return output;
+ }
+
+ public long ReadInt64()
+ {
+ long output = (long)this.FastByte()
+ | (long)this.FastByte() << 8
+ | (long)this.FastByte() << 16
+ | (long)this.FastByte() << 24
+ | (long)this.FastByte() << 32
+ | (long)this.FastByte() << 40
+ | (long)this.FastByte() << 48
+ | (long)this.FastByte() << 56;
+
+ return output;
+ }
+
+ public unsafe float ReadSingle()
+ {
+ float output = 0;
+ fixed (byte* bufPtr = &this.Buffer[this.readHead])
+ {
+ byte* outPtr = (byte*)&output;
+
+ *outPtr = *bufPtr;
+ *(outPtr + 1) = *(bufPtr + 1);
+ *(outPtr + 2) = *(bufPtr + 2);
+ *(outPtr + 3) = *(bufPtr + 3);
+ }
+
+ this.Position += 4;
+ return output;
+ }
+
+ public string ReadString()
+ {
+ int len = this.ReadPackedInt32();
+ if (this.BytesRemaining < len) throw new InvalidDataException($"Read length is longer than message length: {len} of {this.BytesRemaining}");
+
+ string output = UTF8Encoding.UTF8.GetString(this.Buffer, this.readHead, len);
+
+ this.Position += len;
+ return output;
+ }
+
+ public byte[] ReadBytesAndSize()
+ {
+ int len = this.ReadPackedInt32();
+ if (this.BytesRemaining < len) throw new InvalidDataException($"Read length is longer than message length: {len} of {this.BytesRemaining}");
+
+ return this.ReadBytes(len);
+ }
+
+ public byte[] ReadBytes(int length)
+ {
+ if (this.BytesRemaining < length) throw new InvalidDataException($"Read length is longer than message length: {length} of {this.BytesRemaining}");
+
+ byte[] output = new byte[length];
+ Array.Copy(this.Buffer, this.readHead, output, 0, output.Length);
+ this.Position += output.Length;
+ return output;
+ }
+
+ ///
+ public int ReadPackedInt32()
+ {
+ return (int)this.ReadPackedUInt32();
+ }
+
+ ///
+ public uint ReadPackedUInt32()
+ {
+ bool readMore = true;
+ int shift = 0;
+ uint output = 0;
+
+ while (readMore)
+ {
+ if (this.BytesRemaining < 1) throw new InvalidDataException($"Read length is longer than message length.");
+
+ byte b = this.ReadByte();
+ if (b >= 0x80)
+ {
+ readMore = true;
+ b ^= 0x80;
+ }
+ else
+ {
+ readMore = false;
+ }
+
+ output |= (uint)(b << shift);
+ shift += 7;
+ }
+
+ return output;
+ }
+ #endregion
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ private byte FastByte()
+ {
+ this._position++;
+ return this.Buffer[this.readHead++];
+ }
+
+ public unsafe static bool IsLittleEndian()
+ {
+ byte b;
+ unsafe
+ {
+ int i = 1;
+ byte* bp = (byte*)&i;
+ b = *bp;
+ }
+
+ return b == 1;
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/MessageWriter.cs b/Tools/Hazel-Networking/Hazel/MessageWriter.cs
new file mode 100644
index 0000000..9caaaf2
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/MessageWriter.cs
@@ -0,0 +1,365 @@
+using System;
+using System.Collections.Generic;
+using System.IO;
+using System.Text;
+
+namespace Hazel
+{
+ /// <summary>
+ /// 嵌套结构的Message
+ /// 结构:
+ /// ------------------------------------
+ /// 2bytes (ushort) 包长度
+ /// 1bytes (tag) 协议ID,在AmongUS里是tags.cs里定义的tag和subtag
+ /// ------------------------------------
+ /// 数据 包括嵌套的子协议
+ /// ------------------------------------
+ /// </summary>
+ public class MessageWriter : IRecyclable
+ {
+ public static int BufferSize = 64000;
+ public static readonly ObjectPool<MessageWriter> WriterPool = new ObjectPool<MessageWriter>(() => new MessageWriter(BufferSize));
+
+ public byte[] Buffer;
+ public int Length; // 总长度
+ public int Position; // 写入游标
+
+ public SendOption SendOption { get; private set; }
+
+ private Stack<int> messageStarts = new Stack<int>();
+
+ public MessageWriter(byte[] buffer)
+ {
+ this.Buffer = buffer;
+ this.Length = this.Buffer.Length;
+ }
+
+ ///
+ public MessageWriter(int bufferSize)
+ {
+ this.Buffer = new byte[bufferSize];
+ }
+
+ /// <summary>
+ /// 去掉header
+ /// </summary>
+ /// <param name="includeHeader"></param>
+ /// <returns></returns>
+ /// <exception cref="NotImplementedException"></exception>
+ public byte[] ToByteArray(bool includeHeader)
+ {
+ if (includeHeader)
+ {
+ byte[] output = new byte[this.Length];
+ System.Buffer.BlockCopy(this.Buffer, 0, output, 0, this.Length);
+ return output;
+ }
+ else
+ {
+ switch (this.SendOption)
+ {
+ case SendOption.Reliable:
+ {
+ byte[] output = new byte[this.Length - 3];
+ System.Buffer.BlockCopy(this.Buffer, 3, output, 0, this.Length - 3);
+ return output;
+ }
+ case SendOption.None:
+ {
+ byte[] output = new byte[this.Length - 1];
+ System.Buffer.BlockCopy(this.Buffer, 1, output, 0, this.Length - 1);
+ return output;
+ }
+ }
+ }
+
+ throw new NotImplementedException();
+ }
+
+ ///
+ /// <param name="sendOption">The option specifying how the message should be sent.</param>
+ public static MessageWriter Get(SendOption sendOption = SendOption.None)
+ {
+ var output = WriterPool.GetObject();
+ output.Clear(sendOption);
+
+ return output;
+ }
+
+ public bool HasBytes(int expected)
+ {
+ if (this.SendOption == SendOption.None)
+ {
+ return this.Length > 1 + expected;
+ }
+
+ return this.Length > 3 + expected;
+ }
+
+ ///
+ public void StartMessage(byte typeFlag)
+ {
+ var messageStart = this.Position;
+ messageStarts.Push(messageStart);
+ this.Buffer[messageStart] = 0;
+ this.Buffer[messageStart + 1] = 0;
+ this.Position += 2;
+ this.Write(typeFlag);
+ }
+
+ ///
+ public void EndMessage()
+ {
+ var lastMessageStart = messageStarts.Pop();
+ ushort length = (ushort)(this.Position - lastMessageStart - 3); // Minus length and type byte
+ this.Buffer[lastMessageStart] = (byte)length;
+ this.Buffer[lastMessageStart + 1] = (byte)(length >> 8);
+ }
+
+ ///
+ public void CancelMessage()
+ {
+ this.Position = this.messageStarts.Pop();
+ this.Length = this.Position;
+ }
+
+ public void Clear(SendOption sendOption)
+ {
+ Array.Clear(this.Buffer, 0, this.Buffer.Length);
+ this.messageStarts.Clear();
+ this.SendOption = sendOption;
+ this.Buffer[0] = (byte)sendOption;
+ switch (sendOption)
+ {
+ default:
+ case SendOption.None:
+ this.Length = this.Position = 1;
+ break;
+ case SendOption.Reliable:
+ this.Length = this.Position = 3;
+ break;
+ }
+ }
+
+ ///
+ public void Recycle()
+ {
+ this.Position = this.Length = 0;
+ WriterPool.PutObject(this);
+ }
+
+ #region WriteMethods
+
+ public void CopyFrom(MessageReader target)
+ {
+ int offset, length;
+ if (target.Tag == byte.MaxValue)
+ {
+ offset = target.Offset;
+ length = target.Length;
+ }
+ else
+ {
+ offset = target.Offset - 3;
+ length = target.Length + 3;
+ }
+
+ System.Buffer.BlockCopy(target.Buffer, offset, this.Buffer, this.Position, length);
+ this.Position += length;
+ if (this.Position > this.Length) this.Length = this.Position;
+ }
+
+ public void Write(bool value)
+ {
+ this.Buffer[this.Position++] = (byte)(value ? 1 : 0);
+ if (this.Position > this.Length) this.Length = this.Position;
+ }
+
+ public void Write(sbyte value)
+ {
+ this.Buffer[this.Position++] = (byte)value;
+ if (this.Position > this.Length) this.Length = this.Position;
+ }
+
+ public void Write(byte value)
+ {
+ this.Buffer[this.Position++] = value;
+ if (this.Position > this.Length) this.Length = this.Position;
+ }
+
+ public void Write(short value)
+ {
+ this.Buffer[this.Position++] = (byte)value;
+ this.Buffer[this.Position++] = (byte)(value >> 8);
+ if (this.Position > this.Length) this.Length = this.Position;
+ }
+
+ public void Write(ushort value)
+ {
+ this.Buffer[this.Position++] = (byte)value;
+ this.Buffer[this.Position++] = (byte)(value >> 8);
+ if (this.Position > this.Length) this.Length = this.Position;
+ }
+
+ public void Write(uint value)
+ {
+ this.Buffer[this.Position++] = (byte)value;
+ this.Buffer[this.Position++] = (byte)(value >> 8);
+ this.Buffer[this.Position++] = (byte)(value >> 16);
+ this.Buffer[this.Position++] = (byte)(value >> 24);
+ if (this.Position > this.Length) this.Length = this.Position;
+ }
+
+ public void Write(int value)
+ {
+ this.Buffer[this.Position++] = (byte)value;
+ this.Buffer[this.Position++] = (byte)(value >> 8);
+ this.Buffer[this.Position++] = (byte)(value >> 16);
+ this.Buffer[this.Position++] = (byte)(value >> 24);
+ if (this.Position > this.Length) this.Length = this.Position;
+ }
+
+ public void Write(ulong value)
+ {
+ this.Buffer[this.Position++] = (byte)value;
+ this.Buffer[this.Position++] = (byte)(value >> 8);
+ this.Buffer[this.Position++] = (byte)(value >> 16);
+ this.Buffer[this.Position++] = (byte)(value >> 24);
+ this.Buffer[this.Position++] = (byte)(value >> 32);
+ this.Buffer[this.Position++] = (byte)(value >> 40);
+ this.Buffer[this.Position++] = (byte)(value >> 48);
+ this.Buffer[this.Position++] = (byte)(value >> 56);
+ if (this.Position > this.Length) this.Length = this.Position;
+ }
+
+ public void Write(long value)
+ {
+ this.Buffer[this.Position++] = (byte)value;
+ this.Buffer[this.Position++] = (byte)(value >> 8);
+ this.Buffer[this.Position++] = (byte)(value >> 16);
+ this.Buffer[this.Position++] = (byte)(value >> 24);
+ this.Buffer[this.Position++] = (byte)(value >> 32);
+ this.Buffer[this.Position++] = (byte)(value >> 40);
+ this.Buffer[this.Position++] = (byte)(value >> 48);
+ this.Buffer[this.Position++] = (byte)(value >> 56);
+ if (this.Position > this.Length) this.Length = this.Position;
+ }
+
+ public unsafe void Write(float value)
+ {
+ fixed (byte* ptr = &this.Buffer[this.Position])
+ {
+ byte* valuePtr = (byte*)&value;
+
+ *ptr = *valuePtr;
+ *(ptr + 1) = *(valuePtr + 1);
+ *(ptr + 2) = *(valuePtr + 2);
+ *(ptr + 3) = *(valuePtr + 3);
+ }
+
+ this.Position += 4;
+ if (this.Position > this.Length) this.Length = this.Position;
+ }
+
+ public void Write(string value)
+ {
+ var bytes = UTF8Encoding.UTF8.GetBytes(value);
+ this.WritePacked(bytes.Length);
+ this.Write(bytes);
+ }
+
+ public void WriteBytesAndSize(byte[] bytes)
+ {
+ this.WritePacked((uint)bytes.Length);
+ this.Write(bytes);
+ }
+
+ public void WriteBytesAndSize(byte[] bytes, int length)
+ {
+ this.WritePacked((uint)length);
+ this.Write(bytes, length);
+ }
+
+ public void WriteBytesAndSize(byte[] bytes, int offset, int length)
+ {
+ this.WritePacked((uint)length);
+ this.Write(bytes, offset, length);
+ }
+
+ public void Write(byte[] bytes)
+ {
+ Array.Copy(bytes, 0, this.Buffer, this.Position, bytes.Length);
+ this.Position += bytes.Length;
+ if (this.Position > this.Length) this.Length = this.Position;
+ }
+
+ public void Write(byte[] bytes, int offset, int length)
+ {
+ Array.Copy(bytes, offset, this.Buffer, this.Position, length);
+ this.Position += length;
+ if (this.Position > this.Length) this.Length = this.Position;
+ }
+
+ public void Write(byte[] bytes, int length)
+ {
+ Array.Copy(bytes, 0, this.Buffer, this.Position, length);
+ this.Position += length;
+ if (this.Position > this.Length) this.Length = this.Position;
+ }
+
+ ///
+ public void WritePacked(int value)
+ {
+ this.WritePacked((uint)value);
+ }
+
+ ///
+ public void WritePacked(uint value)
+ {
+ do
+ {
+ byte b = (byte)(value & 0xFF);
+ if (value >= 0x80)
+ {
+ b |= 0x80;
+ }
+
+ this.Write(b);
+ value >>= 7;
+ } while (value > 0);
+ }
+ #endregion
+
+ public void Write(MessageWriter msg, bool includeHeader)
+ {
+ int offset = 0;
+ if (!includeHeader)
+ {
+ switch (msg.SendOption)
+ {
+ case SendOption.None:
+ offset = 1;
+ break;
+ case SendOption.Reliable:
+ offset = 3;
+ break;
+ }
+ }
+
+ this.Write(msg.Buffer, offset, msg.Length - offset);
+ }
+
+ public unsafe static bool IsLittleEndian()
+ {
+ byte b;
+ unsafe
+ {
+ int i = 1;
+ byte* bp = (byte*)&i;
+ b = *bp;
+ }
+
+ return b == 1;
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/NetworkConnection.cs b/Tools/Hazel-Networking/Hazel/NetworkConnection.cs
new file mode 100644
index 0000000..d1de8a8
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/NetworkConnection.cs
@@ -0,0 +1,117 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Net;
+using System.Text;
+
+
+namespace Hazel
+{
+ public enum HazelInternalErrors
+ {
+ SocketExceptionSend,
+ SocketExceptionReceive,
+ ReceivedZeroBytes,
+ PingsWithoutResponse,
+ ReliablePacketWithoutResponse,
+ ConnectionDisconnected,
+ DtlsNegotiationFailed
+ }
+
+ /// <summary>
+ /// Abstract base class for a <see cref="Connection"/> to a remote end point via a network protocol like TCP or UDP.
+ /// </summary>
+ /// <threadsafety static="true" instance="true"/>
+ public abstract class NetworkConnection : Connection
+ {
+ /// <summary>
+ /// An event that gives us a chance to send well-formed disconnect messages to clients when an internal disconnect happens.
+ /// </summary>
+ public Func<HazelInternalErrors, MessageWriter> OnInternalDisconnect;
+
+ public virtual float AveragePingMs { get; }
+
+ public long GetIP4Address()
+ {
+ if (IPMode == IPMode.IPv4)
+ {
+ return this.EndPoint.Address.Address;
+ }
+ else
+ {
+ var bytes = this.EndPoint.Address.GetAddressBytes();
+ return BitConverter.ToInt64(bytes, bytes.Length - 8);
+ }
+ }
+
+ /// <summary>
+ /// Sends a disconnect message to the end point.
+ /// </summary>
+ protected abstract bool SendDisconnect(MessageWriter writer);
+
+ /// <summary>
+ /// Called when the socket has been disconnected at the remote host.
+ /// </summary>
+ protected void DisconnectRemote(string reason, MessageReader reader)
+ {
+ if (this.SendDisconnect(null))
+ {
+ try
+ {
+ InvokeDisconnected(reason, reader);
+ }
+ catch { }
+ }
+
+ this.Dispose();
+ }
+
+ /// <summary>
+ /// Called when socket is disconnected internally
+ /// </summary>
+ internal void DisconnectInternal(HazelInternalErrors error, string reason)
+ {
+ var handler = this.OnInternalDisconnect;
+ if (handler != null)
+ {
+ MessageWriter messageToRemote = handler(error);
+ if (messageToRemote != null)
+ {
+ try
+ {
+ Disconnect(reason, messageToRemote);
+ }
+ finally
+ {
+ messageToRemote.Recycle();
+ }
+ }
+ else
+ {
+ Disconnect(reason);
+ }
+ }
+ else
+ {
+ Disconnect(reason);
+ }
+ }
+
+ /// <summary>
+ /// Called when the socket has been disconnected locally.
+ /// </summary>
+ public override void Disconnect(string reason, MessageWriter writer = null)
+ {
+ if (this.SendDisconnect(writer))
+ {
+ try
+ {
+ InvokeDisconnected(reason, null);
+ }
+ catch { }
+ }
+
+ this.Dispose();
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/NetworkConnectionListener.cs b/Tools/Hazel-Networking/Hazel/NetworkConnectionListener.cs
new file mode 100644
index 0000000..af26c4c
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/NetworkConnectionListener.cs
@@ -0,0 +1,26 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Net;
+using System.Text;
+
+
+namespace Hazel
+{
+ /// <summary>
+ /// Abstract base class for a <see cref="ConnectionListener"/> for network based connections.
+ /// </summary>
+ /// <threadsafety static="true" instance="true"/>
+ public abstract class NetworkConnectionListener : ConnectionListener
+ {
+ /// <summary>
+ /// The local end point the listener is listening for new clients on.
+ /// </summary>
+ public IPEndPoint EndPoint { get; protected set; }
+
+ /// <summary>
+ /// The <see cref="IPMode">IPMode</see> the listener is listening for new clients on.
+ /// </summary>
+ public IPMode IPMode { get; protected set; }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/NewConnectionEventArgs.cs b/Tools/Hazel-Networking/Hazel/NewConnectionEventArgs.cs
new file mode 100644
index 0000000..c3fd62f
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/NewConnectionEventArgs.cs
@@ -0,0 +1,22 @@
+namespace Hazel
+{
+ public struct NewConnectionEventArgs
+ {
+ /// <summary>
+ /// The data received from the client in the handshake.
+ /// You must not recycle this. If you need the message outside of a callback, you should copy it.
+ /// </summary>
+ public readonly MessageReader HandshakeData;
+
+ /// <summary>
+ /// The <see cref="Connection"/> to the new client.
+ /// </summary>
+ public readonly Connection Connection;
+
+ public NewConnectionEventArgs(MessageReader handshakeData, Connection connection)
+ {
+ this.HandshakeData = handshakeData;
+ this.Connection = connection;
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/ObjectPool.cs b/Tools/Hazel-Networking/Hazel/ObjectPool.cs
new file mode 100644
index 0000000..510e55a
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/ObjectPool.cs
@@ -0,0 +1,108 @@
+using System;
+using System.Collections.Concurrent;
+using System.Collections.Generic;
+using System.Threading;
+
+namespace Hazel
+{
+ /// <summary>
+ /// A fairly simple object pool for items that will be created a lot.
+ /// </summary>
+ /// <typeparam name="T">The type that is pooled.</typeparam>
+ /// <threadsafety static="true" instance="true"/>
+ public sealed class ObjectPool<T> where T : IRecyclable
+ {
+ private int numberCreated;
+ public int NumberCreated { get { return numberCreated; } }
+
+ public int NumberInUse { get { return this.inuse.Count; } }
+ public int NumberNotInUse { get { return this.pool.Count; } }
+ public int Size { get { return this.NumberInUse + this.NumberNotInUse; } }
+
+#if HAZEL_BAG
+ private readonly ConcurrentBag<T> pool = new ConcurrentBag<T>();
+#else
+ private readonly List<T> pool = new List<T>();
+#endif
+
+ // Unavailable objects
+ private readonly ConcurrentDictionary<T, bool> inuse = new ConcurrentDictionary<T, bool>();
+
+ /// <summary>
+ /// The generator for creating new objects.
+ /// </summary>
+ /// <returns></returns>
+ private readonly Func<T> objectFactory;
+
+ /// <summary>
+ /// Internal constructor for our ObjectPool.
+ /// </summary>
+ internal ObjectPool(Func<T> objectFactory)
+ {
+ this.objectFactory = objectFactory;
+ }
+
+ /// <summary>
+ /// Returns a pooled object of type T, if none are available another is created.
+ /// </summary>
+ /// <returns>An instance of T.</returns>
+ internal T GetObject()
+ {
+#if HAZEL_BAG
+ if (!pool.TryTake(out T item))
+ {
+ Interlocked.Increment(ref numberCreated);
+ item = objectFactory.Invoke();
+ }
+#else
+ T item;
+ lock (this.pool)
+ {
+ if (this.pool.Count > 0)
+ {
+ var idx = this.pool.Count - 1;
+ item = this.pool[idx];
+ this.pool.RemoveAt(idx);
+ }
+ else
+ {
+ Interlocked.Increment(ref numberCreated);
+ item = objectFactory.Invoke();
+ }
+ }
+#endif
+
+ if (!inuse.TryAdd(item, true))
+ {
+ throw new Exception("Duplicate pull " + typeof(T).Name);
+ }
+
+ return item;
+ }
+
+ /// <summary>
+ /// Returns an object to the pool.
+ /// </summary>
+ /// <param name="item">The item to return.</param>
+ internal void PutObject(T item)
+ {
+ if (inuse.TryRemove(item, out bool b))
+ {
+#if HAZEL_BAG
+ pool.Add(item);
+#else
+ lock (this.pool)
+ {
+ pool.Add(item);
+ }
+#endif
+ }
+ else
+ {
+#if DEBUG
+ throw new Exception("Duplicate add " + typeof(T).Name);
+#endif
+ }
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/SendErrors.cs b/Tools/Hazel-Networking/Hazel/SendErrors.cs
new file mode 100644
index 0000000..6871c6a
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/SendErrors.cs
@@ -0,0 +1,15 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+
+namespace Hazel
+{
+ [Flags]
+ public enum SendErrors
+ {
+ None,
+ Disconnected,
+ Unknown
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/SendOption.cs b/Tools/Hazel-Networking/Hazel/SendOption.cs
new file mode 100644
index 0000000..c2ffb22
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/SendOption.cs
@@ -0,0 +1,35 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+
+namespace Hazel
+{
+ /// <summary>
+ /// Specifies how a message should be sent between connections.
+ /// </summary>
+ [Flags]
+ public enum SendOption : byte
+ {
+ /// <summary>
+ /// Requests unreliable delivery with no framentation.
+ /// </summary>
+ /// <remarks>
+ /// Sending data using unreliable delivery means that data is not guaranteed to arrive at it's destination nor is
+ /// it guarenteed to arrive only once. However, unreliable delivery can be faster than other methods and it
+ /// typically requires a smaller number of protocol bytes than other methods. There is also typically less
+ /// processing involved and less memory needed as packets are not stored once sent.
+ /// </remarks>
+ None = 0,
+
+ /// <summary>
+ /// Requests data be sent reliably but with no fragmentation.
+ /// </summary>
+ /// <remarks>
+ /// Sending data reliably means that data is guarenteed to arrive and to arrive only once. Reliable delivery
+ /// typically requires more processing, more memory (as packets need to be stored in case they need resending),
+ /// a larger number of protocol bytes and can be slower than unreliable delivery.
+ /// </remarks>
+ Reliable = 1,
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/UPnP/ILogger.cs b/Tools/Hazel-Networking/Hazel/UPnP/ILogger.cs
new file mode 100644
index 0000000..3c7abcf
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/UPnP/ILogger.cs
@@ -0,0 +1,65 @@
+using System;
+
+namespace Hazel
+{
+ public interface ILogger
+ {
+ void WriteVerbose(string msg);
+ void WriteError(string msg);
+ void WriteWarning(string msg);
+ void WriteInfo(string msg);
+ }
+
+ public class NullLogger : ILogger
+ {
+ public static readonly NullLogger Instance = new NullLogger();
+
+ public void WriteVerbose(string msg)
+ {
+ }
+
+ public void WriteError(string msg)
+ {
+ }
+
+ public void WriteWarning(string msg)
+ {
+ }
+
+ public void WriteInfo(string msg)
+ {
+ }
+ }
+
+ public class ConsoleLogger : ILogger
+ {
+ private bool verbose;
+ public ConsoleLogger(bool verbose)
+ {
+ this.verbose = verbose;
+ }
+
+ public void WriteVerbose(string msg)
+ {
+ if (this.verbose)
+ {
+ Console.WriteLine($"{DateTime.Now} [VERBOSE] {msg}");
+ }
+ }
+
+ public void WriteWarning(string msg)
+ {
+ Console.WriteLine($"{DateTime.Now} [WARN] {msg}");
+ }
+
+ public void WriteError(string msg)
+ {
+ Console.WriteLine($"{DateTime.Now} [ERROR] {msg}");
+ }
+
+ public void WriteInfo(string msg)
+ {
+ Console.WriteLine($"{DateTime.Now} [INFO] {msg}");
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/UPnP/NetUtility.cs b/Tools/Hazel-Networking/Hazel/UPnP/NetUtility.cs
new file mode 100644
index 0000000..d856823
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/UPnP/NetUtility.cs
@@ -0,0 +1,158 @@
+using System;
+using System.Collections.Generic;
+using System.Net;
+using System.Net.NetworkInformation;
+using System.Net.Sockets;
+
+namespace Hazel.UPnP
+{
+ internal class NetUtility
+ {
+ private static IList<NetworkInterface> GetValidNetworkInterfaces()
+ {
+ var nics = NetworkInterface.GetAllNetworkInterfaces();
+ if (nics == null || nics.Length < 1)
+ return new NetworkInterface[0];
+
+ var validInterfaces = new List<NetworkInterface>(nics.Length);
+
+ NetworkInterface best = null;
+ foreach (NetworkInterface adapter in nics)
+ {
+ if (adapter.NetworkInterfaceType == NetworkInterfaceType.Loopback || adapter.NetworkInterfaceType == NetworkInterfaceType.Unknown)
+ continue;
+ if (!adapter.Supports(NetworkInterfaceComponent.IPv4) && !adapter.Supports(NetworkInterfaceComponent.IPv6))
+ continue;
+ if (best == null)
+ best = adapter;
+ if (adapter.OperationalStatus != OperationalStatus.Up)
+ continue;
+
+ // make sure this adapter has any ip addresses
+ IPInterfaceProperties properties = adapter.GetIPProperties();
+ foreach (UnicastIPAddressInformation unicastAddress in properties.UnicastAddresses)
+ {
+ if (unicastAddress != null && unicastAddress.Address != null)
+ {
+ // Yes it does, add this network interface.
+ validInterfaces.Add(adapter);
+ break;
+ }
+ }
+ }
+
+ if (validInterfaces.Count == 0 && best != null)
+ validInterfaces.Add(best);
+
+ return validInterfaces;
+ }
+
+ /// <summary>
+ /// Gets the addresses from all active network interfaces, but at most one per interface.
+ /// </summary>
+ /// <param name="addressFamily">The <see cref="AddressFamily"/> of the addresses to return</param>
+ /// <returns>An <see cref="ICollection{T}"/> of <see cref="UnicastIPAddressInformation"/>.</returns>
+ public static ICollection<UnicastIPAddressInformation> GetAddressesFromNetworkInterfaces(AddressFamily addressFamily)
+ {
+ var unicastAddresses = new List<UnicastIPAddressInformation>();
+
+ foreach (NetworkInterface adapter in GetValidNetworkInterfaces())
+ {
+ IPInterfaceProperties properties = adapter.GetIPProperties();
+ foreach (UnicastIPAddressInformation unicastAddress in properties.UnicastAddresses)
+ {
+ if (unicastAddress != null && unicastAddress.Address != null && unicastAddress.Address.AddressFamily == addressFamily)
+ {
+ unicastAddresses.Add(unicastAddress);
+ break;
+ }
+ }
+ }
+
+ return unicastAddresses;
+ }
+
+ /// <summary>
+ /// Gets my local IPv4 address (not necessarily external) and subnet mask
+ /// </summary>
+ public static IPAddress GetMyAddress(out IPAddress mask)
+ {
+ var networkInterfaces = GetValidNetworkInterfaces();
+ IPInterfaceProperties properties = null;
+
+ if (networkInterfaces.Count > 0)
+ properties = networkInterfaces[0]?.GetIPProperties();
+
+ if (properties != null)
+ {
+ foreach (UnicastIPAddressInformation unicastAddress in properties.UnicastAddresses)
+ {
+ if (unicastAddress != null && unicastAddress.Address != null && unicastAddress.Address.AddressFamily == AddressFamily.InterNetwork)
+ {
+ mask = unicastAddress.IPv4Mask;
+ return unicastAddress.Address;
+ }
+ }
+ }
+
+ mask = null;
+ return null;
+ }
+
+ /// <summary>
+ /// Gets the broadcast address for the first network interface or, if not able to,
+ /// the limited broadcast address.
+ /// </summary>
+ /// <returns>An <see cref="IPAddress"/> for broadcasting.</returns>
+ public static IPAddress GetBroadcastAddress()
+ {
+ var networkInterfaces = GetValidNetworkInterfaces();
+ IPInterfaceProperties properties = null;
+
+ if (networkInterfaces.Count > 0)
+ properties = networkInterfaces[0]?.GetIPProperties();
+
+ if (properties != null)
+ {
+ foreach (UnicastIPAddressInformation unicastAddress in properties.UnicastAddresses)
+ {
+ IPAddress ipAddress = GetBroadcastAddress(unicastAddress);
+ if (ipAddress != null)
+ {
+ return ipAddress;
+ }
+ }
+ }
+
+ return IPAddress.Broadcast;
+ }
+
+ /// <summary>
+ /// Gets the broadcast address for the given <paramref name="unicastAddress"/>.
+ /// </summary>
+ /// <param name="unicastAddress">A <see cref="UnicastIPAddressInformation"/></param>
+ /// <returns>An <see cref="IPAddress"/> for broadcasting, null if the <paramref name="unicastAddress"/>
+ /// is not an IPv4 address.</returns>
+ public static IPAddress GetBroadcastAddress(UnicastIPAddressInformation unicastAddress)
+ {
+ if (unicastAddress != null && unicastAddress.Address != null && unicastAddress.Address.AddressFamily == AddressFamily.InterNetwork)
+ {
+ var mask = unicastAddress.IPv4Mask;
+ byte[] ipAdressBytes = unicastAddress.Address.GetAddressBytes();
+ byte[] subnetMaskBytes = mask.GetAddressBytes();
+
+ if (ipAdressBytes.Length != subnetMaskBytes.Length)
+ throw new ArgumentException("Lengths of IP address and subnet mask do not match.");
+
+ byte[] broadcastAddress = new byte[ipAdressBytes.Length];
+ for (int i = 0; i < broadcastAddress.Length; i++)
+ {
+ broadcastAddress[i] = (byte)(ipAdressBytes[i] | (subnetMaskBytes[i] ^ 255));
+ }
+ return new IPAddress(broadcastAddress);
+ }
+
+ return null;
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/UPnP/UPnPHelper.cs b/Tools/Hazel-Networking/Hazel/UPnP/UPnPHelper.cs
new file mode 100644
index 0000000..771709e
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/UPnP/UPnPHelper.cs
@@ -0,0 +1,347 @@
+using System;
+using System.IO;
+using System.Xml;
+using System.Net;
+using System.Net.Sockets;
+using System.Threading;
+
+namespace Hazel.UPnP
+{
+ /// <summary>
+ /// Status of the UPnP capabilities
+ /// </summary>
+ public enum UPnPStatus
+ {
+ /// <summary>
+ /// Still discovering UPnP capabilities
+ /// </summary>
+ Discovering,
+
+ /// <summary>
+ /// UPnP is not available
+ /// </summary>
+ NotAvailable,
+
+ /// <summary>
+ /// UPnP is available and ready to use
+ /// </summary>
+ Available
+ }
+
+ public class UPnPHelper : IDisposable
+ {
+ private const int DiscoveryTimeOutMs = 1000;
+
+ private string serviceUrl;
+ private string serviceName = "";
+
+ private ManualResetEvent discoveryComplete = new ManualResetEvent(false);
+ private Socket socket;
+
+ private DateTime discoveryResponseDeadline;
+
+ private EndPoint ep;
+ private byte[] buffer;
+
+ private ILogger logger;
+
+ /// <summary>
+ /// Status of the UPnP capabilities of this NetPeer
+ /// </summary>
+ public UPnPStatus Status { get; private set; }
+
+ public UPnPHelper(ILogger logger)
+ {
+ this.logger = logger;
+
+ this.socket = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp);
+ this.socket.EnableBroadcast = true;
+ this.socket.MulticastLoopback = false;
+
+ this.socket.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.ReuseAddress, 1);
+ this.socket.Bind(new IPEndPoint(IPAddress.Any, 0));
+
+ this.ep = new IPEndPoint(IPAddress.Any, 1900);
+ this.buffer = new byte[ushort.MaxValue];
+
+ ListenForUPnP();
+
+ this.Discover();
+ }
+
+ private void ListenForUPnP()
+ {
+ try
+ {
+ socket.BeginReceiveFrom(this.buffer, 0, this.buffer.Length, SocketFlags.None, ref ep, HandleMessage, null);
+ }
+ catch(Exception e)
+ {
+ this.logger.WriteInfo("Exception listening for UPnP: " + e.Message);
+ }
+ }
+
+ private void HandleMessage(IAsyncResult ar)
+ {
+ int len;
+ try
+ {
+ len = this.socket.EndReceiveFrom(ar, ref ep);
+ }
+ catch
+ {
+ return;
+ }
+
+ string resp = System.Text.Encoding.UTF8.GetString(buffer, 0, len);
+ if (resp.Contains("upnp:rootdevice") || resp.Contains("UPnP/1.0"))
+ {
+ var locationStart = resp.IndexOf("location:", StringComparison.OrdinalIgnoreCase);
+ if (locationStart >= 0)
+ {
+ locationStart += 10;
+ var locationEnd = resp.IndexOf("\r", locationStart);
+
+ resp = resp.Substring(locationStart, locationEnd - locationStart);
+ if (!ExtractServiceUrl(resp))
+ {
+ ListenForUPnP();
+ }
+ }
+ else
+ {
+ ListenForUPnP();
+ }
+ }
+ else
+ {
+ ListenForUPnP();
+ }
+ }
+
+ internal void Discover()
+ {
+ string str =
+"M-SEARCH * HTTP/1.1\r\n" +
+"HOST: 239.255.255.250:1900\r\n" +
+"ST:upnp:rootdevice\r\n" +
+"MAN:\"ssdp:discover\"\r\n" +
+"MX:3\r\n\r\n";
+
+ discoveryResponseDeadline = DateTime.UtcNow.AddSeconds(6);
+ Status = UPnPStatus.Discovering;
+
+ byte[] buffer = System.Text.Encoding.UTF8.GetBytes(str);
+
+ this.logger.WriteInfo("Attempting UPnP discovery");
+
+ socket.SendTo(buffer, new IPEndPoint(NetUtility.GetBroadcastAddress(), 1900));
+ }
+
+ internal bool ExtractServiceUrl(string resp)
+ {
+ try
+ {
+ XmlDocument desc = new XmlDocument();
+ using (var response = WebRequest.Create(resp).GetResponse())
+ {
+ desc.Load(response.GetResponseStream());
+ }
+
+ XmlNamespaceManager nsMgr = new XmlNamespaceManager(desc.NameTable);
+ nsMgr.AddNamespace("tns", "urn:schemas-upnp-org:device-1-0");
+ XmlNode typen = desc.SelectSingleNode("//tns:device/tns:deviceType/text()", nsMgr);
+ if (!typen.Value.Contains("InternetGatewayDevice"))
+ return false;
+
+ serviceName = "WANIPConnection";
+ XmlNode node = desc.SelectSingleNode("//tns:service[tns:serviceType=\"urn:schemas-upnp-org:service:" + serviceName + ":1\"]/tns:controlURL/text()", nsMgr);
+ if (node == null)
+ {
+ //try another service name
+ serviceName = "WANPPPConnection";
+ node = desc.SelectSingleNode("//tns:service[tns:serviceType=\"urn:schemas-upnp-org:service:" + serviceName + ":1\"]/tns:controlURL/text()", nsMgr);
+ if (node == null)
+ return false;
+ }
+
+ serviceUrl = CombineUrls(resp, node.Value);
+ this.logger.WriteInfo("UPnP service ready");
+ Status = UPnPStatus.Available;
+ discoveryComplete.Set();
+ return true;
+ }
+ catch (Exception e)
+ {
+ this.logger.WriteError("Exception while parsing UPnP Service URL: " + e.Message);
+ return false;
+ }
+ }
+
+ private static string CombineUrls(string gatewayURL, string subURL)
+ {
+ // Is Control URL an absolute URL?
+ if (subURL.Contains("http:") || subURL.Contains("."))
+ return subURL;
+
+ gatewayURL = gatewayURL.Replace("http://", ""); // strip any protocol
+ int n = gatewayURL.IndexOf("/");
+ if (n >= 0)
+ {
+ gatewayURL = gatewayURL.Substring(0, n); // Use first portion of URL
+ }
+
+ return "http://" + gatewayURL + subURL;
+ }
+
+ private bool CheckAvailability()
+ {
+ switch (Status)
+ {
+ case UPnPStatus.NotAvailable:
+ return false;
+ case UPnPStatus.Available:
+ return true;
+ case UPnPStatus.Discovering:
+ while (!discoveryComplete.WaitOne(DiscoveryTimeOutMs))
+ {
+ if (DateTime.UtcNow > discoveryResponseDeadline)
+ {
+ Status = UPnPStatus.NotAvailable;
+ return false;
+ }
+ }
+
+ return true;
+ }
+
+ return false;
+ }
+
+ /// <summary>
+ /// Add a forwarding rule to the router using UPnP
+ /// </summary>
+ /// <param name="externalPort">The external, WAN facing, port</param>
+ /// <param name="description">A description for the port forwarding rule</param>
+ /// <param name="internalPort">The port on the client machine to send traffic to</param>
+ /// <param name="durationSeconds">The lease duration on the port forwarding rule, in seconds. 0 for indefinite.</param>
+ public bool ForwardPort(int externalPort, string description, int internalPort = 0, int durationSeconds = 0)
+ {
+ if (!CheckAvailability())
+ return false;
+
+ if (internalPort == 0)
+ internalPort = externalPort;
+
+ try
+ {
+ var client = NetUtility.GetMyAddress(out _);
+ if (client == null)
+ return false;
+
+ SOAPRequest(serviceUrl,
+ $"<u:AddPortMapping xmlns:u=\"urn:schemas-upnp-org:service:{serviceName}:1\">" +
+ "<NewRemoteHost></NewRemoteHost>" +
+ $"<NewExternalPort>{externalPort}</NewExternalPort>" +
+ "<NewProtocol>UDP</NewProtocol>" +
+ $"<NewInternalPort>{internalPort}</NewInternalPort>" +
+ $"<NewInternalClient>{client}</NewInternalClient>" +
+ "<NewEnabled>1</NewEnabled>" +
+ $"<NewPortMappingDescription>{description}</NewPortMappingDescription>" +
+ $"<NewLeaseDuration>{durationSeconds}</NewLeaseDuration>" +
+ "</u:AddPortMapping>",
+ "AddPortMapping");
+
+ this.logger.WriteInfo("Sent UPnP port forward request.");
+ return true;
+ }
+ catch (Exception ex)
+ {
+ this.logger.WriteError("UPnP port forward failed: " + ex.Message);
+ return false;
+ }
+ }
+
+ /// <summary>
+ /// Delete a forwarding rule from the router using UPnP
+ /// </summary>
+ /// <param name="externalPort">The external, 'internet facing', port</param>
+ public bool DeleteForwardingRule(int externalPort)
+ {
+ if (!CheckAvailability())
+ return false;
+
+ try
+ {
+ SOAPRequest(serviceUrl,
+ $"<u:DeletePortMapping xmlns:u=\"urn:schemas-upnp-org:service:{serviceName}:1\">" +
+ "<NewRemoteHost></NewRemoteHost>" +
+ $"<NewExternalPort>{externalPort}</NewExternalPort>" +
+ $"<NewProtocol>UDP</NewProtocol>" +
+ "</u:DeletePortMapping>", "DeletePortMapping");
+ return true;
+ }
+ catch (Exception ex)
+ {
+ // m_peer.LogWarning("UPnP delete forwarding rule failed: " + ex.Message);
+ return false;
+ }
+ }
+
+ /// <summary>
+ /// Retrieve the extern ip using UPnP
+ /// </summary>
+ public IPAddress GetExternalIP()
+ {
+ if (!CheckAvailability())
+ return null;
+ try
+ {
+ XmlDocument xdoc = SOAPRequest(serviceUrl, "<u:GetExternalIPAddress xmlns:u=\"urn:schemas-upnp-org:service:" + serviceName + ":1\">" +
+ "</u:GetExternalIPAddress>", "GetExternalIPAddress");
+ XmlNamespaceManager nsMgr = new XmlNamespaceManager(xdoc.NameTable);
+ nsMgr.AddNamespace("tns", "urn:schemas-upnp-org:device-1-0");
+ string IP = xdoc.SelectSingleNode("//NewExternalIPAddress/text()", nsMgr).Value;
+ return IPAddress.Parse(IP);
+ }
+ catch (Exception ex)
+ {
+ // m_peer.LogWarning("Failed to get external IP: " + ex.Message);
+ return null;
+ }
+ }
+
+ private XmlDocument SOAPRequest(string url, string soap, string function)
+ {
+ string req =
+"<?xml version=\"1.0\"?>" +
+"<s:Envelope xmlns:s=\"http://schemas.xmlsoap.org/soap/envelope/\" s:encodingStyle=\"http://schemas.xmlsoap.org/soap/encoding/\">" +
+$"<s:Body>{soap}</s:Body>" +
+"</s:Envelope>";
+
+ WebRequest r = HttpWebRequest.Create(url);
+ r.Headers.Add("SOAPACTION", $"\"urn:schemas-upnp-org:service:{serviceName}:1#{function}\"");
+ r.ContentType = "text/xml; charset=\"utf-8\"";
+ r.Method = "POST";
+
+ byte[] b = System.Text.Encoding.UTF8.GetBytes(req);
+ r.ContentLength = b.Length;
+ r.GetRequestStream().Write(b, 0, b.Length);
+
+ using (WebResponse wres = r.GetResponse())
+ {
+ XmlDocument resp = new XmlDocument();
+ Stream ress = wres.GetResponseStream();
+ resp.Load(ress);
+ return resp;
+ }
+ }
+
+ public void Dispose()
+ {
+ this.discoveryComplete.Dispose();
+ try { this.socket.Shutdown(SocketShutdown.Both); } catch { }
+ this.socket.Dispose();
+ }
+ }
+} \ No newline at end of file
diff --git a/Tools/Hazel-Networking/Hazel/Udp/SendOptionInternal.cs b/Tools/Hazel-Networking/Hazel/Udp/SendOptionInternal.cs
new file mode 100644
index 0000000..74786d8
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/Udp/SendOptionInternal.cs
@@ -0,0 +1,39 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+
+
+namespace Hazel.Udp
+{
+ /// <summary>
+ /// Extra internal states for SendOption enumeration when using UDP.
+ /// </summary>
+ public enum UdpSendOption : byte
+ {
+ /// <summary>
+ /// Hello message for initiating communication.
+ /// </summary>
+ Hello = 8,
+
+ /// <summary>
+ /// A single byte of continued existence
+ /// </summary>
+ Ping = 12,
+
+ /// <summary>
+ /// Message for discontinuing communication.
+ /// </summary>
+ Disconnect = 9,
+
+ /// <summary>
+ /// Message acknowledging the receipt of a message.
+ /// </summary>
+ Acknowledgement = 10,
+
+ /// <summary>
+ /// Message that is part of a larger, fragmented message.
+ /// </summary>
+ Fragment = 11,
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/Udp/UdpBroadcastListener.cs b/Tools/Hazel-Networking/Hazel/Udp/UdpBroadcastListener.cs
new file mode 100644
index 0000000..13b8d0b
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/Udp/UdpBroadcastListener.cs
@@ -0,0 +1,157 @@
+using System;
+using System.Collections.Generic;
+using System.Net;
+using System.Net.Sockets;
+using System.Text;
+using System.Threading;
+
+namespace Hazel.Udp
+{
+ public class BroadcastPacket
+ {
+ public string Data;
+ public DateTime ReceiveTime;
+ public IPEndPoint Sender;
+
+ public BroadcastPacket(string data, IPEndPoint sender)
+ {
+ this.Data = data;
+ this.Sender = sender;
+ this.ReceiveTime = DateTime.Now;
+ }
+
+ public string GetAddress()
+ {
+ return this.Sender.Address.ToString();
+ }
+ }
+
+ public class UdpBroadcastListener : IDisposable
+ {
+ private Socket socket;
+ private EndPoint endpoint;
+ private Action<string> logger;
+
+ private byte[] buffer = new byte[1024];
+
+ private List<BroadcastPacket> packets = new List<BroadcastPacket>();
+
+ public bool Running { get; private set; }
+
+ ///
+ public UdpBroadcastListener(int port, Action<string> logger = null)
+ {
+ this.logger = logger;
+ this.socket = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp);
+ this.socket.EnableBroadcast = true;
+ this.socket.MulticastLoopback = false;
+ this.endpoint = new IPEndPoint(IPAddress.Any, port);
+ this.socket.Bind(this.endpoint);
+ }
+
+ ///
+ public void StartListen()
+ {
+ if (this.Running) return;
+ this.Running = true;
+
+ try
+ {
+ EndPoint endpt = new IPEndPoint(IPAddress.Any, 0);
+ this.socket.BeginReceiveFrom(buffer, 0, buffer.Length, SocketFlags.None, ref endpt, this.HandleData, null);
+ }
+ catch (NullReferenceException) { }
+ catch (Exception e)
+ {
+ this.logger?.Invoke("BroadcastListener: " + e);
+ this.Dispose();
+ }
+ }
+
+ private void HandleData(IAsyncResult result)
+ {
+ this.Running = false;
+
+ int numBytes;
+ EndPoint endpt = new IPEndPoint(IPAddress.Any, 0);
+ try
+ {
+ numBytes = this.socket.EndReceiveFrom(result, ref endpt);
+ }
+ catch (NullReferenceException)
+ {
+ // Already disposed
+ return;
+ }
+ catch (Exception e)
+ {
+ this.logger?.Invoke("BroadcastListener: " + e);
+ this.Dispose();
+ return;
+ }
+
+ if (numBytes < 3
+ || buffer[0] != 4 || buffer[1] != 2)
+ {
+ this.StartListen();
+ return;
+ }
+
+ IPEndPoint ipEnd = (IPEndPoint)endpt;
+ string data = UTF8Encoding.UTF8.GetString(buffer, 2, numBytes - 2);
+ int dataHash = data.GetHashCode();
+
+ lock (packets)
+ {
+ bool found = false;
+ for (int i = 0; i < this.packets.Count; ++i)
+ {
+ var pkt = this.packets[i];
+ if (pkt == null || pkt.Data == null)
+ {
+ this.packets.RemoveAt(i);
+ i--;
+ continue;
+ }
+
+ if (pkt.Data.GetHashCode() == dataHash
+ && pkt.Sender.Equals(ipEnd))
+ {
+ this.packets[i].ReceiveTime = DateTime.Now;
+ break;
+ }
+ }
+
+ if (!found)
+ {
+ this.packets.Add(new BroadcastPacket(data, ipEnd));
+ }
+ }
+
+ this.StartListen();
+ }
+
+ ///
+ public BroadcastPacket[] GetPackets()
+ {
+ lock (this.packets)
+ {
+ var output = this.packets.ToArray();
+ this.packets.Clear();
+ return output;
+ }
+ }
+
+ ///
+ public void Dispose()
+ {
+ if (this.socket != null)
+ {
+ try { this.socket.Shutdown(SocketShutdown.Both); } catch { }
+ try { this.socket.Close(); } catch { }
+ try { this.socket.Dispose(); } catch { }
+ this.socket = null;
+ }
+ }
+ }
+} \ No newline at end of file
diff --git a/Tools/Hazel-Networking/Hazel/Udp/UdpBroadcaster.cs b/Tools/Hazel-Networking/Hazel/Udp/UdpBroadcaster.cs
new file mode 100644
index 0000000..8877f86
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/Udp/UdpBroadcaster.cs
@@ -0,0 +1,127 @@
+using Hazel.UPnP;
+using System;
+using System.Net;
+using System.Net.Sockets;
+using System.Text;
+
+namespace Hazel.Udp
+{
+ public class UdpBroadcaster : IDisposable
+ {
+ private SocketBroadcast[] socketBroadcasts;
+ private byte[] data;
+ private Action<string> logger;
+
+ ///
+ public UdpBroadcaster(int port, Action<string> logger = null)
+ {
+ this.logger = logger;
+
+ var addresses = NetUtility.GetAddressesFromNetworkInterfaces(AddressFamily.InterNetwork);
+ this.socketBroadcasts = new SocketBroadcast[addresses.Count > 0 ? addresses.Count : 1];
+
+ int count = 0;
+ foreach (var addressInformation in addresses)
+ {
+ Socket socket = CreateSocket(new IPEndPoint(addressInformation.Address, 0));
+ IPAddress broadcast = NetUtility.GetBroadcastAddress(addressInformation);
+
+ this.socketBroadcasts[count] = new SocketBroadcast(socket, new IPEndPoint(broadcast, port));
+ count++;
+ }
+ if (count == 0)
+ {
+ Socket socket = CreateSocket(new IPEndPoint(IPAddress.Any, 0));
+
+ this.socketBroadcasts[0] = new SocketBroadcast(socket, new IPEndPoint(IPAddress.Broadcast, port));
+ }
+ }
+
+ private static Socket CreateSocket(IPEndPoint endPoint)
+ {
+ var socket = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp);
+ socket.EnableBroadcast = true;
+ socket.MulticastLoopback = false;
+ socket.Bind(endPoint);
+
+ return socket;
+ }
+
+ ///
+ public void SetData(string data)
+ {
+ int len = UTF8Encoding.UTF8.GetByteCount(data);
+ this.data = new byte[len + 2];
+ this.data[0] = 4;
+ this.data[1] = 2;
+
+ UTF8Encoding.UTF8.GetBytes(data, 0, data.Length, this.data, 2);
+ }
+
+ ///
+ public void Broadcast()
+ {
+ if (this.data == null)
+ {
+ return;
+ }
+
+ foreach (SocketBroadcast socketBroadcast in this.socketBroadcasts)
+ {
+ try
+ {
+ Socket socket = socketBroadcast.Socket;
+ socket.BeginSendTo(data, 0, data.Length, SocketFlags.None, socketBroadcast.Broadcast, this.FinishSendTo, socket);
+ }
+ catch (Exception e)
+ {
+ this.logger?.Invoke("BroadcastListener: " + e);
+ }
+ }
+ }
+
+ private void FinishSendTo(IAsyncResult evt)
+ {
+ try
+ {
+ Socket socket = (Socket)evt.AsyncState;
+ socket.EndSendTo(evt);
+ }
+ catch (Exception e)
+ {
+ this.logger?.Invoke("BroadcastListener: " + e);
+ }
+ }
+
+ ///
+ public void Dispose()
+ {
+ if (this.socketBroadcasts != null)
+ {
+ foreach (SocketBroadcast socketBroadcast in this.socketBroadcasts)
+ {
+ Socket socket = socketBroadcast.Socket;
+ if (socket != null)
+ {
+ try { socket.Shutdown(SocketShutdown.Both); } catch { }
+ try { socket.Close(); } catch { }
+ try { socket.Dispose(); } catch { }
+ }
+ }
+ Array.Clear(this.socketBroadcasts, 0, this.socketBroadcasts.Length);
+ }
+ }
+
+ private struct SocketBroadcast
+ {
+ public Socket Socket;
+ public IPEndPoint Broadcast;
+
+ public SocketBroadcast(Socket socket, IPEndPoint broadcast)
+ {
+ Socket = socket;
+ Broadcast = broadcast;
+ }
+ }
+ }
+} \ No newline at end of file
diff --git a/Tools/Hazel-Networking/Hazel/Udp/UdpClientConnection.cs b/Tools/Hazel-Networking/Hazel/Udp/UdpClientConnection.cs
new file mode 100644
index 0000000..f6da329
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/Udp/UdpClientConnection.cs
@@ -0,0 +1,364 @@
+using System;
+using System.Net;
+using System.Net.Sockets;
+using System.Threading;
+
+
+namespace Hazel.Udp
+{
+ /// <summary>
+ /// Represents a client's connection to a server that uses the UDP protocol.
+ /// </summary>
+ /// <inheritdoc/>
+ public sealed class UdpClientConnection : UdpConnection
+ {
+ /// <summary>
+ /// The max size Hazel attempts to read from the network.
+ /// Defaults to 8096.
+ /// </summary>
+ /// <remarks>
+ /// 8096 is 5 times the standard modern MTU of 1500, so it's already too large imo.
+ /// If Hazel ever implements fragmented packets, then we might consider a larger value since combining 5
+ /// packets into 1 reader would be realistic and would cause reallocations. That said, Hazel is not meant
+ /// for transferring large contiguous blocks of data, so... please don't?
+ /// </remarks>
+ public int ReceiveBufferSize = 8096;
+
+ /// <summary>
+ /// The socket we're connected via.
+ /// </summary>
+ private Socket socket;
+
+ /// <summary>
+ /// Reset event that is triggered when the connection is marked Connected.
+ /// </summary>
+ private ManualResetEvent connectWaitLock = new ManualResetEvent(false);
+
+ private Timer reliablePacketTimer;
+
+#if DEBUG
+ public event Action<byte[], int> DataSentRaw;
+ public event Action<byte[], int> DataReceivedRaw;
+#endif
+
+ /// <summary>
+ /// Creates a new UdpClientConnection.
+ /// </summary>
+ /// <param name="remoteEndPoint">A <see cref="NetworkEndPoint"/> to connect to.</param>
+ public UdpClientConnection(ILogger logger, IPEndPoint remoteEndPoint, IPMode ipMode = IPMode.IPv4)
+ : base(logger)
+ {
+ this.EndPoint = remoteEndPoint;
+ this.IPMode = ipMode;
+
+ this.socket = CreateSocket(ipMode);
+
+ reliablePacketTimer = new Timer(ManageReliablePacketsInternal, null, 100, Timeout.Infinite);
+ this.InitializeKeepAliveTimer();
+ }
+
+ ~UdpClientConnection()
+ {
+ this.Dispose(false);
+ }
+
+ private void ManageReliablePacketsInternal(object state)
+ {
+ base.ManageReliablePackets();
+ try
+ {
+ reliablePacketTimer.Change(100, Timeout.Infinite);
+ }
+ catch { }
+ }
+
+ /// <inheritdoc />
+ protected override void WriteBytesToConnection(byte[] bytes, int length)
+ {
+#if DEBUG
+ if (TestLagMs > 0)
+ {
+ ThreadPool.QueueUserWorkItem(a => { Thread.Sleep(this.TestLagMs); WriteBytesToConnectionReal(bytes, length); });
+ }
+ else
+#endif
+ {
+ WriteBytesToConnectionReal(bytes, length);
+ }
+ }
+
+ private void WriteBytesToConnectionReal(byte[] bytes, int length)
+ {
+#if DEBUG
+ DataSentRaw?.Invoke(bytes, length);
+#endif
+
+ try
+ {
+ this.Statistics.LogPacketSend(length);
+ socket.BeginSendTo(
+ bytes,
+ 0,
+ length,
+ SocketFlags.None,
+ EndPoint,
+ HandleSendTo,
+ null);
+ }
+ catch (NullReferenceException) { }
+ catch (ObjectDisposedException)
+ {
+ // Already disposed and disconnected...
+ }
+ catch (SocketException ex)
+ {
+ DisconnectInternal(HazelInternalErrors.SocketExceptionSend, "Could not send data as a SocketException occurred: " + ex.Message);
+ }
+ }
+
+ private void HandleSendTo(IAsyncResult result)
+ {
+ try
+ {
+ socket.EndSendTo(result);
+ }
+ catch (NullReferenceException) { }
+ catch (ObjectDisposedException)
+ {
+ // Already disposed and disconnected...
+ }
+ catch (SocketException ex)
+ {
+ DisconnectInternal(HazelInternalErrors.SocketExceptionSend, "Could not send data as a SocketException occurred: " + ex.Message);
+ }
+ }
+
+ /// <inheritdoc />
+ public override void Connect(byte[] bytes = null, int timeout = 5000)
+ {
+ this.ConnectAsync(bytes);
+
+ //Wait till hello packet is acknowledged and the state is set to Connected
+ bool timedOut = !WaitOnConnect(timeout);
+
+ //If we timed out raise an exception
+ if (timedOut)
+ {
+ Dispose();
+ throw new HazelException("Connection attempt timed out.");
+ }
+ }
+
+ /// <inheritdoc />
+ public override void ConnectAsync(byte[] bytes = null)
+ {
+ this.State = ConnectionState.Connecting;
+
+ try
+ {
+ if (IPMode == IPMode.IPv4)
+ socket.Bind(new IPEndPoint(IPAddress.Any, 0));
+ else
+ socket.Bind(new IPEndPoint(IPAddress.IPv6Any, 0));
+ }
+ catch (SocketException e)
+ {
+ this.State = ConnectionState.NotConnected;
+ throw new HazelException("A SocketException occurred while binding to the port.", e);
+ }
+
+ try
+ {
+ StartListeningForData();
+ }
+ catch (ObjectDisposedException)
+ {
+ // If the socket's been disposed then we can just end there but make sure we're in NotConnected state.
+ // If we end up here I'm really lost...
+ this.State = ConnectionState.NotConnected;
+ return;
+ }
+ catch (SocketException e)
+ {
+ Dispose();
+ throw new HazelException("A SocketException occurred while initiating a receive operation.", e);
+ }
+
+ // Write bytes to the server to tell it hi (and to punch a hole in our NAT, if present)
+ // When acknowledged set the state to connected
+ SendHello(bytes, () =>
+ {
+ this.State = ConnectionState.Connected;
+ this.InitializeKeepAliveTimer();
+ });
+ }
+
+ /// <summary>
+ /// Instructs the listener to begin listening.
+ /// </summary>
+ void StartListeningForData()
+ {
+#if DEBUG
+ if (this.TestLagMs > 0)
+ {
+ Thread.Sleep(this.TestLagMs);
+ }
+#endif
+
+ var msg = MessageReader.GetSized(this.ReceiveBufferSize);
+ try
+ {
+ socket.BeginReceive(msg.Buffer, 0, msg.Buffer.Length, SocketFlags.None, ReadCallback, msg);
+ }
+ catch
+ {
+ msg.Recycle();
+ this.Dispose();
+ }
+ }
+
+ protected override void SetState(ConnectionState state)
+ {
+ try
+ {
+ // If the server disconnects you during the hello
+ // you can go straight from Connecting to NotConnected.
+ if (state == ConnectionState.Connected
+ || state == ConnectionState.NotConnected)
+ {
+ connectWaitLock.Set();
+ }
+ else
+ {
+ connectWaitLock.Reset();
+ }
+ }
+ catch (ObjectDisposedException)
+ {
+ }
+ }
+
+ /// <summary>
+ /// Blocks until the Connection is connected.
+ /// </summary>
+ /// <param name="timeout">The number of milliseconds to wait before timing out.</param>
+ public bool WaitOnConnect(int timeout)
+ {
+ return connectWaitLock.WaitOne(timeout);
+ }
+
+ /// <summary>
+ /// Called when data has been received by the socket.
+ /// </summary>
+ /// <param name="result">The asyncronous operation's result.</param>
+ void ReadCallback(IAsyncResult result)
+ {
+ var msg = (MessageReader)result.AsyncState;
+
+ try
+ {
+ msg.Length = socket.EndReceive(result);
+ }
+ catch (SocketException e)
+ {
+ msg.Recycle();
+ DisconnectInternal(HazelInternalErrors.SocketExceptionReceive, "Socket exception while reading data: " + e.Message);
+ return;
+ }
+ catch (Exception)
+ {
+ msg.Recycle();
+ return;
+ }
+
+ //Exit if no bytes read, we've failed.
+ if (msg.Length == 0)
+ {
+ msg.Recycle();
+ DisconnectInternal(HazelInternalErrors.ReceivedZeroBytes, "Received 0 bytes");
+ return;
+ }
+
+ //Begin receiving again
+ try
+ {
+ StartListeningForData();
+ }
+ catch (SocketException e)
+ {
+ DisconnectInternal(HazelInternalErrors.SocketExceptionReceive, "Socket exception during receive: " + e.Message);
+ }
+ catch (ObjectDisposedException)
+ {
+ //If the socket's been disposed then we can just end there.
+ return;
+ }
+
+#if DEBUG
+ if (this.TestDropRate > 0)
+ {
+ if ((this.testDropCount++ % this.TestDropRate) == 0)
+ {
+ return;
+ }
+ }
+
+ DataReceivedRaw?.Invoke(msg.Buffer, msg.Length);
+#endif
+ HandleReceive(msg, msg.Length);
+ }
+
+ /// <summary>
+ /// Sends a disconnect message to the end point.
+ /// You may include optional disconnect data. The SendOption must be unreliable.
+ /// </summary>
+ protected override bool SendDisconnect(MessageWriter data = null)
+ {
+ lock (this)
+ {
+ if (this._state == ConnectionState.NotConnected) return false;
+ this.State = ConnectionState.NotConnected; // Use the property so we release the state lock
+ }
+
+ var bytes = EmptyDisconnectBytes;
+ if (data != null && data.Length > 0)
+ {
+ if (data.SendOption != SendOption.None) throw new ArgumentException("Disconnect messages can only be unreliable.");
+
+ bytes = data.ToByteArray(true);
+ bytes[0] = (byte)UdpSendOption.Disconnect;
+ }
+
+ try
+ {
+ socket.SendTo(
+ bytes,
+ 0,
+ bytes.Length,
+ SocketFlags.None,
+ EndPoint);
+ }
+ catch { }
+
+ return true;
+ }
+
+ /// <inheritdoc />
+ protected override void Dispose(bool disposing)
+ {
+ if (disposing)
+ {
+ SendDisconnect();
+ }
+
+ try { this.socket.Shutdown(SocketShutdown.Both); } catch { }
+ try { this.socket.Close(); } catch { }
+ try { this.socket.Dispose(); } catch { }
+
+ this.reliablePacketTimer.Dispose();
+ this.connectWaitLock.Dispose();
+
+ base.Dispose(disposing);
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/Udp/UdpConnection.KeepAlive.cs b/Tools/Hazel-Networking/Hazel/Udp/UdpConnection.KeepAlive.cs
new file mode 100644
index 0000000..75b4f1d
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/Udp/UdpConnection.KeepAlive.cs
@@ -0,0 +1,167 @@
+using System;
+using System.Collections.Concurrent;
+using System.Diagnostics;
+using System.Threading;
+
+
+namespace Hazel.Udp
+{
+ partial class UdpConnection
+ {
+
+ /// <summary>
+ /// Class to hold packet data
+ /// </summary>
+ public class PingPacket : IRecyclable
+ {
+ private static readonly ObjectPool<PingPacket> PacketPool = new ObjectPool<PingPacket>(() => new PingPacket());
+
+ public readonly Stopwatch Stopwatch = new Stopwatch();
+
+ internal static PingPacket GetObject()
+ {
+ return PacketPool.GetObject();
+ }
+
+ public void Recycle()
+ {
+ Stopwatch.Stop();
+ PacketPool.PutObject(this);
+ }
+ }
+
+ internal ConcurrentDictionary<ushort, PingPacket> activePingPackets = new ConcurrentDictionary<ushort, PingPacket>();
+
+ /// <summary>
+ /// The interval from data being received or transmitted to a keepalive packet being sent in milliseconds.
+ /// </summary>
+ /// <remarks>
+ /// <para>
+ /// Keepalive packets serve to close connections when an endpoint abruptly disconnects and to ensure than any
+ /// NAT devices do not close their translation for our argument. By ensuring there is regular contact the
+ /// connection can detect and prevent these issues.
+ /// </para>
+ /// <para>
+ /// The default value is 10 seconds, set to System.Threading.Timeout.Infinite to disable keepalive packets.
+ /// </para>
+ /// </remarks>
+ public int KeepAliveInterval
+ {
+ get
+ {
+ return keepAliveInterval;
+ }
+
+ set
+ {
+ keepAliveInterval = value;
+ ResetKeepAliveTimer();
+ }
+ }
+ private int keepAliveInterval = 1500;
+
+ public int MissingPingsUntilDisconnect { get; set; } = 6;
+ private volatile int pingsSinceAck = 0;
+
+ /// <summary>
+ /// The timer creating keepalive pulses.
+ /// </summary>
+ private Timer keepAliveTimer;
+
+ /// <summary>
+ /// Starts the keepalive timer.
+ /// </summary>
+ protected void InitializeKeepAliveTimer()
+ {
+ keepAliveTimer = new Timer(
+ HandleKeepAlive,
+ null,
+ keepAliveInterval,
+ keepAliveInterval
+ );
+ }
+
+ private void HandleKeepAlive(object state)
+ {
+ if (this.State != ConnectionState.Connected) return;
+
+ if (this.pingsSinceAck >= this.MissingPingsUntilDisconnect)
+ {
+ this.DisposeKeepAliveTimer();
+ this.DisconnectInternal(HazelInternalErrors.PingsWithoutResponse, $"Sent {this.pingsSinceAck} pings that remote has not responded to.");
+ return;
+ }
+
+ try
+ {
+ this.pingsSinceAck++;
+ SendPing();
+ }
+ catch
+ {
+ }
+ }
+
+ // Pings are special, quasi-reliable packets.
+ // We send them to trigger responses that validate our connection is alive
+ // An unacked ping should never be the sole cause of a disconnect.
+ // Rather, the responses will reset our pingsSinceAck, enough unacked
+ // pings should cause a disconnect.
+ private void SendPing()
+ {
+ ushort id = (ushort)Interlocked.Increment(ref lastIDAllocated);
+
+ byte[] bytes = new byte[3];
+ bytes[0] = (byte)UdpSendOption.Ping;
+ bytes[1] = (byte)(id >> 8);
+ bytes[2] = (byte)id;
+
+ PingPacket pkt;
+ if (!this.activePingPackets.TryGetValue(id, out pkt))
+ {
+ pkt = PingPacket.GetObject();
+ if (!this.activePingPackets.TryAdd(id, pkt))
+ {
+ throw new Exception("This shouldn't be possible");
+ }
+ }
+
+ pkt.Stopwatch.Restart();
+
+ WriteBytesToConnection(bytes, bytes.Length);
+
+ Statistics.LogReliableSend(0);
+ }
+
+ /// <summary>
+ /// Resets the keepalive timer to zero.
+ /// </summary>
+ protected void ResetKeepAliveTimer()
+ {
+ try
+ {
+ keepAliveTimer?.Change(keepAliveInterval, keepAliveInterval);
+ }
+ catch { }
+ }
+
+ /// <summary>
+ /// Disposes of the keep alive timer.
+ /// </summary>
+ private void DisposeKeepAliveTimer()
+ {
+ if (this.keepAliveTimer != null)
+ {
+ this.keepAliveTimer.Dispose();
+ }
+
+ foreach (var kvp in activePingPackets)
+ {
+ if (this.activePingPackets.TryRemove(kvp.Key, out var pkt))
+ {
+ pkt.Recycle();
+ }
+ }
+ }
+ }
+} \ No newline at end of file
diff --git a/Tools/Hazel-Networking/Hazel/Udp/UdpConnection.Reliable.cs b/Tools/Hazel-Networking/Hazel/Udp/UdpConnection.Reliable.cs
new file mode 100644
index 0000000..bed4738
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/Udp/UdpConnection.Reliable.cs
@@ -0,0 +1,490 @@
+using System;
+using System.Collections.Concurrent;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Threading;
+
+namespace Hazel.Udp
+{
+ partial class UdpConnection
+ {
+ private const int MinResendDelayMs = 50;
+ private const int MaxInitialResendDelayMs = 300;
+ private const int MaxAdditionalResendDelayMs = 1000;
+
+ public readonly ObjectPool<Packet> PacketPool;
+
+ /// <summary>
+ /// The starting timeout, in miliseconds, at which data will be resent.
+ /// </summary>
+ /// <remarks>
+ /// <para>
+ /// For reliable delivery data is resent at specified intervals unless an acknowledgement is received from the
+ /// receiving device. The ResendTimeout specifies the interval between the packets being resent, each time a packet
+ /// is resent the interval is increased for that packet until the duration exceeds the <see cref="DisconnectTimeoutMs"/> value.
+ /// </para>
+ /// <para>
+ /// Setting this to its default of 0 will mean the timeout is 2 times the value of the average ping, usually
+ /// resulting in a more dynamic resend that responds to endpoints on slower or faster connections.
+ /// </para>
+ /// </remarks>
+ public volatile int ResendTimeoutMs = 0;
+
+ /// <summary>
+ /// Max number of times to resend. 0 == no limit
+ /// </summary>
+ public volatile int ResendLimit = 0;
+
+ /// <summary>
+ /// A compounding multiplier to back off resend timeout.
+ /// Applied to ping before first timeout when ResendTimeout == 0.
+ /// </summary>
+ public volatile float ResendPingMultiplier = 2;
+
+ /// <summary>
+ /// Holds the last ID allocated.
+ /// </summary>
+ private int lastIDAllocated = -1;
+
+ /// <summary>
+ /// The packets of data that have been transmitted reliably and not acknowledged.
+ /// </summary>
+ internal ConcurrentDictionary<ushort, Packet> reliableDataPacketsSent = new ConcurrentDictionary<ushort, Packet>();
+
+ /// <summary>
+ /// Packet ids that have not been received, but are expected.
+ /// </summary>
+ private HashSet<ushort> reliableDataPacketsMissing = new HashSet<ushort>();
+
+ /// <summary>
+ /// The packet id that was received last.
+ /// </summary>
+ protected volatile ushort reliableReceiveLast = ushort.MaxValue;
+
+ private object PingLock = new object();
+
+ /// <summary>
+ /// Returns the average ping to this endpoint.
+ /// </summary>
+ /// <remarks>
+ /// This returns the average ping for a one-way trip as calculated from the reliable packets that have been sent
+ /// and acknowledged by the endpoint.
+ /// </remarks>
+ private float _pingMs = 500;
+
+ /// <summary>
+ /// The maximum times a message should be resent before marking the endpoint as disconnected.
+ /// </summary>
+ /// <remarks>
+ /// Reliable packets will be resent at an interval defined in <see cref="ResendTimeoutMs"/> for the number of times
+ /// specified here. Once a packet has been retransmitted this number of times and has not been acknowledged the
+ /// connection will be marked as disconnected and the <see cref="Connection.Disconnected">Disconnected</see> event
+ /// will be invoked.
+ /// </remarks>
+ public volatile int DisconnectTimeoutMs = 5000;
+
+ /// <summary>
+ /// Class to hold packet data
+ /// </summary>
+ public class Packet : IRecyclable
+ {
+ public ushort Id;
+ private byte[] Data;
+ private readonly UdpConnection Connection;
+ private int Length;
+
+ public int NextTimeoutMs;
+ public volatile bool Acknowledged;
+
+ public Action AckCallback;
+
+ public int Retransmissions;
+ public Stopwatch Stopwatch = new Stopwatch();
+
+ internal Packet(UdpConnection connection)
+ {
+ this.Connection = connection;
+ }
+
+ internal void Set(ushort id, byte[] data, int length, int timeout, Action ackCallback)
+ {
+ this.Id = id;
+ this.Data = data;
+ this.Length = length;
+
+ this.Acknowledged = false;
+ this.NextTimeoutMs = timeout;
+ this.AckCallback = ackCallback;
+ this.Retransmissions = 0;
+
+ this.Stopwatch.Restart();
+ }
+
+ // Packets resent
+ public int Resend()
+ {
+ var connection = this.Connection;
+ if (!this.Acknowledged && connection != null)
+ {
+ long lifetimeMs = this.Stopwatch.ElapsedMilliseconds;
+ if (lifetimeMs >= connection.DisconnectTimeoutMs)
+ {
+ if (connection.reliableDataPacketsSent.TryRemove(this.Id, out Packet self))
+ {
+ connection.DisconnectInternal(HazelInternalErrors.ReliablePacketWithoutResponse, $"Reliable packet {self.Id} (size={this.Length}) was not ack'd after {lifetimeMs}ms ({self.Retransmissions} resends)");
+
+ self.Recycle();
+ }
+
+ return 0;
+ }
+
+ if (lifetimeMs >= this.NextTimeoutMs)
+ {
+ ++this.Retransmissions;
+ if (connection.ResendLimit != 0
+ && this.Retransmissions > connection.ResendLimit)
+ {
+ if (connection.reliableDataPacketsSent.TryRemove(this.Id, out Packet self))
+ {
+ connection.DisconnectInternal(HazelInternalErrors.ReliablePacketWithoutResponse, $"Reliable packet {self.Id} (size={this.Length}) was not ack'd after {self.Retransmissions} resends ({lifetimeMs}ms)");
+
+ self.Recycle();
+ }
+
+ return 0;
+ }
+
+ this.NextTimeoutMs += (int)Math.Min(this.NextTimeoutMs * connection.ResendPingMultiplier, MaxAdditionalResendDelayMs);
+ try
+ {
+ connection.WriteBytesToConnection(this.Data, this.Length);
+ connection.Statistics.LogMessageResent();
+ return 1;
+ }
+ catch (InvalidOperationException)
+ {
+ connection.DisconnectInternal(HazelInternalErrors.ConnectionDisconnected, "Could not resend data as connection is no longer connected");
+ }
+ }
+ }
+
+ return 0;
+ }
+
+ /// <summary>
+ /// Returns this object back to the object pool from whence it came.
+ /// </summary>
+ public void Recycle()
+ {
+ this.Acknowledged = true;
+
+ this.Connection.PacketPool.PutObject(this);
+ }
+ }
+
+ internal int ManageReliablePackets()
+ {
+ int output = 0;
+ if (this.reliableDataPacketsSent.Count > 0)
+ {
+ foreach (var kvp in this.reliableDataPacketsSent)
+ {
+ Packet pkt = kvp.Value;
+
+ try
+ {
+ output += pkt.Resend();
+ }
+ catch { }
+ }
+ }
+
+ return output;
+ }
+
+ /// <summary>
+ /// Adds a 2 byte ID to the packet at offset and stores the packet reference for retransmission.
+ /// </summary>
+ /// <param name="buffer">The buffer to attach to.</param>
+ /// <param name="offset">The offset to attach at.</param>
+ /// <param name="ackCallback">The callback to make once the packet has been acknowledged.</param>
+ protected void AttachReliableID(byte[] buffer, int offset, Action ackCallback = null)
+ {
+ ushort id = (ushort)Interlocked.Increment(ref lastIDAllocated);
+
+ buffer[offset] = (byte)(id >> 8);
+ buffer[offset + 1] = (byte)id;
+
+ int resendDelayMs = this.ResendTimeoutMs;
+ if (resendDelayMs <= 0)
+ {
+ resendDelayMs = (_pingMs * this.ResendPingMultiplier).ClampToInt(MinResendDelayMs, MaxInitialResendDelayMs);
+ }
+
+ Packet packet = this.PacketPool.GetObject();
+ packet.Set(
+ id,
+ buffer,
+ buffer.Length,
+ resendDelayMs,
+ ackCallback);
+
+ if (!reliableDataPacketsSent.TryAdd(id, packet))
+ {
+ throw new Exception("That shouldn't be possible");
+ }
+ }
+
+ public static int ClampToInt(float value, int min, int max)
+ {
+ if (value < min) return min;
+ if (value > max) return max;
+ return (int)value;
+ }
+
+ /// <summary>
+ /// Sends the bytes reliably and stores the send.
+ /// </summary>
+ /// <param name="sendOption"></param>
+ /// <param name="data">The byte array to write to.</param>
+ /// <param name="ackCallback">The callback to make once the packet has been acknowledged.</param>
+ private void ReliableSend(byte sendOption, byte[] data, Action ackCallback = null)
+ {
+ //Inform keepalive not to send for a while
+ ResetKeepAliveTimer();
+
+ byte[] bytes = new byte[data.Length + 3];
+
+ //Add message type
+ bytes[0] = sendOption;
+
+ //Add reliable ID
+ AttachReliableID(bytes, 1, ackCallback);
+
+ //Copy data into new array
+ Buffer.BlockCopy(data, 0, bytes, bytes.Length - data.Length, data.Length);
+
+ //Write to connection
+ WriteBytesToConnection(bytes, bytes.Length);
+
+ Statistics.LogReliableSend(data.Length);
+ }
+
+ /// <summary>
+ /// Handles a reliable message being received and invokes the data event.
+ /// </summary>
+ /// <param name="message">The buffer received.</param>
+ private void ReliableMessageReceive(MessageReader message, int bytesReceived)
+ {
+ ushort id;
+ if (ProcessReliableReceive(message.Buffer, 1, out id))
+ {
+ InvokeDataReceived(SendOption.Reliable, message, 3, bytesReceived);
+ }
+ else
+ {
+ message.Recycle();
+ }
+
+ Statistics.LogReliableReceive(message.Length - 3, message.Length);
+ }
+
+ /// <summary>
+ /// Handles receives from reliable packets.
+ /// </summary>
+ /// <param name="bytes">The buffer containing the data.</param>
+ /// <param name="offset">The offset of the reliable header.</param>
+ /// <returns>Whether the packet was a new packet or not.</returns>
+ private bool ProcessReliableReceive(byte[] bytes, int offset, out ushort id)
+ {
+ byte b1 = bytes[offset];
+ byte b2 = bytes[offset + 1];
+
+ //Get the ID form the packet
+ id = (ushort)((b1 << 8) + b2);
+
+ /*
+ * It gets a little complicated here (note the fact I'm actually using a multiline comment for once...)
+ *
+ * In a simple world if our data is greater than the last reliable packet received (reliableReceiveLast)
+ * then it is guaranteed to be a new packet, if it's not we can see if we are missing that packet (lookup
+ * in reliableDataPacketsMissing).
+ *
+ * --------rrl############# (1)
+ *
+ * (where --- are packets received already and #### are packets that will be counted as new)
+ *
+ * Unfortunately if id becomes greater than 65535 it will loop back to zero so we will add a pointer that
+ * specifies any packets with an id behind it are also new (overwritePointer).
+ *
+ * ####op----------rrl##### (2)
+ *
+ * ------rll#########op---- (3)
+ *
+ * Anything behind than the reliableReceiveLast pointer (but greater than the overwritePointer is either a
+ * missing packet or something we've already received so when we change the pointers we need to make sure
+ * we keep note of what hasn't been received yet (reliableDataPacketsMissing).
+ *
+ * So...
+ */
+
+ bool result = true;
+
+ lock (reliableDataPacketsMissing)
+ {
+ //Calculate overwritePointer
+ ushort overwritePointer = (ushort)(reliableReceiveLast - 32768);
+
+ //Calculate if it is a new packet by examining if it is within the range
+ bool isNew;
+ if (overwritePointer < reliableReceiveLast)
+ isNew = id > reliableReceiveLast || id <= overwritePointer; //Figure (2)
+ else
+ isNew = id > reliableReceiveLast && id <= overwritePointer; //Figure (3)
+
+ //If it's new or we've not received anything yet
+ if (isNew)
+ {
+ // Mark items between the most recent receive and the id received as missing
+ if (id > reliableReceiveLast)
+ {
+ for (ushort i = (ushort)(reliableReceiveLast + 1); i < id; i++)
+ {
+ reliableDataPacketsMissing.Add(i);
+ }
+ }
+ else
+ {
+ int cnt = (ushort.MaxValue - reliableReceiveLast) + id;
+ for (ushort i = 1; i <= cnt; ++i)
+ {
+ reliableDataPacketsMissing.Add((ushort)(i + reliableReceiveLast));
+ }
+ }
+
+ //Update the most recently received
+ reliableReceiveLast = id;
+ }
+
+ //Else it could be a missing packet
+ else
+ {
+ //See if we're missing it, else this packet is a duplicate as so we return false
+ if (!reliableDataPacketsMissing.Remove(id))
+ {
+ result = false;
+ }
+ }
+ }
+
+ // Send an acknowledgement
+ SendAck(id);
+
+ return result;
+ }
+
+ /// <summary>
+ /// Handles acknowledgement packets to us.
+ /// </summary>
+ /// <param name="bytes">The buffer containing the data.</param>
+ private void AcknowledgementMessageReceive(byte[] bytes, int bytesReceived)
+ {
+ this.pingsSinceAck = 0;
+
+ ushort id = (ushort)((bytes[1] << 8) + bytes[2]);
+ AcknowledgeMessageId(id);
+
+ if (bytesReceived == 4)
+ {
+ byte recentPackets = bytes[3];
+ for (int i = 1; i <= 8; ++i)
+ {
+ if ((recentPackets & 1) != 0)
+ {
+ AcknowledgeMessageId((ushort)(id - i));
+ }
+
+ recentPackets >>= 1;
+ }
+ }
+
+ Statistics.LogAcknowledgementReceive(bytesReceived);
+ }
+
+ private void AcknowledgeMessageId(ushort id)
+ {
+ // Dispose of timer and remove from dictionary
+ if (reliableDataPacketsSent.TryRemove(id, out Packet packet))
+ {
+ this.Statistics.LogReliablePacketAcknowledged();
+ float rt = packet.Stopwatch.ElapsedMilliseconds;
+
+ packet.AckCallback?.Invoke();
+ packet.Recycle();
+
+ lock (PingLock)
+ {
+ this._pingMs = this._pingMs * .7f + rt * .3f;
+ }
+ }
+ else if (this.activePingPackets.TryRemove(id, out PingPacket pingPkt))
+ {
+ this.Statistics.LogReliablePacketAcknowledged();
+ float rt = pingPkt.Stopwatch.ElapsedMilliseconds;
+
+ pingPkt.Recycle();
+
+ lock (PingLock)
+ {
+ this._pingMs = this._pingMs * .7f + rt * .3f;
+ }
+ }
+ }
+
+ /// <summary>
+ /// Sends an acknowledgement for a packet given its identification bytes.
+ /// </summary>
+ /// <param name="byte1">The first identification byte.</param>
+ /// <param name="byte2">The second identification byte.</param>
+ private void SendAck(ushort id)
+ {
+ byte recentPackets = 0;
+ lock (this.reliableDataPacketsMissing)
+ {
+ for (int i = 1; i <= 8; ++i)
+ {
+ if (!this.reliableDataPacketsMissing.Contains((ushort)(id - i)))
+ {
+ recentPackets |= (byte)(1 << (i - 1));
+ }
+ }
+ }
+
+ byte[] bytes = new byte[]
+ {
+ (byte)UdpSendOption.Acknowledgement,
+ (byte)(id >> 8),
+ (byte)(id >> 0),
+ recentPackets
+ };
+
+ try
+ {
+ WriteBytesToConnection(bytes, bytes.Length);
+ }
+ catch (InvalidOperationException) { }
+ }
+
+ private void DisposeReliablePackets()
+ {
+ foreach (var kvp in reliableDataPacketsSent)
+ {
+ if (this.reliableDataPacketsSent.TryRemove(kvp.Key, out var pkt))
+ {
+ pkt.Recycle();
+ }
+ }
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/Udp/UdpConnection.cs b/Tools/Hazel-Networking/Hazel/Udp/UdpConnection.cs
new file mode 100644
index 0000000..e64576a
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/Udp/UdpConnection.cs
@@ -0,0 +1,259 @@
+using System;
+using System.Net.Sockets;
+
+namespace Hazel.Udp
+{
+ /// <summary>
+ /// Represents a connection that uses the UDP protocol.
+ /// </summary>
+ /// <inheritdoc />
+ public abstract partial class UdpConnection : NetworkConnection
+ {
+ public static readonly byte[] EmptyDisconnectBytes = new byte[] { (byte)UdpSendOption.Disconnect };
+
+ public override float AveragePingMs => this._pingMs;
+ protected readonly ILogger logger;
+
+
+ public UdpConnection(ILogger logger) : base()
+ {
+ this.logger = logger;
+ this.PacketPool = new ObjectPool<Packet>(() => new Packet(this));
+ }
+
+ internal static Socket CreateSocket(IPMode ipMode)
+ {
+ Socket socket;
+ if (ipMode == IPMode.IPv4)
+ {
+ socket = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp);
+ }
+ else
+ {
+ if (!Socket.OSSupportsIPv6)
+ throw new InvalidOperationException("IPV6 not supported!");
+
+ socket = new Socket(AddressFamily.InterNetworkV6, SocketType.Dgram, ProtocolType.Udp);
+ socket.SetSocketOption(SocketOptionLevel.IPv6, SocketOptionName.IPv6Only, false);
+ }
+
+ try
+ {
+ socket.DontFragment = false;
+ }
+ catch { }
+
+ try
+ {
+ const int SIO_UDP_CONNRESET = -1744830452;
+ socket.IOControl(SIO_UDP_CONNRESET, new byte[1], null);
+ }
+ catch { } // Only necessary on Windows
+
+ return socket;
+ }
+
+ /// <summary>
+ /// Writes the given bytes to the connection.
+ /// </summary>
+ /// <param name="bytes">The bytes to write.</param>
+ protected abstract void WriteBytesToConnection(byte[] bytes, int length);
+
+ /// <inheritdoc/>
+ public override SendErrors Send(MessageWriter msg)
+ {
+ if (this._state != ConnectionState.Connected)
+ {
+ return SendErrors.Disconnected;
+ }
+
+ try
+ {
+ byte[] buffer = new byte[msg.Length];
+ Buffer.BlockCopy(msg.Buffer, 0, buffer, 0, msg.Length);
+
+ switch (msg.SendOption)
+ {
+ case SendOption.Reliable:
+ ResetKeepAliveTimer();
+
+ AttachReliableID(buffer, 1);
+ WriteBytesToConnection(buffer, buffer.Length);
+ Statistics.LogReliableSend(buffer.Length - 3);
+ break;
+
+ default:
+ WriteBytesToConnection(buffer, buffer.Length);
+ Statistics.LogUnreliableSend(buffer.Length - 1);
+ break;
+ }
+ }
+ catch (Exception e)
+ {
+ this.logger?.WriteError("Unknown exception while sending: " + e);
+ return SendErrors.Unknown;
+ }
+
+ return SendErrors.None;
+ }
+
+ /// <summary>
+ /// Handles the reliable/fragmented sending from this connection.
+ /// </summary>
+ /// <param name="data">The data being sent.</param>
+ /// <param name="sendOption">The <see cref="SendOption"/> specified as its byte value.</param>
+ /// <param name="ackCallback">The callback to invoke when this packet is acknowledged.</param>
+ /// <returns>The bytes that should actually be sent.</returns>
+ protected virtual void HandleSend(byte[] data, byte sendOption, Action ackCallback = null)
+ {
+ switch (sendOption)
+ {
+ case (byte)UdpSendOption.Ping:
+ case (byte)SendOption.Reliable:
+ case (byte)UdpSendOption.Hello:
+ ReliableSend(sendOption, data, ackCallback);
+ break;
+
+ //Treat all else as unreliable
+ default:
+ UnreliableSend(sendOption, data);
+ break;
+ }
+ }
+
+ /// <summary>
+ /// Handles the receiving of data.
+ /// </summary>
+ /// <param name="message">The buffer containing the bytes received.</param>
+ protected internal virtual void HandleReceive(MessageReader message, int bytesReceived)
+ {
+ ushort id;
+ switch (message.Buffer[0])
+ {
+ //Handle reliable receives
+ case (byte)SendOption.Reliable:
+ ReliableMessageReceive(message, bytesReceived);
+ break;
+
+ //Handle acknowledgments
+ case (byte)UdpSendOption.Acknowledgement:
+ AcknowledgementMessageReceive(message.Buffer, bytesReceived);
+ message.Recycle();
+ break;
+
+ //We need to acknowledge hello and ping messages but dont want to invoke any events!
+ case (byte)UdpSendOption.Ping:
+ ProcessReliableReceive(message.Buffer, 1, out id);
+ Statistics.LogHelloReceive(bytesReceived);
+ message.Recycle();
+ break;
+ case (byte)UdpSendOption.Hello:
+ ProcessReliableReceive(message.Buffer, 1, out id);
+ Statistics.LogHelloReceive(bytesReceived);
+ message.Recycle();
+ break;
+
+ case (byte)UdpSendOption.Disconnect:
+ message.Offset = 1;
+ message.Position = 0;
+ DisconnectRemote("The remote sent a disconnect request", message);
+ message.Recycle();
+ break;
+
+ case (byte)SendOption.None:
+ InvokeDataReceived(SendOption.None, message, 1, bytesReceived);
+ Statistics.LogUnreliableReceive(bytesReceived - 1, bytesReceived);
+ break;
+
+ // Treat everything else as garbage
+ default:
+ message.Recycle();
+
+ // TODO: A new stat for unused data
+ Statistics.LogUnreliableReceive(bytesReceived - 1, bytesReceived);
+ break;
+ }
+ }
+
+ /// <summary>
+ /// Sends bytes using the unreliable UDP protocol.
+ /// </summary>
+ /// <param name="sendOption">The SendOption to attach.</param>
+ /// <param name="data">The data.</param>
+ void UnreliableSend(byte sendOption, byte[] data)
+ {
+ this.UnreliableSend(sendOption, data, 0, data.Length);
+ }
+
+ /// <summary>
+ /// Sends bytes using the unreliable UDP protocol.
+ /// </summary>
+ /// <param name="data">The data.</param>
+ /// <param name="sendOption">The SendOption to attach.</param>
+ /// <param name="offset"></param>
+ /// <param name="length"></param>
+ void UnreliableSend(byte sendOption, byte[] data, int offset, int length)
+ {
+ byte[] bytes = new byte[length + 1];
+
+ //Add message type
+ bytes[0] = sendOption;
+
+ //Copy data into new array
+ Buffer.BlockCopy(data, offset, bytes, bytes.Length - length, length);
+
+ //Write to connection
+ WriteBytesToConnection(bytes, bytes.Length);
+
+ Statistics.LogUnreliableSend(length);
+ }
+
+ /// <summary>
+ /// Helper method to invoke the data received event.
+ /// </summary>
+ /// <param name="sendOption">The send option the message was received with.</param>
+ /// <param name="buffer">The buffer received.</param>
+ /// <param name="dataOffset">The offset of data in the buffer.</param>
+ void InvokeDataReceived(SendOption sendOption, MessageReader buffer, int dataOffset, int bytesReceived)
+ {
+ buffer.Offset = dataOffset;
+ buffer.Length = bytesReceived - dataOffset;
+ buffer.Position = 0;
+
+ InvokeDataReceived(buffer, sendOption);
+ }
+
+ /// <summary>
+ /// Sends a hello packet to the remote endpoint.
+ /// </summary>
+ /// <param name="acknowledgeCallback">The callback to invoke when the hello packet is acknowledged.</param>
+ protected void SendHello(byte[] bytes, Action acknowledgeCallback)
+ {
+ //First byte of handshake is version indicator so add data after
+ byte[] actualBytes;
+ if (bytes == null)
+ {
+ actualBytes = new byte[1];
+ }
+ else
+ {
+ actualBytes = new byte[bytes.Length + 1];
+ Buffer.BlockCopy(bytes, 0, actualBytes, 1, bytes.Length);
+ }
+
+ HandleSend(actualBytes, (byte)UdpSendOption.Hello, acknowledgeCallback);
+ }
+
+ /// <inheritdoc/>
+ protected override void Dispose(bool disposing)
+ {
+ if (disposing)
+ {
+ DisposeKeepAliveTimer();
+ DisposeReliablePackets();
+ }
+
+ base.Dispose(disposing);
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/Udp/UdpConnectionListener.cs b/Tools/Hazel-Networking/Hazel/Udp/UdpConnectionListener.cs
new file mode 100644
index 0000000..c017a0f
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/Udp/UdpConnectionListener.cs
@@ -0,0 +1,339 @@
+using System;
+using System.Collections.Concurrent;
+using System.Linq;
+using System.Net;
+using System.Net.Sockets;
+using System.Threading;
+
+namespace Hazel.Udp
+{
+ /// <summary>
+ /// Listens for new UDP connections and creates UdpConnections for them.
+ /// </summary>
+ /// <inheritdoc />
+ public class UdpConnectionListener : NetworkConnectionListener
+ {
+ private const int SendReceiveBufferSize = 1024 * 1024;
+ private const int BufferSize = ushort.MaxValue;
+
+ private Socket socket;
+ private ILogger Logger;
+ private Timer reliablePacketTimer;
+
+ private ConcurrentDictionary<EndPoint, UdpServerConnection> allConnections = new ConcurrentDictionary<EndPoint, UdpServerConnection>();
+
+ public override double AveragePing => this.allConnections.Values.Sum(c => c.AveragePingMs) / this.allConnections.Count;
+ public override int ConnectionCount { get { return this.allConnections.Count; } }
+ public override int ReceiveQueueLength => throw new NotImplementedException();
+ public override int SendQueueLength => throw new NotImplementedException();
+
+ /// <summary>
+ /// Creates a new UdpConnectionListener for the given <see cref="IPAddress"/>, port and <see cref="IPMode"/>.
+ /// </summary>
+ /// <param name="endPoint">The endpoint to listen on.</param>
+ public UdpConnectionListener(IPEndPoint endPoint, IPMode ipMode = IPMode.IPv4, ILogger logger = null)
+ {
+ this.Logger = logger;
+ this.EndPoint = endPoint;
+ this.IPMode = ipMode;
+
+ this.socket = UdpConnection.CreateSocket(this.IPMode);
+
+ socket.ReceiveBufferSize = SendReceiveBufferSize;
+ socket.SendBufferSize = SendReceiveBufferSize;
+
+ reliablePacketTimer = new Timer(ManageReliablePackets, null, 100, Timeout.Infinite);
+ }
+
+ ~UdpConnectionListener()
+ {
+ this.Dispose(false);
+ }
+
+ private void ManageReliablePackets(object state)
+ {
+ foreach (var kvp in this.allConnections)
+ {
+ var sock = kvp.Value;
+ sock.ManageReliablePackets();
+ }
+
+ try
+ {
+ this.reliablePacketTimer.Change(100, Timeout.Infinite);
+ }
+ catch { }
+ }
+
+ /// <inheritdoc />
+ public override void Start()
+ {
+ try
+ {
+ socket.Bind(EndPoint);
+ }
+ catch (SocketException e)
+ {
+ throw new HazelException("Could not start listening as a SocketException occurred", e);
+ }
+
+ StartListeningForData();
+ }
+
+ /// <summary>
+ /// Instructs the listener to begin listening.
+ /// </summary>
+ private void StartListeningForData()
+ {
+ EndPoint remoteEP = EndPoint;
+
+ MessageReader message = null;
+ try
+ {
+ message = MessageReader.GetSized(this.ReceiveBufferSize);
+ socket.BeginReceiveFrom(message.Buffer, 0, message.Buffer.Length, SocketFlags.None, ref remoteEP, ReadCallback, message);
+ }
+ catch (SocketException sx)
+ {
+ message?.Recycle();
+
+ this.Logger?.WriteError("Socket Ex in StartListening: " + sx.Message);
+
+ Thread.Sleep(10);
+ StartListeningForData();
+ return;
+ }
+ catch (Exception ex)
+ {
+ message.Recycle();
+ this.Logger?.WriteError("Stopped due to: " + ex.Message);
+ return;
+ }
+ }
+
+ void ReadCallback(IAsyncResult result)
+ {
+ var message = (MessageReader)result.AsyncState;
+ int bytesReceived;
+ EndPoint remoteEndPoint = new IPEndPoint(this.EndPoint.Address, this.EndPoint.Port);
+
+ //End the receive operation
+ try
+ {
+ bytesReceived = socket.EndReceiveFrom(result, ref remoteEndPoint);
+
+ message.Offset = 0;
+ message.Length = bytesReceived;
+ }
+ catch (ObjectDisposedException)
+ {
+ message.Recycle();
+ return;
+ }
+ catch (SocketException sx)
+ {
+ message.Recycle();
+ if (sx.SocketErrorCode == SocketError.NotConnected)
+ {
+ this.InvokeInternalError(HazelInternalErrors.ConnectionDisconnected);
+ return;
+ }
+
+ // Client no longer reachable, pretend it didn't happen
+ // TODO should this not inform the connection this client is lost???
+
+ // This thread suggests the IP is not passed out from WinSoc so maybe not possible
+ // http://stackoverflow.com/questions/2576926/python-socket-error-on-udp-data-receive-10054
+ this.Logger?.WriteError($"Socket Ex {sx.SocketErrorCode} in ReadCallback: {sx.Message}");
+
+ Thread.Sleep(10);
+ StartListeningForData();
+ return;
+ }
+ catch (Exception ex)
+ {
+ // Idk, maybe a null ref after dispose?
+ message.Recycle();
+ this.Logger?.WriteError("Stopped due to: " + ex.Message);
+ return;
+ }
+
+ // I'm a little concerned about a infinite loop here, but it seems like it's possible
+ // to get 0 bytes read on UDP without the socket being shut down.
+ if (bytesReceived == 0)
+ {
+ message.Recycle();
+ this.Logger?.WriteInfo("Received 0 bytes");
+ Thread.Sleep(10);
+ StartListeningForData();
+ return;
+ }
+
+ //Begin receiving again
+ StartListeningForData();
+
+ bool aware = true;
+ bool isHello = message.Buffer[0] == (byte)UdpSendOption.Hello;
+
+ // If we're aware of this connection use the one already
+ // If this is a new client then connect with them!
+ UdpServerConnection connection;
+ if (!this.allConnections.TryGetValue(remoteEndPoint, out connection))
+ {
+ lock (this.allConnections)
+ {
+ if (!this.allConnections.TryGetValue(remoteEndPoint, out connection))
+ {
+ // Check for malformed connection attempts
+ if (!isHello)
+ {
+ message.Recycle();
+ return;
+ }
+
+ if (AcceptConnection != null)
+ {
+ if (!AcceptConnection((IPEndPoint)remoteEndPoint, message.Buffer, out var response))
+ {
+ message.Recycle();
+ if (response != null)
+ {
+ SendData(response, response.Length, remoteEndPoint);
+ }
+
+ return;
+ }
+ }
+
+ aware = false;
+ connection = new UdpServerConnection(this, (IPEndPoint)remoteEndPoint, this.IPMode, this.Logger);
+ if (!this.allConnections.TryAdd(remoteEndPoint, connection))
+ {
+ throw new HazelException("Failed to add a connection. This should never happen.");
+ }
+ }
+ }
+ }
+
+ // If it's a new connection invoke the NewConnection event.
+ // This needs to happen before handling the message because in localhost scenarios, the ACK and
+ // subsequent messages can happen before the NewConnection event sets up OnDataRecieved handlers
+ if (!aware)
+ {
+ // Skip header and hello byte;
+ message.Offset = 4;
+ message.Length = bytesReceived - 4;
+ message.Position = 0;
+ InvokeNewConnection(message, connection);
+ }
+
+ // Inform the connection of the buffer (new connections need to send an ack back to client)
+ connection.HandleReceive(message, bytesReceived);
+ }
+
+#if DEBUG
+ public int TestDropRate = -1;
+ private int dropCounter = 0;
+#endif
+
+ /// <summary>
+ /// Sends data from the listener socket.
+ /// </summary>
+ /// <param name="bytes">The bytes to send.</param>
+ /// <param name="endPoint">The endpoint to send to.</param>
+ internal void SendData(byte[] bytes, int length, EndPoint endPoint)
+ {
+ if (length > bytes.Length) return;
+
+#if DEBUG
+ if (TestDropRate > 0)
+ {
+ if (Interlocked.Increment(ref dropCounter) % TestDropRate == 0)
+ {
+ return;
+ }
+ }
+#endif
+
+ try
+ {
+ socket.BeginSendTo(
+ bytes,
+ 0,
+ length,
+ SocketFlags.None,
+ endPoint,
+ SendCallback,
+ null);
+
+ this.Statistics.AddBytesSent(length);
+ }
+ catch (SocketException e)
+ {
+ this.Logger?.WriteError("Could not send data as a SocketException occurred: " + e);
+ }
+ catch (ObjectDisposedException)
+ {
+ //Keep alive timer probably ran, ignore
+ return;
+ }
+ }
+
+ private void SendCallback(IAsyncResult result)
+ {
+ try
+ {
+ socket.EndSendTo(result);
+ }
+ catch { }
+ }
+
+ /// <summary>
+ /// Sends data from the listener socket.
+ /// </summary>
+ /// <param name="bytes">The bytes to send.</param>
+ /// <param name="endPoint">The endpoint to send to.</param>
+ internal void SendDataSync(byte[] bytes, int length, EndPoint endPoint)
+ {
+ try
+ {
+ socket.SendTo(
+ bytes,
+ 0,
+ length,
+ SocketFlags.None,
+ endPoint
+ );
+
+ this.Statistics.AddBytesSent(length);
+ }
+ catch { }
+ }
+
+ /// <summary>
+ /// Removes a virtual connection from the list.
+ /// </summary>
+ /// <param name="endPoint">The endpoint of the virtual connection.</param>
+ internal void RemoveConnectionTo(EndPoint endPoint)
+ {
+ this.allConnections.TryRemove(endPoint, out var conn);
+ }
+
+ /// <inheritdoc />
+ protected override void Dispose(bool disposing)
+ {
+ foreach (var kvp in this.allConnections)
+ {
+ kvp.Value.Dispose();
+ }
+
+ try { this.socket.Shutdown(SocketShutdown.Both); } catch { }
+ try { this.socket.Close(); } catch { }
+ try { this.socket.Dispose(); } catch { }
+
+ this.reliablePacketTimer.Dispose();
+
+ base.Dispose(disposing);
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/Udp/UdpServerConnection.cs b/Tools/Hazel-Networking/Hazel/Udp/UdpServerConnection.cs
new file mode 100644
index 0000000..ff5b29d
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/Udp/UdpServerConnection.cs
@@ -0,0 +1,108 @@
+using System;
+using System.Net;
+
+namespace Hazel.Udp
+{
+ /// <summary>
+ /// Represents a servers's connection to a client that uses the UDP protocol.
+ /// </summary>
+ /// <inheritdoc/>
+ internal sealed class UdpServerConnection : UdpConnection
+ {
+ /// <summary>
+ /// The connection listener that we use the socket of.
+ /// </summary>
+ /// <remarks>
+ /// Udp server connections utilize the same socket in the listener for sends/receives, this is the listener that
+ /// created this connection and is hence the listener this conenction sends and receives via.
+ /// </remarks>
+ public UdpConnectionListener Listener { get; private set; }
+
+ /// <summary>
+ /// Creates a UdpConnection for the virtual connection to the endpoint.
+ /// </summary>
+ /// <param name="listener">The listener that created this connection.</param>
+ /// <param name="endPoint">The endpoint that we are connected to.</param>
+ /// <param name="IPMode">The IPMode we are connected using.</param>
+ internal UdpServerConnection(UdpConnectionListener listener, IPEndPoint endPoint, IPMode IPMode, ILogger logger)
+ : base(logger)
+ {
+ this.Listener = listener;
+ this.EndPoint = endPoint;
+ this.IPMode = IPMode;
+
+ State = ConnectionState.Connected;
+ this.InitializeKeepAliveTimer();
+ }
+
+ /// <inheritdoc />
+ protected override void WriteBytesToConnection(byte[] bytes, int length)
+ {
+ this.Statistics.LogPacketSend(length);
+ Listener.SendData(bytes, length, EndPoint);
+ }
+
+ /// <inheritdoc />
+ /// <remarks>
+ /// This will always throw a HazelException.
+ /// </remarks>
+ public override void Connect(byte[] bytes = null, int timeout = 5000)
+ {
+ throw new InvalidOperationException("Cannot manually connect a UdpServerConnection, did you mean to use UdpClientConnection?");
+ }
+
+ /// <inheritdoc />
+ /// <remarks>
+ /// This will always throw a HazelException.
+ /// </remarks>
+ public override void ConnectAsync(byte[] bytes = null)
+ {
+ throw new InvalidOperationException("Cannot manually connect a UdpServerConnection, did you mean to use UdpClientConnection?");
+ }
+
+ /// <summary>
+ /// Sends a disconnect message to the end point.
+ /// </summary>
+ protected override bool SendDisconnect(MessageWriter data = null)
+ {
+ lock (this)
+ {
+ if (this._state != ConnectionState.Connected)
+ {
+ return false;
+ }
+
+ this._state = ConnectionState.NotConnected;
+ }
+
+ var bytes = EmptyDisconnectBytes;
+ if (data != null && data.Length > 0)
+ {
+ if (data.SendOption != SendOption.None) throw new ArgumentException("Disconnect messages can only be unreliable.");
+
+ bytes = data.ToByteArray(true);
+ bytes[0] = (byte)UdpSendOption.Disconnect;
+ }
+
+ try
+ {
+ Listener.SendDataSync(bytes, bytes.Length, EndPoint);
+ }
+ catch { }
+
+ return true;
+ }
+
+ protected override void Dispose(bool disposing)
+ {
+ Listener.RemoveConnectionTo(EndPoint);
+
+ if (disposing)
+ {
+ SendDisconnect();
+ }
+
+ base.Dispose(disposing);
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/Hazel/Udp/UnityUdpClientConnection.cs b/Tools/Hazel-Networking/Hazel/Udp/UnityUdpClientConnection.cs
new file mode 100644
index 0000000..8e6063d
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel/Udp/UnityUdpClientConnection.cs
@@ -0,0 +1,353 @@
+using System;
+using System.Net;
+using System.Net.Sockets;
+using System.Threading;
+
+
+namespace Hazel.Udp
+{
+ /// <summary>
+ /// Unity doesn't always get along with thread pools well, so this interface will hopefully suit that case better.
+ /// </summary>
+ /// <inheritdoc/>
+ public class UnityUdpClientConnection : UdpConnection
+ {
+ /// <summary>
+ /// The max size Hazel attempts to read from the network.
+ /// Defaults to 8096.
+ /// </summary>
+ /// <remarks>
+ /// 8096 is 5 times the standard modern MTU of 1500, so it's already too large imo.
+ /// If Hazel ever implements fragmented packets, then we might consider a larger value since combining 5
+ /// packets into 1 reader would be realistic and would cause reallocations. That said, Hazel is not meant
+ /// for transferring large contiguous blocks of data, so... please don't?
+ /// </remarks>
+ public int ReceiveBufferSize = 8096;
+
+ private Socket socket;
+
+ public UnityUdpClientConnection(ILogger logger, IPEndPoint remoteEndPoint, IPMode ipMode = IPMode.IPv4)
+ : base(logger)
+ {
+ this.EndPoint = remoteEndPoint;
+ this.IPMode = ipMode;
+
+ this.socket = CreateSocket(ipMode);
+ this.socket.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.ExclusiveAddressUse, true);
+ }
+
+ ~UnityUdpClientConnection()
+ {
+ this.Dispose(false);
+ }
+
+ public void FixedUpdate()
+ {
+ try
+ {
+ ResendPacketsIfNeeded();
+ }
+ catch (Exception e)
+ {
+ this.logger.WriteError("FixedUpdate: " + e);
+ }
+
+ try
+ {
+ ManageReliablePackets();
+ }
+ catch (Exception e)
+ {
+ this.logger.WriteError("FixedUpdate: " + e);
+ }
+ }
+
+ protected virtual void RestartConnection()
+ {
+ }
+
+ protected virtual void ResendPacketsIfNeeded()
+ {
+ }
+
+ /// <inheritdoc />
+ protected override void WriteBytesToConnection(byte[] bytes, int length)
+ {
+#if DEBUG
+ if (TestLagMs > 0)
+ {
+ ThreadPool.QueueUserWorkItem(a => { Thread.Sleep(this.TestLagMs); WriteBytesToConnectionReal(bytes, length); });
+ }
+ else
+#endif
+ {
+ WriteBytesToConnectionReal(bytes, length);
+ }
+ }
+
+ private void WriteBytesToConnectionReal(byte[] bytes, int length)
+ {
+ try
+ {
+ this.Statistics.LogPacketSend(length);
+ socket.BeginSendTo(
+ bytes,
+ 0,
+ length,
+ SocketFlags.None,
+ EndPoint,
+ HandleSendTo,
+ null);
+ }
+ catch (NullReferenceException) { }
+ catch (ObjectDisposedException)
+ {
+ // Already disposed and disconnected...
+ }
+ catch (SocketException ex)
+ {
+ DisconnectInternal(HazelInternalErrors.SocketExceptionSend, "Could not send data as a SocketException occurred: " + ex.Message);
+ }
+ }
+
+ /// <summary>
+ /// Synchronously writes the given bytes to the connection.
+ /// </summary>
+ /// <param name="bytes">The bytes to write.</param>
+ protected virtual void WriteBytesToConnectionSync(byte[] bytes, int length)
+ {
+ try
+ {
+ socket.SendTo(
+ bytes,
+ 0,
+ length,
+ SocketFlags.None,
+ EndPoint);
+ }
+ catch (NullReferenceException) { }
+ catch (ObjectDisposedException)
+ {
+ // Already disposed and disconnected...
+ }
+ catch (SocketException ex)
+ {
+ DisconnectInternal(HazelInternalErrors.SocketExceptionSend, "Could not send data as a SocketException occurred: " + ex.Message);
+ }
+ }
+
+ private void HandleSendTo(IAsyncResult result)
+ {
+ try
+ {
+ socket.EndSendTo(result);
+ }
+ catch (NullReferenceException) { }
+ catch (ObjectDisposedException)
+ {
+ // Already disposed and disconnected...
+ }
+ catch (SocketException ex)
+ {
+ DisconnectInternal(HazelInternalErrors.SocketExceptionSend, "Could not send data as a SocketException occurred: " + ex.Message);
+ }
+ }
+
+ public override void Connect(byte[] bytes = null, int timeout = 5000)
+ {
+ this.ConnectAsync(bytes);
+ for (int timer = 0; timer < timeout; timer += 100)
+ {
+ if (this.State != ConnectionState.Connecting) return;
+ Thread.Sleep(100);
+
+ // I guess if we're gonna block in Unity, then let's assume no one will pump this for us.
+ this.FixedUpdate();
+ }
+ }
+
+ /// <inheritdoc />
+ public override void ConnectAsync(byte[] bytes = null)
+ {
+ this.State = ConnectionState.Connecting;
+
+ try
+ {
+ if (IPMode == IPMode.IPv4)
+ socket.Bind(new IPEndPoint(IPAddress.Any, 0));
+ else
+ socket.Bind(new IPEndPoint(IPAddress.IPv6Any, 0));
+ }
+ catch (SocketException e)
+ {
+ this.State = ConnectionState.NotConnected;
+ throw new HazelException("A SocketException occurred while binding to the port.", e);
+ }
+
+ this.RestartConnection();
+
+ try
+ {
+ StartListeningForData();
+ }
+ catch (ObjectDisposedException)
+ {
+ // If the socket's been disposed then we can just end there but make sure we're in NotConnected state.
+ // If we end up here I'm really lost...
+ this.State = ConnectionState.NotConnected;
+ return;
+ }
+ catch (SocketException e)
+ {
+ Dispose();
+ throw new HazelException("A SocketException occurred while initiating a receive operation.", e);
+ }
+
+ // Write bytes to the server to tell it hi (and to punch a hole in our NAT, if present)
+ // When acknowledged set the state to connected
+ SendHello(bytes, () =>
+ {
+ this.InitializeKeepAliveTimer();
+ this.State = ConnectionState.Connected;
+ });
+ }
+
+ /// <summary>
+ /// Instructs the listener to begin listening.
+ /// </summary>
+ void StartListeningForData()
+ {
+ var msg = MessageReader.GetSized(this.ReceiveBufferSize);
+ try
+ {
+ EndPoint ep = this.EndPoint;
+ socket.BeginReceiveFrom(msg.Buffer, 0, msg.Buffer.Length, SocketFlags.None, ref ep, ReadCallback, msg);
+ }
+ catch
+ {
+ msg.Recycle();
+ this.Dispose();
+ }
+ }
+
+ /// <summary>
+ /// Called when data has been received by the socket.
+ /// </summary>
+ /// <param name="result">The asyncronous operation's result.</param>
+ void ReadCallback(IAsyncResult result)
+ {
+#if DEBUG
+ if (this.TestLagMs > 0)
+ {
+ Thread.Sleep(this.TestLagMs);
+ }
+#endif
+
+ var msg = (MessageReader)result.AsyncState;
+
+ try
+ {
+ EndPoint ep = this.EndPoint;
+ msg.Length = socket.EndReceiveFrom(result, ref ep);
+ }
+ catch (SocketException e)
+ {
+ msg.Recycle();
+ DisconnectInternal(HazelInternalErrors.SocketExceptionReceive, "Socket exception while reading data: " + e.Message);
+ return;
+ }
+ catch (ObjectDisposedException)
+ {
+ // Weirdly, it seems that this method can be called twice on the same AsyncState when object is disposed...
+ // So this just keeps us from hitting Duplicate Add errors at the risk of if this is a platform
+ // specific bug, we leak a MessageReader while the socket is disposing. Not a bad trade off.
+ return;
+ }
+ catch (Exception)
+ {
+ msg.Recycle();
+ return;
+ }
+
+ //Exit if no bytes read, we've failed.
+ if (msg.Length == 0)
+ {
+ msg.Recycle();
+ DisconnectInternal(HazelInternalErrors.ReceivedZeroBytes, "Received 0 bytes");
+ return;
+ }
+
+ //Begin receiving again
+ try
+ {
+ StartListeningForData();
+ }
+ catch (SocketException e)
+ {
+ DisconnectInternal(HazelInternalErrors.SocketExceptionReceive, "Socket exception during receive: " + e.Message);
+ }
+ catch (ObjectDisposedException)
+ {
+ //If the socket's been disposed then we can just end there.
+ return;
+ }
+
+#if DEBUG
+ if (this.TestDropRate > 0)
+ {
+ if ((this.testDropCount++ % this.TestDropRate) == 0)
+ {
+ return;
+ }
+ }
+#endif
+
+ HandleReceive(msg, msg.Length);
+ }
+
+ /// <summary>
+ /// Sends a disconnect message to the end point.
+ /// You may include optional disconnect data. The SendOption must be unreliable.
+ /// </summary>
+ protected override bool SendDisconnect(MessageWriter data = null)
+ {
+ lock (this)
+ {
+ if (this._state == ConnectionState.NotConnected) return false;
+ this._state = ConnectionState.NotConnected;
+ }
+
+ var bytes = EmptyDisconnectBytes;
+ if (data != null && data.Length > 0)
+ {
+ if (data.SendOption != SendOption.None) throw new ArgumentException("Disconnect messages can only be unreliable.");
+
+ bytes = data.ToByteArray(true);
+ bytes[0] = (byte)UdpSendOption.Disconnect;
+ }
+
+ try
+ {
+ this.WriteBytesToConnectionSync(bytes, bytes.Length);
+ }
+ catch { }
+
+ return true;
+ }
+
+ /// <inheritdoc />
+ protected override void Dispose(bool disposing)
+ {
+ if (disposing)
+ {
+ SendDisconnect();
+ }
+
+ try { this.socket.Shutdown(SocketShutdown.Both); } catch { }
+ try { this.socket.Close(); } catch { }
+ try { this.socket.Dispose(); } catch { }
+
+ base.Dispose(disposing);
+ }
+ }
+}
diff --git a/Tools/Hazel-Networking/LICENSE b/Tools/Hazel-Networking/LICENSE
new file mode 100644
index 0000000..f3ad895
--- /dev/null
+++ b/Tools/Hazel-Networking/LICENSE
@@ -0,0 +1,22 @@
+The MIT License (MIT)
+
+Copyright (c) 2018 Innersloth LLC
+Copyright (c) 2016-2017 DarkRift Networking
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/Tools/Hazel-Networking/README.md b/Tools/Hazel-Networking/README.md
new file mode 100644
index 0000000..5de4c0c
--- /dev/null
+++ b/Tools/Hazel-Networking/README.md
@@ -0,0 +1,47 @@
+#### Hazel Networking is a low-level networking library for C# providing connection-oriented, message based communication via UDP and RUDP.
+
+The aim of this fork is to create a simple interface for ultra-fast connection-based UDP communication for games. JamJar and I generally consider it the primary fork, but we goofed and didn't transfer the repo, so idk maybe someday.
+
+-----
+
+## Features
+- UDP and Reliable UDP.
+- Encrypted packets using DTLS
+- UDP Broadcast for local-multiplayer.
+- Completely thread safe.
+- All protocols are connection oriented (similar to TCP) and message based (similar to UDP)
+- IPv4 and IPv6 support
+- Automatic statistics about data passing in and out of connections
+- Designed to be as fast and lightweight as possible
+
+-----
+
+### This fork has been heavily modified from the original to reduce allocations, copies, and locking. It's pretty stable and Among Us uses it for all platforms, but still has the occasional issue.
+
+-----
+
+There is currently no online documentation. I might get around to it someday. I have changed some interfaces in "unintuitive ways", it is my hope that [this example repo](https://github.com/willardf/Hazel-Examples) will be able to help users get started.
+
+If you want to make improvements, I am open to pull requests. If you find bugs, feel free raise issues.
+
+-----
+
+## Building Hazel
+
+To build Hazel open [solution file](Hazel.sln) using your favourite C# IDE (I use Visual Studio 2019) and then build as you would any other project.
+
+-----
+## Tips with this fork
+
+ * Pay attention to which callbacks give you ownership of the MessageReader, making you responsible for recycling. In particular:
+ * You *should not* recycle messages after NewConnection events.
+ * You *should not* recycle messages after Disconnect events.
+ * You *should* recycle messages after DataReceived events.
+ * Hazel doesn't support fragmented packets. It used to, but I wasn't sure of it so I removed it and have never needed it since. Just stay under 1kb packets.
+
+## Tips for using Hazel with Unity
+
+ * Unity doesn't like other threads messing with GameObjects. This isn't a problem for tasks like relaying information. But for tasks like spawning GameObjects on clients or correcting physics, you will want to have a thread safe list of events that are run and cleared during Update or FixedUpdate.
+ * A List<T>+lock(object) is fine because you have many writers, one reader and Hazel doesn't guarantee event order.
+ * A ConcurrentBag is not a bad choice, but you will have to do something special to keep the Update method from hanging if you get an overwhelming number of new events (which suggests problems with your code elsewhere).
+ * I also recommend using the ConnectAsync method in a Coroutine that waits for State to change so you don't hang the game while connecting.