Add buffer size checks to Net::Icmp

This commit is contained in:
Sijmen 2020-12-28 18:35:38 +01:00
parent 56cf8cf447
commit 57c490da6e
Signed by: vijfhoek
GPG Key ID: DAF7821E067D9C48
2 changed files with 70 additions and 46 deletions

View File

@ -9,16 +9,18 @@
namespace Net::Icmp namespace Net::Icmp
{ {
// //
// PacketHeader // Header
// //
PacketHeader::PacketHeader() {} Header::Header() : type(static_cast<Type>(0)), code(0), checksum(0) {}
Header::Header(Type type, uint8_t code) : type(type), code(code), checksum(0) {}
PacketHeader::PacketHeader(Type type, uint8_t code) : size_t Header::Serialize(uint8_t* buffer, const size_t bufferSize) const
type(type), code(code), checksum(0)
{}
size_t PacketHeader::Serialize(uint8_t* buffer) const
{ {
if (bufferSize < SerializedLength())
{
return 0;
}
size_t i = 0; size_t i = 0;
buffer[i++] = static_cast<uint8_t>(type); buffer[i++] = static_cast<uint8_t>(type);
buffer[i++] = code; buffer[i++] = code;
@ -27,13 +29,18 @@ namespace Net::Icmp
return i; return i;
} }
PacketHeader PacketHeader::Deserialize(const uint8_t* buffer) size_t Header::Deserialize(
{ Header& out, const uint8_t* buffer, const size_t bufferSize
PacketHeader self; ) {
self.type = static_cast<Type>(buffer[0]); if (bufferSize < SerializedLength())
self.code = buffer[1]; {
self.checksum = buffer[2] << 8 | buffer[3]; return 0;
return self; }
out.type = static_cast<Type>(buffer[0]);
out.code = buffer[1];
out.checksum = buffer[2] << 8 | buffer[3];
return 4;
} }
// //
@ -43,8 +50,13 @@ namespace Net::Icmp
EchoHeader::EchoHeader(uint16_t identifier, uint16_t sequenceNumber) : EchoHeader::EchoHeader(uint16_t identifier, uint16_t sequenceNumber) :
identifier(identifier), sequenceNumber(sequenceNumber) {} identifier(identifier), sequenceNumber(sequenceNumber) {}
size_t EchoHeader::Serialize(uint8_t* buffer) const size_t EchoHeader::Serialize(uint8_t* buffer, const size_t bufferSize) const
{ {
if (bufferSize < SerializedLength())
{
return 0;
}
size_t i = 0; size_t i = 0;
buffer[i++] = identifier >> 8; buffer[i++] = identifier >> 8;
buffer[i++] = identifier; buffer[i++] = identifier;
@ -53,20 +65,25 @@ namespace Net::Icmp
return i; return i;
} }
EchoHeader EchoHeader::Deserialize(const uint8_t* buffer) size_t EchoHeader::Deserialize(
{ EchoHeader& out, const uint8_t* buffer, const size_t bufferSize
EchoHeader self; ) {
self.identifier = buffer[0] << 8 | buffer[1]; if (bufferSize < SerializedLength())
self.sequenceNumber = buffer[2] << 8 | buffer[3]; {
return self; return 0;
}
out.identifier = buffer[0] << 8 | buffer[1];
out.sequenceNumber = buffer[2] << 8 | buffer[3];
return 4;
} }
void SendEchoRequest(Utils::MacAddress mac, uint32_t ip) void SendEchoRequest(Utils::MacAddress mac, uint32_t ip)
{ {
Icmp::PacketHeader icmpHeader(Icmp::Type::EchoRequest, 0); Icmp::Header icmpHeader(Icmp::Type::EchoRequest, 0);
Icmp::EchoHeader pingHeader(0, 0); Icmp::EchoHeader pingHeader(0, 0);
size_t ipv4TotalSize = Icmp::PacketHeader::SerializedLength() + size_t ipv4TotalSize = Icmp::Header::SerializedLength() +
Icmp::EchoHeader::SerializedLength() + Icmp::EchoHeader::SerializedLength() +
Ipv4::Header::SerializedLength(); Ipv4::Header::SerializedLength();
Ipv4::Header ipv4Header( Ipv4::Header ipv4Header(
@ -80,8 +97,8 @@ namespace Net::Icmp
size += ethernetHeader.Serialize(buffer + size, sizeof(buffer) - size); size += ethernetHeader.Serialize(buffer + size, sizeof(buffer) - size);
size += ipv4Header.Serialize(buffer + size); size += ipv4Header.Serialize(buffer + size);
size += pingHeader.Serialize(buffer + size); size += pingHeader.Serialize(buffer + size, sizeof(buffer) - size);
size += icmpHeader.Serialize(buffer + 1); size += icmpHeader.Serialize(buffer + size, sizeof(buffer) - size);
const auto expectedSize = const auto expectedSize =
ethernetHeader.SerializedLength() + ethernetHeader.SerializedLength() +
@ -97,14 +114,17 @@ namespace Net::Icmp
static void handleEchoRequest( static void handleEchoRequest(
const Ethernet::Header& reqEthernetHeader, const Ethernet::Header& reqEthernetHeader,
const Ipv4::Header& reqIpv4Header, const Ipv4::Header& reqIpv4Header,
const Icmp::PacketHeader& reqIcmpHeader, const Icmp::Header& reqIcmpHeader,
const uint8_t* reqBuffer, const uint8_t* reqBuffer,
const size_t reqBufferSize const size_t reqBufferSize
) { ) {
const auto reqEchoHeader = Icmp::EchoHeader::Deserialize(reqBuffer); EchoHeader reqEchoHeader;
const auto reqHeaderSize = reqEchoHeader.SerializedLength(); const auto reqEchoHeaderSize =
Icmp::EchoHeader::Deserialize(reqEchoHeader, reqBuffer, reqBufferSize);
assert(reqEchoHeaderSize == reqEchoHeader.SerializedLength());
assert(reqEchoHeaderSize <= reqBufferSize);
const Icmp::PacketHeader respIcmpHeader(Icmp::Type::EchoReply, 0); const Icmp::Header respIcmpHeader(Icmp::Type::EchoReply, 0);
const Ipv4::Header respIpv4Header( const Ipv4::Header respIpv4Header(
Ipv4::Protocol::Icmp, Ipv4::Protocol::Icmp,
Utils::Ipv4Address, Utils::Ipv4Address,
@ -117,11 +137,11 @@ namespace Net::Icmp
Ethernet::EtherType::Ipv4 Ethernet::EtherType::Ipv4
); );
const auto payloadLength = const auto payloadSize =
reqIpv4Header.totalLength - reqIpv4Header.totalLength -
reqIpv4Header.SerializedLength() - reqIpv4Header.SerializedLength() -
reqIcmpHeader.SerializedLength() - reqIcmpHeader.SerializedLength() -
reqEchoHeader.SerializedLength(); reqEchoHeaderSize;
std::array<uint8_t, USPI_FRAME_BUFFER_SIZE> respBuffer; std::array<uint8_t, USPI_FRAME_BUFFER_SIZE> respBuffer;
@ -129,19 +149,20 @@ namespace Net::Icmp
respSize += respEthernetHeader.Serialize( respSize += respEthernetHeader.Serialize(
respBuffer.data() + respSize, respBuffer.size() - respSize); respBuffer.data() + respSize, respBuffer.size() - respSize);
respSize += respIpv4Header.Serialize(respBuffer.data() + respSize); respSize += respIpv4Header.Serialize(respBuffer.data() + respSize);
respSize += respIcmpHeader.Serialize(respBuffer.data() + respSize); respSize += respIcmpHeader.Serialize(
respBuffer.data() + respSize, respBuffer.size() - respSize);
std::memcpy( std::memcpy(
respBuffer.data() + respSize, respBuffer.data() + respSize,
reqBuffer + reqHeaderSize, reqBuffer + reqEchoHeaderSize,
payloadLength payloadSize
); );
respSize += payloadLength; respSize += payloadSize;
const auto expectedRespSize = const auto expectedRespSize =
respEthernetHeader.SerializedLength() + respEthernetHeader.SerializedLength() +
respIpv4Header.SerializedLength() + respIpv4Header.SerializedLength() +
respIcmpHeader.SerializedLength() + respIcmpHeader.SerializedLength() +
payloadLength; payloadSize;
assert(respSize == expectedRespSize); assert(respSize == expectedRespSize);
assert(respSize <= respBuffer.size()); assert(respSize <= respBuffer.size());
@ -160,8 +181,9 @@ namespace Net::Icmp
const auto ipv4Header = Ipv4::Header::Deserialize(buffer + headerSize); const auto ipv4Header = Ipv4::Header::Deserialize(buffer + headerSize);
headerSize += ipv4Header.SerializedLength(); headerSize += ipv4Header.SerializedLength();
const auto icmpHeader = Icmp::PacketHeader::Deserialize(buffer + headerSize); Header icmpHeader;
headerSize += icmpHeader.SerializedLength(); headerSize += Icmp::Header::Deserialize(
icmpHeader, buffer + headerSize, bufferSize - headerSize);
const auto expectedHeaderSize = const auto expectedHeaderSize =
ethernetHeader.SerializedLength() + ethernetHeader.SerializedLength() +

View File

@ -9,22 +9,23 @@ namespace Net::Icmp
EchoRequest = 8, EchoRequest = 8,
}; };
struct PacketHeader struct Header
{ {
Type type; Type type;
uint8_t code; uint8_t code;
uint16_t checksum; uint16_t checksum;
PacketHeader(); Header();
PacketHeader(Type type, uint8_t code); Header(Type type, uint8_t code);
constexpr static size_t SerializedLength() constexpr static size_t SerializedLength()
{ {
return sizeof(type) + sizeof(code) + sizeof(checksum); return sizeof(type) + sizeof(code) + sizeof(checksum);
} }
size_t Serialize(uint8_t* buffer) const; size_t Serialize(uint8_t* buffer, const size_t bufferSize) const;
static PacketHeader Deserialize(const uint8_t* buffer); static size_t Deserialize(
Header& out, const uint8_t* buffer, const size_t bufferSize);
}; };
struct EchoHeader struct EchoHeader
@ -40,10 +41,11 @@ namespace Net::Icmp
return sizeof(identifier) + sizeof(sequenceNumber); return sizeof(identifier) + sizeof(sequenceNumber);
} }
size_t Serialize(uint8_t* buffer) const; size_t Serialize(uint8_t* buffer, const size_t bufferSize) const;
static EchoHeader Deserialize(const uint8_t* buffer); static size_t Deserialize(
EchoHeader& out, const uint8_t* buffer, const size_t bufferSize);
}; };
void SendEchoRequest(Utils::MacAddress mac, uint32_t ip); void SendEchoRequest(const Utils::MacAddress mac, const uint32_t ip);
void HandlePacket(const uint8_t* buffer, const size_t bufferSize); void HandlePacket(const uint8_t* buffer, const size_t bufferSize);
} // namespace Net::Icmp } // namespace Net::Icmp