From adf5172e94066e75f234c9e479bd9d3239eda74b Mon Sep 17 00:00:00 2001 From: Sijmen Schoon Date: Mon, 28 Dec 2020 23:44:39 +0100 Subject: [PATCH] More buffer size checks --- src/net-arp.cpp | 2 +- src/net-dhcp.cpp | 32 +++++----- src/net-icmp.cpp | 33 +++++++--- src/net-ipv4.cpp | 63 ++++++++++++------- src/net-ipv4.h | 7 ++- src/net-tftp.cpp | 156 +++++++++++++++++++++++++++++++---------------- src/net-tftp.h | 31 +++++----- src/net-udp.cpp | 53 ++++++++++++---- src/net-udp.h | 4 +- 9 files changed, 250 insertions(+), 131 deletions(-) diff --git a/src/net-arp.cpp b/src/net-arp.cpp index 1af1dcb..3699556 100644 --- a/src/net-arp.cpp +++ b/src/net-arp.cpp @@ -94,7 +94,7 @@ namespace Net::Arp const auto expectedSize = ethernetHeader.SerializedLength() + arpPacket.SerializedLength(); assert(size == expectedSize); - assert(size <= USPI_FRAME_BUFFER_SIZE); + assert(size <= sizeof(buffer)); USPiSendFrame(buffer, size); } diff --git a/src/net-dhcp.cpp b/src/net-dhcp.cpp index 876859d..19ddbd4 100644 --- a/src/net-dhcp.cpp +++ b/src/net-dhcp.cpp @@ -7,6 +7,7 @@ #include "net-ipv4.h" #include "net-ethernet.h" +#include "debug.h" #include "types.h" #include #include @@ -150,8 +151,8 @@ namespace Net::Dhcp uint8_t buffer[USPI_FRAME_BUFFER_SIZE]; size_t size = 0; size += ethernetHeader.Serialize(buffer + size, sizeof(buffer) - size); - size += ipv4Header.Serialize(buffer + size); - size += udpHeader.Serialize(buffer + size); + size += ipv4Header.Serialize(buffer + size, sizeof(buffer) - size); + size += udpHeader.Serialize(buffer + size, sizeof(buffer) - size); size += dhcpHeader.Serialize(buffer + size, sizeof(buffer) - size); const auto expectedSize = @@ -202,8 +203,8 @@ namespace Net::Dhcp size_t size = 0; size += ethernetHeader.Serialize(buffer + size, sizeof(buffer) - size); - size += ipv4Header.Serialize(buffer + size); - size += udpHeader.Serialize(buffer + size); + size += ipv4Header.Serialize(buffer + size, sizeof(buffer) - size); + size += udpHeader.Serialize(buffer + size, sizeof(buffer) - size); size += dhcpHeader.Serialize(buffer + size, sizeof(buffer) - size); const auto expectedSize = @@ -249,26 +250,29 @@ namespace Net::Dhcp const uint8_t* buffer, size_t size ) { - auto dhcpHeader = Header(); - const auto dhcpSize = Header::Deserialize(dhcpHeader, buffer, size); - if (dhcpSize == 0) + Header header; + const auto dhcpSize = Header::Deserialize(header, buffer, size); + if (dhcpSize != Header::SerializedLength()) { - // TODO log + DEBUG_LOG( + "Dropped DHCP packet (invalid buffer size %lu, expected %lu)\r\n", + size, Header::SerializedLength() + ); return; } - if (dhcpHeader.opcode != Opcode::BootReply) return; - if (dhcpHeader.hardwareAddressType != 1) return; - if (dhcpHeader.hardwareAddressLength != 6) return; - if (dhcpHeader.transactionId != transactionId) return; + if (header.opcode != Opcode::BootReply) return; + if (header.hardwareAddressType != 1) return; + if (header.hardwareAddressLength != 6) return; + if (header.transactionId != transactionId) return; if (!serverSelected) { - handleOfferPacket(ethernetHeader, dhcpHeader); + handleOfferPacket(ethernetHeader, header); } else { - handleAckPacket(ethernetHeader, dhcpHeader); + handleAckPacket(ethernetHeader, header); } } } // namespace Net::Dhcp diff --git a/src/net-icmp.cpp b/src/net-icmp.cpp index 7948772..ae66116 100644 --- a/src/net-icmp.cpp +++ b/src/net-icmp.cpp @@ -3,6 +3,7 @@ #include "net-icmp.h" +#include "debug.h" #include "types.h" #include @@ -96,7 +97,7 @@ namespace Net::Icmp size_t size = 0; size += ethernetHeader.Serialize(buffer + size, sizeof(buffer) - size); - size += ipv4Header.Serialize(buffer + size); + size += ipv4Header.Serialize(buffer + size, sizeof(buffer) - size); size += pingHeader.Serialize(buffer + size, sizeof(buffer) - size); size += icmpHeader.Serialize(buffer + size, sizeof(buffer) - size); @@ -121,8 +122,15 @@ namespace Net::Icmp EchoHeader reqEchoHeader; const auto reqEchoHeaderSize = Icmp::EchoHeader::Deserialize(reqEchoHeader, reqBuffer, reqBufferSize); - assert(reqEchoHeaderSize == reqEchoHeader.SerializedLength()); - assert(reqEchoHeaderSize <= reqBufferSize); + if (reqEchoHeaderSize == 0 || reqBufferSize < reqEchoHeaderSize) + { + DEBUG_LOG( + "Dropped ICMP packet " + "(invalid buffer size %ul, expected at least %ul)\r\n", + reqBufferSize, EchoHeader::SerializedLength() + ); + return; + } const Icmp::Header respIcmpHeader(Icmp::Type::EchoReply, 0); const Ipv4::Header respIpv4Header( @@ -148,7 +156,8 @@ namespace Net::Icmp size_t respSize = 0; respSize += respEthernetHeader.Serialize( respBuffer.data() + respSize, respBuffer.size() - respSize); - respSize += respIpv4Header.Serialize(respBuffer.data() + respSize); + respSize += respIpv4Header.Serialize( + respBuffer.data() + respSize, respBuffer.size() - respSize); respSize += respIcmpHeader.Serialize( respBuffer.data() + respSize, respBuffer.size() - respSize); std::memcpy( @@ -178,8 +187,9 @@ namespace Net::Icmp headerSize += Ethernet::Header::Deserialize( ethernetHeader, buffer + headerSize, bufferSize - headerSize); - const auto ipv4Header = Ipv4::Header::Deserialize(buffer + headerSize); - headerSize += ipv4Header.SerializedLength(); + Ipv4::Header ipv4Header; + headerSize += Ipv4::Header::Deserialize( + ipv4Header, buffer + headerSize, bufferSize - headerSize); Header icmpHeader; headerSize += Icmp::Header::Deserialize( @@ -189,9 +199,16 @@ namespace Net::Icmp ethernetHeader.SerializedLength() + ipv4Header.SerializedLength() + icmpHeader.SerializedLength(); - assert(headerSize == expectedHeaderSize); + if (headerSize != expectedHeaderSize) + { + DEBUG_LOG( + "Dropped ICMP packet " + "(invalid buffer size %ul, expected at least %ul)\r\n", + bufferSize, expectedHeaderSize + ); + } - if (icmpHeader.type == Icmp::Type::EchoRequest) + if (icmpHeader.type == Type::EchoRequest) { handleEchoRequest( ethernetHeader, diff --git a/src/net-ipv4.cpp b/src/net-ipv4.cpp index f94b160..af40b02 100644 --- a/src/net-ipv4.cpp +++ b/src/net-ipv4.cpp @@ -1,3 +1,5 @@ +#include + #include "net-ipv4.h" #include "net-ethernet.h" #include "net-arp.h" @@ -5,6 +7,8 @@ #include "net-udp.h" #include "net-utils.h" +#include "debug.h" + namespace Net::Ipv4 { Header::Header() {} @@ -27,10 +31,14 @@ namespace Net::Ipv4 destinationIp(destinationIp) {} - size_t Header::Serialize(uint8_t* buffer) const + size_t Header::Serialize(uint8_t* buffer, const size_t bufferSize) const { - size_t i = 0; + if (bufferSize <= SerializedLength()) + { + return 0; + } + size_t i = 0; buffer[i++] = version << 4 | ihl; buffer[i++] = dscp << 2 | ecn; buffer[i++] = totalLength >> 8; @@ -62,30 +70,35 @@ namespace Net::Ipv4 return i; } - Header Header::Deserialize(const uint8_t* buffer) - { - Header self; - self.version = buffer[0] >> 4; - self.ihl = buffer[0] & 0x0F; + size_t Header::Deserialize( + Header& out, const uint8_t* buffer, const size_t bufferSize + ) { + if (bufferSize <= SerializedLength()) + { + return 0; + } - self.dscp = buffer[1] >> 2; - self.ecn = buffer[1] & 0x03; + out.version = buffer[0] >> 4; + out.ihl = buffer[0] & 0x0F; - self.totalLength = buffer[2] << 8 | buffer[3]; - self.identification = buffer[4] << 8 | buffer[5]; + out.dscp = buffer[1] >> 2; + out.ecn = buffer[1] & 0x03; - self.flags = buffer[6] >> 5; - self.fragmentOffset = (buffer[6] & 0x1F) << 8 | buffer[7]; + out.totalLength = buffer[2] << 8 | buffer[3]; + out.identification = buffer[4] << 8 | buffer[5]; - self.ttl = buffer[8]; - self.protocol = static_cast(buffer[9]); - self.headerChecksum = buffer[10] << 8 | buffer[11]; + out.flags = buffer[6] >> 5; + out.fragmentOffset = (buffer[6] & 0x1F) << 8 | buffer[7]; - self.sourceIp = buffer[12] << 24 | buffer[13] << 16 | buffer[14] << 8 | buffer[15]; - self.destinationIp = + out.ttl = buffer[8]; + out.protocol = static_cast(buffer[9]); + out.headerChecksum = buffer[10] << 8 | buffer[11]; + + out.sourceIp = buffer[12] << 24 | buffer[13] << 16 | buffer[14] << 8 | buffer[15]; + out.destinationIp = buffer[16] << 24 | buffer[17] << 16 | buffer[18] << 8 | buffer[19]; - return self; + return 20; } void HandlePacket( @@ -93,8 +106,16 @@ namespace Net::Ipv4 const uint8_t* buffer, const size_t bufferSize ) { - const auto header = Header::Deserialize(buffer); - const auto headerSize = Header::SerializedLength(); + Header header; + const auto headerSize = Header::Deserialize(header, buffer, bufferSize); + if (headerSize != Header::SerializedLength()) + { + DEBUG_LOG( + "Dropped IPv4 packet (invalid buffer size %lu, expected at least %lu)\r\n" + bufferSize, headerSize + ); + return; + } // Update ARP table Arp::ArpTable.insert( diff --git a/src/net-ipv4.h b/src/net-ipv4.h index 2089b82..a1ea9f4 100644 --- a/src/net-ipv4.h +++ b/src/net-ipv4.h @@ -42,13 +42,14 @@ namespace Net::Ipv4 return 20; } - size_t Serialize(uint8_t* buffer) const; - static Header Deserialize(const uint8_t* buffer); + size_t Serialize(uint8_t* buffer, const size_t bufferSize) const; + static size_t Deserialize( + Header& out, const uint8_t* buffer, const size_t bufferSize); }; void HandlePacket( const Ethernet::Header& ethernetHeader, const uint8_t* buffer, - const size_t size + const size_t bufferSize ); } // namespace Net::Ipv4 diff --git a/src/net-tftp.cpp b/src/net-tftp.cpp index 6b9f4b0..c819b2d 100644 --- a/src/net-tftp.cpp +++ b/src/net-tftp.cpp @@ -9,6 +9,7 @@ #include "net-udp.h" #include "net.h" +#include "debug.h" #include "ff.h" #include "types.h" #include @@ -20,16 +21,28 @@ namespace Net::Tftp static bool shouldReboot = false; static uint32_t currentBlockNumber = -1; - static std::unique_ptr handleTftpWriteRequest(const uint8_t* data) - { - auto packet = WriteReadRequestPacket::Deserialize(data); + Packet::Packet() : opcode(static_cast(0)) {} + Packet::Packet(const Opcode opcode) : opcode(opcode) {} + + static std::unique_ptr handleTftpWriteRequest( + const uint8_t* data, const size_t dataSize + ) { + WriteReadRequestPacket packet; + const auto size = packet.Deserialize(data, dataSize); + if (size == 0) + { + DEBUG_LOG( + "Dropped TFTP packet (invalid buffer size %lu, expected at least %lu)\r\n", + dataSize, sizeof(WriteReadRequestPacket::opcode) + 2 + ) + return nullptr; + } // TODO Implement netscii, maybe if (packet.mode != "octet") { - return std::unique_ptr( - new ErrorPacket(0, "please use mode octet") - ); + const auto pointer = new ErrorPacket(0, "please use mode octet"); + return std::unique_ptr(pointer); } currentBlockNumber = 0; @@ -79,10 +92,14 @@ namespace Net::Tftp static std::unique_ptr handleTftpData(const uint8_t* buffer, size_t size) { DataPacket packet; - const auto tftpSize = DataPacket::Deserialize(packet, buffer, size); - if (size == 0) + const auto tftpSize = packet.Deserialize(buffer, size); + if (tftpSize == 0) { - // TODO log + DEBUG_LOG( + "Dropped TFTP data packet " + "(invalid buffer size %lu, expected at least %lu)\r\n", + size, sizeof(packet.opcode) + sizeof(packet.blockNumber) + ) return nullptr; } @@ -120,20 +137,28 @@ namespace Net::Tftp const Ethernet::Header ethernetReqHeader, const Ipv4::Header ipv4ReqHeader, const Udp::Header udpReqHeader, - const uint8_t* data + const uint8_t* reqBuffer, + const size_t reqBufferSize ) { - const auto opcode = static_cast(data[0] << 8 | data[1]); + const auto opcode = static_cast(reqBuffer[0] << 8 | reqBuffer[1]); std::unique_ptr response; - bool last = false; + + const auto payloadSize = udpReqHeader.length - udpReqHeader.SerializedLength(); + if (reqBufferSize < payloadSize) + { + DEBUG_LOG( + "Dropped TFTP packet (invalid buffer size %lu, expected at least %lu)\r\n", + reqBufferSize, payloadSize + ); + } if (opcode == Opcode::WriteRequest) { - response = handleTftpWriteRequest(data); + response = handleTftpWriteRequest(reqBuffer, payloadSize); } else if (opcode == Opcode::Data) { - const auto length = udpReqHeader.length - Udp::Header::SerializedLength(); - response = handleTftpData(data, length); + response = handleTftpData(reqBuffer, payloadSize); } else { @@ -164,9 +189,9 @@ namespace Net::Tftp size_t size = 0; uint8_t buffer[USPI_FRAME_BUFFER_SIZE]; size += ethernetRespHeader.Serialize(buffer + size, sizeof(buffer) - size); - size += ipv4RespHeader.Serialize(buffer + size); - size += udpRespHeader.Serialize(buffer + size); - size += response->Serialize(buffer + size); + size += ipv4RespHeader.Serialize(buffer + size, sizeof(buffer) - size); + size += udpRespHeader.Serialize(buffer + size, sizeof(buffer) - size); + size += response->Serialize(buffer + size, sizeof(buffer) - size); const auto expectedSize = ethernetRespHeader.SerializedLength() + @@ -179,28 +204,29 @@ namespace Net::Tftp USPiSendFrame(buffer, size); } - if (last && shouldReboot) - { - // TODO eww - extern void Reboot_Pi(); - Reboot_Pi(); - } + // TODO Reboot the Pi when a system file was received } // // WriteReadRequestPacket // - WriteReadRequestPacket::WriteReadRequestPacket(const Opcode opcode) : - Packet(opcode) + WriteReadRequestPacket::WriteReadRequestPacket() : Packet() {} + WriteReadRequestPacket::WriteReadRequestPacket(const Opcode opcode) : Packet(opcode) {} size_t WriteReadRequestPacket::SerializedLength() const { - return Packet::SerializedLength() + filename.size() + 1 + mode.size() + 1; + return sizeof(opcode) + filename.size() + 1 + mode.size() + 1; } - size_t WriteReadRequestPacket::Serialize(uint8_t* buffer) const + size_t WriteReadRequestPacket::Serialize(uint8_t* buffer, const size_t bufferSize) + const { + if (bufferSize < SerializedLength()) + { + return 0; + } + size_t i = 0; buffer[i++] = static_cast(opcode) >> 8; buffer[i++] = static_cast(opcode); @@ -214,21 +240,28 @@ namespace Net::Tftp return i; } - WriteReadRequestPacket WriteReadRequestPacket::Deserialize(const uint8_t* buffer) - { + size_t WriteReadRequestPacket::Deserialize( + const uint8_t* buffer, const size_t bufferSize + ) { + // Can't use SerializedLength here, as it's variable. + // Check for each field instead. size_t i = 0; - const auto opcode = static_cast(buffer[i] << 8 | buffer[i + 1]); - WriteReadRequestPacket self(opcode); + if (sizeof(Opcode) >= bufferSize - i) return 0; + opcode = static_cast(buffer[i] << 8 | buffer[i + 1]); i += 2; - self.filename = reinterpret_cast(buffer + i); - i += self.filename.size() + 1; + const char* filenameStr = reinterpret_cast(buffer + i); + if (std::strlen(filenameStr) + 1 >= bufferSize - i) return 0; + filename = std::string(filenameStr); + i += filename.size() + 1; - self.mode = reinterpret_cast(buffer + i); - i += self.mode.size() + 1; + const char* modeStr = reinterpret_cast(buffer + i); + if (std::strlen(modeStr) + 1 >= bufferSize - i) return 0; + mode = std::string(modeStr); + i += mode.size() + 1; - return self; + return i; } // @@ -241,11 +274,16 @@ namespace Net::Tftp size_t ErrorPacket::SerializedLength() const { - return Packet::SerializedLength() + sizeof(errorCode) + message.size() + 1; + return sizeof(opcode) + sizeof(errorCode) + message.size() + 1; } - size_t ErrorPacket::Serialize(uint8_t* buffer) const + size_t ErrorPacket::Serialize(uint8_t* buffer, const size_t bufferSize) const { + if (bufferSize < SerializedLength()) + { + return 0; + } + size_t i = 0; buffer[i++] = static_cast(opcode) >> 8; buffer[i++] = static_cast(opcode); @@ -261,9 +299,7 @@ namespace Net::Tftp // // AcknowledgementPacket // - AcknowledgementPacket::AcknowledgementPacket() : - Packet(Opcode::Acknowledgement) - {} + AcknowledgementPacket::AcknowledgementPacket() : Packet(Opcode::Acknowledgement) {} AcknowledgementPacket::AcknowledgementPacket(uint16_t blockNumber) : Packet(Opcode::Acknowledgement), blockNumber(blockNumber) @@ -271,11 +307,17 @@ namespace Net::Tftp size_t AcknowledgementPacket::SerializedLength() const { - return Packet::SerializedLength() + sizeof(blockNumber); + return sizeof(opcode) + sizeof(blockNumber); } - size_t AcknowledgementPacket::Serialize(uint8_t* buffer) const + size_t AcknowledgementPacket::Serialize(uint8_t* buffer, const size_t bufferSize) + const { + if (bufferSize < SerializedLength()) + { + return 0; + } + size_t i = 0; buffer[i++] = static_cast(opcode) >> 8; buffer[i++] = static_cast(opcode); @@ -290,8 +332,18 @@ namespace Net::Tftp DataPacket::DataPacket() : Packet(Opcode::Data), blockNumber(0) {} - size_t DataPacket::Serialize(uint8_t* buffer) const + size_t DataPacket::SerializedLength() const { + return sizeof(opcode) + sizeof(blockNumber) + data.size(); + } + + size_t DataPacket::Serialize(uint8_t* buffer, const size_t bufferSize) const + { + if (bufferSize <= SerializedLength()) + { + return 0; + } + size_t i = 0; buffer[i++] = static_cast(opcode) >> 8; buffer[i++] = static_cast(opcode); @@ -304,17 +356,15 @@ namespace Net::Tftp return i; } - size_t DataPacket::Deserialize( - DataPacket& out, const uint8_t* buffer, size_t size - ) { - if (size < sizeof(opcode) + sizeof(blockNumber)) + size_t DataPacket::Deserialize(const uint8_t* buffer, const size_t bufferSize) { + if (bufferSize < sizeof(opcode) + sizeof(blockNumber)) { return 0; } - out.opcode = static_cast(buffer[0] << 8 | buffer[1]); - out.blockNumber = buffer[2] << 8 | buffer[3]; - out.data = std::vector(buffer + 4, buffer + size); - return size; + opcode = static_cast(buffer[0] << 8 | buffer[1]); + blockNumber = buffer[2] << 8 | buffer[3]; + data = std::vector(buffer + 4, buffer + bufferSize); + return bufferSize; } } // namespace Net::Tftp diff --git a/src/net-tftp.h b/src/net-tftp.h index f1a8682..798e06a 100644 --- a/src/net-tftp.h +++ b/src/net-tftp.h @@ -21,15 +21,10 @@ namespace Net::Tftp { Opcode opcode; - Packet(Opcode opcode) : opcode(opcode) - {} - - virtual size_t SerializedLength() const - { - return sizeof(opcode); - } - - virtual size_t Serialize(uint8_t* buffer) const = 0; + Packet(); + Packet(Opcode opcode); + virtual size_t SerializedLength() const = 0; + virtual size_t Serialize(uint8_t* buffer, const size_t bufferSize) const = 0; }; struct WriteReadRequestPacket : public Packet @@ -37,10 +32,11 @@ namespace Net::Tftp std::string filename; std::string mode; + WriteReadRequestPacket(); WriteReadRequestPacket(const Opcode opcode); size_t SerializedLength() const override; - size_t Serialize(uint8_t* buffer) const override; - static WriteReadRequestPacket Deserialize(const uint8_t* buffer); + size_t Serialize(uint8_t* buffer, const size_t bufferSize) const override; + size_t Deserialize(const uint8_t* buffer, const size_t bufferSize); }; struct ErrorPacket : public Packet @@ -51,7 +47,7 @@ namespace Net::Tftp ErrorPacket(); ErrorPacket(uint16_t errorCode, std::string message); size_t SerializedLength() const override; - size_t Serialize(uint8_t* buffer) const override; + size_t Serialize(uint8_t* buffer, const size_t bufferSize) const override; }; struct AcknowledgementPacket : public Packet @@ -61,7 +57,7 @@ namespace Net::Tftp AcknowledgementPacket(); AcknowledgementPacket(uint16_t blockNumber); size_t SerializedLength() const override; - size_t Serialize(uint8_t* buffer) const override; + size_t Serialize(uint8_t* buffer, const size_t bufferSize) const override; }; struct DataPacket : public Packet @@ -70,15 +66,16 @@ namespace Net::Tftp std::vector data; DataPacket(); - size_t Serialize(uint8_t* buffer) const override; - static size_t Deserialize( - DataPacket& out, const uint8_t* buffer, size_t length); + size_t SerializedLength() const override; + size_t Serialize(uint8_t* buffer, const size_t bufferSize) const override; + size_t Deserialize(const uint8_t* buffer, const size_t bufferSize); }; void HandlePacket( const Ethernet::Header ethernetReqHeader, const Ipv4::Header ipv4ReqHeader, const Udp::Header udpReqHeader, - const uint8_t* buffer + const uint8_t* data, + const size_t dataSize ); } // namespace Net::Tftp diff --git a/src/net-udp.cpp b/src/net-udp.cpp index 777ac5b..853865c 100644 --- a/src/net-udp.cpp +++ b/src/net-udp.cpp @@ -2,6 +2,8 @@ #include "net-dhcp.h" #include "net-tftp.h" +#include "debug.h" + namespace Net::Udp { Header::Header() @@ -18,8 +20,13 @@ namespace Net::Udp checksum(0) {} - size_t Header::Serialize(uint8_t* buffer) const + size_t Header::Serialize(uint8_t* buffer, const size_t bufferSize) const { + if (bufferSize < SerializedLength()) + { + return 0; + } + size_t i = 0; buffer[i++] = static_cast(sourcePort) >> 8; buffer[i++] = static_cast(sourcePort); @@ -32,30 +39,51 @@ namespace Net::Udp return i; } - Header Header::Deserialize(const uint8_t* buffer) + size_t Header::Deserialize(const uint8_t* buffer, const size_t bufferSize) { - Header self; - self.sourcePort = static_cast(buffer[0] << 8 | buffer[1]); - self.destinationPort = static_cast(buffer[2] << 8 | buffer[3]); - self.length = buffer[4] << 8 | buffer[5]; - self.checksum = buffer[6] << 8 | buffer[7]; - return self; + if (bufferSize < Header::SerializedLength()) + { + return 0; + } + + sourcePort = static_cast(buffer[0] << 8 | buffer[1]); + destinationPort = static_cast(buffer[2] << 8 | buffer[3]); + length = buffer[4] << 8 | buffer[5]; + checksum = buffer[6] << 8 | buffer[7]; + return 8; } void HandlePacket( const Ethernet::Header ethernetHeader, const Ipv4::Header ipv4Header, const uint8_t* buffer, - const size_t size + const size_t bufferSize ) { - const auto udpHeader = Header::Deserialize(buffer); + Header udpHeader; + const auto headerSize = udpHeader.Deserialize(buffer, bufferSize); + if (headerSize == 0 || headerSize != udpHeader.SerializedLength()) + { + DEBUG_LOG( + "Dropped UDP header (invalid buffer size %lu, expected at least %lu)\r\n", + bufferSize, Header::SerializedLength() + ); + return; + } + if (udpHeader.length <= bufferSize) + { + DEBUG_LOG( + "Dropped UDP packet (invalid buffer size %lu, expected at least %lu)\r\n", + bufferSize, udpHeader.length + ); + return; + } if (udpHeader.destinationPort == Port::DhcpClient) { Dhcp::HandlePacket( ethernetHeader, buffer + udpHeader.SerializedLength(), - size - udpHeader.SerializedLength() + bufferSize - udpHeader.SerializedLength() ); } else if (udpHeader.destinationPort == Port::Tftp) @@ -64,7 +92,8 @@ namespace Net::Udp ethernetHeader, ipv4Header, udpHeader, - buffer + udpHeader.SerializedLength() + buffer + udpHeader.SerializedLength(), + bufferSize - udpHeader.SerializedLength() ); } } diff --git a/src/net-udp.h b/src/net-udp.h index b0941c4..95bb588 100644 --- a/src/net-udp.h +++ b/src/net-udp.h @@ -33,8 +33,8 @@ namespace Net::Udp sizeof(checksum); } - size_t Serialize(uint8_t* buffer) const; - static Header Deserialize(const uint8_t* buffer); + size_t Serialize(uint8_t* buffer, const size_t size) const; + size_t Deserialize(const uint8_t* buffer, const size_t size); }; void HandlePacket(