using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using System.Collections;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Net;
using System.Net.Sockets;
using System.Text;
using System.Threading;
namespace Hazel.UnitTests
{
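/// <summary>
/// Acts as an intermediary between two sockets, relaying UDP packets in both directions.
/// </summary>
/// <remarks>
/// Use <see cref="SendToLocalSemaphore"/> and <see cref="SendToRemoteSemaphore"/> for
/// explicit control of packet flow.
/// </remarks>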
public class SocketCapture : IDisposable
{
    // How long the blocking test helpers wait before failing the test.
    private const int AssertTimeoutSeconds = 5;

    // Large enough for any UDP payload, so received datagrams are never truncated.
    private const int MaxDatagramSize = 65535;

    // Written by the receive thread, read by the send-to-local thread; null until the local side sends.
    private volatile IPEndPoint localEndPoint;
    private readonly IPEndPoint remoteEndPoint;
    private Socket captureSocket;
    private Thread receiveThread;
    private Thread forLocalThread;
    private Thread forRemoteThread;
    private readonly ILogger logger;

    private readonly BlockingCollection<ByteSpan> forLocal = new BlockingCollection<ByteSpan>();
    private readonly BlockingCollection<ByteSpan> forRemote = new BlockingCollection<ByteSpan>();

    /// <summary>
    /// Number of captured packets waiting to be forwarded to the local endpoint.
    /// Useful for debug logging; prefer <see cref="AssertPacketsToLocalCountEquals"/> for assertions.
    /// </summary>
    public int PacketsForLocalCount => this.forLocal.Count;

    /// <summary>
    /// Number of captured packets waiting to be forwarded to the remote endpoint.
    /// Useful for debug logging; prefer <see cref="AssertPacketsToRemoteCountEquals"/> for assertions.
    /// </summary>
    public int PacketsForRemoteCount => this.forRemote.Count;

    /// <summary>
    /// When non-null, the forwarding thread acquires one permit from this semaphore before each
    /// attempt to forward a queued packet to the local endpoint. A permit acquired while no packet
    /// is queued is discarded. Leave null to forward packets as soon as they arrive.
    /// </summary>
    public volatile Semaphore SendToLocalSemaphore = null;

    /// <summary>
    /// When non-null, the forwarding thread acquires one permit from this semaphore before each
    /// attempt to forward a queued packet to the remote endpoint. A permit acquired while no packet
    /// is queued is discarded. Leave null to forward packets as soon as they arrive.
    /// </summary>
    public volatile Semaphore SendToRemoteSemaphore = null;

    private CancellationTokenSource cancellationSource = new CancellationTokenSource();
    private readonly CancellationToken cancellationToken;

    /// <summary>
    /// Binds a UDP socket to <paramref name="captureEndpoint"/> and starts relaying packets between
    /// <paramref name="remoteEndPoint"/> and whichever other endpoint most recently sent to the capture socket.
    /// </summary>
    /// <param name="captureEndpoint">Endpoint the capture socket binds to; the local side should send here.</param>
    /// <param name="remoteEndPoint">Endpoint that packets from the local side are forwarded to.</param>
    /// <param name="logger">Optional logger for forwarding activity; a <c>NullLogger</c> is used when null.</param>
    /// <exception cref="ArgumentNullException"><paramref name="captureEndpoint"/> is null.</exception>
    public SocketCapture(IPEndPoint captureEndpoint, IPEndPoint remoteEndPoint, ILogger logger = null)
    {
        if (captureEndpoint == null)
        {
            throw new ArgumentNullException(nameof(captureEndpoint));
        }

        this.logger = logger ?? new NullLogger();
        this.cancellationToken = this.cancellationSource.Token;
        this.remoteEndPoint = remoteEndPoint;

        // Follow the capture endpoint's address family so IPv6 endpoints work as well as IPv4.
        this.captureSocket = new Socket(captureEndpoint.AddressFamily, SocketType.Dgram, ProtocolType.Udp);
        try
        {
            this.captureSocket.Bind(captureEndpoint);
        }
        catch
        {
            // Don't leak the socket when the endpoint is unavailable.
            this.captureSocket.Dispose();
            this.captureSocket = null;
            throw;
        }

        this.receiveThread = StartWorker(this.ReceiveLoop, "SocketCapture receive");
        this.forLocalThread = StartWorker(this.SendToLocalLoop, "SocketCapture to-local");
        this.forRemoteThread = StartWorker(this.SendToRemoteLoop, "SocketCapture to-remote");
    }

    /// <summary>
    /// Stops relaying, closes the capture socket and waits for all worker threads to exit.
    /// Calling it more than once is harmless.
    /// </summary>
    public void Dispose()
    {
        // Signal the send loops, then close the socket to unblock the pending ReceiveFrom.
        this.cancellationSource?.Cancel();
        this.captureSocket?.Close();

        // Join before releasing anything the workers still reference, so none of them
        // can observe a null field or a disposed token source mid-iteration.
        this.receiveThread?.Join();
        this.receiveThread = null;
        this.forLocalThread?.Join();
        this.forLocalThread = null;
        this.forRemoteThread?.Join();
        this.forRemoteThread = null;

        this.captureSocket?.Dispose();
        this.captureSocket = null;
        this.cancellationSource?.Dispose();
        this.cancellationSource = null;

        GC.SuppressFinalize(this);
    }

    // Starts a named background worker so a missing Dispose cannot keep the test process alive.
    private static Thread StartWorker(ThreadStart body, string name)
    {
        var thread = new Thread(body) { IsBackground = true, Name = name };
        thread.Start();
        return thread;
    }

    /// <summary>
    /// Receives datagrams on the capture socket. Packets from the remote endpoint are queued for the
    /// local side. Any other packet records its sender as the local endpoint and is queued for the remote.
    /// Runs until the socket is closed or fails, then marks both queues as complete.
    /// </summary>
    private void ReceiveLoop()
    {
        Socket socket = this.captureSocket;
        byte[] receiveBuffer = new byte[MaxDatagramSize];
        IPAddress anyAddress = socket.AddressFamily == AddressFamily.InterNetworkV6 ? IPAddress.IPv6Any : IPAddress.Any;

        try
        {
            while (true)
            {
                EndPoint sender = new IPEndPoint(anyAddress, 0);
                int read;
                try
                {
                    read = socket.ReceiveFrom(receiveBuffer, ref sender);
                }
                catch (SocketException ex) when (ex.SocketErrorCode == SocketError.ConnectionReset)
                {
                    // Windows reports an ICMP "port unreachable" for an earlier send as a failed
                    // receive; the socket is still usable, so keep capturing.
                    continue;
                }

                if (read <= 0)
                {
                    continue;
                }

                // Copy exactly the received bytes so the receive buffer can be reused.
                byte[] packet = new byte[read];
                Buffer.BlockCopy(receiveBuffer, 0, packet, 0, read);

                if (sender.Equals(this.remoteEndPoint))
                {
                    this.forLocal.Add(new ByteSpan(packet, 0, read));
                }
                else
                {
                    this.localEndPoint = (IPEndPoint)sender;
                    this.forRemote.Add(new ByteSpan(packet, 0, read));
                }
            }
        }
        catch (SocketException)
        {
            // Expected when Dispose closes the socket.
        }
        catch (ObjectDisposedException)
        {
            // Also possible when the socket is disposed while a receive is pending.
        }
        finally
        {
            this.forLocal.CompleteAdding();
            this.forRemote.CompleteAdding();
        }
    }

    // Forwards packets captured from the local side to the remote endpoint.
    private void SendToRemoteLoop()
    {
        this.ForwardLoop(this.forRemote, () => this.SendToRemoteSemaphore, () => this.remoteEndPoint, "remote");
    }

    // Forwards packets captured from the remote side to the most recently seen local endpoint.
    private void SendToLocalLoop()
    {
        this.ForwardLoop(this.forLocal, () => this.SendToLocalSemaphore, () => this.localEndPoint, "local");
    }

    /// <summary>
    /// Forwards packets from <paramref name="queue"/> to the endpoint returned by
    /// <paramref name="readDestination"/>. It runs until cancellation, or until the queue is
    /// completed and drained. The gate is re-read every iteration because tests may install
    /// it after construction.
    /// </summary>
    private void ForwardLoop(BlockingCollection<ByteSpan> queue, Func<Semaphore> readGate, Func<IPEndPoint> readDestination, string direction)
    {
        Socket socket = this.captureSocket;
        while (!this.cancellationToken.IsCancellationRequested)
        {
            Semaphore gate = readGate();
            if (gate != null && !gate.WaitOne(100))
            {
                continue;
            }

            if (!queue.TryTake(out ByteSpan packet))
            {
                if (queue.IsCompleted)
                {
                    // The receive loop has exited; nothing more will ever be queued.
                    return;
                }

                // Back off instead of spinning at full CPU. The queue is polled rather than waited
                // on so a gate installed in the meantime still holds back the next packet.
                Thread.Sleep(1);
                continue;
            }

            IPEndPoint destination = readDestination();
            if (destination == null)
            {
                // Nothing has been received from the local side yet, so there is no address to send to.
                this.logger.WriteInfo($"Dropped 1 packet of {packet.Length} bytes to {direction}: destination not known yet");
                continue;
            }

            this.logger.WriteInfo($"Passed 1 packet of {packet.Length} bytes to {direction}");
            try
            {
                socket.SendTo(packet.GetUnderlyingArray(), packet.Offset, packet.Length, SocketFlags.None, destination);
            }
            catch (Exception ex) when (ex is SocketException || ex is ObjectDisposedException)
            {
                if (this.cancellationToken.IsCancellationRequested)
                {
                    // The socket was closed by Dispose.
                    return;
                }

                this.logger.WriteInfo($"Failed to pass 1 packet of {packet.Length} bytes to {direction}: {ex.Message}");
            }
        }
    }

    /// <summary>
    /// Waits up to five seconds for exactly <paramref name="pktCnt"/> packets to be queued for the
    /// local endpoint, failing the test if the count does not match in time.
    /// </summary>
    public void AssertPacketsToLocalCountEquals(int pktCnt)
    {
        WaitForCount(this.forLocal, pktCnt);
    }

    /// <summary>
    /// Waits up to five seconds for exactly <paramref name="pktCnt"/> packets to be queued for the
    /// remote endpoint, failing the test if the count does not match in time.
    /// </summary>
    public void AssertPacketsToRemoteCountEquals(int pktCnt)
    {
        WaitForCount(this.forRemote, pktCnt);
    }

    // Polls the queue's count until it equals expected; after the timeout, Assert.AreEqual fails the test.
    private static void WaitForCount(BlockingCollection<ByteSpan> queue, int expected)
    {
        DateTime deadline = DateTime.UtcNow.AddSeconds(AssertTimeoutSeconds);
        while (queue.Count != expected)
        {
            if (DateTime.UtcNow >= deadline)
            {
                Assert.AreEqual(expected, queue.Count);
            }

            Thread.Yield();
        }
    }

    /// <summary>
    /// Removes <paramref name="numToDiscard"/> packets queued for the local endpoint without forwarding them.
    /// Waits up to five seconds for each packet and fails the test if one does not arrive.
    /// </summary>
    public void DiscardPacketForLocal(int numToDiscard = 1)
    {
        Discard(this.forLocal, numToDiscard, "for local");
    }

    /// <summary>
    /// Removes <paramref name="numToDiscard"/> packets queued for the remote endpoint without forwarding them.
    /// Waits up to five seconds for each packet and fails the test if one does not arrive.
    /// </summary>
    public void DiscardPacketForRemote(int numToDiscard = 1)
    {
        Discard(this.forRemote, numToDiscard, "for remote");
    }

    // Takes and drops packets, bounded by the assert timeout so a missing packet fails instead of hanging.
    private static void Discard(BlockingCollection<ByteSpan> queue, int numToDiscard, string direction)
    {
        for (int i = 0; i < numToDiscard; ++i)
        {
            if (!queue.TryTake(out _, TimeSpan.FromSeconds(AssertTimeoutSeconds)))
            {
                Assert.Fail($"Timed out waiting for packet {i + 1} of {numToDiscard} {direction} to discard");
            }
        }
    }

    /// <summary>
    /// Returns the oldest packet queued for the local endpoint without removing it.
    /// </summary>
    /// <exception cref="InvalidOperationException">No packet is queued.</exception>
    public ByteSpan PeekPacketForLocal()
    {
        return this.forLocal.First();
    }

    /// <summary>
    /// Drains the packets queued for the remote endpoint, lets <paramref name="reorderCallback"/>
    /// rearrange (or add/remove) them, then re-queues the result in list order.
    /// </summary>
    public void ReorderPacketsForRemote(Action<List<ByteSpan>> reorderCallback)
    {
        Reorder(this.forRemote, reorderCallback);
    }

    /// <summary>
    /// Drains the packets queued for the local endpoint, lets <paramref name="reorderCallback"/>
    /// rearrange (or add/remove) them, then re-queues the result in list order.
    /// </summary>
    public void ReorderPacketsForLocal(Action<List<ByteSpan>> reorderCallback)
    {
        Reorder(this.forLocal, reorderCallback);
    }

    private static void Reorder(BlockingCollection<ByteSpan> queue, Action<List<ByteSpan>> reorderCallback)
    {
        var pending = new List<ByteSpan>();
        while (queue.TryTake(out ByteSpan pkt))
        {
            pending.Add(pkt);
        }

        reorderCallback(pending);
        foreach (ByteSpan item in pending)
        {
            queue.Add(item);
        }
    }

    /// <summary>
    /// Releases <paramref name="numToSend"/> permits on <see cref="SendToRemoteSemaphore"/>, then
    /// asserts that the queue for the remote endpoint shrinks by that many packets.
    /// </summary>
    /// <exception cref="InvalidOperationException"><see cref="SendToRemoteSemaphore"/> is not set.</exception>
    internal void ReleasePacketsForRemote(int numToSend)
    {
        ReleasePackets(this.forRemote, this.SendToRemoteSemaphore, nameof(SendToRemoteSemaphore), numToSend);
    }

    /// <summary>
    /// Releases <paramref name="numToSend"/> permits on <see cref="SendToLocalSemaphore"/>, then
    /// asserts that the queue for the local endpoint shrinks by that many packets.
    /// </summary>
    /// <exception cref="InvalidOperationException"><see cref="SendToLocalSemaphore"/> is not set.</exception>
    internal void ReleasePacketsToLocal(int numToSend)
    {
        ReleasePackets(this.forLocal, this.SendToLocalSemaphore, nameof(SendToLocalSemaphore), numToSend);
    }

    private static void ReleasePackets(BlockingCollection<ByteSpan> queue, Semaphore gate, string gateName, int numToSend)
    {
        if (gate == null)
        {
            throw new InvalidOperationException($"{gateName} must be set before packets can be released.");
        }

        int newExpected = queue.Count - numToSend;
        if (numToSend > 0)
        {
            // Semaphore.Release rejects a count of zero.
            gate.Release(numToSend);
        }

        WaitForCount(queue, newExpected);
    }

    /// <summary>Lists the packets queued for the remote endpoint, one per line, for diagnostics.</summary>
    internal string PacketsForRemoteToString()
    {
        return Describe(this.forRemote);
    }

    /// <summary>Lists the packets queued for the local endpoint, one per line, for diagnostics.</summary>
    internal string PacketsForLocalToString()
    {
        return Describe(this.forLocal);
    }

    private static string Describe(BlockingCollection<ByteSpan> queue)
    {
        var sb = new StringBuilder();
        foreach (ByteSpan item in queue)
        {
            sb.AppendLine(item.ToString());
        }

        return sb.ToString();
    }
}
}