diff --git a/src/net-ipv4.cpp b/src/net-ipv4.cpp index 00950c3..d9af4d0 100644 --- a/src/net-ipv4.cpp +++ b/src/net-ipv4.cpp @@ -111,11 +111,25 @@ namespace Net::Ipv4 if (headerSize != Header::SerializedLength()) { DEBUG_LOG( - "Dropped IPv4 packet (invalid buffer size %u, expected at least %u)\r\n", + "Dropped IPv4 header (invalid buffer size %u, expected at least %u)\r\n", bufferSize, headerSize); return; } + DEBUG_LOG( + "IPv4 { src=%08lx, dst=%08lx, len=%u, protocol=%u }\r\n", + header.sourceIp, + header.destinationIp, + header.totalLength, + static_cast(header.protocol)); + if (bufferSize < header.totalLength) + { + DEBUG_LOG( + "Dropped IPv4 packet (invalid buffer size %u, expected at least %u)\r\n", + bufferSize, + header.totalLength); + return; + } // Update ARP table Arp::ArpTable.insert(std::make_pair(header.sourceIp, ethernetHeader.macSource)); diff --git a/src/net-tftp.cpp b/src/net-tftp.cpp index f3fe6c0..64c2afa 100644 --- a/src/net-tftp.cpp +++ b/src/net-tftp.cpp @@ -25,15 +25,16 @@ namespace Net::Tftp Packet::Packet(const Opcode opcode) : opcode(opcode) {} static std::unique_ptr - handleTftpWriteRequest(const uint8_t* data, const size_t dataSize) + handleTftpWriteRequest(const uint8_t* buffer, const size_t bufferSize) { + DEBUG_LOG("Received TFTP write request\r\n"); WriteReadRequestPacket packet; - const auto size = packet.Deserialize(data, dataSize); + const auto size = packet.Deserialize(buffer, bufferSize); if (size == 0) { DEBUG_LOG( "Dropped TFTP packet (invalid buffer size %u, expected at least %u)\r\n", - dataSize, + bufferSize, sizeof(WriteReadRequestPacket::opcode) + 2); return nullptr; } @@ -88,6 +89,7 @@ namespace Net::Tftp static std::unique_ptr handleTftpData(const uint8_t* buffer, size_t size) { + DEBUG_LOG("Received TFTP data\r\n"); DataPacket packet; const auto tftpSize = packet.Deserialize(buffer, size); if (tftpSize == 0) @@ -134,6 +136,7 @@ namespace Net::Tftp const size_t reqBufferSize) { const auto opcode = static_cast(reqBuffer[0] << 8 | reqBuffer[1]); + DEBUG_LOG("Received TFTP %u packet\r\n", static_cast(opcode)); std::unique_ptr response; const auto payloadSize = udpReqHeader.length - udpReqHeader.SerializedLength(); @@ -160,6 +163,7 @@ namespace Net::Tftp if (response != nullptr) { + DEBUG_LOG("Sending TFTP response\r\n"); Udp::Header udpRespHeader( udpReqHeader.destinationPort, udpReqHeader.sourcePort, @@ -189,6 +193,10 @@ namespace Net::Tftp USPiSendFrame(buffer, size); } + else + { + DEBUG_LOG("TFTP response was nullptr\r\n"); + } // TODO Reboot the Pi when a system file was received } @@ -226,25 +234,33 @@ namespace Net::Tftp size_t WriteReadRequestPacket::Deserialize(const uint8_t* buffer, const size_t bufferSize) { - // Can't use SerializedLength here, as it's variable. - // Check for each field instead. + // Can't use SerializedLength here, as it's variable. Check each field separately instead. size_t i = 0; if (sizeof(Opcode) >= bufferSize - i) return 0; opcode = static_cast(buffer[i] << 8 | buffer[i + 1]); - i += 2; + i += sizeof(Opcode); - const char* filenameStr = reinterpret_cast(buffer + i); - if (std::strlen(filenameStr) + 1 >= bufferSize - i) + // Check if there's a null terminator within the remaining buffer + size_t j; + for (j = i; j < bufferSize; j++) + if (buffer[j] == 0) + break; + if (j == bufferSize) return 0; - filename = std::string(filenameStr); + + filename = std::string(reinterpret_cast(buffer + i)); i += filename.size() + 1; - const char* modeStr = reinterpret_cast(buffer + i); - if (std::strlen(modeStr) + 1 >= bufferSize - i) + // Check if there's a null terminator within the remaining buffer + for (j = i; j < bufferSize; j++) + if (buffer[j] == 0) + break; + if (j == bufferSize) return 0; - mode = std::string(modeStr); + + mode = std::string(reinterpret_cast(buffer + i)); i += mode.size() + 1; return i; diff --git a/src/net-udp.cpp b/src/net-udp.cpp index fb60895..15d30a1 100644 --- a/src/net-udp.cpp +++ b/src/net-udp.cpp @@ -62,7 +62,14 @@ namespace Net::Udp Header::SerializedLength()); return; } - if (udpHeader.length <= bufferSize) + + DEBUG_LOG( + "UDP { src=%u, dst=%u, len=%u, chk=%u }\r\n", + static_cast(udpHeader.sourcePort), + static_cast(udpHeader.destinationPort), + udpHeader.length, + udpHeader.checksum); + if (bufferSize < udpHeader.length) { DEBUG_LOG( "Dropped UDP packet (invalid buffer size %u, expected at least %u)\r\n",