aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorchai <215380520@qq.com>2023-10-14 12:43:01 +0800
committerchai <215380520@qq.com>2023-10-14 12:43:01 +0800
commit1931354421132dbe9b7a44be0219c3e33be831fd (patch)
treeaf16b9b3a905d28a2ac57490dc62353623245130
parent949f5accd281a051e5e2eeb3df8767b4794b6a1e (diff)
+ Message UnitTest
-rw-r--r--Projects/Message/Message/Message.csproj1
-rw-r--r--Projects/Message/Message/MessageReader.cs33
-rw-r--r--Projects/Message/Message/MessageReaderTests.cs683
-rw-r--r--Projects/Message/Message/MessageWriter.cs7
4 files changed, 715 insertions, 9 deletions
diff --git a/Projects/Message/Message/Message.csproj b/Projects/Message/Message/Message.csproj
index a15b051..56698ef 100644
--- a/Projects/Message/Message/Message.csproj
+++ b/Projects/Message/Message/Message.csproj
@@ -53,6 +53,7 @@
<ItemGroup>
<Compile Include="IRecyclable.cs" />
<Compile Include="MessageReader.cs" />
+ <Compile Include="MessageReaderTests.cs" />
<Compile Include="MessageWriter.cs" />
<Compile Include="MessageWriterTests.cs" />
<Compile Include="ObjectPool.cs" />
diff --git a/Projects/Message/Message/MessageReader.cs b/Projects/Message/Message/MessageReader.cs
index b49f9e2..32b0218 100644
--- a/Projects/Message/Message/MessageReader.cs
+++ b/Projects/Message/Message/MessageReader.cs
@@ -88,7 +88,7 @@ namespace MultiplayerToolkit
}
/// <summary>
- /// 将缓冲区封装为根MessageReader,不拷贝
+ /// 将缓冲区封装为根消息,不拷贝
/// </summary>
/// <param name="buffer"></param>
/// <returns></returns>
@@ -106,7 +106,7 @@ namespace MultiplayerToolkit
}
/// <summary>
- /// 拷贝buffer,封装为根MessageReader
+ /// 将缓冲区封装为根消息,拷贝
/// </summary>
/// <param name="buffer"></param>
/// <returns></returns>
@@ -169,9 +169,25 @@ namespace MultiplayerToolkit
return output;
}
+
+ public static MessageReader Copy(MessageReader source)
+ {
+ var output = MessageReader.GetSized(source.Buffer.Length);
+ System.Buffer.BlockCopy(source.Buffer, 0, output.Buffer, 0, source.Buffer.Length);
+
+ output.Offset = source.Offset;
+
+ output._position = source._position;
+ output.readHead = source.readHead;
+
+ output.Length = source.Length;
+
+ return output;
+ }
+
#endif
-#endregion
+ #endregion
/// <summary>
/// 返回一个可读子消息,引用了父缓冲区,**不要回收**
@@ -179,7 +195,7 @@ namespace MultiplayerToolkit
public MessageReader ReadMessage()
{
// 至少有一个length
- if (this.BytesRemaining < 2) throw new InvalidDataException($"ReadMessage header is longer than message length: 3 of {this.BytesRemaining}");
+ if (this.BytesRemaining < 2) throw new InvalidDataException($"ReadMessage header is longer than message length: 2 of {this.BytesRemaining}");
MessageReader output = new MessageReader();
@@ -207,7 +223,7 @@ namespace MultiplayerToolkit
/// </summary>
public MessageReader ReadMessageAsNewBuffer()
{
- if (this.BytesRemaining < 2) throw new InvalidDataException($"ReadMessage header is longer than message length: 3 of {this.BytesRemaining}");
+ if (this.BytesRemaining < 2) throw new InvalidDataException($"ReadMessage header is longer than message length: 2 of {this.BytesRemaining}");
var len = this.ReadUInt16(); // Position += 2
@@ -218,7 +234,6 @@ namespace MultiplayerToolkit
Array.Copy(this.Buffer, this.readHead, output.Buffer, 0, len);
output.Length = len;
- //output.Tag = tag;
this.Position += output.Length;
return output;
@@ -244,7 +259,7 @@ namespace MultiplayerToolkit
{
var headerOffset = reader.Offset - 2;
var endOfMessage = reader.Offset + reader.Length; // 有效部分的后一个byte
- var len = this.TotalLength - endOfMessage;
+ var len = reader.Buffer.Length - endOfMessage;
temp = MessageReader.GetSized(len);
Array.Copy(reader.Buffer, endOfMessage, temp.Buffer, 0, len);
@@ -271,8 +286,8 @@ namespace MultiplayerToolkit
var temp = MessageReader.GetSized(reader.Buffer.Length);
try
{
- var headerOffset = reader.Offset - 3; // headerOffset是length+tag,这个方法仅仅接受reader不含sendoption的情况
- var startOfMessage = reader.Offset; // 头部后面的数据开始的索引
+ var headerOffset = reader.Offset - 2;
+ var startOfMessage = reader.Offset;
var len = reader.Buffer.Length - headerOffset; // 疑似写错了,应该是headerOffset
int writerOffset = 0;
diff --git a/Projects/Message/Message/MessageReaderTests.cs b/Projects/Message/Message/MessageReaderTests.cs
new file mode 100644
index 0000000..d6f39d2
--- /dev/null
+++ b/Projects/Message/Message/MessageReaderTests.cs
@@ -0,0 +1,683 @@
+using System;
+using System.IO;
+using System.Linq;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using MultiplayerToolkit;
+
+namespace MultiplayerToolkit
+{
+ [TestClass]
+ public class MessageReaderTests
+ {
+ [TestMethod]
+ public void ReadProperInt()
+ {
+ const int Test1 = int.MaxValue;
+ const int Test2 = int.MinValue;
+
+ var msg = new MessageWriter(128);
+ msg.StartMessage(1);
+ msg.Write(Test1);
+ msg.Write(Test2);
+ msg.EndMessage();
+
+ Assert.AreEqual(10, msg.Length);
+ Assert.AreEqual(msg.Length, msg.Position);
+
+ MessageReader reader = MessageReader.GetChildMessage(msg.Buffer, 0);
+ Assert.AreEqual(Test1, reader.ReadInt32());
+ Assert.AreEqual(Test2, reader.ReadInt32());
+ }
+
+ [TestMethod]
+ public void ReadProperBool()
+ {
+ const bool Test1 = true;
+ const bool Test2 = false;
+
+ var msg = new MessageWriter(128);
+ msg.StartMessage(1);
+ msg.Write(Test1);
+ msg.Write(Test2);
+ msg.EndMessage();
+
+ Assert.AreEqual(4, msg.Length);
+ Assert.AreEqual(msg.Length, msg.Position);
+
+ MessageReader reader = MessageReader.GetChildMessage(msg.Buffer, 0);
+
+ Assert.AreEqual(Test1, reader.ReadBoolean());
+ Assert.AreEqual(Test2, reader.ReadBoolean());
+
+ }
+
+ [TestMethod]
+ public void ReadProperString()
+ {
+ const string Test1 = "Hello";
+ string Test2 = new string(' ', 1024);
+ var msg = new MessageWriter(2048);
+ msg.StartMessage(1);
+ msg.Write(Test1);
+ msg.Write(Test2);
+ msg.Write(string.Empty);
+ msg.EndMessage();
+
+ Assert.AreEqual(msg.Length, msg.Position);
+
+ MessageReader reader = MessageReader.GetChildMessage(msg.Buffer, 0);
+
+ Assert.AreEqual(Test1, reader.ReadString());
+ Assert.AreEqual(Test2, reader.ReadString());
+ Assert.AreEqual(string.Empty, reader.ReadString());
+
+ }
+
+ [TestMethod]
+ public void ReadProperFloat()
+ {
+ const float Test1 = 12.34f;
+
+ var msg = new MessageWriter(2048);
+ msg.StartMessage(1);
+ msg.Write(Test1);
+ msg.EndMessage();
+
+ Assert.AreEqual(6, msg.Length);
+ Assert.AreEqual(msg.Length, msg.Position);
+
+ MessageReader reader = MessageReader.GetChildMessage(msg.Buffer, 0);
+
+ Assert.AreEqual(Test1, reader.ReadSingle());
+ }
+
+ [TestMethod]
+ public void RemoveMessageWorks()
+ {
+ const byte Test0 = 11;
+ const byte Test3 = 33;
+ const byte Test4 = 44;
+ const byte Test5 = 55;
+
+ var msg = new MessageWriter(2048);
+ msg.StartMessage(0);
+ msg.Write(Test0);
+ msg.EndMessage();
+
+ msg.StartMessage(12);
+ msg.StartMessage(23);
+
+ msg.StartMessage(34);
+ msg.Write(Test3);
+ msg.EndMessage();
+
+ msg.StartMessage(45);
+ msg.Write(Test4);
+ msg.EndMessage();
+
+ msg.EndMessage();
+ msg.EndMessage();
+
+ msg.StartMessage(56);
+ msg.Write(Test5);
+ msg.EndMessage();
+
+ MessageReader reader = MessageReader.Get(msg.Buffer);
+ reader.Length = msg.Length;
+
+ var zero = reader.ReadMessage();
+
+ var one = reader.ReadMessage();
+ var two = one.ReadMessage();
+ var three = two.ReadMessage();
+ two.RemoveMessage(three);
+
+ // Reader becomes invalid
+ Assert.AreNotEqual(Test3, three.ReadByte());
+
+ // Unrealistic, but nice. Earlier data is not affected
+ Assert.AreEqual(Test0, zero.ReadByte());
+
+ // Continuing to read depth-first works
+ var four = two.ReadMessage();
+ Assert.AreEqual(Test4, four.ReadByte());
+
+ var five = reader.ReadMessage();
+ Assert.AreEqual(Test5, five.ReadByte());
+ }
+
+ [TestMethod]
+ public void InsertMessageWorks()
+ {
+ const byte Test0 = 11;
+ const byte Test3 = 33;
+ const byte Test4 = 44;
+ const byte Test5 = 55;
+ const byte TestInsert = 66;
+
+ var msg = new MessageWriter(2048);
+ msg.StartMessage(0);
+ msg.Write(Test0);
+ msg.EndMessage();
+
+ msg.StartMessage(12);
+ msg.StartMessage(23);
+
+ msg.StartMessage(34);
+ msg.Write(Test3);
+ msg.EndMessage();
+
+ msg.StartMessage(45);
+ msg.Write(Test4);
+ msg.EndMessage();
+
+ msg.EndMessage();
+ msg.EndMessage();
+
+ msg.StartMessage(56);
+ msg.Write(Test5);
+ msg.EndMessage();
+
+ MessageReader reader = MessageReader.Get(msg.Buffer);
+
+ MessageWriter writer = MessageWriter.Get();
+ writer.StartMessage(5);
+ writer.Write(TestInsert);
+ writer.EndMessage();
+
+ reader.ReadMessage();
+ var one = reader.ReadMessage();
+ var two = one.ReadMessage();
+ var three = two.ReadMessage();
+
+ two.InsertMessage(three, writer);
+
+ //set the position back to zero to read back the updated message
+ reader.Position = 0;
+
+ var zero = reader.ReadMessage();
+ Assert.AreEqual(Test0, zero.ReadByte());
+ one = reader.ReadMessage();
+ two = one.ReadMessage();
+ var insert = two.ReadMessage();
+ Assert.AreEqual(TestInsert, insert.ReadByte());
+ three = two.ReadMessage();
+ Assert.AreEqual(Test3, three.ReadByte());
+ var four = two.ReadMessage();
+ Assert.AreEqual(Test4, four.ReadByte());
+
+ var five = reader.ReadMessage();
+ Assert.AreEqual(Test5, five.ReadByte());
+ }
+
+ [TestMethod]
+ public void InsertMessageWorksWithSendOptionNone()
+ {
+ const byte Test0 = 11;
+ const byte Test3 = 33;
+ const byte Test4 = 44;
+ const byte Test5 = 55;
+ const byte TestInsert = 66;
+
+ var msg = new MessageWriter(2048);
+ msg.StartMessage(0);
+ msg.Write(Test0);
+ msg.EndMessage();
+
+ msg.StartMessage(12);
+ msg.StartMessage(23);
+
+ msg.StartMessage(34);
+ msg.Write(Test3);
+ msg.EndMessage();
+
+ msg.StartMessage(45);
+ msg.Write(Test4);
+ msg.EndMessage();
+
+ msg.EndMessage();
+ msg.EndMessage();
+
+ msg.StartMessage(56);
+ msg.Write(Test5);
+ msg.EndMessage();
+
+ MessageReader reader = MessageReader.Get(msg.Buffer);
+
+ MessageWriter writer = MessageWriter.Get();
+ writer.StartMessage(5);
+ writer.Write(TestInsert);
+ writer.EndMessage();
+
+ reader.ReadMessage();
+ var one = reader.ReadMessage();
+ var two = one.ReadMessage();
+ var three = two.ReadMessage();
+
+ two.InsertMessage(three, writer);
+
+ //set the position back to zero to read back the updated message
+ reader.Position = 0;
+
+ var zero = reader.ReadMessage();
+ Assert.AreEqual(Test0, zero.ReadByte());
+ one = reader.ReadMessage();
+ two = one.ReadMessage();
+ var insert = two.ReadMessage();
+ Assert.AreEqual(TestInsert, insert.ReadByte());
+ three = two.ReadMessage();
+ Assert.AreEqual(Test3, three.ReadByte());
+ var four = two.ReadMessage();
+ Assert.AreEqual(Test4, four.ReadByte());
+
+ var five = reader.ReadMessage();
+ Assert.AreEqual(Test5, five.ReadByte());
+
+ }
+
+ [TestMethod]
+ public void InsertMessageWithoutStartMessageInWriter()
+ {
+ const byte Test0 = 11;
+ const byte Test3 = 33;
+ const byte Test4 = 44;
+ const byte Test5 = 55;
+ const byte TestInsert = 66;
+
+ var msg = new MessageWriter(2048);
+ msg.StartMessage(0);
+ msg.Write(Test0);
+ msg.EndMessage();
+
+ msg.StartMessage(12);
+ msg.StartMessage(23);
+
+ msg.StartMessage(34);
+ msg.Write(Test3);
+ msg.EndMessage();
+
+ msg.StartMessage(45);
+ msg.Write(Test4);
+ msg.EndMessage();
+
+ msg.EndMessage();
+ msg.EndMessage();
+
+ msg.StartMessage(56);
+ msg.Write(Test5);
+ msg.EndMessage();
+
+ MessageReader reader = MessageReader.Get(msg.Buffer);
+
+ MessageWriter writer = MessageWriter.Get();
+ writer.Write(TestInsert);
+
+ reader.ReadMessage();
+ var one = reader.ReadMessage();
+ var two = one.ReadMessage();
+ var three = two.ReadMessage();
+
+ two.InsertMessage(three, writer);
+
+ //set the position back to zero to read back the updated message
+ reader.Position = 0;
+
+ var zero = reader.ReadMessage();
+ Assert.AreEqual(Test0, zero.ReadByte());
+ one = reader.ReadMessage();
+ two = one.ReadMessage();
+ Assert.AreEqual(TestInsert, two.ReadByte());
+ three = two.ReadMessage();
+ Assert.AreEqual(Test3, three.ReadByte());
+ var four = two.ReadMessage();
+ Assert.AreEqual(Test4, four.ReadByte());
+
+ var five = reader.ReadMessage();
+ Assert.AreEqual(Test5, five.ReadByte());
+ }
+
+ [TestMethod]
+ public void InsertMessageWithMultipleMessagesInWriter()
+ {
+ const byte Test0 = 11;
+ const byte Test3 = 33;
+ const byte Test4 = 44;
+ const byte Test5 = 55;
+ const byte TestInsert = 66;
+ const byte TestInsert2 = 77;
+
+ var msg = new MessageWriter(2048);
+ msg.StartMessage(0);
+ msg.Write(Test0);
+ msg.EndMessage();
+
+ msg.StartMessage(12);
+ msg.StartMessage(23);
+
+ msg.StartMessage(34);
+ msg.Write(Test3);
+ msg.EndMessage();
+
+ msg.StartMessage(45);
+ msg.Write(Test4);
+ msg.EndMessage();
+
+ msg.EndMessage();
+ msg.EndMessage();
+
+ msg.StartMessage(56);
+ msg.Write(Test5);
+ msg.EndMessage();
+
+ MessageReader reader = MessageReader.Get(msg.Buffer);
+
+ MessageWriter writer = MessageWriter.Get();
+ writer.StartMessage(5);
+ writer.Write(TestInsert);
+ writer.EndMessage();
+
+ writer.StartMessage(6);
+ writer.Write(TestInsert2);
+ writer.EndMessage();
+
+ reader.ReadMessage();
+ var one = reader.ReadMessage();
+ var two = one.ReadMessage();
+ var three = two.ReadMessage();
+
+ two.InsertMessage(three, writer);
+
+ //set the position back to zero to read back the updated message
+ reader.Position = 0;
+
+ var zero = reader.ReadMessage();
+ Assert.AreEqual(Test0, zero.ReadByte());
+ one = reader.ReadMessage();
+ two = one.ReadMessage();
+ var insert = two.ReadMessage();
+ Assert.AreEqual(TestInsert, insert.ReadByte());
+ var insert2 = two.ReadMessage();
+ Assert.AreEqual(TestInsert2, insert2.ReadByte());
+ three = two.ReadMessage();
+ Assert.AreEqual(Test3, three.ReadByte());
+ var four = two.ReadMessage();
+ Assert.AreEqual(Test4, four.ReadByte());
+
+ var five = reader.ReadMessage();
+ Assert.AreEqual(Test5, five.ReadByte());
+ }
+
+ [TestMethod]
+ public void InsertMessageMultipleInsertsWithoutReset()
+ {
+ const byte Test0 = 11;
+ const byte Test3 = 33;
+ const byte Test4 = 44;
+ const byte Test5 = 55;
+ const byte Test6 = 66;
+ const byte TestInsert = 77;
+ const byte TestInsert2 = 88;
+
+ var msg = new MessageWriter(2048);
+ msg.StartMessage(0);
+ msg.Write(Test0);
+ msg.EndMessage();
+
+ msg.StartMessage(12);
+ msg.StartMessage(23);
+
+ msg.StartMessage(34);
+ msg.Write(Test3);
+ msg.EndMessage();
+
+ msg.StartMessage(45);
+ msg.Write(Test4);
+ msg.EndMessage();
+
+ msg.EndMessage();
+
+ msg.StartMessage(56);
+ msg.Write(Test5);
+ msg.EndMessage();
+
+ msg.EndMessage();
+
+ msg.StartMessage(67);
+ msg.Write(Test6);
+ msg.EndMessage();
+
+ MessageReader reader = MessageReader.Get(msg.Buffer);
+
+ MessageWriter writer = MessageWriter.Get();
+ writer.StartMessage(5);
+ writer.Write(TestInsert);
+ writer.EndMessage();
+
+ MessageWriter writer2 = MessageWriter.Get();
+ writer2.StartMessage(6);
+ writer2.Write(TestInsert2);
+ writer2.EndMessage();
+
+ reader.ReadMessage();
+ var one = reader.ReadMessage();
+ var two = one.ReadMessage();
+ var three = two.ReadMessage();
+
+ two.InsertMessage(three, writer);
+
+ // three becomes invalid
+ Assert.AreNotEqual(Test3, three.ReadByte());
+
+ // Continuing to read works
+ var four = two.ReadMessage();
+ Assert.AreEqual(Test4, four.ReadByte());
+
+ var five = one.ReadMessage();
+ Assert.AreEqual(Test5, five.ReadByte());
+
+ reader.InsertMessage(one, writer2);
+
+ var six = reader.ReadMessage();
+ Assert.AreEqual(Test6, six.ReadByte());
+ }
+
+ private void SetZero(MessageReader reader)
+ {
+ for (int i = 0; i < reader.Buffer.Length; ++i)
+ reader.Buffer[i] = 0;
+ }
+
+ [TestMethod]
+ public void CopySubMessage()
+ {
+ const byte Test1 = 12;
+ const byte Test2 = 146;
+
+ var msg = new MessageWriter(2048);
+ msg.StartMessage(1);
+
+ msg.StartMessage(2);
+ msg.Write(Test1);
+ msg.Write(Test2);
+ msg.EndMessage();
+
+ msg.EndMessage();
+
+ MessageReader handleMessage = MessageReader.GetChildMessage(msg.Buffer, 0);
+
+ var parentReader = MessageReader.Copy(handleMessage);
+
+ handleMessage.Recycle();
+ SetZero(handleMessage);
+
+ for (int i = 0; i < 5; ++i)
+ {
+
+ var reader = parentReader.ReadMessage();
+ Assert.AreEqual(Test1, reader.ReadByte());
+ Assert.AreEqual(Test2, reader.ReadByte());
+
+ var temp = parentReader;
+ parentReader = MessageReader.CopyMessageIntoParent(reader);
+
+ temp.Recycle();
+ SetZero(temp);
+ SetZero(reader);
+ }
+ }
+
+
+ [TestMethod]
+ public void ReadMessageLength()
+ {
+ var msg = new MessageWriter(2048);
+ msg.StartMessage(1);
+ msg.Write(65534);
+ msg.StartMessage(2);
+ msg.Write("HO");
+ msg.EndMessage();
+ msg.StartMessage(2);
+ msg.EndMessage();
+ msg.EndMessage();
+
+ Assert.AreEqual(msg.Length, msg.Position);
+
+ MessageReader reader = MessageReader.GetChildMessage(msg.Buffer, 0);
+ Assert.AreEqual(65534, reader.ReadInt32()); // Content
+
+ var sub = reader.ReadMessage();
+ Assert.AreEqual(3, sub.Length);
+ Assert.AreEqual("HO", sub.ReadString());
+
+ sub = reader.ReadMessage();
+ Assert.AreEqual(0, sub.Length);
+ }
+
+
+ [TestMethod]
+ public void ReadMessageAsNewBufferLength()
+ {
+ var msg = new MessageWriter(2048);
+ msg.StartMessage(1);
+ msg.Write(65534);
+ msg.StartMessage(2);
+ msg.Write("HO");
+ msg.EndMessage();
+ msg.StartMessage(232);
+ msg.EndMessage();
+ msg.EndMessage();
+
+ Assert.AreEqual(msg.Length, msg.Position);
+
+ MessageReader reader = MessageReader.GetChildMessage(msg.Buffer, 0);
+ Assert.AreEqual(65534, reader.ReadInt32()); // Content
+
+ var sub = reader.ReadMessageAsNewBuffer();
+ Assert.AreEqual(0, sub.Position);
+ Assert.AreEqual(0, sub.Offset);
+
+ Assert.AreEqual(3, sub.Length);
+ Assert.AreEqual("HO", sub.ReadString());
+
+ sub.Recycle();
+
+ sub = reader.ReadMessageAsNewBuffer();
+ Assert.AreEqual(0, sub.Position);
+ Assert.AreEqual(0, sub.Offset);
+
+ Assert.AreEqual(0, sub.Length);
+ sub.Recycle();
+ }
+
+
+ [TestMethod]
+ public void ReadStringProtectsAgainstOverrun()
+ {
+ const string TestDataFromAPreviousPacket = "You shouldn't be able to see this data";
+
+ // An extra byte from the length of TestData when written via MessageWriter
+ int DataLength = TestDataFromAPreviousPacket.Length + 1;
+
+ // THE BUG
+ //
+ // No bound checks. When the server wants to read a string from
+ // an offset, it reads the packed int at that offset, treats it
+ // as a length and then proceeds to read the data that comes after
+ // it without any bound checks. This can be chained with something
+ // else to create an infoleak.
+
+ MessageWriter writer = MessageWriter.Get();
+
+ // This will be our malicious "string length"
+ writer.WritePacked(DataLength);
+
+ // This is data from a "previous packet"
+ writer.Write(TestDataFromAPreviousPacket);
+
+ byte[] testData = writer.ToByteArray();
+
+ // One extra byte for the MessageWriter header, one more for the malicious data
+ Assert.AreEqual(DataLength + 1, testData.Length);
+
+ var dut = MessageReader.Get(testData);
+
+ // If Length is short by even a byte, ReadString should obey that.
+ dut.Length--;
+
+ try
+ {
+ dut.ReadString();
+ Assert.Fail("ReadString is expected to throw");
+ }
+ catch (InvalidDataException) { }
+ }
+
+ [TestMethod]
+ public void ReadMessageProtectsAgainstOverrun()
+ {
+ const string TestDataFromAPreviousPacket = "You shouldn't be able to see this data";
+
+ // An extra byte from the length of TestData when written via MessageWriter
+ // Extra 3 bytes for the length header for ReadMessage.
+ int DataLength = TestDataFromAPreviousPacket.Length + 1 + 2;
+
+ // THE BUG
+ //
+ // No bound checks. When the server wants to read a message, it
+ // reads the uint16 at that offset, treats it as a length without any bound checks.
+ // This can be allow a later ReadString or ReadBytes to create an infoleak.
+
+ MessageWriter writer = MessageWriter.Get();
+
+ // This is the malicious length. No data in this message, so it should be zero.
+ writer.Write((ushort)1); // length
+
+ // This is data from a "previous packet"
+ writer.Write(TestDataFromAPreviousPacket);
+
+ byte[] testData = writer.ToByteArray();
+
+ Assert.AreEqual(DataLength, testData.Length);
+
+ var outer = MessageReader.Get(testData);
+
+ // Length is just the malicious message header.
+ outer.Length = 2;
+
+ try
+ {
+ outer.ReadMessage();
+ Assert.Fail("ReadMessage is expected to throw");
+ }
+ catch (InvalidDataException) { }
+ }
+
+ [TestMethod]
+ public void GetLittleEndian()
+ {
+ Assert.IsTrue(MessageWriter.IsLittleEndian());
+ }
+
+ }
+
+} \ No newline at end of file
diff --git a/Projects/Message/Message/MessageWriter.cs b/Projects/Message/Message/MessageWriter.cs
index 845a901..68a5ac3 100644
--- a/Projects/Message/Message/MessageWriter.cs
+++ b/Projects/Message/Message/MessageWriter.cs
@@ -90,6 +90,13 @@ namespace MultiplayerToolkit
this.Length = this.Position;
}
+#if UNIT_TEST
+ public void StartMessage(int dummy)
+ {
+ StartMessage();
+ }
+#endif
+
/// <summary>
/// ôϢȣ־Ϣ
/// </summary>