From 60e186cdc96f9cc24942671d483d25902d99a5d6 Mon Sep 17 00:00:00 2001 From: Sijmen Schoon Date: Fri, 11 Dec 2020 22:44:10 +0100 Subject: [PATCH] Split TFTP code into readable chunks and deduplicate a bunch of code --- Makefile | 2 +- src/net-arp.h | 1 + src/net-tftp.cpp | 168 ++++++++++++++++++++++++++++++ src/net-tftp.h | 138 ++++++++++++++++++++++++ src/net-udp.h | 136 ++---------------------- src/net.cpp | 265 ++--------------------------------------------- 6 files changed, 323 insertions(+), 387 deletions(-) create mode 100644 src/net-tftp.cpp create mode 100644 src/net-tftp.h diff --git a/Makefile b/Makefile index b3e31b6..658a782 100644 --- a/Makefile +++ b/Makefile @@ -4,7 +4,7 @@ OBJS = armc-start.o armc-cstartup.o armc-cstubs.o armc-cppstubs.o \ Drive.o Pi1541.o DiskImage.o iec_bus.o iec_commands.o m6502.o m6522.o \ gcr.o prot.o lz.o emmc.o diskio.o options.o Screen.o SSD1306.o ScreenLCD.o \ Timer.o FileBrowser.o DiskCaddy.o ROMs.o InputMappings.o xga_font_data.o m8520.o wd177x.o Pi1581.o SpinLock.o \ - net.o + net.o net-tftp.o SRCDIR = src OBJS := $(addprefix $(SRCDIR)/, $(OBJS)) diff --git a/src/net-arp.h b/src/net-arp.h index 2244e95..729dcad 100644 --- a/src/net-arp.h +++ b/src/net-arp.h @@ -1,4 +1,5 @@ #pragma once +#include "net.h" struct Ipv4ArpPacket { diff --git a/src/net-tftp.cpp b/src/net-tftp.cpp new file mode 100644 index 0000000..b837e60 --- /dev/null +++ b/src/net-tftp.cpp @@ -0,0 +1,168 @@ +#include + +#include "ff.h" +#include "net-ethernet.h" +#include "net-ipv4.h" +#include "net-tftp.h" +#include "net-udp.h" +#include "net.h" +#include "types.h" + +#include + +// TODO Allow multiple files open +static FIL outFile; +static bool shouldReboot = false; +static uint32_t currentBlockNumber = -1; + +static std::unique_ptr handleTftpWriteRequest(const uint8_t* data) +{ + auto packet = TftpWriteReadRequestPacket::Deserialize(data); + + // TODO Implement netscii, maybe + if (packet.mode != "octet") + { + return std::unique_ptr( + new TftpErrorPacket(0, "please use mode octet") + ); + } + + currentBlockNumber = 0; + + // TODO Return to the original working directory. + char workingDirectory[256]; + f_getcwd(workingDirectory, sizeof(workingDirectory)); + + // Try opening the file + auto separator = packet.filename.rfind('/', packet.filename.size()); + if (separator != std::string::npos) + { + auto path = "/" + packet.filename.substr(0, separator); + f_chdir(path.c_str()); + } + else + { + f_chdir("/"); + separator = 0; + } + + // Open the output file. + auto filename = packet.filename.substr(separator + 1); + const auto result = f_open(&outFile, filename.c_str(), FA_CREATE_ALWAYS | FA_WRITE); + + std::unique_ptr response; + if (result != FR_OK) + { + response = std::unique_ptr( + new TftpErrorPacket(0, "error opening target file") + ); + } + else + { + shouldReboot = + packet.filename == "kernel.img" || packet.filename == "options.txt"; + response = std::unique_ptr( + new TftpAcknowledgementPacket(currentBlockNumber) + ); + } + + // TODO Return to the original working directory here + + return response; +} + +static std::unique_ptr handleTftpData(const uint8_t* data, size_t length) +{ + auto packet = TftpDataPacket::Deserialize(data, length); + + if (packet.blockNumber != currentBlockNumber + 1) + { + f_close(&outFile); + return std::unique_ptr( + new TftpErrorPacket(0, "invalid block number") + ); + } + currentBlockNumber = packet.blockNumber; + + unsigned int bytesWritten; + const auto result = + f_write(&outFile, packet.data.data(), packet.data.size(), &bytesWritten); + + if (result != FR_OK || bytesWritten != packet.data.size()) + { + f_close(&outFile); + return std::unique_ptr(new TftpErrorPacket(0, "io error")); + } + + if (packet.data.size() < 512) + { + // Close the file for the last packet. + f_close(&outFile); + } + + return std::unique_ptr( + new TftpAcknowledgementPacket(currentBlockNumber) + ); +} + +void HandleTftpDatagram( + const EthernetFrameHeader ethernetReqHeader, + const Ipv4Header ipv4ReqHeader, + const UdpDatagramHeader udpReqHeader, + const uint8_t* data +) { + const auto opcode = data[0] << 8 | data[1]; + std::unique_ptr response; + bool last = false; + + if (opcode == TFTP_OP_WRITE_REQUEST) + { + response = handleTftpWriteRequest(data); + } + else if (opcode == TFTP_OP_DATA) + { + const auto length = udpReqHeader.length - UdpDatagramHeader::SerializedLength(); + response = handleTftpData(data, length); + } + else + { + response = std::unique_ptr( + new TftpErrorPacket(4, "not implemented yet") + ); + } + + if (response != nullptr) + { + UdpDatagramHeader udpRespHeader( + udpReqHeader.destinationPort, + udpReqHeader.sourcePort, + response->SerializedLength() + UdpDatagramHeader::SerializedLength() + ); + Ipv4Header ipv4RespHeader( + IP_PROTO_UDP, + Ipv4Address, + ipv4ReqHeader.sourceIp, + udpRespHeader.length + Ipv4Header::SerializedLength() + ); + EthernetFrameHeader ethernetRespHeader( + ArpTable[ipv4RespHeader.destinationIp], + GetMacAddress(), + ETHERTYPE_IPV4 + ); + + size_t i = 0; + uint8_t buffer[USPI_FRAME_BUFFER_SIZE]; + i += ethernetRespHeader.Serialize(buffer + i); + i += ipv4RespHeader.Serialize(buffer + i); + i += udpRespHeader.Serialize(buffer + i); + i += response->Serialize(buffer + i); + USPiSendFrame(buffer, i); + } + + if (last && shouldReboot) + { + // TODO eww + extern void Reboot_Pi(); + Reboot_Pi(); + } +} diff --git a/src/net-tftp.h b/src/net-tftp.h new file mode 100644 index 0000000..31e68ed --- /dev/null +++ b/src/net-tftp.h @@ -0,0 +1,138 @@ +#pragma once +#include + +const size_t TFTP_BLOCK_SIZE = 512; + +enum TftpOperation +{ + TFTP_OP_READ_REQUEST = 1, + TFTP_OP_WRITE_REQUEST = 2, + TFTP_OP_DATA = 3, + TFTP_OP_ACKNOWLEDGEMENT = 4, + TFTP_OP_ERROR = 5, +}; + +struct TftpPacket +{ + uint16_t opcode; + + TftpPacket(uint16_t opcode) : opcode(opcode) {} + + virtual size_t SerializedLength() const = 0; + virtual size_t Serialize(uint8_t* buffer) const = 0; +}; + +struct TftpWriteReadRequestPacket : public TftpPacket +{ + std::string filename; + std::string mode; + + TftpWriteReadRequestPacket(uint16_t opcode) : TftpPacket(opcode) {} + + constexpr size_t SerializedLength() override { + return sizeof(opcode) + filename.size() + 1 + mode.size() + 1; + } + + size_t Serialize(uint8_t* buffer) const override { + size_t i = 0; + buffer[i++] = opcode >> 8; + buffer[i++] = opcode; + + i += filename.copy(reinterpret_cast(buffer + i), filename.size()); + buffer[i++] = 0; + + i += mode.copy(reinterpret_cast(buffer + i), mode.size()); + buffer[i++] = 0; + + return i; + } + + static TftpWriteReadRequestPacket Deserialize(const uint8_t* buffer) { + size_t i = 0; + + TftpWriteReadRequestPacket self(buffer[i] << 8 | buffer[i + 1]); + i += 2; + + self.filename = reinterpret_cast(buffer + i); + i += self.filename.size() + 1; + + self.mode = reinterpret_cast(buffer + i); + i += self.mode.size() + 1; + + return self; + } +}; + +struct TftpErrorPacket : public TftpPacket +{ + uint16_t errorCode; + std::string message; + + TftpErrorPacket() : TftpPacket(TFTP_OP_ERROR) {} + TftpErrorPacket(uint16_t errorCode, std::string message) : + TftpPacket(TFTP_OP_ERROR), errorCode(errorCode), message(message) + {} + + constexpr size_t SerializedLength() const override + { + return sizeof(opcode) + sizeof(errorCode) + message.size() + 1; + } + + size_t Serialize(uint8_t* buffer) const + { + size_t i = 0; + buffer[i++] = opcode >> 8; + buffer[i++] = opcode; + buffer[i++] = errorCode >> 8; + buffer[i++] = errorCode; + + i += message.copy(reinterpret_cast(buffer + i), message.size()); + buffer[i++] = 0; + + return i; + } +}; + +struct TftpAcknowledgementPacket : public TftpPacket +{ + uint16_t blockNumber; + + TftpAcknowledgementPacket() : TftpPacket(TFTP_OP_ACKNOWLEDGEMENT) {} + + TftpAcknowledgementPacket(uint16_t blockNumber) : + TftpPacket(TFTP_OP_ACKNOWLEDGEMENT), blockNumber(blockNumber) + {} + + constexpr size_t SerializedLength() override + { + return sizeof(opcode) + sizeof(blockNumber); + } + + size_t Serialize(uint8_t* buffer) const override + { + size_t i = 0; + buffer[i++] = opcode >> 8; + buffer[i++] = opcode; + buffer[i++] = blockNumber >> 8; + buffer[i++] = blockNumber; + return i; + } +}; + +struct TftpDataPacket +{ + uint16_t opcode; + uint16_t blockNumber; + std::vector data; + + TftpDataPacket() : opcode(TFTP_OP_DATA) {} + + static TftpDataPacket Deserialize(const uint8_t* buffer, size_t length) + { + TftpDataPacket self; + self.opcode = buffer[0] << 8 | buffer[1]; + self.blockNumber = buffer[2] << 8 | buffer[3]; + self.data = std::vector(buffer + 4, buffer + length); + return self; + } +}; diff --git a/src/net-udp.h b/src/net-udp.h index 56e1596..20cdc9c 100644 --- a/src/net-udp.h +++ b/src/net-udp.h @@ -1,21 +1,14 @@ #pragma once #include - -enum TftpOperation -{ - TFTP_OP_READ_REQUEST = 1, - TFTP_OP_WRITE_REQUEST = 2, - TFTP_OP_DATA = 3, - TFTP_OP_ACKNOWLEDGEMENT = 4, - TFTP_OP_ERROR = 5, -}; +#include +#include struct UdpDatagramHeader { - std::uint16_t sourcePort; - std::uint16_t destinationPort; - std::uint16_t length; - std::uint16_t checksum; + uint16_t sourcePort; + uint16_t destinationPort; + uint16_t length; + uint16_t checksum; UdpDatagramHeader() {} @@ -73,120 +66,3 @@ struct UdpDatagram payload(payload) {} } __attribute__((packed)); - -struct TftpWriteReadRequestPacket -{ - uint16_t opcode; - std::string filename; - std::string mode; - - size_t SerializedLength() const { - return sizeof(opcode) + filename.size() + 1 + mode.size() + 1; - } - - size_t Serialize(uint8_t* buffer) const { - size_t i = 0; - buffer[i++] = opcode >> 8; - buffer[i++] = opcode; - - i += filename.copy(reinterpret_cast(buffer + i), filename.size()); - buffer[i++] = 0; - - i += mode.copy(reinterpret_cast(buffer + i), mode.size()); - buffer[i++] = 0; - - return i; - } - - static TftpWriteReadRequestPacket Deserialize(const uint8_t* buffer) { - TftpWriteReadRequestPacket self; - size_t i = 0; - - self.opcode = buffer[i] << 8 | buffer[i + 1]; - i += 2; - - self.filename = reinterpret_cast(buffer + i); - i += self.filename.size() + 1; - - self.mode = reinterpret_cast(buffer + i); - i += self.mode.size() + 1; - - return self; - } -}; - -struct TftpErrorPacket -{ - uint16_t opcode; - uint16_t errorCode; - std::string message; - - TftpErrorPacket() : opcode(TFTP_OP_ERROR) {} - TftpErrorPacket(uint16_t errorCode, std::string message) : - opcode(TFTP_OP_ERROR), errorCode(errorCode), message(message) - {} - - constexpr size_t SerializedLength() - { - return sizeof(opcode) + sizeof(errorCode) + message.size() + 1; - } - - size_t Serialize(uint8_t* buffer) const - { - size_t i = 0; - buffer[i++] = opcode >> 8; - buffer[i++] = opcode; - buffer[i++] = errorCode >> 8; - buffer[i++] = errorCode; - - i += message.copy(reinterpret_cast(buffer + i), message.size()); - buffer[i++] = 0; - - return i; - } -}; - -struct TftpAcknowledgementPacket -{ - uint16_t opcode; - uint16_t blockNumber; - - TftpAcknowledgementPacket() : opcode(TFTP_OP_ACKNOWLEDGEMENT) {} - - TftpAcknowledgementPacket(uint16_t blockNumber) : - opcode(TFTP_OP_ACKNOWLEDGEMENT), blockNumber(blockNumber) - {} - - constexpr size_t SerializedLength() - { - return sizeof(opcode) + sizeof(blockNumber); - } - - size_t Serialize(uint8_t* buffer) const - { - size_t i = 0; - buffer[i++] = opcode >> 8; - buffer[i++] = opcode; - buffer[i++] = blockNumber >> 8; - buffer[i++] = blockNumber; - return i; - } -}; - -struct TftpDataPacket -{ - uint16_t opcode; - uint16_t blockNumber; - std::vector data; - - TftpDataPacket() : opcode(TFTP_OP_DATA) {} - - static TftpDataPacket Deserialize(const uint8_t* buffer, size_t length) - { - TftpDataPacket self; - self.opcode = buffer[0] << 8 | buffer[1]; - self.blockNumber = buffer[2] << 8 | buffer[3]; - self.data = std::vector(buffer + 4, buffer + length); - return self; - } -}; diff --git a/src/net.cpp b/src/net.cpp index 2878d8d..e591127 100644 --- a/src/net.cpp +++ b/src/net.cpp @@ -1,14 +1,16 @@ -#include "net.h" -#include "net-ethernet.h" -#include "net-ipv4.h" -#include "net-arp.h" -#include "net-icmp.h" -#include "net-udp.h" +#include +#include "ff.h" +#include "net-arp.h" +#include "net-ethernet.h" +#include "net-icmp.h" +#include "net-ipv4.h" +#include "net-udp.h" +#include "net.h" #include "types.h" + #include #include -#include "ff.h" // // ARP @@ -131,255 +133,6 @@ void HandleUdpFrame(const uint8_t* buffer) } } -FIL tftpFp; -bool shouldReboot = false; -std::string tftpPrevCwd; - -void HandleTftpDatagram( - const EthernetFrameHeader ethernetReqHeader, - const Ipv4Header ipv4ReqHeader, - const UdpDatagramHeader udpReqHeader, - const uint8_t* data -) { - const auto opcode = data[0] << 8 | data[1]; - static auto currentBlockNumber = -1; - if (opcode == TFTP_OP_WRITE_REQUEST) - { - auto packet = TftpWriteReadRequestPacket::Deserialize(data); - if (packet.mode == "octet") - { - currentBlockNumber = 0; - - // Try opening the file - { - char cwd[256]; - f_getcwd(cwd, sizeof(cwd)); - tftpPrevCwd = cwd; - } - - auto separator = packet.filename.rfind('/', packet.filename.size()); - if (separator != std::string::npos) - { - auto path = "/" + packet.filename.substr(0, separator); - f_chdir(path.c_str()); - } - else - { - f_chdir("/"); - separator = 0; - } - - auto filename = packet.filename.substr(separator + 1); - const auto result = f_open(&tftpFp, filename.c_str(), FA_CREATE_ALWAYS | FA_WRITE); - if (result == FR_OK) - { - shouldReboot = - packet.filename == "kernel.img" || packet.filename == "options.txt"; - - TftpAcknowledgementPacket response(currentBlockNumber); - UdpDatagramHeader udpRespHeader( - udpReqHeader.destinationPort, - udpReqHeader.sourcePort, - response.SerializedLength() + UdpDatagramHeader::SerializedLength() - ); - Ipv4Header ipv4RespHeader( - IP_PROTO_UDP, - Ipv4Address, - ipv4ReqHeader.sourceIp, - udpRespHeader.length + Ipv4Header::SerializedLength() - ); - EthernetFrameHeader ethernetRespHeader( - ArpTable[ipv4RespHeader.destinationIp], - GetMacAddress(), - ETHERTYPE_IPV4 - ); - - size_t i = 0; - uint8_t buffer[USPI_FRAME_BUFFER_SIZE]; - i += ethernetRespHeader.Serialize(buffer + i); - i += ipv4RespHeader.Serialize(buffer + i); - i += udpRespHeader.Serialize(buffer + i); - i += response.Serialize(buffer + i); - USPiSendFrame(buffer, i); - } - else - { - TftpErrorPacket response(0, "error opening target file"); - UdpDatagramHeader udpRespHeader( - udpReqHeader.destinationPort, - udpReqHeader.sourcePort, - response.SerializedLength() + UdpDatagramHeader::SerializedLength() - ); - Ipv4Header ipv4RespHeader( - IP_PROTO_UDP, - Ipv4Address, - ipv4ReqHeader.sourceIp, - udpRespHeader.length + Ipv4Header::SerializedLength() - ); - EthernetFrameHeader ethernetRespHeader( - ArpTable[ipv4RespHeader.destinationIp], - GetMacAddress(), - ETHERTYPE_IPV4 - ); - - size_t i = 0; - uint8_t buffer[USPI_FRAME_BUFFER_SIZE]; - i += ethernetRespHeader.Serialize(buffer + i); - i += ipv4RespHeader.Serialize(buffer + i); - i += udpRespHeader.Serialize(buffer + i); - i += response.Serialize(buffer + i); - USPiSendFrame(buffer, i); - } - } - else - { - TftpErrorPacket response(0, "please use mode octet"); - UdpDatagramHeader udpRespHeader( - udpReqHeader.destinationPort, - udpReqHeader.sourcePort, - response.SerializedLength() + UdpDatagramHeader::SerializedLength() - ); - Ipv4Header ipv4RespHeader( - IP_PROTO_UDP, - Ipv4Address, - ipv4ReqHeader.sourceIp, - udpRespHeader.length + Ipv4Header::SerializedLength() - ); - EthernetFrameHeader ethernetRespHeader( - ArpTable[ipv4RespHeader.destinationIp], - GetMacAddress(), - ETHERTYPE_IPV4 - ); - - size_t i = 0; - uint8_t buffer[USPI_FRAME_BUFFER_SIZE]; - i += ethernetRespHeader.Serialize(buffer + i); - i += ipv4RespHeader.Serialize(buffer + i); - i += udpRespHeader.Serialize(buffer + i); - i += response.Serialize(buffer + i); - USPiSendFrame(buffer, i); - } - } - else if (opcode == TFTP_OP_DATA) - { - auto packet = TftpDataPacket::Deserialize( - data, - udpReqHeader.length - UdpDatagramHeader::SerializedLength() - ); - - if (packet.blockNumber == currentBlockNumber + 1) - { - const auto last = packet.data.size() < 512; - currentBlockNumber = packet.blockNumber; - - unsigned int bytesWritten; - const auto response = - f_write(&tftpFp, packet.data.data(), packet.data.size(), &bytesWritten); - if (response == FR_OK || bytesWritten != packet.data.size()) - { - TftpAcknowledgementPacket response(currentBlockNumber); - UdpDatagramHeader udpRespHeader( - udpReqHeader.destinationPort, - udpReqHeader.sourcePort, - response.SerializedLength() + UdpDatagramHeader::SerializedLength() - ); - Ipv4Header ipv4RespHeader( - IP_PROTO_UDP, - Ipv4Address, - ipv4ReqHeader.sourceIp, - udpRespHeader.length + Ipv4Header::SerializedLength() - ); - EthernetFrameHeader ethernetRespHeader( - ArpTable[ipv4RespHeader.destinationIp], - GetMacAddress(), - ETHERTYPE_IPV4 - ); - - size_t i = 0; - uint8_t buffer[USPI_FRAME_BUFFER_SIZE]; - i += ethernetRespHeader.Serialize(buffer + i); - i += ipv4RespHeader.Serialize(buffer + i); - i += udpRespHeader.Serialize(buffer + i); - i += response.Serialize(buffer + i); - USPiSendFrame(buffer, i); - - if (last) - { - MsDelay(500); - f_close(&tftpFp); - f_chdir(tftpPrevCwd.c_str()); - tftpPrevCwd.clear(); - - if (shouldReboot) - { - // TODO eww - extern void Reboot_Pi(); - Reboot_Pi(); - } - } - } - else - { - f_close(&tftpFp); - - TftpErrorPacket response(0, "io error"); - UdpDatagramHeader udpRespHeader( - udpReqHeader.destinationPort, - udpReqHeader.sourcePort, - response.SerializedLength() + UdpDatagramHeader::SerializedLength() - ); - Ipv4Header ipv4RespHeader( - IP_PROTO_UDP, - Ipv4Address, - ipv4ReqHeader.sourceIp, - udpRespHeader.length + Ipv4Header::SerializedLength() - ); - EthernetFrameHeader ethernetRespHeader( - ArpTable[ipv4RespHeader.destinationIp], - GetMacAddress(), - ETHERTYPE_IPV4 - ); - - size_t i = 0; - uint8_t buffer[USPI_FRAME_BUFFER_SIZE]; - i += ethernetRespHeader.Serialize(buffer + i); - i += ipv4RespHeader.Serialize(buffer + i); - i += udpRespHeader.Serialize(buffer + i); - i += response.Serialize(buffer + i); - USPiSendFrame(buffer, i); - } - } - else - { - TftpErrorPacket response(0, "invalid block number"); - UdpDatagramHeader udpRespHeader( - udpReqHeader.destinationPort, - udpReqHeader.sourcePort, - response.SerializedLength() + UdpDatagramHeader::SerializedLength() - ); - Ipv4Header ipv4RespHeader( - IP_PROTO_UDP, - Ipv4Address, - ipv4ReqHeader.sourceIp, - udpRespHeader.length + Ipv4Header::SerializedLength() - ); - EthernetFrameHeader ethernetRespHeader( - ArpTable[ipv4RespHeader.destinationIp], - GetMacAddress(), - ETHERTYPE_IPV4 - ); - - size_t i = 0; - uint8_t buffer[USPI_FRAME_BUFFER_SIZE]; - i += ethernetRespHeader.Serialize(buffer + i); - i += ipv4RespHeader.Serialize(buffer + i); - i += udpRespHeader.Serialize(buffer + i); - i += response.Serialize(buffer + i); - USPiSendFrame(buffer, i); - } - } -} - // // ICMP //