Fix TFTP packets being dropped
This commit is contained in:
parent
f26febe13d
commit
81ea9e27ca
3 changed files with 51 additions and 14 deletions
|
@ -111,11 +111,25 @@ namespace Net::Ipv4
|
|||
if (headerSize != Header::SerializedLength())
|
||||
{
|
||||
DEBUG_LOG(
|
||||
"Dropped IPv4 packet (invalid buffer size %u, expected at least %u)\r\n",
|
||||
"Dropped IPv4 header (invalid buffer size %u, expected at least %u)\r\n",
|
||||
bufferSize,
|
||||
headerSize);
|
||||
return;
|
||||
}
|
||||
DEBUG_LOG(
|
||||
"IPv4 { src=%08lx, dst=%08lx, len=%u, protocol=%u }\r\n",
|
||||
header.sourceIp,
|
||||
header.destinationIp,
|
||||
header.totalLength,
|
||||
static_cast<uint8_t>(header.protocol));
|
||||
if (bufferSize < header.totalLength)
|
||||
{
|
||||
DEBUG_LOG(
|
||||
"Dropped IPv4 packet (invalid buffer size %u, expected at least %u)\r\n",
|
||||
bufferSize,
|
||||
header.totalLength);
|
||||
return;
|
||||
}
|
||||
|
||||
// Update ARP table
|
||||
Arp::ArpTable.insert(std::make_pair(header.sourceIp, ethernetHeader.macSource));
|
||||
|
|
|
@ -25,15 +25,16 @@ namespace Net::Tftp
|
|||
Packet::Packet(const Opcode opcode) : opcode(opcode) {}
|
||||
|
||||
static std::unique_ptr<Packet>
|
||||
handleTftpWriteRequest(const uint8_t* data, const size_t dataSize)
|
||||
handleTftpWriteRequest(const uint8_t* buffer, const size_t bufferSize)
|
||||
{
|
||||
DEBUG_LOG("Received TFTP write request\r\n");
|
||||
WriteReadRequestPacket packet;
|
||||
const auto size = packet.Deserialize(data, dataSize);
|
||||
const auto size = packet.Deserialize(buffer, bufferSize);
|
||||
if (size == 0)
|
||||
{
|
||||
DEBUG_LOG(
|
||||
"Dropped TFTP packet (invalid buffer size %u, expected at least %u)\r\n",
|
||||
dataSize,
|
||||
bufferSize,
|
||||
sizeof(WriteReadRequestPacket::opcode) + 2);
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -88,6 +89,7 @@ namespace Net::Tftp
|
|||
|
||||
static std::unique_ptr<Packet> handleTftpData(const uint8_t* buffer, size_t size)
|
||||
{
|
||||
DEBUG_LOG("Received TFTP data\r\n");
|
||||
DataPacket packet;
|
||||
const auto tftpSize = packet.Deserialize(buffer, size);
|
||||
if (tftpSize == 0)
|
||||
|
@ -134,6 +136,7 @@ namespace Net::Tftp
|
|||
const size_t reqBufferSize)
|
||||
{
|
||||
const auto opcode = static_cast<Opcode>(reqBuffer[0] << 8 | reqBuffer[1]);
|
||||
DEBUG_LOG("Received TFTP %u packet\r\n", static_cast<uint16_t>(opcode));
|
||||
std::unique_ptr<Packet> response;
|
||||
|
||||
const auto payloadSize = udpReqHeader.length - udpReqHeader.SerializedLength();
|
||||
|
@ -160,6 +163,7 @@ namespace Net::Tftp
|
|||
|
||||
if (response != nullptr)
|
||||
{
|
||||
DEBUG_LOG("Sending TFTP response\r\n");
|
||||
Udp::Header udpRespHeader(
|
||||
udpReqHeader.destinationPort,
|
||||
udpReqHeader.sourcePort,
|
||||
|
@ -189,6 +193,10 @@ namespace Net::Tftp
|
|||
|
||||
USPiSendFrame(buffer, size);
|
||||
}
|
||||
else
|
||||
{
|
||||
DEBUG_LOG("TFTP response was nullptr\r\n");
|
||||
}
|
||||
|
||||
// TODO Reboot the Pi when a system file was received
|
||||
}
|
||||
|
@ -226,25 +234,33 @@ namespace Net::Tftp
|
|||
|
||||
size_t WriteReadRequestPacket::Deserialize(const uint8_t* buffer, const size_t bufferSize)
|
||||
{
|
||||
// Can't use SerializedLength here, as it's variable.
|
||||
// Check for each field instead.
|
||||
// Can't use SerializedLength here, as it's variable. Check each field separately instead.
|
||||
size_t i = 0;
|
||||
|
||||
if (sizeof(Opcode) >= bufferSize - i)
|
||||
return 0;
|
||||
opcode = static_cast<Opcode>(buffer[i] << 8 | buffer[i + 1]);
|
||||
i += 2;
|
||||
i += sizeof(Opcode);
|
||||
|
||||
const char* filenameStr = reinterpret_cast<const char*>(buffer + i);
|
||||
if (std::strlen(filenameStr) + 1 >= bufferSize - i)
|
||||
// Check if there's a null terminator within the remaining buffer
|
||||
size_t j;
|
||||
for (j = i; j < bufferSize; j++)
|
||||
if (buffer[j] == 0)
|
||||
break;
|
||||
if (j == bufferSize)
|
||||
return 0;
|
||||
filename = std::string(filenameStr);
|
||||
|
||||
filename = std::string(reinterpret_cast<const char*>(buffer + i));
|
||||
i += filename.size() + 1;
|
||||
|
||||
const char* modeStr = reinterpret_cast<const char*>(buffer + i);
|
||||
if (std::strlen(modeStr) + 1 >= bufferSize - i)
|
||||
// Check if there's a null terminator within the remaining buffer
|
||||
for (j = i; j < bufferSize; j++)
|
||||
if (buffer[j] == 0)
|
||||
break;
|
||||
if (j == bufferSize)
|
||||
return 0;
|
||||
mode = std::string(modeStr);
|
||||
|
||||
mode = std::string(reinterpret_cast<const char*>(buffer + i));
|
||||
i += mode.size() + 1;
|
||||
|
||||
return i;
|
||||
|
|
|
@ -62,7 +62,14 @@ namespace Net::Udp
|
|||
Header::SerializedLength());
|
||||
return;
|
||||
}
|
||||
if (udpHeader.length <= bufferSize)
|
||||
|
||||
DEBUG_LOG(
|
||||
"UDP { src=%u, dst=%u, len=%u, chk=%u }\r\n",
|
||||
static_cast<uint16_t>(udpHeader.sourcePort),
|
||||
static_cast<uint16_t>(udpHeader.destinationPort),
|
||||
udpHeader.length,
|
||||
udpHeader.checksum);
|
||||
if (bufferSize < udpHeader.length)
|
||||
{
|
||||
DEBUG_LOG(
|
||||
"Dropped UDP packet (invalid buffer size %u, expected at least %u)\r\n",
|
||||
|
|
Loading…
Reference in a new issue