aboutsummaryrefslogtreecommitdiff
path: root/Tools/Hazel-Networking/Hazel.UnitTests/SocketCapture.cs
diff options
context:
space:
mode:
Diffstat (limited to 'Tools/Hazel-Networking/Hazel.UnitTests/SocketCapture.cs')
-rw-r--r--Tools/Hazel-Networking/Hazel.UnitTests/SocketCapture.cs292
1 files changed, 292 insertions, 0 deletions
diff --git a/Tools/Hazel-Networking/Hazel.UnitTests/SocketCapture.cs b/Tools/Hazel-Networking/Hazel.UnitTests/SocketCapture.cs
new file mode 100644
index 0000000..584d08c
--- /dev/null
+++ b/Tools/Hazel-Networking/Hazel.UnitTests/SocketCapture.cs
@@ -0,0 +1,292 @@
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using System;
+using System.Collections;
+using System.Collections.Concurrent;
+using System.Collections.Generic;
+using System.Linq;
+using System.Net;
+using System.Net.Sockets;
+using System.Text;
+using System.Threading;
+
+namespace Hazel.UnitTests
+{
+ /// <summary>
+ /// Acts as an intermediate between to sockets.
+ ///
+ /// Use SendToLocalSemaphore and SendToRemoteSemaphore for
+ /// explicit control of packet flow.
+ /// </summary>
+ public class SocketCapture : IDisposable
+ {
+ private IPEndPoint localEndPoint;
+ private readonly IPEndPoint remoteEndPoint;
+
+ private Socket captureSocket;
+
+ private Thread receiveThread;
+ private Thread forLocalThread;
+ private Thread forRemoteThread;
+
+ private ILogger logger;
+
+ private readonly BlockingCollection<ByteSpan> forLocal = new BlockingCollection<ByteSpan>();
+ private readonly BlockingCollection<ByteSpan> forRemote = new BlockingCollection<ByteSpan>();
+
+ /// <summary>
+ /// Useful for debug logging, prefer <see cref="AssertPacketsToLocalCountEquals(int)"/> for assertions
+ /// </summary>
+ public int PacketsForLocalCount => this.forLocal.Count;
+
+ /// <summary>
+ /// Useful for debug logging, prefer <see cref="AssertPacketsToRemoteCountEquals(int)"/> for assertions
+ /// </summary>
+ public int PacketsForRemoteCount => this.forRemote.Count;
+
+ public Semaphore SendToLocalSemaphore = null;
+ public Semaphore SendToRemoteSemaphore = null;
+
+ private CancellationTokenSource cancellationSource = new CancellationTokenSource();
+ private readonly CancellationToken cancellationToken;
+
+ public SocketCapture(IPEndPoint captureEndpoint, IPEndPoint remoteEndPoint, ILogger logger = null)
+ {
+ this.logger = logger ?? new NullLogger();
+ this.cancellationToken = this.cancellationSource.Token;
+
+ this.remoteEndPoint = remoteEndPoint;
+
+ this.captureSocket = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp);
+ this.captureSocket.Bind(captureEndpoint);
+
+ this.receiveThread = new Thread(this.ReceiveLoop);
+ this.receiveThread.Start();
+
+ this.forLocalThread = new Thread(this.SendToLocalLoop);
+ this.forLocalThread.Start();
+
+ this.forRemoteThread = new Thread(this.SendToRemoteLoop);
+ this.forRemoteThread.Start();
+ }
+
+ public void Dispose()
+ {
+ if (this.cancellationSource != null)
+ {
+ this.cancellationSource.Cancel();
+ this.cancellationSource.Dispose();
+ this.cancellationSource = null;
+ }
+
+ if (this.captureSocket != null)
+ {
+ this.captureSocket.Close();
+ this.captureSocket.Dispose();
+ this.captureSocket = null;
+ }
+
+ if (this.receiveThread != null)
+ {
+ this.receiveThread.Join();
+ this.receiveThread = null;
+ }
+
+ if (this.forLocalThread != null)
+ {
+ this.forLocalThread.Join();
+ this.forLocalThread = null;
+ }
+
+ if (this.forRemoteThread != null)
+ {
+ this.forRemoteThread.Join();
+ this.forRemoteThread = null;
+ }
+
+ GC.SuppressFinalize(this);
+ }
+
+ private void ReceiveLoop()
+ {
+ try
+ {
+ IPEndPoint fromEndPoint = new IPEndPoint(IPAddress.Any, 0);
+
+ for (; ; )
+ {
+ byte[] buffer = new byte[2000];
+ EndPoint endPoint = fromEndPoint;
+ int read = this.captureSocket.ReceiveFrom(buffer, ref endPoint);
+ if (read > 0)
+ {
+ // from the remote endpoint?
+ if (IPEndPoint.Equals(endPoint, remoteEndPoint))
+ {
+ this.forLocal.Add(new ByteSpan(buffer, 0, read));
+ }
+ else
+ {
+ this.localEndPoint = (IPEndPoint)endPoint;
+ this.forRemote.Add(new ByteSpan(buffer, 0, read));
+ }
+ }
+ }
+ }
+ catch (SocketException)
+ {
+ }
+ finally
+ {
+ this.forLocal.CompleteAdding();
+ this.forRemote.CompleteAdding();
+ }
+ }
+
+ private void SendToRemoteLoop()
+ {
+ while (!this.cancellationToken.IsCancellationRequested)
+ {
+ if (this.SendToRemoteSemaphore != null)
+ {
+ if (!this.SendToRemoteSemaphore.WaitOne(100))
+ {
+ continue;
+ }
+ }
+
+ if (this.forRemote.TryTake(out var packet))
+ {
+ this.logger.WriteInfo($"Passed 1 packet of {packet.Length} bytes to remote");
+ this.captureSocket.SendTo(packet.GetUnderlyingArray(), packet.Offset, packet.Length, SocketFlags.None, this.remoteEndPoint);
+ }
+ }
+ }
+
+ private void SendToLocalLoop()
+ {
+ while (!this.cancellationToken.IsCancellationRequested)
+ {
+ if (this.SendToLocalSemaphore != null)
+ {
+ if (!this.SendToLocalSemaphore.WaitOne(100))
+ {
+ continue;
+ }
+ }
+
+ if (this.forLocal.TryTake(out var packet))
+ {
+ this.logger.WriteInfo($"Passed 1 packet of {packet.Length} bytes to local");
+ this.captureSocket.SendTo(packet.GetUnderlyingArray(), packet.Offset, packet.Length, SocketFlags.None, this.localEndPoint);
+ }
+ }
+ }
+
+ public void AssertPacketsToLocalCountEquals(int pktCnt)
+ {
+ DateTime start = DateTime.UtcNow;
+ while (this.forLocal.Count != pktCnt)
+ {
+ if ((DateTime.UtcNow - start).TotalSeconds >= 5)
+ {
+ Assert.AreEqual(pktCnt, this.forLocal.Count);
+ }
+
+ Thread.Yield();
+ }
+ }
+
+ public void AssertPacketsToRemoteCountEquals(int pktCnt)
+ {
+ DateTime start = DateTime.UtcNow;
+ while (this.forRemote.Count != pktCnt)
+ {
+ if ((DateTime.UtcNow - start).TotalSeconds >= 5)
+ {
+ Assert.AreEqual(pktCnt, this.forRemote.Count);
+ }
+
+ Thread.Yield();
+ }
+ }
+
+ public void DiscardPacketForLocal(int numToDiscard = 1)
+ {
+ for (int i = 0; i < numToDiscard; ++i)
+ {
+ this.forLocal.Take();
+ }
+ }
+
+ public void DiscardPacketForRemote(int numToDiscard = 1)
+ {
+ for (int i = 0; i < numToDiscard; ++i)
+ {
+ this.forRemote.Take();
+ }
+ }
+
+ public ByteSpan PeekPacketForLocal()
+ {
+ return this.forLocal.First();
+ }
+
+ public void ReorderPacketsForRemote(Action<List<ByteSpan>> reorderCallback)
+ {
+ List<ByteSpan> buffer = new List<ByteSpan>();
+ while (this.forRemote.TryTake(out var pkt)) buffer.Add(pkt);
+ reorderCallback(buffer);
+ foreach (var item in buffer)
+ {
+ this.forRemote.Add(item);
+ }
+ }
+
+ public void ReorderPacketsForLocal(Action<List<ByteSpan>> reorderCallback)
+ {
+ List<ByteSpan> buffer = new List<ByteSpan>();
+ while (this.forLocal.TryTake(out var pkt)) buffer.Add(pkt);
+ reorderCallback(buffer);
+ foreach (var item in buffer)
+ {
+ this.forLocal.Add(item);
+ }
+ }
+
+ internal void ReleasePacketsForRemote(int numToSend)
+ {
+ var newExpected = this.forRemote.Count - numToSend;
+ this.SendToRemoteSemaphore.Release(numToSend);
+ this.AssertPacketsToRemoteCountEquals(newExpected);
+ }
+
+ internal void ReleasePacketsToLocal(int numToSend)
+ {
+ var newExpected = this.forLocal.Count - numToSend;
+ this.SendToLocalSemaphore.Release(numToSend);
+ this.AssertPacketsToLocalCountEquals(newExpected);
+ }
+
+ internal string PacketsForRemoteToString()
+ {
+ StringBuilder sb = new StringBuilder();
+ foreach (var item in this.forRemote)
+ {
+ sb.AppendLine(item.ToString());
+ }
+
+ return sb.ToString();
+ }
+
+ internal string PacketsForLocalToString()
+ {
+ StringBuilder sb = new StringBuilder();
+ foreach(var item in this.forLocal)
+ {
+ sb.AppendLine(item.ToString());
+ }
+
+ return sb.ToString();
+ }
+ }
+} \ No newline at end of file