More buffer size checks
This commit is contained in:
parent
57c490da6e
commit
adf5172e94
9 changed files with 250 additions and 131 deletions
|
@ -94,7 +94,7 @@ namespace Net::Arp
|
||||||
const auto expectedSize =
|
const auto expectedSize =
|
||||||
ethernetHeader.SerializedLength() + arpPacket.SerializedLength();
|
ethernetHeader.SerializedLength() + arpPacket.SerializedLength();
|
||||||
assert(size == expectedSize);
|
assert(size == expectedSize);
|
||||||
assert(size <= USPI_FRAME_BUFFER_SIZE);
|
assert(size <= sizeof(buffer));
|
||||||
|
|
||||||
USPiSendFrame(buffer, size);
|
USPiSendFrame(buffer, size);
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,6 +7,7 @@
|
||||||
#include "net-ipv4.h"
|
#include "net-ipv4.h"
|
||||||
#include "net-ethernet.h"
|
#include "net-ethernet.h"
|
||||||
|
|
||||||
|
#include "debug.h"
|
||||||
#include "types.h"
|
#include "types.h"
|
||||||
#include <uspi.h>
|
#include <uspi.h>
|
||||||
#include <uspios.h>
|
#include <uspios.h>
|
||||||
|
@ -150,8 +151,8 @@ namespace Net::Dhcp
|
||||||
uint8_t buffer[USPI_FRAME_BUFFER_SIZE];
|
uint8_t buffer[USPI_FRAME_BUFFER_SIZE];
|
||||||
size_t size = 0;
|
size_t size = 0;
|
||||||
size += ethernetHeader.Serialize(buffer + size, sizeof(buffer) - size);
|
size += ethernetHeader.Serialize(buffer + size, sizeof(buffer) - size);
|
||||||
size += ipv4Header.Serialize(buffer + size);
|
size += ipv4Header.Serialize(buffer + size, sizeof(buffer) - size);
|
||||||
size += udpHeader.Serialize(buffer + size);
|
size += udpHeader.Serialize(buffer + size, sizeof(buffer) - size);
|
||||||
size += dhcpHeader.Serialize(buffer + size, sizeof(buffer) - size);
|
size += dhcpHeader.Serialize(buffer + size, sizeof(buffer) - size);
|
||||||
|
|
||||||
const auto expectedSize =
|
const auto expectedSize =
|
||||||
|
@ -202,8 +203,8 @@ namespace Net::Dhcp
|
||||||
size_t size = 0;
|
size_t size = 0;
|
||||||
|
|
||||||
size += ethernetHeader.Serialize(buffer + size, sizeof(buffer) - size);
|
size += ethernetHeader.Serialize(buffer + size, sizeof(buffer) - size);
|
||||||
size += ipv4Header.Serialize(buffer + size);
|
size += ipv4Header.Serialize(buffer + size, sizeof(buffer) - size);
|
||||||
size += udpHeader.Serialize(buffer + size);
|
size += udpHeader.Serialize(buffer + size, sizeof(buffer) - size);
|
||||||
size += dhcpHeader.Serialize(buffer + size, sizeof(buffer) - size);
|
size += dhcpHeader.Serialize(buffer + size, sizeof(buffer) - size);
|
||||||
|
|
||||||
const auto expectedSize =
|
const auto expectedSize =
|
||||||
|
@ -249,26 +250,29 @@ namespace Net::Dhcp
|
||||||
const uint8_t* buffer,
|
const uint8_t* buffer,
|
||||||
size_t size
|
size_t size
|
||||||
) {
|
) {
|
||||||
auto dhcpHeader = Header();
|
Header header;
|
||||||
const auto dhcpSize = Header::Deserialize(dhcpHeader, buffer, size);
|
const auto dhcpSize = Header::Deserialize(header, buffer, size);
|
||||||
if (dhcpSize == 0)
|
if (dhcpSize != Header::SerializedLength())
|
||||||
{
|
{
|
||||||
// TODO log
|
DEBUG_LOG(
|
||||||
|
"Dropped DHCP packet (invalid buffer size %lu, expected %lu)\r\n",
|
||||||
|
size, Header::SerializedLength()
|
||||||
|
);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (dhcpHeader.opcode != Opcode::BootReply) return;
|
if (header.opcode != Opcode::BootReply) return;
|
||||||
if (dhcpHeader.hardwareAddressType != 1) return;
|
if (header.hardwareAddressType != 1) return;
|
||||||
if (dhcpHeader.hardwareAddressLength != 6) return;
|
if (header.hardwareAddressLength != 6) return;
|
||||||
if (dhcpHeader.transactionId != transactionId) return;
|
if (header.transactionId != transactionId) return;
|
||||||
|
|
||||||
if (!serverSelected)
|
if (!serverSelected)
|
||||||
{
|
{
|
||||||
handleOfferPacket(ethernetHeader, dhcpHeader);
|
handleOfferPacket(ethernetHeader, header);
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
handleAckPacket(ethernetHeader, dhcpHeader);
|
handleAckPacket(ethernetHeader, header);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} // namespace Net::Dhcp
|
} // namespace Net::Dhcp
|
||||||
|
|
|
@ -3,6 +3,7 @@
|
||||||
|
|
||||||
#include "net-icmp.h"
|
#include "net-icmp.h"
|
||||||
|
|
||||||
|
#include "debug.h"
|
||||||
#include "types.h"
|
#include "types.h"
|
||||||
#include <uspi.h>
|
#include <uspi.h>
|
||||||
|
|
||||||
|
@ -96,7 +97,7 @@ namespace Net::Icmp
|
||||||
size_t size = 0;
|
size_t size = 0;
|
||||||
|
|
||||||
size += ethernetHeader.Serialize(buffer + size, sizeof(buffer) - size);
|
size += ethernetHeader.Serialize(buffer + size, sizeof(buffer) - size);
|
||||||
size += ipv4Header.Serialize(buffer + size);
|
size += ipv4Header.Serialize(buffer + size, sizeof(buffer) - size);
|
||||||
size += pingHeader.Serialize(buffer + size, sizeof(buffer) - size);
|
size += pingHeader.Serialize(buffer + size, sizeof(buffer) - size);
|
||||||
size += icmpHeader.Serialize(buffer + size, sizeof(buffer) - size);
|
size += icmpHeader.Serialize(buffer + size, sizeof(buffer) - size);
|
||||||
|
|
||||||
|
@ -121,8 +122,15 @@ namespace Net::Icmp
|
||||||
EchoHeader reqEchoHeader;
|
EchoHeader reqEchoHeader;
|
||||||
const auto reqEchoHeaderSize =
|
const auto reqEchoHeaderSize =
|
||||||
Icmp::EchoHeader::Deserialize(reqEchoHeader, reqBuffer, reqBufferSize);
|
Icmp::EchoHeader::Deserialize(reqEchoHeader, reqBuffer, reqBufferSize);
|
||||||
assert(reqEchoHeaderSize == reqEchoHeader.SerializedLength());
|
if (reqEchoHeaderSize == 0 || reqBufferSize < reqEchoHeaderSize)
|
||||||
assert(reqEchoHeaderSize <= reqBufferSize);
|
{
|
||||||
|
DEBUG_LOG(
|
||||||
|
"Dropped ICMP packet "
|
||||||
|
"(invalid buffer size %ul, expected at least %ul)\r\n",
|
||||||
|
reqBufferSize, EchoHeader::SerializedLength()
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
const Icmp::Header respIcmpHeader(Icmp::Type::EchoReply, 0);
|
const Icmp::Header respIcmpHeader(Icmp::Type::EchoReply, 0);
|
||||||
const Ipv4::Header respIpv4Header(
|
const Ipv4::Header respIpv4Header(
|
||||||
|
@ -148,7 +156,8 @@ namespace Net::Icmp
|
||||||
size_t respSize = 0;
|
size_t respSize = 0;
|
||||||
respSize += respEthernetHeader.Serialize(
|
respSize += respEthernetHeader.Serialize(
|
||||||
respBuffer.data() + respSize, respBuffer.size() - respSize);
|
respBuffer.data() + respSize, respBuffer.size() - respSize);
|
||||||
respSize += respIpv4Header.Serialize(respBuffer.data() + respSize);
|
respSize += respIpv4Header.Serialize(
|
||||||
|
respBuffer.data() + respSize, respBuffer.size() - respSize);
|
||||||
respSize += respIcmpHeader.Serialize(
|
respSize += respIcmpHeader.Serialize(
|
||||||
respBuffer.data() + respSize, respBuffer.size() - respSize);
|
respBuffer.data() + respSize, respBuffer.size() - respSize);
|
||||||
std::memcpy(
|
std::memcpy(
|
||||||
|
@ -178,8 +187,9 @@ namespace Net::Icmp
|
||||||
headerSize += Ethernet::Header::Deserialize(
|
headerSize += Ethernet::Header::Deserialize(
|
||||||
ethernetHeader, buffer + headerSize, bufferSize - headerSize);
|
ethernetHeader, buffer + headerSize, bufferSize - headerSize);
|
||||||
|
|
||||||
const auto ipv4Header = Ipv4::Header::Deserialize(buffer + headerSize);
|
Ipv4::Header ipv4Header;
|
||||||
headerSize += ipv4Header.SerializedLength();
|
headerSize += Ipv4::Header::Deserialize(
|
||||||
|
ipv4Header, buffer + headerSize, bufferSize - headerSize);
|
||||||
|
|
||||||
Header icmpHeader;
|
Header icmpHeader;
|
||||||
headerSize += Icmp::Header::Deserialize(
|
headerSize += Icmp::Header::Deserialize(
|
||||||
|
@ -189,9 +199,16 @@ namespace Net::Icmp
|
||||||
ethernetHeader.SerializedLength() +
|
ethernetHeader.SerializedLength() +
|
||||||
ipv4Header.SerializedLength() +
|
ipv4Header.SerializedLength() +
|
||||||
icmpHeader.SerializedLength();
|
icmpHeader.SerializedLength();
|
||||||
assert(headerSize == expectedHeaderSize);
|
if (headerSize != expectedHeaderSize)
|
||||||
|
{
|
||||||
|
DEBUG_LOG(
|
||||||
|
"Dropped ICMP packet "
|
||||||
|
"(invalid buffer size %ul, expected at least %ul)\r\n",
|
||||||
|
bufferSize, expectedHeaderSize
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
if (icmpHeader.type == Icmp::Type::EchoRequest)
|
if (icmpHeader.type == Type::EchoRequest)
|
||||||
{
|
{
|
||||||
handleEchoRequest(
|
handleEchoRequest(
|
||||||
ethernetHeader,
|
ethernetHeader,
|
||||||
|
|
|
@ -1,3 +1,5 @@
|
||||||
|
#include <cassert>
|
||||||
|
|
||||||
#include "net-ipv4.h"
|
#include "net-ipv4.h"
|
||||||
#include "net-ethernet.h"
|
#include "net-ethernet.h"
|
||||||
#include "net-arp.h"
|
#include "net-arp.h"
|
||||||
|
@ -5,6 +7,8 @@
|
||||||
#include "net-udp.h"
|
#include "net-udp.h"
|
||||||
#include "net-utils.h"
|
#include "net-utils.h"
|
||||||
|
|
||||||
|
#include "debug.h"
|
||||||
|
|
||||||
namespace Net::Ipv4
|
namespace Net::Ipv4
|
||||||
{
|
{
|
||||||
Header::Header() {}
|
Header::Header() {}
|
||||||
|
@ -27,10 +31,14 @@ namespace Net::Ipv4
|
||||||
destinationIp(destinationIp)
|
destinationIp(destinationIp)
|
||||||
{}
|
{}
|
||||||
|
|
||||||
size_t Header::Serialize(uint8_t* buffer) const
|
size_t Header::Serialize(uint8_t* buffer, const size_t bufferSize) const
|
||||||
{
|
{
|
||||||
size_t i = 0;
|
if (bufferSize <= SerializedLength())
|
||||||
|
{
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t i = 0;
|
||||||
buffer[i++] = version << 4 | ihl;
|
buffer[i++] = version << 4 | ihl;
|
||||||
buffer[i++] = dscp << 2 | ecn;
|
buffer[i++] = dscp << 2 | ecn;
|
||||||
buffer[i++] = totalLength >> 8;
|
buffer[i++] = totalLength >> 8;
|
||||||
|
@ -62,30 +70,35 @@ namespace Net::Ipv4
|
||||||
return i;
|
return i;
|
||||||
}
|
}
|
||||||
|
|
||||||
Header Header::Deserialize(const uint8_t* buffer)
|
size_t Header::Deserialize(
|
||||||
{
|
Header& out, const uint8_t* buffer, const size_t bufferSize
|
||||||
Header self;
|
) {
|
||||||
self.version = buffer[0] >> 4;
|
if (bufferSize <= SerializedLength())
|
||||||
self.ihl = buffer[0] & 0x0F;
|
{
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
self.dscp = buffer[1] >> 2;
|
out.version = buffer[0] >> 4;
|
||||||
self.ecn = buffer[1] & 0x03;
|
out.ihl = buffer[0] & 0x0F;
|
||||||
|
|
||||||
self.totalLength = buffer[2] << 8 | buffer[3];
|
out.dscp = buffer[1] >> 2;
|
||||||
self.identification = buffer[4] << 8 | buffer[5];
|
out.ecn = buffer[1] & 0x03;
|
||||||
|
|
||||||
self.flags = buffer[6] >> 5;
|
out.totalLength = buffer[2] << 8 | buffer[3];
|
||||||
self.fragmentOffset = (buffer[6] & 0x1F) << 8 | buffer[7];
|
out.identification = buffer[4] << 8 | buffer[5];
|
||||||
|
|
||||||
self.ttl = buffer[8];
|
out.flags = buffer[6] >> 5;
|
||||||
self.protocol = static_cast<Protocol>(buffer[9]);
|
out.fragmentOffset = (buffer[6] & 0x1F) << 8 | buffer[7];
|
||||||
self.headerChecksum = buffer[10] << 8 | buffer[11];
|
|
||||||
|
|
||||||
self.sourceIp = buffer[12] << 24 | buffer[13] << 16 | buffer[14] << 8 | buffer[15];
|
out.ttl = buffer[8];
|
||||||
self.destinationIp =
|
out.protocol = static_cast<Protocol>(buffer[9]);
|
||||||
|
out.headerChecksum = buffer[10] << 8 | buffer[11];
|
||||||
|
|
||||||
|
out.sourceIp = buffer[12] << 24 | buffer[13] << 16 | buffer[14] << 8 | buffer[15];
|
||||||
|
out.destinationIp =
|
||||||
buffer[16] << 24 | buffer[17] << 16 | buffer[18] << 8 | buffer[19];
|
buffer[16] << 24 | buffer[17] << 16 | buffer[18] << 8 | buffer[19];
|
||||||
|
|
||||||
return self;
|
return 20;
|
||||||
}
|
}
|
||||||
|
|
||||||
void HandlePacket(
|
void HandlePacket(
|
||||||
|
@ -93,8 +106,16 @@ namespace Net::Ipv4
|
||||||
const uint8_t* buffer,
|
const uint8_t* buffer,
|
||||||
const size_t bufferSize
|
const size_t bufferSize
|
||||||
) {
|
) {
|
||||||
const auto header = Header::Deserialize(buffer);
|
Header header;
|
||||||
const auto headerSize = Header::SerializedLength();
|
const auto headerSize = Header::Deserialize(header, buffer, bufferSize);
|
||||||
|
if (headerSize != Header::SerializedLength())
|
||||||
|
{
|
||||||
|
DEBUG_LOG(
|
||||||
|
"Dropped IPv4 packet (invalid buffer size %lu, expected at least %lu)\r\n"
|
||||||
|
bufferSize, headerSize
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
// Update ARP table
|
// Update ARP table
|
||||||
Arp::ArpTable.insert(
|
Arp::ArpTable.insert(
|
||||||
|
|
|
@ -42,13 +42,14 @@ namespace Net::Ipv4
|
||||||
return 20;
|
return 20;
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t Serialize(uint8_t* buffer) const;
|
size_t Serialize(uint8_t* buffer, const size_t bufferSize) const;
|
||||||
static Header Deserialize(const uint8_t* buffer);
|
static size_t Deserialize(
|
||||||
|
Header& out, const uint8_t* buffer, const size_t bufferSize);
|
||||||
};
|
};
|
||||||
|
|
||||||
void HandlePacket(
|
void HandlePacket(
|
||||||
const Ethernet::Header& ethernetHeader,
|
const Ethernet::Header& ethernetHeader,
|
||||||
const uint8_t* buffer,
|
const uint8_t* buffer,
|
||||||
const size_t size
|
const size_t bufferSize
|
||||||
);
|
);
|
||||||
} // namespace Net::Ipv4
|
} // namespace Net::Ipv4
|
||||||
|
|
156
src/net-tftp.cpp
156
src/net-tftp.cpp
|
@ -9,6 +9,7 @@
|
||||||
#include "net-udp.h"
|
#include "net-udp.h"
|
||||||
#include "net.h"
|
#include "net.h"
|
||||||
|
|
||||||
|
#include "debug.h"
|
||||||
#include "ff.h"
|
#include "ff.h"
|
||||||
#include "types.h"
|
#include "types.h"
|
||||||
#include <uspi.h>
|
#include <uspi.h>
|
||||||
|
@ -20,16 +21,28 @@ namespace Net::Tftp
|
||||||
static bool shouldReboot = false;
|
static bool shouldReboot = false;
|
||||||
static uint32_t currentBlockNumber = -1;
|
static uint32_t currentBlockNumber = -1;
|
||||||
|
|
||||||
static std::unique_ptr<Packet> handleTftpWriteRequest(const uint8_t* data)
|
Packet::Packet() : opcode(static_cast<Opcode>(0)) {}
|
||||||
{
|
Packet::Packet(const Opcode opcode) : opcode(opcode) {}
|
||||||
auto packet = WriteReadRequestPacket::Deserialize(data);
|
|
||||||
|
static std::unique_ptr<Packet> handleTftpWriteRequest(
|
||||||
|
const uint8_t* data, const size_t dataSize
|
||||||
|
) {
|
||||||
|
WriteReadRequestPacket packet;
|
||||||
|
const auto size = packet.Deserialize(data, dataSize);
|
||||||
|
if (size == 0)
|
||||||
|
{
|
||||||
|
DEBUG_LOG(
|
||||||
|
"Dropped TFTP packet (invalid buffer size %lu, expected at least %lu)\r\n",
|
||||||
|
dataSize, sizeof(WriteReadRequestPacket::opcode) + 2
|
||||||
|
)
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
// TODO Implement netscii, maybe
|
// TODO Implement netscii, maybe
|
||||||
if (packet.mode != "octet")
|
if (packet.mode != "octet")
|
||||||
{
|
{
|
||||||
return std::unique_ptr<ErrorPacket>(
|
const auto pointer = new ErrorPacket(0, "please use mode octet");
|
||||||
new ErrorPacket(0, "please use mode octet")
|
return std::unique_ptr<ErrorPacket>(pointer);
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
currentBlockNumber = 0;
|
currentBlockNumber = 0;
|
||||||
|
@ -79,10 +92,14 @@ namespace Net::Tftp
|
||||||
static std::unique_ptr<Packet> handleTftpData(const uint8_t* buffer, size_t size)
|
static std::unique_ptr<Packet> handleTftpData(const uint8_t* buffer, size_t size)
|
||||||
{
|
{
|
||||||
DataPacket packet;
|
DataPacket packet;
|
||||||
const auto tftpSize = DataPacket::Deserialize(packet, buffer, size);
|
const auto tftpSize = packet.Deserialize(buffer, size);
|
||||||
if (size == 0)
|
if (tftpSize == 0)
|
||||||
{
|
{
|
||||||
// TODO log
|
DEBUG_LOG(
|
||||||
|
"Dropped TFTP data packet "
|
||||||
|
"(invalid buffer size %lu, expected at least %lu)\r\n",
|
||||||
|
size, sizeof(packet.opcode) + sizeof(packet.blockNumber)
|
||||||
|
)
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -120,20 +137,28 @@ namespace Net::Tftp
|
||||||
const Ethernet::Header ethernetReqHeader,
|
const Ethernet::Header ethernetReqHeader,
|
||||||
const Ipv4::Header ipv4ReqHeader,
|
const Ipv4::Header ipv4ReqHeader,
|
||||||
const Udp::Header udpReqHeader,
|
const Udp::Header udpReqHeader,
|
||||||
const uint8_t* data
|
const uint8_t* reqBuffer,
|
||||||
|
const size_t reqBufferSize
|
||||||
) {
|
) {
|
||||||
const auto opcode = static_cast<Opcode>(data[0] << 8 | data[1]);
|
const auto opcode = static_cast<Opcode>(reqBuffer[0] << 8 | reqBuffer[1]);
|
||||||
std::unique_ptr<Packet> response;
|
std::unique_ptr<Packet> response;
|
||||||
bool last = false;
|
|
||||||
|
const auto payloadSize = udpReqHeader.length - udpReqHeader.SerializedLength();
|
||||||
|
if (reqBufferSize < payloadSize)
|
||||||
|
{
|
||||||
|
DEBUG_LOG(
|
||||||
|
"Dropped TFTP packet (invalid buffer size %lu, expected at least %lu)\r\n",
|
||||||
|
reqBufferSize, payloadSize
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
if (opcode == Opcode::WriteRequest)
|
if (opcode == Opcode::WriteRequest)
|
||||||
{
|
{
|
||||||
response = handleTftpWriteRequest(data);
|
response = handleTftpWriteRequest(reqBuffer, payloadSize);
|
||||||
}
|
}
|
||||||
else if (opcode == Opcode::Data)
|
else if (opcode == Opcode::Data)
|
||||||
{
|
{
|
||||||
const auto length = udpReqHeader.length - Udp::Header::SerializedLength();
|
response = handleTftpData(reqBuffer, payloadSize);
|
||||||
response = handleTftpData(data, length);
|
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
|
@ -164,9 +189,9 @@ namespace Net::Tftp
|
||||||
size_t size = 0;
|
size_t size = 0;
|
||||||
uint8_t buffer[USPI_FRAME_BUFFER_SIZE];
|
uint8_t buffer[USPI_FRAME_BUFFER_SIZE];
|
||||||
size += ethernetRespHeader.Serialize(buffer + size, sizeof(buffer) - size);
|
size += ethernetRespHeader.Serialize(buffer + size, sizeof(buffer) - size);
|
||||||
size += ipv4RespHeader.Serialize(buffer + size);
|
size += ipv4RespHeader.Serialize(buffer + size, sizeof(buffer) - size);
|
||||||
size += udpRespHeader.Serialize(buffer + size);
|
size += udpRespHeader.Serialize(buffer + size, sizeof(buffer) - size);
|
||||||
size += response->Serialize(buffer + size);
|
size += response->Serialize(buffer + size, sizeof(buffer) - size);
|
||||||
|
|
||||||
const auto expectedSize =
|
const auto expectedSize =
|
||||||
ethernetRespHeader.SerializedLength() +
|
ethernetRespHeader.SerializedLength() +
|
||||||
|
@ -179,28 +204,29 @@ namespace Net::Tftp
|
||||||
USPiSendFrame(buffer, size);
|
USPiSendFrame(buffer, size);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (last && shouldReboot)
|
// TODO Reboot the Pi when a system file was received
|
||||||
{
|
|
||||||
// TODO eww
|
|
||||||
extern void Reboot_Pi();
|
|
||||||
Reboot_Pi();
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
// WriteReadRequestPacket
|
// WriteReadRequestPacket
|
||||||
//
|
//
|
||||||
WriteReadRequestPacket::WriteReadRequestPacket(const Opcode opcode) :
|
WriteReadRequestPacket::WriteReadRequestPacket() : Packet() {}
|
||||||
Packet(opcode)
|
WriteReadRequestPacket::WriteReadRequestPacket(const Opcode opcode) : Packet(opcode)
|
||||||
{}
|
{}
|
||||||
|
|
||||||
size_t WriteReadRequestPacket::SerializedLength() const
|
size_t WriteReadRequestPacket::SerializedLength() const
|
||||||
{
|
{
|
||||||
return Packet::SerializedLength() + filename.size() + 1 + mode.size() + 1;
|
return sizeof(opcode) + filename.size() + 1 + mode.size() + 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t WriteReadRequestPacket::Serialize(uint8_t* buffer) const
|
size_t WriteReadRequestPacket::Serialize(uint8_t* buffer, const size_t bufferSize)
|
||||||
|
const
|
||||||
{
|
{
|
||||||
|
if (bufferSize < SerializedLength())
|
||||||
|
{
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
size_t i = 0;
|
size_t i = 0;
|
||||||
buffer[i++] = static_cast<uint16_t>(opcode) >> 8;
|
buffer[i++] = static_cast<uint16_t>(opcode) >> 8;
|
||||||
buffer[i++] = static_cast<uint16_t>(opcode);
|
buffer[i++] = static_cast<uint16_t>(opcode);
|
||||||
|
@ -214,21 +240,28 @@ namespace Net::Tftp
|
||||||
return i;
|
return i;
|
||||||
}
|
}
|
||||||
|
|
||||||
WriteReadRequestPacket WriteReadRequestPacket::Deserialize(const uint8_t* buffer)
|
size_t WriteReadRequestPacket::Deserialize(
|
||||||
{
|
const uint8_t* buffer, const size_t bufferSize
|
||||||
|
) {
|
||||||
|
// Can't use SerializedLength here, as it's variable.
|
||||||
|
// Check for each field instead.
|
||||||
size_t i = 0;
|
size_t i = 0;
|
||||||
|
|
||||||
const auto opcode = static_cast<Opcode>(buffer[i] << 8 | buffer[i + 1]);
|
if (sizeof(Opcode) >= bufferSize - i) return 0;
|
||||||
WriteReadRequestPacket self(opcode);
|
opcode = static_cast<Opcode>(buffer[i] << 8 | buffer[i + 1]);
|
||||||
i += 2;
|
i += 2;
|
||||||
|
|
||||||
self.filename = reinterpret_cast<const char*>(buffer + i);
|
const char* filenameStr = reinterpret_cast<const char*>(buffer + i);
|
||||||
i += self.filename.size() + 1;
|
if (std::strlen(filenameStr) + 1 >= bufferSize - i) return 0;
|
||||||
|
filename = std::string(filenameStr);
|
||||||
|
i += filename.size() + 1;
|
||||||
|
|
||||||
self.mode = reinterpret_cast<const char*>(buffer + i);
|
const char* modeStr = reinterpret_cast<const char*>(buffer + i);
|
||||||
i += self.mode.size() + 1;
|
if (std::strlen(modeStr) + 1 >= bufferSize - i) return 0;
|
||||||
|
mode = std::string(modeStr);
|
||||||
|
i += mode.size() + 1;
|
||||||
|
|
||||||
return self;
|
return i;
|
||||||
}
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
|
@ -241,11 +274,16 @@ namespace Net::Tftp
|
||||||
|
|
||||||
size_t ErrorPacket::SerializedLength() const
|
size_t ErrorPacket::SerializedLength() const
|
||||||
{
|
{
|
||||||
return Packet::SerializedLength() + sizeof(errorCode) + message.size() + 1;
|
return sizeof(opcode) + sizeof(errorCode) + message.size() + 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t ErrorPacket::Serialize(uint8_t* buffer) const
|
size_t ErrorPacket::Serialize(uint8_t* buffer, const size_t bufferSize) const
|
||||||
{
|
{
|
||||||
|
if (bufferSize < SerializedLength())
|
||||||
|
{
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
size_t i = 0;
|
size_t i = 0;
|
||||||
buffer[i++] = static_cast<uint16_t>(opcode) >> 8;
|
buffer[i++] = static_cast<uint16_t>(opcode) >> 8;
|
||||||
buffer[i++] = static_cast<uint16_t>(opcode);
|
buffer[i++] = static_cast<uint16_t>(opcode);
|
||||||
|
@ -261,9 +299,7 @@ namespace Net::Tftp
|
||||||
//
|
//
|
||||||
// AcknowledgementPacket
|
// AcknowledgementPacket
|
||||||
//
|
//
|
||||||
AcknowledgementPacket::AcknowledgementPacket() :
|
AcknowledgementPacket::AcknowledgementPacket() : Packet(Opcode::Acknowledgement) {}
|
||||||
Packet(Opcode::Acknowledgement)
|
|
||||||
{}
|
|
||||||
|
|
||||||
AcknowledgementPacket::AcknowledgementPacket(uint16_t blockNumber) :
|
AcknowledgementPacket::AcknowledgementPacket(uint16_t blockNumber) :
|
||||||
Packet(Opcode::Acknowledgement), blockNumber(blockNumber)
|
Packet(Opcode::Acknowledgement), blockNumber(blockNumber)
|
||||||
|
@ -271,11 +307,17 @@ namespace Net::Tftp
|
||||||
|
|
||||||
size_t AcknowledgementPacket::SerializedLength() const
|
size_t AcknowledgementPacket::SerializedLength() const
|
||||||
{
|
{
|
||||||
return Packet::SerializedLength() + sizeof(blockNumber);
|
return sizeof(opcode) + sizeof(blockNumber);
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t AcknowledgementPacket::Serialize(uint8_t* buffer) const
|
size_t AcknowledgementPacket::Serialize(uint8_t* buffer, const size_t bufferSize)
|
||||||
|
const
|
||||||
{
|
{
|
||||||
|
if (bufferSize < SerializedLength())
|
||||||
|
{
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
size_t i = 0;
|
size_t i = 0;
|
||||||
buffer[i++] = static_cast<uint16_t>(opcode) >> 8;
|
buffer[i++] = static_cast<uint16_t>(opcode) >> 8;
|
||||||
buffer[i++] = static_cast<uint16_t>(opcode);
|
buffer[i++] = static_cast<uint16_t>(opcode);
|
||||||
|
@ -290,8 +332,18 @@ namespace Net::Tftp
|
||||||
DataPacket::DataPacket() : Packet(Opcode::Data), blockNumber(0)
|
DataPacket::DataPacket() : Packet(Opcode::Data), blockNumber(0)
|
||||||
{}
|
{}
|
||||||
|
|
||||||
size_t DataPacket::Serialize(uint8_t* buffer) const
|
size_t DataPacket::SerializedLength() const
|
||||||
{
|
{
|
||||||
|
return sizeof(opcode) + sizeof(blockNumber) + data.size();
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t DataPacket::Serialize(uint8_t* buffer, const size_t bufferSize) const
|
||||||
|
{
|
||||||
|
if (bufferSize <= SerializedLength())
|
||||||
|
{
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
size_t i = 0;
|
size_t i = 0;
|
||||||
buffer[i++] = static_cast<uint16_t>(opcode) >> 8;
|
buffer[i++] = static_cast<uint16_t>(opcode) >> 8;
|
||||||
buffer[i++] = static_cast<uint16_t>(opcode);
|
buffer[i++] = static_cast<uint16_t>(opcode);
|
||||||
|
@ -304,17 +356,15 @@ namespace Net::Tftp
|
||||||
return i;
|
return i;
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t DataPacket::Deserialize(
|
size_t DataPacket::Deserialize(const uint8_t* buffer, const size_t bufferSize) {
|
||||||
DataPacket& out, const uint8_t* buffer, size_t size
|
if (bufferSize < sizeof(opcode) + sizeof(blockNumber))
|
||||||
) {
|
|
||||||
if (size < sizeof(opcode) + sizeof(blockNumber))
|
|
||||||
{
|
{
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
out.opcode = static_cast<Opcode>(buffer[0] << 8 | buffer[1]);
|
opcode = static_cast<Opcode>(buffer[0] << 8 | buffer[1]);
|
||||||
out.blockNumber = buffer[2] << 8 | buffer[3];
|
blockNumber = buffer[2] << 8 | buffer[3];
|
||||||
out.data = std::vector<uint8_t>(buffer + 4, buffer + size);
|
data = std::vector<uint8_t>(buffer + 4, buffer + bufferSize);
|
||||||
return size;
|
return bufferSize;
|
||||||
}
|
}
|
||||||
} // namespace Net::Tftp
|
} // namespace Net::Tftp
|
||||||
|
|
|
@ -21,15 +21,10 @@ namespace Net::Tftp
|
||||||
{
|
{
|
||||||
Opcode opcode;
|
Opcode opcode;
|
||||||
|
|
||||||
Packet(Opcode opcode) : opcode(opcode)
|
Packet();
|
||||||
{}
|
Packet(Opcode opcode);
|
||||||
|
virtual size_t SerializedLength() const = 0;
|
||||||
virtual size_t SerializedLength() const
|
virtual size_t Serialize(uint8_t* buffer, const size_t bufferSize) const = 0;
|
||||||
{
|
|
||||||
return sizeof(opcode);
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual size_t Serialize(uint8_t* buffer) const = 0;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
struct WriteReadRequestPacket : public Packet
|
struct WriteReadRequestPacket : public Packet
|
||||||
|
@ -37,10 +32,11 @@ namespace Net::Tftp
|
||||||
std::string filename;
|
std::string filename;
|
||||||
std::string mode;
|
std::string mode;
|
||||||
|
|
||||||
|
WriteReadRequestPacket();
|
||||||
WriteReadRequestPacket(const Opcode opcode);
|
WriteReadRequestPacket(const Opcode opcode);
|
||||||
size_t SerializedLength() const override;
|
size_t SerializedLength() const override;
|
||||||
size_t Serialize(uint8_t* buffer) const override;
|
size_t Serialize(uint8_t* buffer, const size_t bufferSize) const override;
|
||||||
static WriteReadRequestPacket Deserialize(const uint8_t* buffer);
|
size_t Deserialize(const uint8_t* buffer, const size_t bufferSize);
|
||||||
};
|
};
|
||||||
|
|
||||||
struct ErrorPacket : public Packet
|
struct ErrorPacket : public Packet
|
||||||
|
@ -51,7 +47,7 @@ namespace Net::Tftp
|
||||||
ErrorPacket();
|
ErrorPacket();
|
||||||
ErrorPacket(uint16_t errorCode, std::string message);
|
ErrorPacket(uint16_t errorCode, std::string message);
|
||||||
size_t SerializedLength() const override;
|
size_t SerializedLength() const override;
|
||||||
size_t Serialize(uint8_t* buffer) const override;
|
size_t Serialize(uint8_t* buffer, const size_t bufferSize) const override;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct AcknowledgementPacket : public Packet
|
struct AcknowledgementPacket : public Packet
|
||||||
|
@ -61,7 +57,7 @@ namespace Net::Tftp
|
||||||
AcknowledgementPacket();
|
AcknowledgementPacket();
|
||||||
AcknowledgementPacket(uint16_t blockNumber);
|
AcknowledgementPacket(uint16_t blockNumber);
|
||||||
size_t SerializedLength() const override;
|
size_t SerializedLength() const override;
|
||||||
size_t Serialize(uint8_t* buffer) const override;
|
size_t Serialize(uint8_t* buffer, const size_t bufferSize) const override;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct DataPacket : public Packet
|
struct DataPacket : public Packet
|
||||||
|
@ -70,15 +66,16 @@ namespace Net::Tftp
|
||||||
std::vector<uint8_t> data;
|
std::vector<uint8_t> data;
|
||||||
|
|
||||||
DataPacket();
|
DataPacket();
|
||||||
size_t Serialize(uint8_t* buffer) const override;
|
size_t SerializedLength() const override;
|
||||||
static size_t Deserialize(
|
size_t Serialize(uint8_t* buffer, const size_t bufferSize) const override;
|
||||||
DataPacket& out, const uint8_t* buffer, size_t length);
|
size_t Deserialize(const uint8_t* buffer, const size_t bufferSize);
|
||||||
};
|
};
|
||||||
|
|
||||||
void HandlePacket(
|
void HandlePacket(
|
||||||
const Ethernet::Header ethernetReqHeader,
|
const Ethernet::Header ethernetReqHeader,
|
||||||
const Ipv4::Header ipv4ReqHeader,
|
const Ipv4::Header ipv4ReqHeader,
|
||||||
const Udp::Header udpReqHeader,
|
const Udp::Header udpReqHeader,
|
||||||
const uint8_t* buffer
|
const uint8_t* data,
|
||||||
|
const size_t dataSize
|
||||||
);
|
);
|
||||||
} // namespace Net::Tftp
|
} // namespace Net::Tftp
|
||||||
|
|
|
@ -2,6 +2,8 @@
|
||||||
#include "net-dhcp.h"
|
#include "net-dhcp.h"
|
||||||
#include "net-tftp.h"
|
#include "net-tftp.h"
|
||||||
|
|
||||||
|
#include "debug.h"
|
||||||
|
|
||||||
namespace Net::Udp
|
namespace Net::Udp
|
||||||
{
|
{
|
||||||
Header::Header()
|
Header::Header()
|
||||||
|
@ -18,8 +20,13 @@ namespace Net::Udp
|
||||||
checksum(0)
|
checksum(0)
|
||||||
{}
|
{}
|
||||||
|
|
||||||
size_t Header::Serialize(uint8_t* buffer) const
|
size_t Header::Serialize(uint8_t* buffer, const size_t bufferSize) const
|
||||||
{
|
{
|
||||||
|
if (bufferSize < SerializedLength())
|
||||||
|
{
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
size_t i = 0;
|
size_t i = 0;
|
||||||
buffer[i++] = static_cast<uint16_t>(sourcePort) >> 8;
|
buffer[i++] = static_cast<uint16_t>(sourcePort) >> 8;
|
||||||
buffer[i++] = static_cast<uint16_t>(sourcePort);
|
buffer[i++] = static_cast<uint16_t>(sourcePort);
|
||||||
|
@ -32,30 +39,51 @@ namespace Net::Udp
|
||||||
return i;
|
return i;
|
||||||
}
|
}
|
||||||
|
|
||||||
Header Header::Deserialize(const uint8_t* buffer)
|
size_t Header::Deserialize(const uint8_t* buffer, const size_t bufferSize)
|
||||||
{
|
{
|
||||||
Header self;
|
if (bufferSize < Header::SerializedLength())
|
||||||
self.sourcePort = static_cast<Port>(buffer[0] << 8 | buffer[1]);
|
{
|
||||||
self.destinationPort = static_cast<Port>(buffer[2] << 8 | buffer[3]);
|
return 0;
|
||||||
self.length = buffer[4] << 8 | buffer[5];
|
}
|
||||||
self.checksum = buffer[6] << 8 | buffer[7];
|
|
||||||
return self;
|
sourcePort = static_cast<Port>(buffer[0] << 8 | buffer[1]);
|
||||||
|
destinationPort = static_cast<Port>(buffer[2] << 8 | buffer[3]);
|
||||||
|
length = buffer[4] << 8 | buffer[5];
|
||||||
|
checksum = buffer[6] << 8 | buffer[7];
|
||||||
|
return 8;
|
||||||
}
|
}
|
||||||
|
|
||||||
void HandlePacket(
|
void HandlePacket(
|
||||||
const Ethernet::Header ethernetHeader,
|
const Ethernet::Header ethernetHeader,
|
||||||
const Ipv4::Header ipv4Header,
|
const Ipv4::Header ipv4Header,
|
||||||
const uint8_t* buffer,
|
const uint8_t* buffer,
|
||||||
const size_t size
|
const size_t bufferSize
|
||||||
) {
|
) {
|
||||||
const auto udpHeader = Header::Deserialize(buffer);
|
Header udpHeader;
|
||||||
|
const auto headerSize = udpHeader.Deserialize(buffer, bufferSize);
|
||||||
|
if (headerSize == 0 || headerSize != udpHeader.SerializedLength())
|
||||||
|
{
|
||||||
|
DEBUG_LOG(
|
||||||
|
"Dropped UDP header (invalid buffer size %lu, expected at least %lu)\r\n",
|
||||||
|
bufferSize, Header::SerializedLength()
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (udpHeader.length <= bufferSize)
|
||||||
|
{
|
||||||
|
DEBUG_LOG(
|
||||||
|
"Dropped UDP packet (invalid buffer size %lu, expected at least %lu)\r\n",
|
||||||
|
bufferSize, udpHeader.length
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
if (udpHeader.destinationPort == Port::DhcpClient)
|
if (udpHeader.destinationPort == Port::DhcpClient)
|
||||||
{
|
{
|
||||||
Dhcp::HandlePacket(
|
Dhcp::HandlePacket(
|
||||||
ethernetHeader,
|
ethernetHeader,
|
||||||
buffer + udpHeader.SerializedLength(),
|
buffer + udpHeader.SerializedLength(),
|
||||||
size - udpHeader.SerializedLength()
|
bufferSize - udpHeader.SerializedLength()
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
else if (udpHeader.destinationPort == Port::Tftp)
|
else if (udpHeader.destinationPort == Port::Tftp)
|
||||||
|
@ -64,7 +92,8 @@ namespace Net::Udp
|
||||||
ethernetHeader,
|
ethernetHeader,
|
||||||
ipv4Header,
|
ipv4Header,
|
||||||
udpHeader,
|
udpHeader,
|
||||||
buffer + udpHeader.SerializedLength()
|
buffer + udpHeader.SerializedLength(),
|
||||||
|
bufferSize - udpHeader.SerializedLength()
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -33,8 +33,8 @@ namespace Net::Udp
|
||||||
sizeof(checksum);
|
sizeof(checksum);
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t Serialize(uint8_t* buffer) const;
|
size_t Serialize(uint8_t* buffer, const size_t size) const;
|
||||||
static Header Deserialize(const uint8_t* buffer);
|
size_t Deserialize(const uint8_t* buffer, const size_t size);
|
||||||
};
|
};
|
||||||
|
|
||||||
void HandlePacket(
|
void HandlePacket(
|
||||||
|
|
Loading…
Reference in a new issue