From 57c490da6e4fcf44dc15c0c1160084996060e794 Mon Sep 17 00:00:00 2001 From: Sijmen Schoon Date: Mon, 28 Dec 2020 18:35:38 +0100 Subject: [PATCH] Add buffer size checks to Net::Icmp --- src/net-icmp.cpp | 98 +++++++++++++++++++++++++++++------------------- src/net-icmp.h | 18 +++++---- 2 files changed, 70 insertions(+), 46 deletions(-) diff --git a/src/net-icmp.cpp b/src/net-icmp.cpp index 223f044..7948772 100644 --- a/src/net-icmp.cpp +++ b/src/net-icmp.cpp @@ -9,16 +9,18 @@ namespace Net::Icmp { // - // PacketHeader + // Header // - PacketHeader::PacketHeader() {} + Header::Header() : type(static_cast(0)), code(0), checksum(0) {} + Header::Header(Type type, uint8_t code) : type(type), code(code), checksum(0) {} - PacketHeader::PacketHeader(Type type, uint8_t code) : - type(type), code(code), checksum(0) - {} - - size_t PacketHeader::Serialize(uint8_t* buffer) const + size_t Header::Serialize(uint8_t* buffer, const size_t bufferSize) const { + if (bufferSize < SerializedLength()) + { + return 0; + } + size_t i = 0; buffer[i++] = static_cast(type); buffer[i++] = code; @@ -27,13 +29,18 @@ namespace Net::Icmp return i; } - PacketHeader PacketHeader::Deserialize(const uint8_t* buffer) - { - PacketHeader self; - self.type = static_cast(buffer[0]); - self.code = buffer[1]; - self.checksum = buffer[2] << 8 | buffer[3]; - return self; + size_t Header::Deserialize( + Header& out, const uint8_t* buffer, const size_t bufferSize + ) { + if (bufferSize < SerializedLength()) + { + return 0; + } + + out.type = static_cast(buffer[0]); + out.code = buffer[1]; + out.checksum = buffer[2] << 8 | buffer[3]; + return 4; } // @@ -43,8 +50,13 @@ namespace Net::Icmp EchoHeader::EchoHeader(uint16_t identifier, uint16_t sequenceNumber) : identifier(identifier), sequenceNumber(sequenceNumber) {} - size_t EchoHeader::Serialize(uint8_t* buffer) const + size_t EchoHeader::Serialize(uint8_t* buffer, const size_t bufferSize) const { + if (bufferSize < SerializedLength()) + { + return 0; + } + size_t i = 0; buffer[i++] = identifier >> 8; buffer[i++] = identifier; @@ -53,20 +65,25 @@ namespace Net::Icmp return i; } - EchoHeader EchoHeader::Deserialize(const uint8_t* buffer) - { - EchoHeader self; - self.identifier = buffer[0] << 8 | buffer[1]; - self.sequenceNumber = buffer[2] << 8 | buffer[3]; - return self; + size_t EchoHeader::Deserialize( + EchoHeader& out, const uint8_t* buffer, const size_t bufferSize + ) { + if (bufferSize < SerializedLength()) + { + return 0; + } + + out.identifier = buffer[0] << 8 | buffer[1]; + out.sequenceNumber = buffer[2] << 8 | buffer[3]; + return 4; } void SendEchoRequest(Utils::MacAddress mac, uint32_t ip) { - Icmp::PacketHeader icmpHeader(Icmp::Type::EchoRequest, 0); + Icmp::Header icmpHeader(Icmp::Type::EchoRequest, 0); Icmp::EchoHeader pingHeader(0, 0); - size_t ipv4TotalSize = Icmp::PacketHeader::SerializedLength() + + size_t ipv4TotalSize = Icmp::Header::SerializedLength() + Icmp::EchoHeader::SerializedLength() + Ipv4::Header::SerializedLength(); Ipv4::Header ipv4Header( @@ -80,8 +97,8 @@ namespace Net::Icmp size += ethernetHeader.Serialize(buffer + size, sizeof(buffer) - size); size += ipv4Header.Serialize(buffer + size); - size += pingHeader.Serialize(buffer + size); - size += icmpHeader.Serialize(buffer + 1); + size += pingHeader.Serialize(buffer + size, sizeof(buffer) - size); + size += icmpHeader.Serialize(buffer + size, sizeof(buffer) - size); const auto expectedSize = ethernetHeader.SerializedLength() + @@ -97,14 +114,17 @@ namespace Net::Icmp static void handleEchoRequest( const Ethernet::Header& reqEthernetHeader, const Ipv4::Header& reqIpv4Header, - const Icmp::PacketHeader& reqIcmpHeader, + const Icmp::Header& reqIcmpHeader, const uint8_t* reqBuffer, const size_t reqBufferSize ) { - const auto reqEchoHeader = Icmp::EchoHeader::Deserialize(reqBuffer); - const auto reqHeaderSize = reqEchoHeader.SerializedLength(); + EchoHeader reqEchoHeader; + const auto reqEchoHeaderSize = + Icmp::EchoHeader::Deserialize(reqEchoHeader, reqBuffer, reqBufferSize); + assert(reqEchoHeaderSize == reqEchoHeader.SerializedLength()); + assert(reqEchoHeaderSize <= reqBufferSize); - const Icmp::PacketHeader respIcmpHeader(Icmp::Type::EchoReply, 0); + const Icmp::Header respIcmpHeader(Icmp::Type::EchoReply, 0); const Ipv4::Header respIpv4Header( Ipv4::Protocol::Icmp, Utils::Ipv4Address, @@ -117,11 +137,11 @@ namespace Net::Icmp Ethernet::EtherType::Ipv4 ); - const auto payloadLength = + const auto payloadSize = reqIpv4Header.totalLength - reqIpv4Header.SerializedLength() - reqIcmpHeader.SerializedLength() - - reqEchoHeader.SerializedLength(); + reqEchoHeaderSize; std::array respBuffer; @@ -129,19 +149,20 @@ namespace Net::Icmp respSize += respEthernetHeader.Serialize( respBuffer.data() + respSize, respBuffer.size() - respSize); respSize += respIpv4Header.Serialize(respBuffer.data() + respSize); - respSize += respIcmpHeader.Serialize(respBuffer.data() + respSize); + respSize += respIcmpHeader.Serialize( + respBuffer.data() + respSize, respBuffer.size() - respSize); std::memcpy( respBuffer.data() + respSize, - reqBuffer + reqHeaderSize, - payloadLength + reqBuffer + reqEchoHeaderSize, + payloadSize ); - respSize += payloadLength; + respSize += payloadSize; const auto expectedRespSize = respEthernetHeader.SerializedLength() + respIpv4Header.SerializedLength() + respIcmpHeader.SerializedLength() + - payloadLength; + payloadSize; assert(respSize == expectedRespSize); assert(respSize <= respBuffer.size()); @@ -160,8 +181,9 @@ namespace Net::Icmp const auto ipv4Header = Ipv4::Header::Deserialize(buffer + headerSize); headerSize += ipv4Header.SerializedLength(); - const auto icmpHeader = Icmp::PacketHeader::Deserialize(buffer + headerSize); - headerSize += icmpHeader.SerializedLength(); + Header icmpHeader; + headerSize += Icmp::Header::Deserialize( + icmpHeader, buffer + headerSize, bufferSize - headerSize); const auto expectedHeaderSize = ethernetHeader.SerializedLength() + diff --git a/src/net-icmp.h b/src/net-icmp.h index 46d1d30..570b176 100644 --- a/src/net-icmp.h +++ b/src/net-icmp.h @@ -9,22 +9,23 @@ namespace Net::Icmp EchoRequest = 8, }; - struct PacketHeader + struct Header { Type type; uint8_t code; uint16_t checksum; - PacketHeader(); - PacketHeader(Type type, uint8_t code); + Header(); + Header(Type type, uint8_t code); constexpr static size_t SerializedLength() { return sizeof(type) + sizeof(code) + sizeof(checksum); } - size_t Serialize(uint8_t* buffer) const; - static PacketHeader Deserialize(const uint8_t* buffer); + size_t Serialize(uint8_t* buffer, const size_t bufferSize) const; + static size_t Deserialize( + Header& out, const uint8_t* buffer, const size_t bufferSize); }; struct EchoHeader @@ -40,10 +41,11 @@ namespace Net::Icmp return sizeof(identifier) + sizeof(sequenceNumber); } - size_t Serialize(uint8_t* buffer) const; - static EchoHeader Deserialize(const uint8_t* buffer); + size_t Serialize(uint8_t* buffer, const size_t bufferSize) const; + static size_t Deserialize( + EchoHeader& out, const uint8_t* buffer, const size_t bufferSize); }; - void SendEchoRequest(Utils::MacAddress mac, uint32_t ip); + void SendEchoRequest(const Utils::MacAddress mac, const uint32_t ip); void HandlePacket(const uint8_t* buffer, const size_t bufferSize); } // namespace Net::Icmp