Split TFTP code into readable chunks and deduplicate a bunch of code

This commit is contained in:
Sijmen 2020-12-11 22:44:10 +01:00
parent 0c49a541ba
commit 60e186cdc9
Signed by: vijfhoek
GPG key ID: DAF7821E067D9C48
6 changed files with 323 additions and 387 deletions

View file

@ -4,7 +4,7 @@ OBJS = armc-start.o armc-cstartup.o armc-cstubs.o armc-cppstubs.o \
Drive.o Pi1541.o DiskImage.o iec_bus.o iec_commands.o m6502.o m6522.o \
gcr.o prot.o lz.o emmc.o diskio.o options.o Screen.o SSD1306.o ScreenLCD.o \
Timer.o FileBrowser.o DiskCaddy.o ROMs.o InputMappings.o xga_font_data.o m8520.o wd177x.o Pi1581.o SpinLock.o \
net.o
net.o net-tftp.o
SRCDIR = src
OBJS := $(addprefix $(SRCDIR)/, $(OBJS))

View file

@ -1,4 +1,5 @@
#pragma once
#include "net.h"
struct Ipv4ArpPacket
{

168
src/net-tftp.cpp Normal file
View file

@ -0,0 +1,168 @@
#include <memory>
#include "ff.h"
#include "net-ethernet.h"
#include "net-ipv4.h"
#include "net-tftp.h"
#include "net-udp.h"
#include "net.h"
#include "types.h"
#include <uspi.h>
// TODO Allow multiple files open
static FIL outFile;
static bool shouldReboot = false;
static uint32_t currentBlockNumber = -1;
static std::unique_ptr<TftpPacket> handleTftpWriteRequest(const uint8_t* data)
{
auto packet = TftpWriteReadRequestPacket::Deserialize(data);
// TODO Implement netscii, maybe
if (packet.mode != "octet")
{
return std::unique_ptr<TftpErrorPacket>(
new TftpErrorPacket(0, "please use mode octet")
);
}
currentBlockNumber = 0;
// TODO Return to the original working directory.
char workingDirectory[256];
f_getcwd(workingDirectory, sizeof(workingDirectory));
// Try opening the file
auto separator = packet.filename.rfind('/', packet.filename.size());
if (separator != std::string::npos)
{
auto path = "/" + packet.filename.substr(0, separator);
f_chdir(path.c_str());
}
else
{
f_chdir("/");
separator = 0;
}
// Open the output file.
auto filename = packet.filename.substr(separator + 1);
const auto result = f_open(&outFile, filename.c_str(), FA_CREATE_ALWAYS | FA_WRITE);
std::unique_ptr<TftpPacket> response;
if (result != FR_OK)
{
response = std::unique_ptr<TftpErrorPacket>(
new TftpErrorPacket(0, "error opening target file")
);
}
else
{
shouldReboot =
packet.filename == "kernel.img" || packet.filename == "options.txt";
response = std::unique_ptr<TftpAcknowledgementPacket>(
new TftpAcknowledgementPacket(currentBlockNumber)
);
}
// TODO Return to the original working directory here
return response;
}
static std::unique_ptr<TftpPacket> handleTftpData(const uint8_t* data, size_t length)
{
auto packet = TftpDataPacket::Deserialize(data, length);
if (packet.blockNumber != currentBlockNumber + 1)
{
f_close(&outFile);
return std::unique_ptr<TftpErrorPacket>(
new TftpErrorPacket(0, "invalid block number")
);
}
currentBlockNumber = packet.blockNumber;
unsigned int bytesWritten;
const auto result =
f_write(&outFile, packet.data.data(), packet.data.size(), &bytesWritten);
if (result != FR_OK || bytesWritten != packet.data.size())
{
f_close(&outFile);
return std::unique_ptr<TftpErrorPacket>(new TftpErrorPacket(0, "io error"));
}
if (packet.data.size() < 512)
{
// Close the file for the last packet.
f_close(&outFile);
}
return std::unique_ptr<TftpAcknowledgementPacket>(
new TftpAcknowledgementPacket(currentBlockNumber)
);
}
void HandleTftpDatagram(
const EthernetFrameHeader ethernetReqHeader,
const Ipv4Header ipv4ReqHeader,
const UdpDatagramHeader udpReqHeader,
const uint8_t* data
) {
const auto opcode = data[0] << 8 | data[1];
std::unique_ptr<TftpPacket> response;
bool last = false;
if (opcode == TFTP_OP_WRITE_REQUEST)
{
response = handleTftpWriteRequest(data);
}
else if (opcode == TFTP_OP_DATA)
{
const auto length = udpReqHeader.length - UdpDatagramHeader::SerializedLength();
response = handleTftpData(data, length);
}
else
{
response = std::unique_ptr<TftpErrorPacket>(
new TftpErrorPacket(4, "not implemented yet")
);
}
if (response != nullptr)
{
UdpDatagramHeader udpRespHeader(
udpReqHeader.destinationPort,
udpReqHeader.sourcePort,
response->SerializedLength() + UdpDatagramHeader::SerializedLength()
);
Ipv4Header ipv4RespHeader(
IP_PROTO_UDP,
Ipv4Address,
ipv4ReqHeader.sourceIp,
udpRespHeader.length + Ipv4Header::SerializedLength()
);
EthernetFrameHeader ethernetRespHeader(
ArpTable[ipv4RespHeader.destinationIp],
GetMacAddress(),
ETHERTYPE_IPV4
);
size_t i = 0;
uint8_t buffer[USPI_FRAME_BUFFER_SIZE];
i += ethernetRespHeader.Serialize(buffer + i);
i += ipv4RespHeader.Serialize(buffer + i);
i += udpRespHeader.Serialize(buffer + i);
i += response->Serialize(buffer + i);
USPiSendFrame(buffer, i);
}
if (last && shouldReboot)
{
// TODO eww
extern void Reboot_Pi();
Reboot_Pi();
}
}

138
src/net-tftp.h Normal file
View file

@ -0,0 +1,138 @@
#pragma once
#include <vector>
const size_t TFTP_BLOCK_SIZE = 512;
enum TftpOperation
{
TFTP_OP_READ_REQUEST = 1,
TFTP_OP_WRITE_REQUEST = 2,
TFTP_OP_DATA = 3,
TFTP_OP_ACKNOWLEDGEMENT = 4,
TFTP_OP_ERROR = 5,
};
struct TftpPacket
{
uint16_t opcode;
TftpPacket(uint16_t opcode) : opcode(opcode) {}
virtual size_t SerializedLength() const = 0;
virtual size_t Serialize(uint8_t* buffer) const = 0;
};
struct TftpWriteReadRequestPacket : public TftpPacket
{
std::string filename;
std::string mode;
TftpWriteReadRequestPacket(uint16_t opcode) : TftpPacket(opcode) {}
constexpr size_t SerializedLength() override {
return sizeof(opcode) + filename.size() + 1 + mode.size() + 1;
}
size_t Serialize(uint8_t* buffer) const override {
size_t i = 0;
buffer[i++] = opcode >> 8;
buffer[i++] = opcode;
i += filename.copy(reinterpret_cast<char*>(buffer + i), filename.size());
buffer[i++] = 0;
i += mode.copy(reinterpret_cast<char*>(buffer + i), mode.size());
buffer[i++] = 0;
return i;
}
static TftpWriteReadRequestPacket Deserialize(const uint8_t* buffer) {
size_t i = 0;
TftpWriteReadRequestPacket self(buffer[i] << 8 | buffer[i + 1]);
i += 2;
self.filename = reinterpret_cast<const char*>(buffer + i);
i += self.filename.size() + 1;
self.mode = reinterpret_cast<const char*>(buffer + i);
i += self.mode.size() + 1;
return self;
}
};
struct TftpErrorPacket : public TftpPacket
{
uint16_t errorCode;
std::string message;
TftpErrorPacket() : TftpPacket(TFTP_OP_ERROR) {}
TftpErrorPacket(uint16_t errorCode, std::string message) :
TftpPacket(TFTP_OP_ERROR), errorCode(errorCode), message(message)
{}
constexpr size_t SerializedLength() const override
{
return sizeof(opcode) + sizeof(errorCode) + message.size() + 1;
}
size_t Serialize(uint8_t* buffer) const
{
size_t i = 0;
buffer[i++] = opcode >> 8;
buffer[i++] = opcode;
buffer[i++] = errorCode >> 8;
buffer[i++] = errorCode;
i += message.copy(reinterpret_cast<char*>(buffer + i), message.size());
buffer[i++] = 0;
return i;
}
};
struct TftpAcknowledgementPacket : public TftpPacket
{
uint16_t blockNumber;
TftpAcknowledgementPacket() : TftpPacket(TFTP_OP_ACKNOWLEDGEMENT) {}
TftpAcknowledgementPacket(uint16_t blockNumber) :
TftpPacket(TFTP_OP_ACKNOWLEDGEMENT), blockNumber(blockNumber)
{}
constexpr size_t SerializedLength() override
{
return sizeof(opcode) + sizeof(blockNumber);
}
size_t Serialize(uint8_t* buffer) const override
{
size_t i = 0;
buffer[i++] = opcode >> 8;
buffer[i++] = opcode;
buffer[i++] = blockNumber >> 8;
buffer[i++] = blockNumber;
return i;
}
};
struct TftpDataPacket
{
uint16_t opcode;
uint16_t blockNumber;
std::vector<uint8_t> data;
TftpDataPacket() : opcode(TFTP_OP_DATA) {}
static TftpDataPacket Deserialize(const uint8_t* buffer, size_t length)
{
TftpDataPacket self;
self.opcode = buffer[0] << 8 | buffer[1];
self.blockNumber = buffer[2] << 8 | buffer[3];
self.data = std::vector<uint8_t>(buffer + 4, buffer + length);
return self;
}
};

View file

@ -1,21 +1,14 @@
#pragma once
#include <vector>
enum TftpOperation
{
TFTP_OP_READ_REQUEST = 1,
TFTP_OP_WRITE_REQUEST = 2,
TFTP_OP_DATA = 3,
TFTP_OP_ACKNOWLEDGEMENT = 4,
TFTP_OP_ERROR = 5,
};
#include <string>
#include <cstdint>
struct UdpDatagramHeader
{
std::uint16_t sourcePort;
std::uint16_t destinationPort;
std::uint16_t length;
std::uint16_t checksum;
uint16_t sourcePort;
uint16_t destinationPort;
uint16_t length;
uint16_t checksum;
UdpDatagramHeader() {}
@ -73,120 +66,3 @@ struct UdpDatagram
payload(payload)
{}
} __attribute__((packed));
struct TftpWriteReadRequestPacket
{
uint16_t opcode;
std::string filename;
std::string mode;
size_t SerializedLength() const {
return sizeof(opcode) + filename.size() + 1 + mode.size() + 1;
}
size_t Serialize(uint8_t* buffer) const {
size_t i = 0;
buffer[i++] = opcode >> 8;
buffer[i++] = opcode;
i += filename.copy(reinterpret_cast<char*>(buffer + i), filename.size());
buffer[i++] = 0;
i += mode.copy(reinterpret_cast<char*>(buffer + i), mode.size());
buffer[i++] = 0;
return i;
}
static TftpWriteReadRequestPacket Deserialize(const uint8_t* buffer) {
TftpWriteReadRequestPacket self;
size_t i = 0;
self.opcode = buffer[i] << 8 | buffer[i + 1];
i += 2;
self.filename = reinterpret_cast<const char*>(buffer + i);
i += self.filename.size() + 1;
self.mode = reinterpret_cast<const char*>(buffer + i);
i += self.mode.size() + 1;
return self;
}
};
struct TftpErrorPacket
{
uint16_t opcode;
uint16_t errorCode;
std::string message;
TftpErrorPacket() : opcode(TFTP_OP_ERROR) {}
TftpErrorPacket(uint16_t errorCode, std::string message) :
opcode(TFTP_OP_ERROR), errorCode(errorCode), message(message)
{}
constexpr size_t SerializedLength()
{
return sizeof(opcode) + sizeof(errorCode) + message.size() + 1;
}
size_t Serialize(uint8_t* buffer) const
{
size_t i = 0;
buffer[i++] = opcode >> 8;
buffer[i++] = opcode;
buffer[i++] = errorCode >> 8;
buffer[i++] = errorCode;
i += message.copy(reinterpret_cast<char*>(buffer + i), message.size());
buffer[i++] = 0;
return i;
}
};
struct TftpAcknowledgementPacket
{
uint16_t opcode;
uint16_t blockNumber;
TftpAcknowledgementPacket() : opcode(TFTP_OP_ACKNOWLEDGEMENT) {}
TftpAcknowledgementPacket(uint16_t blockNumber) :
opcode(TFTP_OP_ACKNOWLEDGEMENT), blockNumber(blockNumber)
{}
constexpr size_t SerializedLength()
{
return sizeof(opcode) + sizeof(blockNumber);
}
size_t Serialize(uint8_t* buffer) const
{
size_t i = 0;
buffer[i++] = opcode >> 8;
buffer[i++] = opcode;
buffer[i++] = blockNumber >> 8;
buffer[i++] = blockNumber;
return i;
}
};
struct TftpDataPacket
{
uint16_t opcode;
uint16_t blockNumber;
std::vector<uint8_t> data;
TftpDataPacket() : opcode(TFTP_OP_DATA) {}
static TftpDataPacket Deserialize(const uint8_t* buffer, size_t length)
{
TftpDataPacket self;
self.opcode = buffer[0] << 8 | buffer[1];
self.blockNumber = buffer[2] << 8 | buffer[3];
self.data = std::vector<uint8_t>(buffer + 4, buffer + length);
return self;
}
};

View file

@ -1,14 +1,16 @@
#include "net.h"
#include "net-ethernet.h"
#include "net-ipv4.h"
#include "net-arp.h"
#include "net-icmp.h"
#include "net-udp.h"
#include <memory>
#include "ff.h"
#include "net-arp.h"
#include "net-ethernet.h"
#include "net-icmp.h"
#include "net-ipv4.h"
#include "net-udp.h"
#include "net.h"
#include "types.h"
#include <uspi.h>
#include <uspios.h>
#include "ff.h"
//
// ARP
@ -131,255 +133,6 @@ void HandleUdpFrame(const uint8_t* buffer)
}
}
FIL tftpFp;
bool shouldReboot = false;
std::string tftpPrevCwd;
void HandleTftpDatagram(
const EthernetFrameHeader ethernetReqHeader,
const Ipv4Header ipv4ReqHeader,
const UdpDatagramHeader udpReqHeader,
const uint8_t* data
) {
const auto opcode = data[0] << 8 | data[1];
static auto currentBlockNumber = -1;
if (opcode == TFTP_OP_WRITE_REQUEST)
{
auto packet = TftpWriteReadRequestPacket::Deserialize(data);
if (packet.mode == "octet")
{
currentBlockNumber = 0;
// Try opening the file
{
char cwd[256];
f_getcwd(cwd, sizeof(cwd));
tftpPrevCwd = cwd;
}
auto separator = packet.filename.rfind('/', packet.filename.size());
if (separator != std::string::npos)
{
auto path = "/" + packet.filename.substr(0, separator);
f_chdir(path.c_str());
}
else
{
f_chdir("/");
separator = 0;
}
auto filename = packet.filename.substr(separator + 1);
const auto result = f_open(&tftpFp, filename.c_str(), FA_CREATE_ALWAYS | FA_WRITE);
if (result == FR_OK)
{
shouldReboot =
packet.filename == "kernel.img" || packet.filename == "options.txt";
TftpAcknowledgementPacket response(currentBlockNumber);
UdpDatagramHeader udpRespHeader(
udpReqHeader.destinationPort,
udpReqHeader.sourcePort,
response.SerializedLength() + UdpDatagramHeader::SerializedLength()
);
Ipv4Header ipv4RespHeader(
IP_PROTO_UDP,
Ipv4Address,
ipv4ReqHeader.sourceIp,
udpRespHeader.length + Ipv4Header::SerializedLength()
);
EthernetFrameHeader ethernetRespHeader(
ArpTable[ipv4RespHeader.destinationIp],
GetMacAddress(),
ETHERTYPE_IPV4
);
size_t i = 0;
uint8_t buffer[USPI_FRAME_BUFFER_SIZE];
i += ethernetRespHeader.Serialize(buffer + i);
i += ipv4RespHeader.Serialize(buffer + i);
i += udpRespHeader.Serialize(buffer + i);
i += response.Serialize(buffer + i);
USPiSendFrame(buffer, i);
}
else
{
TftpErrorPacket response(0, "error opening target file");
UdpDatagramHeader udpRespHeader(
udpReqHeader.destinationPort,
udpReqHeader.sourcePort,
response.SerializedLength() + UdpDatagramHeader::SerializedLength()
);
Ipv4Header ipv4RespHeader(
IP_PROTO_UDP,
Ipv4Address,
ipv4ReqHeader.sourceIp,
udpRespHeader.length + Ipv4Header::SerializedLength()
);
EthernetFrameHeader ethernetRespHeader(
ArpTable[ipv4RespHeader.destinationIp],
GetMacAddress(),
ETHERTYPE_IPV4
);
size_t i = 0;
uint8_t buffer[USPI_FRAME_BUFFER_SIZE];
i += ethernetRespHeader.Serialize(buffer + i);
i += ipv4RespHeader.Serialize(buffer + i);
i += udpRespHeader.Serialize(buffer + i);
i += response.Serialize(buffer + i);
USPiSendFrame(buffer, i);
}
}
else
{
TftpErrorPacket response(0, "please use mode octet");
UdpDatagramHeader udpRespHeader(
udpReqHeader.destinationPort,
udpReqHeader.sourcePort,
response.SerializedLength() + UdpDatagramHeader::SerializedLength()
);
Ipv4Header ipv4RespHeader(
IP_PROTO_UDP,
Ipv4Address,
ipv4ReqHeader.sourceIp,
udpRespHeader.length + Ipv4Header::SerializedLength()
);
EthernetFrameHeader ethernetRespHeader(
ArpTable[ipv4RespHeader.destinationIp],
GetMacAddress(),
ETHERTYPE_IPV4
);
size_t i = 0;
uint8_t buffer[USPI_FRAME_BUFFER_SIZE];
i += ethernetRespHeader.Serialize(buffer + i);
i += ipv4RespHeader.Serialize(buffer + i);
i += udpRespHeader.Serialize(buffer + i);
i += response.Serialize(buffer + i);
USPiSendFrame(buffer, i);
}
}
else if (opcode == TFTP_OP_DATA)
{
auto packet = TftpDataPacket::Deserialize(
data,
udpReqHeader.length - UdpDatagramHeader::SerializedLength()
);
if (packet.blockNumber == currentBlockNumber + 1)
{
const auto last = packet.data.size() < 512;
currentBlockNumber = packet.blockNumber;
unsigned int bytesWritten;
const auto response =
f_write(&tftpFp, packet.data.data(), packet.data.size(), &bytesWritten);
if (response == FR_OK || bytesWritten != packet.data.size())
{
TftpAcknowledgementPacket response(currentBlockNumber);
UdpDatagramHeader udpRespHeader(
udpReqHeader.destinationPort,
udpReqHeader.sourcePort,
response.SerializedLength() + UdpDatagramHeader::SerializedLength()
);
Ipv4Header ipv4RespHeader(
IP_PROTO_UDP,
Ipv4Address,
ipv4ReqHeader.sourceIp,
udpRespHeader.length + Ipv4Header::SerializedLength()
);
EthernetFrameHeader ethernetRespHeader(
ArpTable[ipv4RespHeader.destinationIp],
GetMacAddress(),
ETHERTYPE_IPV4
);
size_t i = 0;
uint8_t buffer[USPI_FRAME_BUFFER_SIZE];
i += ethernetRespHeader.Serialize(buffer + i);
i += ipv4RespHeader.Serialize(buffer + i);
i += udpRespHeader.Serialize(buffer + i);
i += response.Serialize(buffer + i);
USPiSendFrame(buffer, i);
if (last)
{
MsDelay(500);
f_close(&tftpFp);
f_chdir(tftpPrevCwd.c_str());
tftpPrevCwd.clear();
if (shouldReboot)
{
// TODO eww
extern void Reboot_Pi();
Reboot_Pi();
}
}
}
else
{
f_close(&tftpFp);
TftpErrorPacket response(0, "io error");
UdpDatagramHeader udpRespHeader(
udpReqHeader.destinationPort,
udpReqHeader.sourcePort,
response.SerializedLength() + UdpDatagramHeader::SerializedLength()
);
Ipv4Header ipv4RespHeader(
IP_PROTO_UDP,
Ipv4Address,
ipv4ReqHeader.sourceIp,
udpRespHeader.length + Ipv4Header::SerializedLength()
);
EthernetFrameHeader ethernetRespHeader(
ArpTable[ipv4RespHeader.destinationIp],
GetMacAddress(),
ETHERTYPE_IPV4
);
size_t i = 0;
uint8_t buffer[USPI_FRAME_BUFFER_SIZE];
i += ethernetRespHeader.Serialize(buffer + i);
i += ipv4RespHeader.Serialize(buffer + i);
i += udpRespHeader.Serialize(buffer + i);
i += response.Serialize(buffer + i);
USPiSendFrame(buffer, i);
}
}
else
{
TftpErrorPacket response(0, "invalid block number");
UdpDatagramHeader udpRespHeader(
udpReqHeader.destinationPort,
udpReqHeader.sourcePort,
response.SerializedLength() + UdpDatagramHeader::SerializedLength()
);
Ipv4Header ipv4RespHeader(
IP_PROTO_UDP,
Ipv4Address,
ipv4ReqHeader.sourceIp,
udpRespHeader.length + Ipv4Header::SerializedLength()
);
EthernetFrameHeader ethernetRespHeader(
ArpTable[ipv4RespHeader.destinationIp],
GetMacAddress(),
ETHERTYPE_IPV4
);
size_t i = 0;
uint8_t buffer[USPI_FRAME_BUFFER_SIZE];
i += ethernetRespHeader.Serialize(buffer + i);
i += ipv4RespHeader.Serialize(buffer + i);
i += udpRespHeader.Serialize(buffer + i);
i += response.Serialize(buffer + i);
USPiSendFrame(buffer, i);
}
}
}
//
// ICMP
//