diff --git a/src/main.cpp b/src/main.cpp index 01fc52f..185baab 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -369,32 +369,34 @@ void InitialiseLCD() void updateNetwork() { - unsigned int size = 0; + unsigned int frameSize = 0; uint8_t ipBuffer[USPI_FRAME_BUFFER_SIZE]; - if (!USPiEthernetAvailable() || !USPiReceiveFrame(ipBuffer, &size)) + if (!USPiEthernetAvailable() || !USPiReceiveFrame(ipBuffer, &frameSize)) { return; } - auto ethernetHeader = Net::Ethernet::Header::Deserialize(ipBuffer); - const auto offset = ethernetHeader.SerializedLength(); + Net::Ethernet::Header ethernetHeader; + auto headerSize = Net::Ethernet::Header::Deserialize( + ethernetHeader, ipBuffer, frameSize); + assert(headerSize != 0); - static bool announcementSent = false; - if (!announcementSent) + static bool arpAnnouncementSent = false; + if (!arpAnnouncementSent) { Net::Arp::SendAnnouncement( Net::Utils::GetMacAddress(), Net::Utils::Ipv4Address); - announcementSent = true; + arpAnnouncementSent = true; } switch (ethernetHeader.type) { case Net::Ethernet::EtherType::Arp: - Net::Arp::HandlePacket(ethernetHeader, ipBuffer + offset); + Net::Arp::HandlePacket(ethernetHeader, ipBuffer + headerSize); break; case Net::Ethernet::EtherType::Ipv4: Net::Ipv4::HandlePacket( - ethernetHeader, ipBuffer + offset, sizeof(ipBuffer) - offset); + ethernetHeader, ipBuffer + headerSize, frameSize - headerSize); break; } } diff --git a/src/net-arp.cpp b/src/net-arp.cpp index 49ea2d5..1af1dcb 100644 --- a/src/net-arp.cpp +++ b/src/net-arp.cpp @@ -1,3 +1,4 @@ +#include #include #include "net-arp.h" @@ -87,8 +88,14 @@ namespace Net::Arp uint8_t buffer[USPI_FRAME_BUFFER_SIZE]; size_t size = 0; - size += ethernetHeader.Serialize(buffer + size); + size += ethernetHeader.Serialize(buffer + size, sizeof(buffer) - size); size += arpPacket.Serialize(buffer + size); + + const auto expectedSize = + ethernetHeader.SerializedLength() + arpPacket.SerializedLength(); + assert(size == expectedSize); + assert(size <= USPI_FRAME_BUFFER_SIZE); + USPiSendFrame(buffer, size); } diff --git a/src/net-dhcp.cpp b/src/net-dhcp.cpp index b06d1cd..876859d 100644 --- a/src/net-dhcp.cpp +++ b/src/net-dhcp.cpp @@ -149,22 +149,18 @@ namespace Net::Dhcp uint8_t buffer[USPI_FRAME_BUFFER_SIZE]; size_t size = 0; + size += ethernetHeader.Serialize(buffer + size, sizeof(buffer) - size); + size += ipv4Header.Serialize(buffer + size); + size += udpHeader.Serialize(buffer + size); + size += dhcpHeader.Serialize(buffer + size, sizeof(buffer) - size); + const auto expectedSize = ethernetHeader.SerializedLength() + ipv4Header.SerializedLength() + udpHeader.SerializedLength() + dhcpHeader.SerializedLength(); - - size += ethernetHeader.Serialize(buffer + size); - size += ipv4Header.Serialize(buffer + size); - size += udpHeader.Serialize(buffer + size); - size += dhcpHeader.Serialize(buffer + size, USPI_FRAME_BUFFER_SIZE - size); - - if (size != expectedSize) - { - // TODO Log - return; - } + assert(size == expectedSize); + assert(size <= sizeof(buffer)); USPiSendFrame(buffer, size); } @@ -204,22 +200,19 @@ namespace Net::Dhcp uint8_t buffer[USPI_FRAME_BUFFER_SIZE]; size_t size = 0; + + size += ethernetHeader.Serialize(buffer + size, sizeof(buffer) - size); + size += ipv4Header.Serialize(buffer + size); + size += udpHeader.Serialize(buffer + size); + size += dhcpHeader.Serialize(buffer + size, sizeof(buffer) - size); + const auto expectedSize = ethernetHeader.SerializedLength() + ipv4Header.SerializedLength() + udpHeader.SerializedLength() + dhcpHeader.SerializedLength(); - - size += ethernetHeader.Serialize(buffer + size); - size += ipv4Header.Serialize(buffer + size); - size += udpHeader.Serialize(buffer + size); - size += dhcpHeader.Serialize(buffer + size, USPI_FRAME_BUFFER_SIZE - size); - - if (size != expectedSize) - { - // TODO Log - return; - } + assert(size == expectedSize); + assert(size <= sizeof(buffer)); USPiSendFrame(buffer, size); diff --git a/src/net-ethernet.cpp b/src/net-ethernet.cpp index 22084e2..05e382a 100644 --- a/src/net-ethernet.cpp +++ b/src/net-ethernet.cpp @@ -28,8 +28,13 @@ namespace Net::Ethernet type(type) {} - size_t Header::Serialize(uint8_t* buffer) const + size_t Header::Serialize(uint8_t* buffer, const size_t size) const { + if (size < SerializedLength()) + { + return 0; + } + size_t i = 0; std::memcpy(buffer + i, macDestination.data(), macDestination.size()); @@ -44,12 +49,16 @@ namespace Net::Ethernet return i; } - Header Header::Deserialize(const uint8_t* buffer) + size_t Header::Deserialize(Header& out, const uint8_t* buffer, const size_t size) { - Header self; - std::memcpy(self.macDestination.data(), buffer + 0, self.macDestination.size()); - std::memcpy(self.macSource.data(), buffer + 6, self.macSource.size()); - self.type = static_cast(buffer[12] << 8 | buffer[13]); - return self; + if (size < SerializedLength()) + { + return 0; + } + + std::memcpy(out.macDestination.data(), buffer + 0, out.macDestination.size()); + std::memcpy(out.macSource.data(), buffer + 6, out.macSource.size()); + out.type = static_cast(buffer[12] << 8 | buffer[13]); + return 14; } } // namespace Net::Ethernet diff --git a/src/net-ethernet.h b/src/net-ethernet.h index c85a9bb..a169418 100644 --- a/src/net-ethernet.h +++ b/src/net-ethernet.h @@ -28,7 +28,8 @@ namespace Net::Ethernet return sizeof(macDestination) + sizeof(macSource) + sizeof(type); } - size_t Serialize(uint8_t* buffer) const; - static Header Deserialize(const uint8_t* buffer); + size_t Serialize(uint8_t* buffer, const size_t size) const; + static size_t Deserialize( + Header& out, const uint8_t* buffer, const size_t size); }; } // namespace Net::Ethernet diff --git a/src/net-icmp.cpp b/src/net-icmp.cpp index f1be8c7..223f044 100644 --- a/src/net-icmp.cpp +++ b/src/net-icmp.cpp @@ -1,3 +1,4 @@ +#include #include #include "net-icmp.h" @@ -75,64 +76,108 @@ namespace Net::Icmp mac, Utils::GetMacAddress(), Ethernet::EtherType::Ipv4); uint8_t buffer[USPI_FRAME_BUFFER_SIZE]; - size_t i = 0; + size_t size = 0; - i += ethernetHeader.Serialize(buffer + i); - i += ipv4Header.Serialize(buffer + i); - i += pingHeader.Serialize(buffer + i); - i += icmpHeader.Serialize(buffer + 1); + size += ethernetHeader.Serialize(buffer + size, sizeof(buffer) - size); + size += ipv4Header.Serialize(buffer + size); + size += pingHeader.Serialize(buffer + size); + size += icmpHeader.Serialize(buffer + 1); - USPiSendFrame(buffer, i); + const auto expectedSize = + ethernetHeader.SerializedLength() + + ipv4Header.SerializedLength() + + pingHeader.SerializedLength() + + icmpHeader.SerializedLength(); + assert(size == expectedSize); + assert(size <= sizeof(buffer)); + + USPiSendFrame(buffer, size); } - void HandlePacket(const uint8_t* buffer) + static void handleEchoRequest( + const Ethernet::Header& reqEthernetHeader, + const Ipv4::Header& reqIpv4Header, + const Icmp::PacketHeader& reqIcmpHeader, + const uint8_t* reqBuffer, + const size_t reqBufferSize + ) { + const auto reqEchoHeader = Icmp::EchoHeader::Deserialize(reqBuffer); + const auto reqHeaderSize = reqEchoHeader.SerializedLength(); + + const Icmp::PacketHeader respIcmpHeader(Icmp::Type::EchoReply, 0); + const Ipv4::Header respIpv4Header( + Ipv4::Protocol::Icmp, + Utils::Ipv4Address, + reqIpv4Header.sourceIp, + reqIpv4Header.totalLength + ); + const Ethernet::Header respEthernetHeader( + reqEthernetHeader.macSource, + Utils::GetMacAddress(), + Ethernet::EtherType::Ipv4 + ); + + const auto payloadLength = + reqIpv4Header.totalLength - + reqIpv4Header.SerializedLength() - + reqIcmpHeader.SerializedLength() - + reqEchoHeader.SerializedLength(); + + std::array respBuffer; + + size_t respSize = 0; + respSize += respEthernetHeader.Serialize( + respBuffer.data() + respSize, respBuffer.size() - respSize); + respSize += respIpv4Header.Serialize(respBuffer.data() + respSize); + respSize += respIcmpHeader.Serialize(respBuffer.data() + respSize); + std::memcpy( + respBuffer.data() + respSize, + reqBuffer + reqHeaderSize, + payloadLength + ); + respSize += payloadLength; + + const auto expectedRespSize = + respEthernetHeader.SerializedLength() + + respIpv4Header.SerializedLength() + + respIcmpHeader.SerializedLength() + + payloadLength; + assert(respSize == expectedRespSize); + assert(respSize <= respBuffer.size()); + + USPiSendFrame(respBuffer.data(), respSize); + } + + void HandlePacket(const uint8_t* buffer, const size_t bufferSize) { // TODO Don't re-parse the upper layers - size_t requestSize = 0; - const auto requestEthernetHeader = - Ethernet::Header::Deserialize(buffer + requestSize); - requestSize += requestEthernetHeader.SerializedLength(); - const auto requestIpv4Header = Ipv4::Header::Deserialize(buffer + requestSize); - requestSize += requestIpv4Header.SerializedLength(); - const auto requestIcmpHeader = - Icmp::PacketHeader::Deserialize(buffer + requestSize); - requestSize += requestIcmpHeader.SerializedLength(); + size_t headerSize = 0; - if (requestIcmpHeader.type == Icmp::Type::EchoRequest) + Ethernet::Header ethernetHeader; + headerSize += Ethernet::Header::Deserialize( + ethernetHeader, buffer + headerSize, bufferSize - headerSize); + + const auto ipv4Header = Ipv4::Header::Deserialize(buffer + headerSize); + headerSize += ipv4Header.SerializedLength(); + + const auto icmpHeader = Icmp::PacketHeader::Deserialize(buffer + headerSize); + headerSize += icmpHeader.SerializedLength(); + + const auto expectedHeaderSize = + ethernetHeader.SerializedLength() + + ipv4Header.SerializedLength() + + icmpHeader.SerializedLength(); + assert(headerSize == expectedHeaderSize); + + if (icmpHeader.type == Icmp::Type::EchoRequest) { - const auto requestEchoHeader = - Icmp::EchoHeader::Deserialize(buffer + requestSize); - requestSize += requestEchoHeader.SerializedLength(); - - const Icmp::PacketHeader responseIcmpHeader( - Icmp::Type::EchoReply, 0); - const Ipv4::Header responseIpv4Header( - Ipv4::Protocol::Icmp, - Utils::Ipv4Address, - requestIpv4Header.sourceIp, - requestIpv4Header.totalLength + handleEchoRequest( + ethernetHeader, + ipv4Header, + icmpHeader, + buffer + headerSize, + bufferSize - headerSize ); - const Ethernet::Header responseEthernetHeader( - requestEthernetHeader.macSource, - Utils::GetMacAddress(), - Ethernet::EtherType::Ipv4 - ); - - const auto payloadLength = - requestIpv4Header.totalLength - - requestIpv4Header.SerializedLength() - - requestIcmpHeader.SerializedLength() - - requestEchoHeader.SerializedLength(); - - std::array bufferResp; - size_t respSize = 0; - respSize += responseEthernetHeader.Serialize(bufferResp.data() + respSize); - respSize += responseIpv4Header.Serialize(bufferResp.data() + respSize); - respSize += responseIcmpHeader.Serialize(bufferResp.data() + respSize); - std::memcpy( - bufferResp.data() + respSize, buffer + requestSize, payloadLength); - respSize += payloadLength; - USPiSendFrame(bufferResp.data(), respSize); } } } // namespace Net::Icmp diff --git a/src/net-icmp.h b/src/net-icmp.h index 57d5c4e..46d1d30 100644 --- a/src/net-icmp.h +++ b/src/net-icmp.h @@ -45,5 +45,5 @@ namespace Net::Icmp }; void SendEchoRequest(Utils::MacAddress mac, uint32_t ip); - void HandlePacket(const uint8_t* buffer); + void HandlePacket(const uint8_t* buffer, const size_t bufferSize); } // namespace Net::Icmp diff --git a/src/net-ipv4.cpp b/src/net-ipv4.cpp index 266f442..f94b160 100644 --- a/src/net-ipv4.cpp +++ b/src/net-ipv4.cpp @@ -91,28 +91,28 @@ namespace Net::Ipv4 void HandlePacket( const Ethernet::Header& ethernetHeader, const uint8_t* buffer, - const size_t size + const size_t bufferSize ) { - const auto ipv4Header = Header::Deserialize(buffer); - const auto offset = Header::SerializedLength(); + const auto header = Header::Deserialize(buffer); + const auto headerSize = Header::SerializedLength(); // Update ARP table Arp::ArpTable.insert( - std::make_pair(ipv4Header.sourceIp, ethernetHeader.macSource)); + std::make_pair(header.sourceIp, ethernetHeader.macSource)); - if (ipv4Header.version != 4) return; - if (ipv4Header.ihl != 5) return; // Not supported - if (ipv4Header.destinationIp != Utils::Ipv4Address) return; - if (ipv4Header.fragmentOffset != 0) return; // TODO Support this + if (header.version != 4) return; + if (header.ihl != 5) return; // Not supported + if (header.destinationIp != Utils::Ipv4Address) return; + if (header.fragmentOffset != 0) return; // TODO Support this - if (ipv4Header.protocol == Ipv4::Protocol::Icmp) + if (header.protocol == Ipv4::Protocol::Icmp) { - Icmp::HandlePacket(buffer); + Icmp::HandlePacket(buffer, bufferSize - headerSize); } - else if (ipv4Header.protocol == Ipv4::Protocol::Udp) + else if (header.protocol == Ipv4::Protocol::Udp) { Udp::HandlePacket( - ethernetHeader, ipv4Header, buffer + offset, size - offset); + ethernetHeader, header, buffer + headerSize, bufferSize - headerSize); } } } // namespace Net::Ipv4 diff --git a/src/net-tftp.cpp b/src/net-tftp.cpp index 4e24971..6b9f4b0 100644 --- a/src/net-tftp.cpp +++ b/src/net-tftp.cpp @@ -1,7 +1,7 @@ #include +#include #include -#include "ff.h" #include "net-arp.h" #include "net-ethernet.h" #include "net-ipv4.h" @@ -9,6 +9,7 @@ #include "net-udp.h" #include "net.h" +#include "ff.h" #include "types.h" #include @@ -160,13 +161,22 @@ namespace Net::Tftp Ethernet::EtherType::Ipv4 ); - size_t i = 0; + size_t size = 0; uint8_t buffer[USPI_FRAME_BUFFER_SIZE]; - i += ethernetRespHeader.Serialize(buffer + i); - i += ipv4RespHeader.Serialize(buffer + i); - i += udpRespHeader.Serialize(buffer + i); - i += response->Serialize(buffer + i); - USPiSendFrame(buffer, i); + size += ethernetRespHeader.Serialize(buffer + size, sizeof(buffer) - size); + size += ipv4RespHeader.Serialize(buffer + size); + size += udpRespHeader.Serialize(buffer + size); + size += response->Serialize(buffer + size); + + const auto expectedSize = + ethernetRespHeader.SerializedLength() + + ipv4RespHeader.SerializedLength() + + udpRespHeader.SerializedLength() + + response->SerializedLength(); + assert(size == expectedSize); + assert(size <= sizeof(buffer)); + + USPiSendFrame(buffer, size); } if (last && shouldReboot)