diff options
author | chai <215380520@qq.com> | 2023-10-14 12:43:01 +0800 |
---|---|---|
committer | chai <215380520@qq.com> | 2023-10-14 12:43:01 +0800 |
commit | 1931354421132dbe9b7a44be0219c3e33be831fd (patch) | |
tree | af16b9b3a905d28a2ac57490dc62353623245130 /Projects | |
parent | 949f5accd281a051e5e2eeb3df8767b4794b6a1e (diff) |
+ Message UnitTest
Diffstat (limited to 'Projects')
-rw-r--r-- | Projects/Message/Message/Message.csproj | 1 | ||||
-rw-r--r-- | Projects/Message/Message/MessageReader.cs | 33 | ||||
-rw-r--r-- | Projects/Message/Message/MessageReaderTests.cs | 683 | ||||
-rw-r--r-- | Projects/Message/Message/MessageWriter.cs | 7 |
4 files changed, 715 insertions, 9 deletions
diff --git a/Projects/Message/Message/Message.csproj b/Projects/Message/Message/Message.csproj index a15b051..56698ef 100644 --- a/Projects/Message/Message/Message.csproj +++ b/Projects/Message/Message/Message.csproj @@ -53,6 +53,7 @@ <ItemGroup> <Compile Include="IRecyclable.cs" /> <Compile Include="MessageReader.cs" /> + <Compile Include="MessageReaderTests.cs" /> <Compile Include="MessageWriter.cs" /> <Compile Include="MessageWriterTests.cs" /> <Compile Include="ObjectPool.cs" /> diff --git a/Projects/Message/Message/MessageReader.cs b/Projects/Message/Message/MessageReader.cs index b49f9e2..32b0218 100644 --- a/Projects/Message/Message/MessageReader.cs +++ b/Projects/Message/Message/MessageReader.cs @@ -88,7 +88,7 @@ namespace MultiplayerToolkit } /// <summary> - /// 将缓冲区封装为根MessageReader,不拷贝 + /// 将缓冲区封装为根消息,不拷贝 /// </summary> /// <param name="buffer"></param> /// <returns></returns> @@ -106,7 +106,7 @@ namespace MultiplayerToolkit } /// <summary> - /// 拷贝buffer,封装为根MessageReader + /// 将缓冲区封装为根消息,拷贝 /// </summary> /// <param name="buffer"></param> /// <returns></returns> @@ -169,9 +169,25 @@ namespace MultiplayerToolkit return output; } + + public static MessageReader Copy(MessageReader source) + { + var output = MessageReader.GetSized(source.Buffer.Length); + System.Buffer.BlockCopy(source.Buffer, 0, output.Buffer, 0, source.Buffer.Length); + + output.Offset = source.Offset; + + output._position = source._position; + output.readHead = source.readHead; + + output.Length = source.Length; + + return output; + } + #endif -#endregion + #endregion /// <summary> /// 返回一个可读子消息,引用了父缓冲区,**不要回收** @@ -179,7 +195,7 @@ namespace MultiplayerToolkit public MessageReader ReadMessage() { // 至少有一个length - if (this.BytesRemaining < 2) throw new InvalidDataException($"ReadMessage header is longer than message length: 3 of {this.BytesRemaining}"); + if (this.BytesRemaining < 2) throw new InvalidDataException($"ReadMessage header is longer than message length: 2 of {this.BytesRemaining}"); MessageReader output = new MessageReader(); @@ -207,7 +223,7 @@ namespace MultiplayerToolkit /// </summary> public MessageReader ReadMessageAsNewBuffer() { - if (this.BytesRemaining < 2) throw new InvalidDataException($"ReadMessage header is longer than message length: 3 of {this.BytesRemaining}"); + if (this.BytesRemaining < 2) throw new InvalidDataException($"ReadMessage header is longer than message length: 2 of {this.BytesRemaining}"); var len = this.ReadUInt16(); // Position += 2 @@ -218,7 +234,6 @@ namespace MultiplayerToolkit Array.Copy(this.Buffer, this.readHead, output.Buffer, 0, len); output.Length = len; - //output.Tag = tag; this.Position += output.Length; return output; @@ -244,7 +259,7 @@ namespace MultiplayerToolkit { var headerOffset = reader.Offset - 2; var endOfMessage = reader.Offset + reader.Length; // 有效部分的后一个byte - var len = this.TotalLength - endOfMessage; + var len = reader.Buffer.Length - endOfMessage; temp = MessageReader.GetSized(len); Array.Copy(reader.Buffer, endOfMessage, temp.Buffer, 0, len); @@ -271,8 +286,8 @@ namespace MultiplayerToolkit var temp = MessageReader.GetSized(reader.Buffer.Length); try { - var headerOffset = reader.Offset - 3; // headerOffset是length+tag,这个方法仅仅接受reader不含sendoption的情况 - var startOfMessage = reader.Offset; // 头部后面的数据开始的索引 + var headerOffset = reader.Offset - 2; + var startOfMessage = reader.Offset; var len = reader.Buffer.Length - headerOffset; // 疑似写错了,应该是headerOffset int writerOffset = 0; diff --git a/Projects/Message/Message/MessageReaderTests.cs b/Projects/Message/Message/MessageReaderTests.cs new file mode 100644 index 0000000..d6f39d2 --- /dev/null +++ b/Projects/Message/Message/MessageReaderTests.cs @@ -0,0 +1,683 @@ +using System; +using System.IO; +using System.Linq; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using MultiplayerToolkit; + +namespace MultiplayerToolkit +{ + [TestClass] + public class MessageReaderTests + { + [TestMethod] + public void ReadProperInt() + { + const int Test1 = int.MaxValue; + const int Test2 = int.MinValue; + + var msg = new MessageWriter(128); + msg.StartMessage(1); + msg.Write(Test1); + msg.Write(Test2); + msg.EndMessage(); + + Assert.AreEqual(10, msg.Length); + Assert.AreEqual(msg.Length, msg.Position); + + MessageReader reader = MessageReader.GetChildMessage(msg.Buffer, 0); + Assert.AreEqual(Test1, reader.ReadInt32()); + Assert.AreEqual(Test2, reader.ReadInt32()); + } + + [TestMethod] + public void ReadProperBool() + { + const bool Test1 = true; + const bool Test2 = false; + + var msg = new MessageWriter(128); + msg.StartMessage(1); + msg.Write(Test1); + msg.Write(Test2); + msg.EndMessage(); + + Assert.AreEqual(4, msg.Length); + Assert.AreEqual(msg.Length, msg.Position); + + MessageReader reader = MessageReader.GetChildMessage(msg.Buffer, 0); + + Assert.AreEqual(Test1, reader.ReadBoolean()); + Assert.AreEqual(Test2, reader.ReadBoolean()); + + } + + [TestMethod] + public void ReadProperString() + { + const string Test1 = "Hello"; + string Test2 = new string(' ', 1024); + var msg = new MessageWriter(2048); + msg.StartMessage(1); + msg.Write(Test1); + msg.Write(Test2); + msg.Write(string.Empty); + msg.EndMessage(); + + Assert.AreEqual(msg.Length, msg.Position); + + MessageReader reader = MessageReader.GetChildMessage(msg.Buffer, 0); + + Assert.AreEqual(Test1, reader.ReadString()); + Assert.AreEqual(Test2, reader.ReadString()); + Assert.AreEqual(string.Empty, reader.ReadString()); + + } + + [TestMethod] + public void ReadProperFloat() + { + const float Test1 = 12.34f; + + var msg = new MessageWriter(2048); + msg.StartMessage(1); + msg.Write(Test1); + msg.EndMessage(); + + Assert.AreEqual(6, msg.Length); + Assert.AreEqual(msg.Length, msg.Position); + + MessageReader reader = MessageReader.GetChildMessage(msg.Buffer, 0); + + Assert.AreEqual(Test1, reader.ReadSingle()); + } + + [TestMethod] + public void RemoveMessageWorks() + { + const byte Test0 = 11; + const byte Test3 = 33; + const byte Test4 = 44; + const byte Test5 = 55; + + var msg = new MessageWriter(2048); + msg.StartMessage(0); + msg.Write(Test0); + msg.EndMessage(); + + msg.StartMessage(12); + msg.StartMessage(23); + + msg.StartMessage(34); + msg.Write(Test3); + msg.EndMessage(); + + msg.StartMessage(45); + msg.Write(Test4); + msg.EndMessage(); + + msg.EndMessage(); + msg.EndMessage(); + + msg.StartMessage(56); + msg.Write(Test5); + msg.EndMessage(); + + MessageReader reader = MessageReader.Get(msg.Buffer); + reader.Length = msg.Length; + + var zero = reader.ReadMessage(); + + var one = reader.ReadMessage(); + var two = one.ReadMessage(); + var three = two.ReadMessage(); + two.RemoveMessage(three); + + // Reader becomes invalid + Assert.AreNotEqual(Test3, three.ReadByte()); + + // Unrealistic, but nice. Earlier data is not affected + Assert.AreEqual(Test0, zero.ReadByte()); + + // Continuing to read depth-first works + var four = two.ReadMessage(); + Assert.AreEqual(Test4, four.ReadByte()); + + var five = reader.ReadMessage(); + Assert.AreEqual(Test5, five.ReadByte()); + } + + [TestMethod] + public void InsertMessageWorks() + { + const byte Test0 = 11; + const byte Test3 = 33; + const byte Test4 = 44; + const byte Test5 = 55; + const byte TestInsert = 66; + + var msg = new MessageWriter(2048); + msg.StartMessage(0); + msg.Write(Test0); + msg.EndMessage(); + + msg.StartMessage(12); + msg.StartMessage(23); + + msg.StartMessage(34); + msg.Write(Test3); + msg.EndMessage(); + + msg.StartMessage(45); + msg.Write(Test4); + msg.EndMessage(); + + msg.EndMessage(); + msg.EndMessage(); + + msg.StartMessage(56); + msg.Write(Test5); + msg.EndMessage(); + + MessageReader reader = MessageReader.Get(msg.Buffer); + + MessageWriter writer = MessageWriter.Get(); + writer.StartMessage(5); + writer.Write(TestInsert); + writer.EndMessage(); + + reader.ReadMessage(); + var one = reader.ReadMessage(); + var two = one.ReadMessage(); + var three = two.ReadMessage(); + + two.InsertMessage(three, writer); + + //set the position back to zero to read back the updated message + reader.Position = 0; + + var zero = reader.ReadMessage(); + Assert.AreEqual(Test0, zero.ReadByte()); + one = reader.ReadMessage(); + two = one.ReadMessage(); + var insert = two.ReadMessage(); + Assert.AreEqual(TestInsert, insert.ReadByte()); + three = two.ReadMessage(); + Assert.AreEqual(Test3, three.ReadByte()); + var four = two.ReadMessage(); + Assert.AreEqual(Test4, four.ReadByte()); + + var five = reader.ReadMessage(); + Assert.AreEqual(Test5, five.ReadByte()); + } + + [TestMethod] + public void InsertMessageWorksWithSendOptionNone() + { + const byte Test0 = 11; + const byte Test3 = 33; + const byte Test4 = 44; + const byte Test5 = 55; + const byte TestInsert = 66; + + var msg = new MessageWriter(2048); + msg.StartMessage(0); + msg.Write(Test0); + msg.EndMessage(); + + msg.StartMessage(12); + msg.StartMessage(23); + + msg.StartMessage(34); + msg.Write(Test3); + msg.EndMessage(); + + msg.StartMessage(45); + msg.Write(Test4); + msg.EndMessage(); + + msg.EndMessage(); + msg.EndMessage(); + + msg.StartMessage(56); + msg.Write(Test5); + msg.EndMessage(); + + MessageReader reader = MessageReader.Get(msg.Buffer); + + MessageWriter writer = MessageWriter.Get(); + writer.StartMessage(5); + writer.Write(TestInsert); + writer.EndMessage(); + + reader.ReadMessage(); + var one = reader.ReadMessage(); + var two = one.ReadMessage(); + var three = two.ReadMessage(); + + two.InsertMessage(three, writer); + + //set the position back to zero to read back the updated message + reader.Position = 0; + + var zero = reader.ReadMessage(); + Assert.AreEqual(Test0, zero.ReadByte()); + one = reader.ReadMessage(); + two = one.ReadMessage(); + var insert = two.ReadMessage(); + Assert.AreEqual(TestInsert, insert.ReadByte()); + three = two.ReadMessage(); + Assert.AreEqual(Test3, three.ReadByte()); + var four = two.ReadMessage(); + Assert.AreEqual(Test4, four.ReadByte()); + + var five = reader.ReadMessage(); + Assert.AreEqual(Test5, five.ReadByte()); + + } + + [TestMethod] + public void InsertMessageWithoutStartMessageInWriter() + { + const byte Test0 = 11; + const byte Test3 = 33; + const byte Test4 = 44; + const byte Test5 = 55; + const byte TestInsert = 66; + + var msg = new MessageWriter(2048); + msg.StartMessage(0); + msg.Write(Test0); + msg.EndMessage(); + + msg.StartMessage(12); + msg.StartMessage(23); + + msg.StartMessage(34); + msg.Write(Test3); + msg.EndMessage(); + + msg.StartMessage(45); + msg.Write(Test4); + msg.EndMessage(); + + msg.EndMessage(); + msg.EndMessage(); + + msg.StartMessage(56); + msg.Write(Test5); + msg.EndMessage(); + + MessageReader reader = MessageReader.Get(msg.Buffer); + + MessageWriter writer = MessageWriter.Get(); + writer.Write(TestInsert); + + reader.ReadMessage(); + var one = reader.ReadMessage(); + var two = one.ReadMessage(); + var three = two.ReadMessage(); + + two.InsertMessage(three, writer); + + //set the position back to zero to read back the updated message + reader.Position = 0; + + var zero = reader.ReadMessage(); + Assert.AreEqual(Test0, zero.ReadByte()); + one = reader.ReadMessage(); + two = one.ReadMessage(); + Assert.AreEqual(TestInsert, two.ReadByte()); + three = two.ReadMessage(); + Assert.AreEqual(Test3, three.ReadByte()); + var four = two.ReadMessage(); + Assert.AreEqual(Test4, four.ReadByte()); + + var five = reader.ReadMessage(); + Assert.AreEqual(Test5, five.ReadByte()); + } + + [TestMethod] + public void InsertMessageWithMultipleMessagesInWriter() + { + const byte Test0 = 11; + const byte Test3 = 33; + const byte Test4 = 44; + const byte Test5 = 55; + const byte TestInsert = 66; + const byte TestInsert2 = 77; + + var msg = new MessageWriter(2048); + msg.StartMessage(0); + msg.Write(Test0); + msg.EndMessage(); + + msg.StartMessage(12); + msg.StartMessage(23); + + msg.StartMessage(34); + msg.Write(Test3); + msg.EndMessage(); + + msg.StartMessage(45); + msg.Write(Test4); + msg.EndMessage(); + + msg.EndMessage(); + msg.EndMessage(); + + msg.StartMessage(56); + msg.Write(Test5); + msg.EndMessage(); + + MessageReader reader = MessageReader.Get(msg.Buffer); + + MessageWriter writer = MessageWriter.Get(); + writer.StartMessage(5); + writer.Write(TestInsert); + writer.EndMessage(); + + writer.StartMessage(6); + writer.Write(TestInsert2); + writer.EndMessage(); + + reader.ReadMessage(); + var one = reader.ReadMessage(); + var two = one.ReadMessage(); + var three = two.ReadMessage(); + + two.InsertMessage(three, writer); + + //set the position back to zero to read back the updated message + reader.Position = 0; + + var zero = reader.ReadMessage(); + Assert.AreEqual(Test0, zero.ReadByte()); + one = reader.ReadMessage(); + two = one.ReadMessage(); + var insert = two.ReadMessage(); + Assert.AreEqual(TestInsert, insert.ReadByte()); + var insert2 = two.ReadMessage(); + Assert.AreEqual(TestInsert2, insert2.ReadByte()); + three = two.ReadMessage(); + Assert.AreEqual(Test3, three.ReadByte()); + var four = two.ReadMessage(); + Assert.AreEqual(Test4, four.ReadByte()); + + var five = reader.ReadMessage(); + Assert.AreEqual(Test5, five.ReadByte()); + } + + [TestMethod] + public void InsertMessageMultipleInsertsWithoutReset() + { + const byte Test0 = 11; + const byte Test3 = 33; + const byte Test4 = 44; + const byte Test5 = 55; + const byte Test6 = 66; + const byte TestInsert = 77; + const byte TestInsert2 = 88; + + var msg = new MessageWriter(2048); + msg.StartMessage(0); + msg.Write(Test0); + msg.EndMessage(); + + msg.StartMessage(12); + msg.StartMessage(23); + + msg.StartMessage(34); + msg.Write(Test3); + msg.EndMessage(); + + msg.StartMessage(45); + msg.Write(Test4); + msg.EndMessage(); + + msg.EndMessage(); + + msg.StartMessage(56); + msg.Write(Test5); + msg.EndMessage(); + + msg.EndMessage(); + + msg.StartMessage(67); + msg.Write(Test6); + msg.EndMessage(); + + MessageReader reader = MessageReader.Get(msg.Buffer); + + MessageWriter writer = MessageWriter.Get(); + writer.StartMessage(5); + writer.Write(TestInsert); + writer.EndMessage(); + + MessageWriter writer2 = MessageWriter.Get(); + writer2.StartMessage(6); + writer2.Write(TestInsert2); + writer2.EndMessage(); + + reader.ReadMessage(); + var one = reader.ReadMessage(); + var two = one.ReadMessage(); + var three = two.ReadMessage(); + + two.InsertMessage(three, writer); + + // three becomes invalid + Assert.AreNotEqual(Test3, three.ReadByte()); + + // Continuing to read works + var four = two.ReadMessage(); + Assert.AreEqual(Test4, four.ReadByte()); + + var five = one.ReadMessage(); + Assert.AreEqual(Test5, five.ReadByte()); + + reader.InsertMessage(one, writer2); + + var six = reader.ReadMessage(); + Assert.AreEqual(Test6, six.ReadByte()); + } + + private void SetZero(MessageReader reader) + { + for (int i = 0; i < reader.Buffer.Length; ++i) + reader.Buffer[i] = 0; + } + + [TestMethod] + public void CopySubMessage() + { + const byte Test1 = 12; + const byte Test2 = 146; + + var msg = new MessageWriter(2048); + msg.StartMessage(1); + + msg.StartMessage(2); + msg.Write(Test1); + msg.Write(Test2); + msg.EndMessage(); + + msg.EndMessage(); + + MessageReader handleMessage = MessageReader.GetChildMessage(msg.Buffer, 0); + + var parentReader = MessageReader.Copy(handleMessage); + + handleMessage.Recycle(); + SetZero(handleMessage); + + for (int i = 0; i < 5; ++i) + { + + var reader = parentReader.ReadMessage(); + Assert.AreEqual(Test1, reader.ReadByte()); + Assert.AreEqual(Test2, reader.ReadByte()); + + var temp = parentReader; + parentReader = MessageReader.CopyMessageIntoParent(reader); + + temp.Recycle(); + SetZero(temp); + SetZero(reader); + } + } + + + [TestMethod] + public void ReadMessageLength() + { + var msg = new MessageWriter(2048); + msg.StartMessage(1); + msg.Write(65534); + msg.StartMessage(2); + msg.Write("HO"); + msg.EndMessage(); + msg.StartMessage(2); + msg.EndMessage(); + msg.EndMessage(); + + Assert.AreEqual(msg.Length, msg.Position); + + MessageReader reader = MessageReader.GetChildMessage(msg.Buffer, 0); + Assert.AreEqual(65534, reader.ReadInt32()); // Content + + var sub = reader.ReadMessage(); + Assert.AreEqual(3, sub.Length); + Assert.AreEqual("HO", sub.ReadString()); + + sub = reader.ReadMessage(); + Assert.AreEqual(0, sub.Length); + } + + + [TestMethod] + public void ReadMessageAsNewBufferLength() + { + var msg = new MessageWriter(2048); + msg.StartMessage(1); + msg.Write(65534); + msg.StartMessage(2); + msg.Write("HO"); + msg.EndMessage(); + msg.StartMessage(232); + msg.EndMessage(); + msg.EndMessage(); + + Assert.AreEqual(msg.Length, msg.Position); + + MessageReader reader = MessageReader.GetChildMessage(msg.Buffer, 0); + Assert.AreEqual(65534, reader.ReadInt32()); // Content + + var sub = reader.ReadMessageAsNewBuffer(); + Assert.AreEqual(0, sub.Position); + Assert.AreEqual(0, sub.Offset); + + Assert.AreEqual(3, sub.Length); + Assert.AreEqual("HO", sub.ReadString()); + + sub.Recycle(); + + sub = reader.ReadMessageAsNewBuffer(); + Assert.AreEqual(0, sub.Position); + Assert.AreEqual(0, sub.Offset); + + Assert.AreEqual(0, sub.Length); + sub.Recycle(); + } + + + [TestMethod] + public void ReadStringProtectsAgainstOverrun() + { + const string TestDataFromAPreviousPacket = "You shouldn't be able to see this data"; + + // An extra byte from the length of TestData when written via MessageWriter + int DataLength = TestDataFromAPreviousPacket.Length + 1; + + // THE BUG + // + // No bound checks. When the server wants to read a string from + // an offset, it reads the packed int at that offset, treats it + // as a length and then proceeds to read the data that comes after + // it without any bound checks. This can be chained with something + // else to create an infoleak. + + MessageWriter writer = MessageWriter.Get(); + + // This will be our malicious "string length" + writer.WritePacked(DataLength); + + // This is data from a "previous packet" + writer.Write(TestDataFromAPreviousPacket); + + byte[] testData = writer.ToByteArray(); + + // One extra byte for the MessageWriter header, one more for the malicious data + Assert.AreEqual(DataLength + 1, testData.Length); + + var dut = MessageReader.Get(testData); + + // If Length is short by even a byte, ReadString should obey that. + dut.Length--; + + try + { + dut.ReadString(); + Assert.Fail("ReadString is expected to throw"); + } + catch (InvalidDataException) { } + } + + [TestMethod] + public void ReadMessageProtectsAgainstOverrun() + { + const string TestDataFromAPreviousPacket = "You shouldn't be able to see this data"; + + // An extra byte from the length of TestData when written via MessageWriter + // Extra 3 bytes for the length header for ReadMessage. + int DataLength = TestDataFromAPreviousPacket.Length + 1 + 2; + + // THE BUG + // + // No bound checks. When the server wants to read a message, it + // reads the uint16 at that offset, treats it as a length without any bound checks. + // This can be allow a later ReadString or ReadBytes to create an infoleak. + + MessageWriter writer = MessageWriter.Get(); + + // This is the malicious length. No data in this message, so it should be zero. + writer.Write((ushort)1); // length + + // This is data from a "previous packet" + writer.Write(TestDataFromAPreviousPacket); + + byte[] testData = writer.ToByteArray(); + + Assert.AreEqual(DataLength, testData.Length); + + var outer = MessageReader.Get(testData); + + // Length is just the malicious message header. + outer.Length = 2; + + try + { + outer.ReadMessage(); + Assert.Fail("ReadMessage is expected to throw"); + } + catch (InvalidDataException) { } + } + + [TestMethod] + public void GetLittleEndian() + { + Assert.IsTrue(MessageWriter.IsLittleEndian()); + } + + } + +}
\ No newline at end of file diff --git a/Projects/Message/Message/MessageWriter.cs b/Projects/Message/Message/MessageWriter.cs index 845a901..68a5ac3 100644 --- a/Projects/Message/Message/MessageWriter.cs +++ b/Projects/Message/Message/MessageWriter.cs @@ -90,6 +90,13 @@ namespace MultiplayerToolkit this.Length = this.Position; } +#if UNIT_TEST + public void StartMessage(int dummy) + { + StartMessage(); + } +#endif + /// <summary> /// ôϢȣ־Ϣ /// </summary> |