More buffer size checks

This commit is contained in:
Sijmen 2020-12-28 23:44:39 +01:00
parent 57c490da6e
commit adf5172e94
Signed by: vijfhoek
GPG key ID: DAF7821E067D9C48
9 changed files with 250 additions and 131 deletions

View file

@ -94,7 +94,7 @@ namespace Net::Arp
const auto expectedSize = const auto expectedSize =
ethernetHeader.SerializedLength() + arpPacket.SerializedLength(); ethernetHeader.SerializedLength() + arpPacket.SerializedLength();
assert(size == expectedSize); assert(size == expectedSize);
assert(size <= USPI_FRAME_BUFFER_SIZE); assert(size <= sizeof(buffer));
USPiSendFrame(buffer, size); USPiSendFrame(buffer, size);
} }

View file

@ -7,6 +7,7 @@
#include "net-ipv4.h" #include "net-ipv4.h"
#include "net-ethernet.h" #include "net-ethernet.h"
#include "debug.h"
#include "types.h" #include "types.h"
#include <uspi.h> #include <uspi.h>
#include <uspios.h> #include <uspios.h>
@ -150,8 +151,8 @@ namespace Net::Dhcp
uint8_t buffer[USPI_FRAME_BUFFER_SIZE]; uint8_t buffer[USPI_FRAME_BUFFER_SIZE];
size_t size = 0; size_t size = 0;
size += ethernetHeader.Serialize(buffer + size, sizeof(buffer) - size); size += ethernetHeader.Serialize(buffer + size, sizeof(buffer) - size);
size += ipv4Header.Serialize(buffer + size); size += ipv4Header.Serialize(buffer + size, sizeof(buffer) - size);
size += udpHeader.Serialize(buffer + size); size += udpHeader.Serialize(buffer + size, sizeof(buffer) - size);
size += dhcpHeader.Serialize(buffer + size, sizeof(buffer) - size); size += dhcpHeader.Serialize(buffer + size, sizeof(buffer) - size);
const auto expectedSize = const auto expectedSize =
@ -202,8 +203,8 @@ namespace Net::Dhcp
size_t size = 0; size_t size = 0;
size += ethernetHeader.Serialize(buffer + size, sizeof(buffer) - size); size += ethernetHeader.Serialize(buffer + size, sizeof(buffer) - size);
size += ipv4Header.Serialize(buffer + size); size += ipv4Header.Serialize(buffer + size, sizeof(buffer) - size);
size += udpHeader.Serialize(buffer + size); size += udpHeader.Serialize(buffer + size, sizeof(buffer) - size);
size += dhcpHeader.Serialize(buffer + size, sizeof(buffer) - size); size += dhcpHeader.Serialize(buffer + size, sizeof(buffer) - size);
const auto expectedSize = const auto expectedSize =
@ -249,26 +250,29 @@ namespace Net::Dhcp
const uint8_t* buffer, const uint8_t* buffer,
size_t size size_t size
) { ) {
auto dhcpHeader = Header(); Header header;
const auto dhcpSize = Header::Deserialize(dhcpHeader, buffer, size); const auto dhcpSize = Header::Deserialize(header, buffer, size);
if (dhcpSize == 0) if (dhcpSize != Header::SerializedLength())
{ {
// TODO log DEBUG_LOG(
"Dropped DHCP packet (invalid buffer size %lu, expected %lu)\r\n",
size, Header::SerializedLength()
);
return; return;
} }
if (dhcpHeader.opcode != Opcode::BootReply) return; if (header.opcode != Opcode::BootReply) return;
if (dhcpHeader.hardwareAddressType != 1) return; if (header.hardwareAddressType != 1) return;
if (dhcpHeader.hardwareAddressLength != 6) return; if (header.hardwareAddressLength != 6) return;
if (dhcpHeader.transactionId != transactionId) return; if (header.transactionId != transactionId) return;
if (!serverSelected) if (!serverSelected)
{ {
handleOfferPacket(ethernetHeader, dhcpHeader); handleOfferPacket(ethernetHeader, header);
} }
else else
{ {
handleAckPacket(ethernetHeader, dhcpHeader); handleAckPacket(ethernetHeader, header);
} }
} }
} // namespace Net::Dhcp } // namespace Net::Dhcp

View file

@ -3,6 +3,7 @@
#include "net-icmp.h" #include "net-icmp.h"
#include "debug.h"
#include "types.h" #include "types.h"
#include <uspi.h> #include <uspi.h>
@ -96,7 +97,7 @@ namespace Net::Icmp
size_t size = 0; size_t size = 0;
size += ethernetHeader.Serialize(buffer + size, sizeof(buffer) - size); size += ethernetHeader.Serialize(buffer + size, sizeof(buffer) - size);
size += ipv4Header.Serialize(buffer + size); size += ipv4Header.Serialize(buffer + size, sizeof(buffer) - size);
size += pingHeader.Serialize(buffer + size, sizeof(buffer) - size); size += pingHeader.Serialize(buffer + size, sizeof(buffer) - size);
size += icmpHeader.Serialize(buffer + size, sizeof(buffer) - size); size += icmpHeader.Serialize(buffer + size, sizeof(buffer) - size);
@ -121,8 +122,15 @@ namespace Net::Icmp
EchoHeader reqEchoHeader; EchoHeader reqEchoHeader;
const auto reqEchoHeaderSize = const auto reqEchoHeaderSize =
Icmp::EchoHeader::Deserialize(reqEchoHeader, reqBuffer, reqBufferSize); Icmp::EchoHeader::Deserialize(reqEchoHeader, reqBuffer, reqBufferSize);
assert(reqEchoHeaderSize == reqEchoHeader.SerializedLength()); if (reqEchoHeaderSize == 0 || reqBufferSize < reqEchoHeaderSize)
assert(reqEchoHeaderSize <= reqBufferSize); {
DEBUG_LOG(
"Dropped ICMP packet "
"(invalid buffer size %ul, expected at least %ul)\r\n",
reqBufferSize, EchoHeader::SerializedLength()
);
return;
}
const Icmp::Header respIcmpHeader(Icmp::Type::EchoReply, 0); const Icmp::Header respIcmpHeader(Icmp::Type::EchoReply, 0);
const Ipv4::Header respIpv4Header( const Ipv4::Header respIpv4Header(
@ -148,7 +156,8 @@ namespace Net::Icmp
size_t respSize = 0; size_t respSize = 0;
respSize += respEthernetHeader.Serialize( respSize += respEthernetHeader.Serialize(
respBuffer.data() + respSize, respBuffer.size() - respSize); respBuffer.data() + respSize, respBuffer.size() - respSize);
respSize += respIpv4Header.Serialize(respBuffer.data() + respSize); respSize += respIpv4Header.Serialize(
respBuffer.data() + respSize, respBuffer.size() - respSize);
respSize += respIcmpHeader.Serialize( respSize += respIcmpHeader.Serialize(
respBuffer.data() + respSize, respBuffer.size() - respSize); respBuffer.data() + respSize, respBuffer.size() - respSize);
std::memcpy( std::memcpy(
@ -178,8 +187,9 @@ namespace Net::Icmp
headerSize += Ethernet::Header::Deserialize( headerSize += Ethernet::Header::Deserialize(
ethernetHeader, buffer + headerSize, bufferSize - headerSize); ethernetHeader, buffer + headerSize, bufferSize - headerSize);
const auto ipv4Header = Ipv4::Header::Deserialize(buffer + headerSize); Ipv4::Header ipv4Header;
headerSize += ipv4Header.SerializedLength(); headerSize += Ipv4::Header::Deserialize(
ipv4Header, buffer + headerSize, bufferSize - headerSize);
Header icmpHeader; Header icmpHeader;
headerSize += Icmp::Header::Deserialize( headerSize += Icmp::Header::Deserialize(
@ -189,9 +199,16 @@ namespace Net::Icmp
ethernetHeader.SerializedLength() + ethernetHeader.SerializedLength() +
ipv4Header.SerializedLength() + ipv4Header.SerializedLength() +
icmpHeader.SerializedLength(); icmpHeader.SerializedLength();
assert(headerSize == expectedHeaderSize); if (headerSize != expectedHeaderSize)
{
DEBUG_LOG(
"Dropped ICMP packet "
"(invalid buffer size %ul, expected at least %ul)\r\n",
bufferSize, expectedHeaderSize
);
}
if (icmpHeader.type == Icmp::Type::EchoRequest) if (icmpHeader.type == Type::EchoRequest)
{ {
handleEchoRequest( handleEchoRequest(
ethernetHeader, ethernetHeader,

View file

@ -1,3 +1,5 @@
#include <cassert>
#include "net-ipv4.h" #include "net-ipv4.h"
#include "net-ethernet.h" #include "net-ethernet.h"
#include "net-arp.h" #include "net-arp.h"
@ -5,6 +7,8 @@
#include "net-udp.h" #include "net-udp.h"
#include "net-utils.h" #include "net-utils.h"
#include "debug.h"
namespace Net::Ipv4 namespace Net::Ipv4
{ {
Header::Header() {} Header::Header() {}
@ -27,10 +31,14 @@ namespace Net::Ipv4
destinationIp(destinationIp) destinationIp(destinationIp)
{} {}
size_t Header::Serialize(uint8_t* buffer) const size_t Header::Serialize(uint8_t* buffer, const size_t bufferSize) const
{ {
size_t i = 0; if (bufferSize <= SerializedLength())
{
return 0;
}
size_t i = 0;
buffer[i++] = version << 4 | ihl; buffer[i++] = version << 4 | ihl;
buffer[i++] = dscp << 2 | ecn; buffer[i++] = dscp << 2 | ecn;
buffer[i++] = totalLength >> 8; buffer[i++] = totalLength >> 8;
@ -62,30 +70,35 @@ namespace Net::Ipv4
return i; return i;
} }
Header Header::Deserialize(const uint8_t* buffer) size_t Header::Deserialize(
{ Header& out, const uint8_t* buffer, const size_t bufferSize
Header self; ) {
self.version = buffer[0] >> 4; if (bufferSize <= SerializedLength())
self.ihl = buffer[0] & 0x0F; {
return 0;
}
self.dscp = buffer[1] >> 2; out.version = buffer[0] >> 4;
self.ecn = buffer[1] & 0x03; out.ihl = buffer[0] & 0x0F;
self.totalLength = buffer[2] << 8 | buffer[3]; out.dscp = buffer[1] >> 2;
self.identification = buffer[4] << 8 | buffer[5]; out.ecn = buffer[1] & 0x03;
self.flags = buffer[6] >> 5; out.totalLength = buffer[2] << 8 | buffer[3];
self.fragmentOffset = (buffer[6] & 0x1F) << 8 | buffer[7]; out.identification = buffer[4] << 8 | buffer[5];
self.ttl = buffer[8]; out.flags = buffer[6] >> 5;
self.protocol = static_cast<Protocol>(buffer[9]); out.fragmentOffset = (buffer[6] & 0x1F) << 8 | buffer[7];
self.headerChecksum = buffer[10] << 8 | buffer[11];
self.sourceIp = buffer[12] << 24 | buffer[13] << 16 | buffer[14] << 8 | buffer[15]; out.ttl = buffer[8];
self.destinationIp = out.protocol = static_cast<Protocol>(buffer[9]);
out.headerChecksum = buffer[10] << 8 | buffer[11];
out.sourceIp = buffer[12] << 24 | buffer[13] << 16 | buffer[14] << 8 | buffer[15];
out.destinationIp =
buffer[16] << 24 | buffer[17] << 16 | buffer[18] << 8 | buffer[19]; buffer[16] << 24 | buffer[17] << 16 | buffer[18] << 8 | buffer[19];
return self; return 20;
} }
void HandlePacket( void HandlePacket(
@ -93,8 +106,16 @@ namespace Net::Ipv4
const uint8_t* buffer, const uint8_t* buffer,
const size_t bufferSize const size_t bufferSize
) { ) {
const auto header = Header::Deserialize(buffer); Header header;
const auto headerSize = Header::SerializedLength(); const auto headerSize = Header::Deserialize(header, buffer, bufferSize);
if (headerSize != Header::SerializedLength())
{
DEBUG_LOG(
"Dropped IPv4 packet (invalid buffer size %lu, expected at least %lu)\r\n"
bufferSize, headerSize
);
return;
}
// Update ARP table // Update ARP table
Arp::ArpTable.insert( Arp::ArpTable.insert(

View file

@ -42,13 +42,14 @@ namespace Net::Ipv4
return 20; return 20;
} }
size_t Serialize(uint8_t* buffer) const; size_t Serialize(uint8_t* buffer, const size_t bufferSize) const;
static Header Deserialize(const uint8_t* buffer); static size_t Deserialize(
Header& out, const uint8_t* buffer, const size_t bufferSize);
}; };
void HandlePacket( void HandlePacket(
const Ethernet::Header& ethernetHeader, const Ethernet::Header& ethernetHeader,
const uint8_t* buffer, const uint8_t* buffer,
const size_t size const size_t bufferSize
); );
} // namespace Net::Ipv4 } // namespace Net::Ipv4

View file

@ -9,6 +9,7 @@
#include "net-udp.h" #include "net-udp.h"
#include "net.h" #include "net.h"
#include "debug.h"
#include "ff.h" #include "ff.h"
#include "types.h" #include "types.h"
#include <uspi.h> #include <uspi.h>
@ -20,16 +21,28 @@ namespace Net::Tftp
static bool shouldReboot = false; static bool shouldReboot = false;
static uint32_t currentBlockNumber = -1; static uint32_t currentBlockNumber = -1;
static std::unique_ptr<Packet> handleTftpWriteRequest(const uint8_t* data) Packet::Packet() : opcode(static_cast<Opcode>(0)) {}
{ Packet::Packet(const Opcode opcode) : opcode(opcode) {}
auto packet = WriteReadRequestPacket::Deserialize(data);
static std::unique_ptr<Packet> handleTftpWriteRequest(
const uint8_t* data, const size_t dataSize
) {
WriteReadRequestPacket packet;
const auto size = packet.Deserialize(data, dataSize);
if (size == 0)
{
DEBUG_LOG(
"Dropped TFTP packet (invalid buffer size %lu, expected at least %lu)\r\n",
dataSize, sizeof(WriteReadRequestPacket::opcode) + 2
)
return nullptr;
}
// TODO Implement netscii, maybe // TODO Implement netscii, maybe
if (packet.mode != "octet") if (packet.mode != "octet")
{ {
return std::unique_ptr<ErrorPacket>( const auto pointer = new ErrorPacket(0, "please use mode octet");
new ErrorPacket(0, "please use mode octet") return std::unique_ptr<ErrorPacket>(pointer);
);
} }
currentBlockNumber = 0; currentBlockNumber = 0;
@ -79,10 +92,14 @@ namespace Net::Tftp
static std::unique_ptr<Packet> handleTftpData(const uint8_t* buffer, size_t size) static std::unique_ptr<Packet> handleTftpData(const uint8_t* buffer, size_t size)
{ {
DataPacket packet; DataPacket packet;
const auto tftpSize = DataPacket::Deserialize(packet, buffer, size); const auto tftpSize = packet.Deserialize(buffer, size);
if (size == 0) if (tftpSize == 0)
{ {
// TODO log DEBUG_LOG(
"Dropped TFTP data packet "
"(invalid buffer size %lu, expected at least %lu)\r\n",
size, sizeof(packet.opcode) + sizeof(packet.blockNumber)
)
return nullptr; return nullptr;
} }
@ -120,20 +137,28 @@ namespace Net::Tftp
const Ethernet::Header ethernetReqHeader, const Ethernet::Header ethernetReqHeader,
const Ipv4::Header ipv4ReqHeader, const Ipv4::Header ipv4ReqHeader,
const Udp::Header udpReqHeader, const Udp::Header udpReqHeader,
const uint8_t* data const uint8_t* reqBuffer,
const size_t reqBufferSize
) { ) {
const auto opcode = static_cast<Opcode>(data[0] << 8 | data[1]); const auto opcode = static_cast<Opcode>(reqBuffer[0] << 8 | reqBuffer[1]);
std::unique_ptr<Packet> response; std::unique_ptr<Packet> response;
bool last = false;
const auto payloadSize = udpReqHeader.length - udpReqHeader.SerializedLength();
if (reqBufferSize < payloadSize)
{
DEBUG_LOG(
"Dropped TFTP packet (invalid buffer size %lu, expected at least %lu)\r\n",
reqBufferSize, payloadSize
);
}
if (opcode == Opcode::WriteRequest) if (opcode == Opcode::WriteRequest)
{ {
response = handleTftpWriteRequest(data); response = handleTftpWriteRequest(reqBuffer, payloadSize);
} }
else if (opcode == Opcode::Data) else if (opcode == Opcode::Data)
{ {
const auto length = udpReqHeader.length - Udp::Header::SerializedLength(); response = handleTftpData(reqBuffer, payloadSize);
response = handleTftpData(data, length);
} }
else else
{ {
@ -164,9 +189,9 @@ namespace Net::Tftp
size_t size = 0; size_t size = 0;
uint8_t buffer[USPI_FRAME_BUFFER_SIZE]; uint8_t buffer[USPI_FRAME_BUFFER_SIZE];
size += ethernetRespHeader.Serialize(buffer + size, sizeof(buffer) - size); size += ethernetRespHeader.Serialize(buffer + size, sizeof(buffer) - size);
size += ipv4RespHeader.Serialize(buffer + size); size += ipv4RespHeader.Serialize(buffer + size, sizeof(buffer) - size);
size += udpRespHeader.Serialize(buffer + size); size += udpRespHeader.Serialize(buffer + size, sizeof(buffer) - size);
size += response->Serialize(buffer + size); size += response->Serialize(buffer + size, sizeof(buffer) - size);
const auto expectedSize = const auto expectedSize =
ethernetRespHeader.SerializedLength() + ethernetRespHeader.SerializedLength() +
@ -179,28 +204,29 @@ namespace Net::Tftp
USPiSendFrame(buffer, size); USPiSendFrame(buffer, size);
} }
if (last && shouldReboot) // TODO Reboot the Pi when a system file was received
{
// TODO eww
extern void Reboot_Pi();
Reboot_Pi();
}
} }
// //
// WriteReadRequestPacket // WriteReadRequestPacket
// //
WriteReadRequestPacket::WriteReadRequestPacket(const Opcode opcode) : WriteReadRequestPacket::WriteReadRequestPacket() : Packet() {}
Packet(opcode) WriteReadRequestPacket::WriteReadRequestPacket(const Opcode opcode) : Packet(opcode)
{} {}
size_t WriteReadRequestPacket::SerializedLength() const size_t WriteReadRequestPacket::SerializedLength() const
{ {
return Packet::SerializedLength() + filename.size() + 1 + mode.size() + 1; return sizeof(opcode) + filename.size() + 1 + mode.size() + 1;
} }
size_t WriteReadRequestPacket::Serialize(uint8_t* buffer) const size_t WriteReadRequestPacket::Serialize(uint8_t* buffer, const size_t bufferSize)
const
{ {
if (bufferSize < SerializedLength())
{
return 0;
}
size_t i = 0; size_t i = 0;
buffer[i++] = static_cast<uint16_t>(opcode) >> 8; buffer[i++] = static_cast<uint16_t>(opcode) >> 8;
buffer[i++] = static_cast<uint16_t>(opcode); buffer[i++] = static_cast<uint16_t>(opcode);
@ -214,21 +240,28 @@ namespace Net::Tftp
return i; return i;
} }
WriteReadRequestPacket WriteReadRequestPacket::Deserialize(const uint8_t* buffer) size_t WriteReadRequestPacket::Deserialize(
{ const uint8_t* buffer, const size_t bufferSize
) {
// Can't use SerializedLength here, as it's variable.
// Check for each field instead.
size_t i = 0; size_t i = 0;
const auto opcode = static_cast<Opcode>(buffer[i] << 8 | buffer[i + 1]); if (sizeof(Opcode) >= bufferSize - i) return 0;
WriteReadRequestPacket self(opcode); opcode = static_cast<Opcode>(buffer[i] << 8 | buffer[i + 1]);
i += 2; i += 2;
self.filename = reinterpret_cast<const char*>(buffer + i); const char* filenameStr = reinterpret_cast<const char*>(buffer + i);
i += self.filename.size() + 1; if (std::strlen(filenameStr) + 1 >= bufferSize - i) return 0;
filename = std::string(filenameStr);
i += filename.size() + 1;
self.mode = reinterpret_cast<const char*>(buffer + i); const char* modeStr = reinterpret_cast<const char*>(buffer + i);
i += self.mode.size() + 1; if (std::strlen(modeStr) + 1 >= bufferSize - i) return 0;
mode = std::string(modeStr);
i += mode.size() + 1;
return self; return i;
} }
// //
@ -241,11 +274,16 @@ namespace Net::Tftp
size_t ErrorPacket::SerializedLength() const size_t ErrorPacket::SerializedLength() const
{ {
return Packet::SerializedLength() + sizeof(errorCode) + message.size() + 1; return sizeof(opcode) + sizeof(errorCode) + message.size() + 1;
} }
size_t ErrorPacket::Serialize(uint8_t* buffer) const size_t ErrorPacket::Serialize(uint8_t* buffer, const size_t bufferSize) const
{ {
if (bufferSize < SerializedLength())
{
return 0;
}
size_t i = 0; size_t i = 0;
buffer[i++] = static_cast<uint16_t>(opcode) >> 8; buffer[i++] = static_cast<uint16_t>(opcode) >> 8;
buffer[i++] = static_cast<uint16_t>(opcode); buffer[i++] = static_cast<uint16_t>(opcode);
@ -261,9 +299,7 @@ namespace Net::Tftp
// //
// AcknowledgementPacket // AcknowledgementPacket
// //
AcknowledgementPacket::AcknowledgementPacket() : AcknowledgementPacket::AcknowledgementPacket() : Packet(Opcode::Acknowledgement) {}
Packet(Opcode::Acknowledgement)
{}
AcknowledgementPacket::AcknowledgementPacket(uint16_t blockNumber) : AcknowledgementPacket::AcknowledgementPacket(uint16_t blockNumber) :
Packet(Opcode::Acknowledgement), blockNumber(blockNumber) Packet(Opcode::Acknowledgement), blockNumber(blockNumber)
@ -271,11 +307,17 @@ namespace Net::Tftp
size_t AcknowledgementPacket::SerializedLength() const size_t AcknowledgementPacket::SerializedLength() const
{ {
return Packet::SerializedLength() + sizeof(blockNumber); return sizeof(opcode) + sizeof(blockNumber);
} }
size_t AcknowledgementPacket::Serialize(uint8_t* buffer) const size_t AcknowledgementPacket::Serialize(uint8_t* buffer, const size_t bufferSize)
const
{ {
if (bufferSize < SerializedLength())
{
return 0;
}
size_t i = 0; size_t i = 0;
buffer[i++] = static_cast<uint16_t>(opcode) >> 8; buffer[i++] = static_cast<uint16_t>(opcode) >> 8;
buffer[i++] = static_cast<uint16_t>(opcode); buffer[i++] = static_cast<uint16_t>(opcode);
@ -290,8 +332,18 @@ namespace Net::Tftp
DataPacket::DataPacket() : Packet(Opcode::Data), blockNumber(0) DataPacket::DataPacket() : Packet(Opcode::Data), blockNumber(0)
{} {}
size_t DataPacket::Serialize(uint8_t* buffer) const size_t DataPacket::SerializedLength() const
{ {
return sizeof(opcode) + sizeof(blockNumber) + data.size();
}
size_t DataPacket::Serialize(uint8_t* buffer, const size_t bufferSize) const
{
if (bufferSize <= SerializedLength())
{
return 0;
}
size_t i = 0; size_t i = 0;
buffer[i++] = static_cast<uint16_t>(opcode) >> 8; buffer[i++] = static_cast<uint16_t>(opcode) >> 8;
buffer[i++] = static_cast<uint16_t>(opcode); buffer[i++] = static_cast<uint16_t>(opcode);
@ -304,17 +356,15 @@ namespace Net::Tftp
return i; return i;
} }
size_t DataPacket::Deserialize( size_t DataPacket::Deserialize(const uint8_t* buffer, const size_t bufferSize) {
DataPacket& out, const uint8_t* buffer, size_t size if (bufferSize < sizeof(opcode) + sizeof(blockNumber))
) {
if (size < sizeof(opcode) + sizeof(blockNumber))
{ {
return 0; return 0;
} }
out.opcode = static_cast<Opcode>(buffer[0] << 8 | buffer[1]); opcode = static_cast<Opcode>(buffer[0] << 8 | buffer[1]);
out.blockNumber = buffer[2] << 8 | buffer[3]; blockNumber = buffer[2] << 8 | buffer[3];
out.data = std::vector<uint8_t>(buffer + 4, buffer + size); data = std::vector<uint8_t>(buffer + 4, buffer + bufferSize);
return size; return bufferSize;
} }
} // namespace Net::Tftp } // namespace Net::Tftp

View file

@ -21,15 +21,10 @@ namespace Net::Tftp
{ {
Opcode opcode; Opcode opcode;
Packet(Opcode opcode) : opcode(opcode) Packet();
{} Packet(Opcode opcode);
virtual size_t SerializedLength() const = 0;
virtual size_t SerializedLength() const virtual size_t Serialize(uint8_t* buffer, const size_t bufferSize) const = 0;
{
return sizeof(opcode);
}
virtual size_t Serialize(uint8_t* buffer) const = 0;
}; };
struct WriteReadRequestPacket : public Packet struct WriteReadRequestPacket : public Packet
@ -37,10 +32,11 @@ namespace Net::Tftp
std::string filename; std::string filename;
std::string mode; std::string mode;
WriteReadRequestPacket();
WriteReadRequestPacket(const Opcode opcode); WriteReadRequestPacket(const Opcode opcode);
size_t SerializedLength() const override; size_t SerializedLength() const override;
size_t Serialize(uint8_t* buffer) const override; size_t Serialize(uint8_t* buffer, const size_t bufferSize) const override;
static WriteReadRequestPacket Deserialize(const uint8_t* buffer); size_t Deserialize(const uint8_t* buffer, const size_t bufferSize);
}; };
struct ErrorPacket : public Packet struct ErrorPacket : public Packet
@ -51,7 +47,7 @@ namespace Net::Tftp
ErrorPacket(); ErrorPacket();
ErrorPacket(uint16_t errorCode, std::string message); ErrorPacket(uint16_t errorCode, std::string message);
size_t SerializedLength() const override; size_t SerializedLength() const override;
size_t Serialize(uint8_t* buffer) const override; size_t Serialize(uint8_t* buffer, const size_t bufferSize) const override;
}; };
struct AcknowledgementPacket : public Packet struct AcknowledgementPacket : public Packet
@ -61,7 +57,7 @@ namespace Net::Tftp
AcknowledgementPacket(); AcknowledgementPacket();
AcknowledgementPacket(uint16_t blockNumber); AcknowledgementPacket(uint16_t blockNumber);
size_t SerializedLength() const override; size_t SerializedLength() const override;
size_t Serialize(uint8_t* buffer) const override; size_t Serialize(uint8_t* buffer, const size_t bufferSize) const override;
}; };
struct DataPacket : public Packet struct DataPacket : public Packet
@ -70,15 +66,16 @@ namespace Net::Tftp
std::vector<uint8_t> data; std::vector<uint8_t> data;
DataPacket(); DataPacket();
size_t Serialize(uint8_t* buffer) const override; size_t SerializedLength() const override;
static size_t Deserialize( size_t Serialize(uint8_t* buffer, const size_t bufferSize) const override;
DataPacket& out, const uint8_t* buffer, size_t length); size_t Deserialize(const uint8_t* buffer, const size_t bufferSize);
}; };
void HandlePacket( void HandlePacket(
const Ethernet::Header ethernetReqHeader, const Ethernet::Header ethernetReqHeader,
const Ipv4::Header ipv4ReqHeader, const Ipv4::Header ipv4ReqHeader,
const Udp::Header udpReqHeader, const Udp::Header udpReqHeader,
const uint8_t* buffer const uint8_t* data,
const size_t dataSize
); );
} // namespace Net::Tftp } // namespace Net::Tftp

View file

@ -2,6 +2,8 @@
#include "net-dhcp.h" #include "net-dhcp.h"
#include "net-tftp.h" #include "net-tftp.h"
#include "debug.h"
namespace Net::Udp namespace Net::Udp
{ {
Header::Header() Header::Header()
@ -18,8 +20,13 @@ namespace Net::Udp
checksum(0) checksum(0)
{} {}
size_t Header::Serialize(uint8_t* buffer) const size_t Header::Serialize(uint8_t* buffer, const size_t bufferSize) const
{ {
if (bufferSize < SerializedLength())
{
return 0;
}
size_t i = 0; size_t i = 0;
buffer[i++] = static_cast<uint16_t>(sourcePort) >> 8; buffer[i++] = static_cast<uint16_t>(sourcePort) >> 8;
buffer[i++] = static_cast<uint16_t>(sourcePort); buffer[i++] = static_cast<uint16_t>(sourcePort);
@ -32,30 +39,51 @@ namespace Net::Udp
return i; return i;
} }
Header Header::Deserialize(const uint8_t* buffer) size_t Header::Deserialize(const uint8_t* buffer, const size_t bufferSize)
{ {
Header self; if (bufferSize < Header::SerializedLength())
self.sourcePort = static_cast<Port>(buffer[0] << 8 | buffer[1]); {
self.destinationPort = static_cast<Port>(buffer[2] << 8 | buffer[3]); return 0;
self.length = buffer[4] << 8 | buffer[5]; }
self.checksum = buffer[6] << 8 | buffer[7];
return self; sourcePort = static_cast<Port>(buffer[0] << 8 | buffer[1]);
destinationPort = static_cast<Port>(buffer[2] << 8 | buffer[3]);
length = buffer[4] << 8 | buffer[5];
checksum = buffer[6] << 8 | buffer[7];
return 8;
} }
void HandlePacket( void HandlePacket(
const Ethernet::Header ethernetHeader, const Ethernet::Header ethernetHeader,
const Ipv4::Header ipv4Header, const Ipv4::Header ipv4Header,
const uint8_t* buffer, const uint8_t* buffer,
const size_t size const size_t bufferSize
) { ) {
const auto udpHeader = Header::Deserialize(buffer); Header udpHeader;
const auto headerSize = udpHeader.Deserialize(buffer, bufferSize);
if (headerSize == 0 || headerSize != udpHeader.SerializedLength())
{
DEBUG_LOG(
"Dropped UDP header (invalid buffer size %lu, expected at least %lu)\r\n",
bufferSize, Header::SerializedLength()
);
return;
}
if (udpHeader.length <= bufferSize)
{
DEBUG_LOG(
"Dropped UDP packet (invalid buffer size %lu, expected at least %lu)\r\n",
bufferSize, udpHeader.length
);
return;
}
if (udpHeader.destinationPort == Port::DhcpClient) if (udpHeader.destinationPort == Port::DhcpClient)
{ {
Dhcp::HandlePacket( Dhcp::HandlePacket(
ethernetHeader, ethernetHeader,
buffer + udpHeader.SerializedLength(), buffer + udpHeader.SerializedLength(),
size - udpHeader.SerializedLength() bufferSize - udpHeader.SerializedLength()
); );
} }
else if (udpHeader.destinationPort == Port::Tftp) else if (udpHeader.destinationPort == Port::Tftp)
@ -64,7 +92,8 @@ namespace Net::Udp
ethernetHeader, ethernetHeader,
ipv4Header, ipv4Header,
udpHeader, udpHeader,
buffer + udpHeader.SerializedLength() buffer + udpHeader.SerializedLength(),
bufferSize - udpHeader.SerializedLength()
); );
} }
} }

View file

@ -33,8 +33,8 @@ namespace Net::Udp
sizeof(checksum); sizeof(checksum);
} }
size_t Serialize(uint8_t* buffer) const; size_t Serialize(uint8_t* buffer, const size_t size) const;
static Header Deserialize(const uint8_t* buffer); size_t Deserialize(const uint8_t* buffer, const size_t size);
}; };
void HandlePacket( void HandlePacket(