Fix TFTP packets being dropped

This commit is contained in:
Sijmen 2021-01-06 15:45:14 +01:00
parent f26febe13d
commit 81ea9e27ca
Signed by: vijfhoek
GPG Key ID: DAF7821E067D9C48
3 changed files with 51 additions and 14 deletions

View File

@ -111,11 +111,25 @@ namespace Net::Ipv4
if (headerSize != Header::SerializedLength()) if (headerSize != Header::SerializedLength())
{ {
DEBUG_LOG( DEBUG_LOG(
"Dropped IPv4 packet (invalid buffer size %u, expected at least %u)\r\n", "Dropped IPv4 header (invalid buffer size %u, expected at least %u)\r\n",
bufferSize, bufferSize,
headerSize); headerSize);
return; return;
} }
DEBUG_LOG(
"IPv4 { src=%08lx, dst=%08lx, len=%u, protocol=%u }\r\n",
header.sourceIp,
header.destinationIp,
header.totalLength,
static_cast<uint8_t>(header.protocol));
if (bufferSize < header.totalLength)
{
DEBUG_LOG(
"Dropped IPv4 packet (invalid buffer size %u, expected at least %u)\r\n",
bufferSize,
header.totalLength);
return;
}
// Update ARP table // Update ARP table
Arp::ArpTable.insert(std::make_pair(header.sourceIp, ethernetHeader.macSource)); Arp::ArpTable.insert(std::make_pair(header.sourceIp, ethernetHeader.macSource));

View File

@ -25,15 +25,16 @@ namespace Net::Tftp
Packet::Packet(const Opcode opcode) : opcode(opcode) {} Packet::Packet(const Opcode opcode) : opcode(opcode) {}
static std::unique_ptr<Packet> static std::unique_ptr<Packet>
handleTftpWriteRequest(const uint8_t* data, const size_t dataSize) handleTftpWriteRequest(const uint8_t* buffer, const size_t bufferSize)
{ {
DEBUG_LOG("Received TFTP write request\r\n");
WriteReadRequestPacket packet; WriteReadRequestPacket packet;
const auto size = packet.Deserialize(data, dataSize); const auto size = packet.Deserialize(buffer, bufferSize);
if (size == 0) if (size == 0)
{ {
DEBUG_LOG( DEBUG_LOG(
"Dropped TFTP packet (invalid buffer size %u, expected at least %u)\r\n", "Dropped TFTP packet (invalid buffer size %u, expected at least %u)\r\n",
dataSize, bufferSize,
sizeof(WriteReadRequestPacket::opcode) + 2); sizeof(WriteReadRequestPacket::opcode) + 2);
return nullptr; return nullptr;
} }
@ -88,6 +89,7 @@ namespace Net::Tftp
static std::unique_ptr<Packet> handleTftpData(const uint8_t* buffer, size_t size) static std::unique_ptr<Packet> handleTftpData(const uint8_t* buffer, size_t size)
{ {
DEBUG_LOG("Received TFTP data\r\n");
DataPacket packet; DataPacket packet;
const auto tftpSize = packet.Deserialize(buffer, size); const auto tftpSize = packet.Deserialize(buffer, size);
if (tftpSize == 0) if (tftpSize == 0)
@ -134,6 +136,7 @@ namespace Net::Tftp
const size_t reqBufferSize) const size_t reqBufferSize)
{ {
const auto opcode = static_cast<Opcode>(reqBuffer[0] << 8 | reqBuffer[1]); const auto opcode = static_cast<Opcode>(reqBuffer[0] << 8 | reqBuffer[1]);
DEBUG_LOG("Received TFTP %u packet\r\n", static_cast<uint16_t>(opcode));
std::unique_ptr<Packet> response; std::unique_ptr<Packet> response;
const auto payloadSize = udpReqHeader.length - udpReqHeader.SerializedLength(); const auto payloadSize = udpReqHeader.length - udpReqHeader.SerializedLength();
@ -160,6 +163,7 @@ namespace Net::Tftp
if (response != nullptr) if (response != nullptr)
{ {
DEBUG_LOG("Sending TFTP response\r\n");
Udp::Header udpRespHeader( Udp::Header udpRespHeader(
udpReqHeader.destinationPort, udpReqHeader.destinationPort,
udpReqHeader.sourcePort, udpReqHeader.sourcePort,
@ -189,6 +193,10 @@ namespace Net::Tftp
USPiSendFrame(buffer, size); USPiSendFrame(buffer, size);
} }
else
{
DEBUG_LOG("TFTP response was nullptr\r\n");
}
// TODO Reboot the Pi when a system file was received // TODO Reboot the Pi when a system file was received
} }
@ -226,25 +234,33 @@ namespace Net::Tftp
size_t WriteReadRequestPacket::Deserialize(const uint8_t* buffer, const size_t bufferSize) size_t WriteReadRequestPacket::Deserialize(const uint8_t* buffer, const size_t bufferSize)
{ {
// Can't use SerializedLength here, as it's variable. // Can't use SerializedLength here, as it's variable. Check each field separately instead.
// Check for each field instead.
size_t i = 0; size_t i = 0;
if (sizeof(Opcode) >= bufferSize - i) if (sizeof(Opcode) >= bufferSize - i)
return 0; return 0;
opcode = static_cast<Opcode>(buffer[i] << 8 | buffer[i + 1]); opcode = static_cast<Opcode>(buffer[i] << 8 | buffer[i + 1]);
i += 2; i += sizeof(Opcode);
const char* filenameStr = reinterpret_cast<const char*>(buffer + i); // Check if there's a null terminator within the remaining buffer
if (std::strlen(filenameStr) + 1 >= bufferSize - i) size_t j;
for (j = i; j < bufferSize; j++)
if (buffer[j] == 0)
break;
if (j == bufferSize)
return 0; return 0;
filename = std::string(filenameStr);
filename = std::string(reinterpret_cast<const char*>(buffer + i));
i += filename.size() + 1; i += filename.size() + 1;
const char* modeStr = reinterpret_cast<const char*>(buffer + i); // Check if there's a null terminator within the remaining buffer
if (std::strlen(modeStr) + 1 >= bufferSize - i) for (j = i; j < bufferSize; j++)
if (buffer[j] == 0)
break;
if (j == bufferSize)
return 0; return 0;
mode = std::string(modeStr);
mode = std::string(reinterpret_cast<const char*>(buffer + i));
i += mode.size() + 1; i += mode.size() + 1;
return i; return i;

View File

@ -62,7 +62,14 @@ namespace Net::Udp
Header::SerializedLength()); Header::SerializedLength());
return; return;
} }
if (udpHeader.length <= bufferSize)
DEBUG_LOG(
"UDP { src=%u, dst=%u, len=%u, chk=%u }\r\n",
static_cast<uint16_t>(udpHeader.sourcePort),
static_cast<uint16_t>(udpHeader.destinationPort),
udpHeader.length,
udpHeader.checksum);
if (bufferSize < udpHeader.length)
{ {
DEBUG_LOG( DEBUG_LOG(
"Dropped UDP packet (invalid buffer size %u, expected at least %u)\r\n", "Dropped UDP packet (invalid buffer size %u, expected at least %u)\r\n",