diff options
Diffstat (limited to 'Tools/Hazel-Networking/Hazel.UnitTests/SocketCapture.cs')
| -rw-r--r-- | Tools/Hazel-Networking/Hazel.UnitTests/SocketCapture.cs | 292 | 
1 files changed, 292 insertions, 0 deletions
diff --git a/Tools/Hazel-Networking/Hazel.UnitTests/SocketCapture.cs b/Tools/Hazel-Networking/Hazel.UnitTests/SocketCapture.cs new file mode 100644 index 0000000..584d08c --- /dev/null +++ b/Tools/Hazel-Networking/Hazel.UnitTests/SocketCapture.cs @@ -0,0 +1,292 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.Collections; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Linq; +using System.Net; +using System.Net.Sockets; +using System.Text; +using System.Threading; + +namespace Hazel.UnitTests +{ +    /// <summary> +    /// Acts as an intermediate between to sockets. +    ///  +    /// Use SendToLocalSemaphore and SendToRemoteSemaphore for +    /// explicit control of packet flow. +    /// </summary> +    public class SocketCapture : IDisposable +    { +        private IPEndPoint localEndPoint; +        private readonly IPEndPoint remoteEndPoint; + +        private Socket captureSocket; + +        private Thread receiveThread; +        private Thread forLocalThread; +        private Thread forRemoteThread; + +        private ILogger logger; + +        private readonly BlockingCollection<ByteSpan> forLocal = new BlockingCollection<ByteSpan>(); +        private readonly BlockingCollection<ByteSpan> forRemote = new BlockingCollection<ByteSpan>(); + +        /// <summary> +        /// Useful for debug logging, prefer <see cref="AssertPacketsToLocalCountEquals(int)"/> for assertions +        /// </summary> +        public int PacketsForLocalCount => this.forLocal.Count; + +        /// <summary> +        /// Useful for debug logging, prefer <see cref="AssertPacketsToRemoteCountEquals(int)"/> for assertions +        /// </summary> +        public int PacketsForRemoteCount => this.forRemote.Count; + +        public Semaphore SendToLocalSemaphore = null; +        public Semaphore SendToRemoteSemaphore = null; + +        private CancellationTokenSource cancellationSource = new CancellationTokenSource(); +        private readonly CancellationToken cancellationToken; + +        public SocketCapture(IPEndPoint captureEndpoint, IPEndPoint remoteEndPoint, ILogger logger = null) +        { +            this.logger = logger ?? new NullLogger(); +            this.cancellationToken = this.cancellationSource.Token; + +            this.remoteEndPoint = remoteEndPoint; + +            this.captureSocket = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp); +            this.captureSocket.Bind(captureEndpoint); + +            this.receiveThread = new Thread(this.ReceiveLoop); +            this.receiveThread.Start(); + +            this.forLocalThread = new Thread(this.SendToLocalLoop); +            this.forLocalThread.Start(); + +            this.forRemoteThread = new Thread(this.SendToRemoteLoop); +            this.forRemoteThread.Start(); +        } + +        public void Dispose() +        { +            if (this.cancellationSource != null) +            { +                this.cancellationSource.Cancel(); +                this.cancellationSource.Dispose(); +                this.cancellationSource = null; +            } + +            if (this.captureSocket != null) +            { +                this.captureSocket.Close(); +                this.captureSocket.Dispose(); +                this.captureSocket = null; +            } + +            if (this.receiveThread != null) +            { +                this.receiveThread.Join(); +                this.receiveThread = null; +            } + +            if (this.forLocalThread != null) +            { +                this.forLocalThread.Join(); +                this.forLocalThread = null; +            } + +            if (this.forRemoteThread != null) +            { +                this.forRemoteThread.Join(); +                this.forRemoteThread = null; +            } + +            GC.SuppressFinalize(this); +        } + +        private void ReceiveLoop() +        { +            try +            { +                IPEndPoint fromEndPoint = new IPEndPoint(IPAddress.Any, 0); + +                for (; ; ) +                { +                    byte[] buffer = new byte[2000]; +                    EndPoint endPoint = fromEndPoint; +                    int read = this.captureSocket.ReceiveFrom(buffer, ref endPoint); +                    if (read > 0) +                    { +                        // from the remote endpoint? +                        if (IPEndPoint.Equals(endPoint, remoteEndPoint)) +                        { +                            this.forLocal.Add(new ByteSpan(buffer, 0, read)); +                        } +                        else +                        { +                            this.localEndPoint = (IPEndPoint)endPoint; +                            this.forRemote.Add(new ByteSpan(buffer, 0, read)); +                        } +                    } +                } +            } +            catch (SocketException) +            { +            } +            finally +            { +                this.forLocal.CompleteAdding(); +                this.forRemote.CompleteAdding(); +            } +        } + +        private void SendToRemoteLoop() +        { +            while (!this.cancellationToken.IsCancellationRequested) +            { +                if (this.SendToRemoteSemaphore != null) +                { +                    if (!this.SendToRemoteSemaphore.WaitOne(100)) +                    { +                        continue; +                    } +                } + +                if (this.forRemote.TryTake(out var packet)) +                { +                    this.logger.WriteInfo($"Passed 1 packet of {packet.Length} bytes to remote"); +                    this.captureSocket.SendTo(packet.GetUnderlyingArray(), packet.Offset, packet.Length, SocketFlags.None, this.remoteEndPoint); +                } +            } +        } + +        private void SendToLocalLoop() +        { +            while (!this.cancellationToken.IsCancellationRequested) +            { +                if (this.SendToLocalSemaphore != null) +                { +                    if (!this.SendToLocalSemaphore.WaitOne(100)) +                    { +                        continue; +                    } +                } + +                if (this.forLocal.TryTake(out var packet)) +                { +                    this.logger.WriteInfo($"Passed 1 packet of {packet.Length} bytes to local"); +                    this.captureSocket.SendTo(packet.GetUnderlyingArray(), packet.Offset, packet.Length, SocketFlags.None, this.localEndPoint); +                } +            } +        } + +        public void AssertPacketsToLocalCountEquals(int pktCnt) +        { +            DateTime start = DateTime.UtcNow; +            while (this.forLocal.Count != pktCnt) +            { +                if ((DateTime.UtcNow - start).TotalSeconds >= 5) +                { +                    Assert.AreEqual(pktCnt, this.forLocal.Count); +                } + +                Thread.Yield(); +            } +        } + +        public void AssertPacketsToRemoteCountEquals(int pktCnt) +        { +            DateTime start = DateTime.UtcNow; +            while (this.forRemote.Count != pktCnt) +            { +                if ((DateTime.UtcNow - start).TotalSeconds >= 5) +                { +                    Assert.AreEqual(pktCnt, this.forRemote.Count); +                } + +                Thread.Yield(); +            } +        } + +        public void DiscardPacketForLocal(int numToDiscard = 1) +        { +            for (int i = 0; i < numToDiscard; ++i) +            { +                this.forLocal.Take(); +            } +        } + +        public void DiscardPacketForRemote(int numToDiscard = 1) +        { +            for (int i = 0; i < numToDiscard; ++i) +            { +                this.forRemote.Take(); +            } +        } + +        public ByteSpan PeekPacketForLocal() +        { +            return this.forLocal.First(); +        } + +        public void ReorderPacketsForRemote(Action<List<ByteSpan>> reorderCallback) +        { +            List<ByteSpan> buffer = new List<ByteSpan>(); +            while (this.forRemote.TryTake(out var pkt)) buffer.Add(pkt); +            reorderCallback(buffer); +            foreach (var item in buffer) +            { +                this.forRemote.Add(item); +            } +        } + +        public void ReorderPacketsForLocal(Action<List<ByteSpan>> reorderCallback) +        { +            List<ByteSpan> buffer = new List<ByteSpan>(); +            while (this.forLocal.TryTake(out var pkt)) buffer.Add(pkt); +            reorderCallback(buffer); +            foreach (var item in buffer) +            { +                this.forLocal.Add(item); +            } +        } + +        internal void ReleasePacketsForRemote(int numToSend) +        { +            var newExpected = this.forRemote.Count - numToSend; +            this.SendToRemoteSemaphore.Release(numToSend); +            this.AssertPacketsToRemoteCountEquals(newExpected); +        } + +        internal void ReleasePacketsToLocal(int numToSend) +        { +            var newExpected = this.forLocal.Count - numToSend; +            this.SendToLocalSemaphore.Release(numToSend); +            this.AssertPacketsToLocalCountEquals(newExpected); +        } + +        internal string PacketsForRemoteToString() +        { +            StringBuilder sb = new StringBuilder(); +            foreach (var item in this.forRemote) +            { +                sb.AppendLine(item.ToString()); +            } + +            return sb.ToString(); +        } + +        internal string PacketsForLocalToString() +        { +            StringBuilder sb = new StringBuilder(); +            foreach(var item in this.forLocal) +            { +                sb.AppendLine(item.ToString()); +            } + +            return sb.ToString(); +        } +    } +}
\ No newline at end of file  | 
