Skip to content

Commit

Permalink
Kernel: Support for sending UDP packets
Browse files Browse the repository at this point in the history
This kind of works. For some reason it doesn't work the first time, and works most of the time subsequent times. There might be some sort of memory corruption bug because sometimes we RSOD. Too tired to debug but this seems like a good stopping point
  • Loading branch information
byteduck committed Mar 15, 2024
1 parent 5333de5 commit 73b1229
Show file tree
Hide file tree
Showing 12 changed files with 216 additions and 29 deletions.
24 changes: 24 additions & 0 deletions kernel/api/ipv4.h
Expand Up @@ -67,6 +67,30 @@ struct __attribute__((packed)) IPv4Packet {
IPv4Address source_addr;
IPv4Address dest_addr;
uint8_t payload[];

[[nodiscard]] inline BigEndian<uint16_t> compute_checksum() const { return __compute_checksum(this); }

inline void set_checksum() {
checksum = 0;
checksum = compute_checksum();
}

private:
// Necessary to beat the alignment allegations. Scary, I know
[[nodiscard]] inline BigEndian<uint16_t> __compute_checksum(const void* voidptr) const {
uint32_t sum = 0;
auto* ptr = (const uint16_t*) voidptr;
size_t count = sizeof(IPv4Packet);
while (count > 1) {
sum += as_big_endian(*ptr++);
if (sum & 0x80000000)
sum = (sum & 0xffff) | (sum >> 16);
count -= 2;
}
while (sum >> 16)
sum = (sum & 0xffff) + (sum >> 16);
return ~sum & 0xffff;
}
};

static_assert(sizeof(IPv4Packet) == 20);
Expand Down
7 changes: 7 additions & 0 deletions kernel/api/net.h
Expand Up @@ -35,6 +35,13 @@ class __attribute__((packed)) MACAddress {
return false;
}

inline constexpr operator bool() const {
for (auto& val : m_data)
if (val)
return true;
return false;
}

private:
uint8_t m_data[6] = {0};
};
Expand Down
66 changes: 60 additions & 6 deletions kernel/net/IPSocket.cpp
Expand Up @@ -22,32 +22,40 @@ ResultRet<kstd::Arc<IPSocket>> IPSocket::make(Socket::Type type, int protocol) {
}

Result IPSocket::bind(SafePointer<sockaddr> addr_ptr, socklen_t addrlen) {
LOCK(m_lock);

if (m_bound || addrlen != sizeof(sockaddr_in))
return Result(set_error(EINVAL));

auto addr = addr_ptr.as<sockaddr_in>().get();
if (addr.sin_family != AF_INET)
return Result(set_error(EINVAL));

m_port = from_big_endian(addr.sin_port);
m_addr = IPv4Address(from_big_endian(addr.sin_addr.s_addr));
m_bound_port = from_big_endian(addr.sin_port);
m_bound_addr = IPv4Address(from_big_endian(addr.sin_addr.s_addr));

return do_bind();
}

ssize_t IPSocket::recvfrom(FileDescriptor& fd, SafePointer<uint8_t> buf, size_t len, int flags, SafePointer<sockaddr> src_addr, SafePointer<socklen_t> addrlen) {
m_receive_queue_lock.acquire();

// Verify addrlen ptr
if (addrlen && addrlen.get() != sizeof(sockaddr_in))
return -set_error(EINVAL);

// Block until we have a packet to read
while (m_receive_queue.empty()) {
if (fd.nonblock()) {
m_receive_queue_lock.release();
return -EAGAIN;
return -set_error(EAGAIN);
}

update_blocker();
m_receive_queue_lock.release();
TaskManager::current_thread()->block(m_receive_blocker);
if (m_receive_blocker.was_interrupted())
return -set_error(EINTR);
m_receive_queue_lock.acquire();
}

Expand All @@ -56,21 +64,67 @@ ssize_t IPSocket::recvfrom(FileDescriptor& fd, SafePointer<uint8_t> buf, size_t
update_blocker();
m_receive_queue_lock.release();
auto res = do_recv(packet, buf, len);

// Write out addr
if (src_addr && addrlen) {
src_addr.as<sockaddr_in>().set({
AF_INET,
as_big_endian(packet->port),
packet->packet.source_addr.val()
});
addrlen.set(sizeof(sockaddr_in));
}

kfree(packet);
return res;
}

ssize_t IPSocket::sendto(FileDescriptor& fd, SafePointer<uint8_t> buf, size_t len, int flags, SafePointer<sockaddr> dest_addr, socklen_t addrlen) {
LOCK(m_lock);
if (dest_addr) {
if (addrlen != sizeof(sockaddr_in))
return -set_error(EINVAL);

auto addr = dest_addr.as<sockaddr_in>().get();
if (addr.sin_family != AF_INET)
return -set_error(EAFNOSUPPORT);

if (m_type != Stream) {
m_dest_addr = addr.sin_addr.s_addr;
m_dest_port = from_big_endian(addr.sin_port);
} else {
// TODO: TCP. We want to use connect() for that
}
}

if (!m_bound) {
// If we're not bound, bind to 0.0.0.0:0
m_bound_port = 0;
m_bound_addr = {};
auto res = do_bind();
if (res.is_error())
return -res.code();
}

// TODO: Adapter binding?

auto send_res = do_send(buf, len);
if (send_res.is_error())
return -send_res.code();
return (ssize_t) send_res.value();
}

Result IPSocket::recv_packet(const void* buf, size_t len) {
LOCK(m_receive_queue_lock);

if (m_receive_queue.size() == m_receive_queue.capacity()) {
KLog::warn("IPSocket", "Dropping packet because receive queue is full");
return Result(ENOSPC);
return Result(set_error(ENOSPC));
}

auto* src_pkt = (const IPv4Packet*) buf;
auto* new_pkt = (IPv4Packet*) kmalloc(len);
memcpy(new_pkt, src_pkt, len);
auto* new_pkt = new RecvdPacket;
memcpy(&new_pkt->packet, src_pkt, len);

m_receive_queue.push_back(new_pkt);
update_blocker();
Expand Down
18 changes: 14 additions & 4 deletions kernel/net/IPSocket.h
Expand Up @@ -14,6 +14,7 @@ class IPSocket: public Socket {
// Socket
Result bind(SafePointer<sockaddr> addr, socklen_t addrlen) override;
ssize_t recvfrom(FileDescriptor &fd, SafePointer<uint8_t> buf, size_t len, int flags, SafePointer<sockaddr> src_addr, SafePointer<socklen_t> addrlen) override;
ssize_t sendto(FileDescriptor &fd, SafePointer<uint8_t> buf, size_t len, int flags, SafePointer<sockaddr> dest_addr, socklen_t addrlen) override;
Result recv_packet(const void* buf, size_t len) override;

// File
Expand All @@ -22,15 +23,24 @@ class IPSocket: public Socket {
protected:
IPSocket(Socket::Type type, int protocol);

virtual ssize_t do_recv(const IPv4Packet* pkt, SafePointer<uint8_t> buf, size_t len) = 0;
struct RecvdPacket {
uint16_t port;
IPv4Packet packet; // Not actually set until we do do_recv
};

virtual ssize_t do_recv(RecvdPacket* pkt, SafePointer<uint8_t> buf, size_t len) = 0;
virtual Result do_bind() = 0;
virtual ResultRet<size_t> do_send(SafePointer<uint8_t> buf, size_t len) = 0;

void update_blocker();

bool m_bound = false;
uint16_t m_port;
IPv4Address m_addr;
kstd::circular_queue<IPv4Packet*> m_receive_queue { 16 };
uint16_t m_bound_port, m_dest_port;
IPv4Address m_bound_addr, m_dest_addr;
kstd::circular_queue<RecvdPacket*> m_receive_queue { 16 };
Mutex m_receive_queue_lock { "IPSocket::receive_queue" };
Mutex m_lock { "IPSocket::lock" };
BooleanBlocker m_receive_blocker;
uint8_t m_type_of_service = 0;
uint8_t m_ttl = 64;
};
49 changes: 48 additions & 1 deletion kernel/net/NetworkAdapter.cpp
Expand Up @@ -55,7 +55,7 @@ void NetworkAdapter::receive_bytes(SafePointer<uint8_t> bytes, size_t count) {

int i;
for (i = 0; i < 32; i++) {
if (!m_packets[i].used)
if (!m_packets[i].used.load(MemoryOrder::Acquire))
break;
}
if (i == 32) {
Expand Down Expand Up @@ -102,3 +102,50 @@ NetworkAdapter::Packet* NetworkAdapter::dequeue_packet() {
m_packet_queue = m_packet_queue->next;
return pkt;
}

NetworkAdapter::Packet* NetworkAdapter::alloc_packet(size_t size) {
ASSERT(size < 8192);
TaskManager::ScopedCritical crit;
int i;
for (i = 0; i < 32; i++) {
if (!m_packets[i].used.load(MemoryOrder::Acquire))
break;
}

if (i == 32)
return nullptr;

auto& pkt = m_packets[i];
pkt.size = sizeof(FrameHeader) + size;
return &m_packets[i];
}

IPv4Packet* NetworkAdapter::setup_ipv4_packet(Packet* packet, const MACAddress& dest, const IPv4Address& dest_addr, IPv4Proto proto, size_t payload_size, uint8_t dscp, uint8_t ttl) {
ASSERT(packet);

auto* frame = (FrameHeader*) packet->buffer;
frame->type = EtherProto::IPv4;
frame->destination = dest;
frame->source = m_mac_addr;

auto* ipv4 = (IPv4Packet*) (packet->buffer + sizeof(FrameHeader));
ipv4->source_addr = m_ipv4_addr;
ipv4->dest_addr = dest_addr;
ipv4->length = payload_size + sizeof(IPv4Packet);
ipv4->dscp_ecn = dscp;
ipv4->ttl = ttl;
ipv4->proto = proto;
ipv4->identification = 1;
ipv4->version_ihl = (4 << 4) | 5;
ipv4->identification = 1;
ipv4->set_checksum();

return ipv4;
}

void NetworkAdapter::send_packet(NetworkAdapter::Packet* packet) {
ASSERT(packet->size < 8192);
send_raw_packet(KernelPointer(packet->buffer), packet->size);
packet->used.store(false, MemoryOrder::Release);
}

5 changes: 4 additions & 1 deletion kernel/net/NetworkAdapter.h
Expand Up @@ -18,7 +18,7 @@ class NetworkAdapter: public kstd::ArcSelf<NetworkAdapter> {
uint8_t buffer[8192]; /* TODO: We need non-constant packet sizes... */
union {
size_t size;
bool used = false;
Atomic<bool> used = false;
};
Packet* next = nullptr;
};
Expand All @@ -32,7 +32,10 @@ class NetworkAdapter: public kstd::ArcSelf<NetworkAdapter> {

void send_arp_packet(MACAddress dest, const ARPPacket& packet);
void send_raw_packet(SafePointer<uint8_t> bytes, size_t count);
void send_packet(Packet* packet);
Packet* dequeue_packet();
Packet* alloc_packet(size_t size);
IPv4Packet* setup_ipv4_packet(Packet* packet, const MACAddress& dest, const IPv4Address& dest_addr, IPv4Proto proto, size_t payload_size, uint8_t dscp, uint8_t ttl);

[[nodiscard]] IPv4Address ipv4_address() const { return m_ipv4_addr; }
[[nodiscard]] MACAddress mac_address() const { return m_mac_addr; }
Expand Down
2 changes: 1 addition & 1 deletion kernel/net/NetworkManager.cpp
Expand Up @@ -31,7 +31,7 @@ void NetworkManager::do_task() {
for (auto& iface : NetworkAdapter::interfaces()) {
while ((packet = iface->dequeue_packet())) {
handle_packet(iface, packet);
packet->used = false;
packet->used.store(false, MemoryOrder::Release);
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion kernel/net/Router.cpp
Expand Up @@ -102,7 +102,7 @@ Router::Route Router::get_route(const IPv4Address& dest, const IPv4Address& sour
}

// ARP lookup
KLog::dbg_if<ROUTE_DEBUG>("Router", "Could not find route to {}, sending ARP request thru {} for {}", dest, adapter->name(), next_hop);
KLog::dbg_if<ROUTE_DEBUG>("Router", "Could not find route to {}, looking up ARP entry for {}", dest, next_hop);
auto mac = arp_lookup(next_hop, adapter);
if (mac.is_error())
return {{}, {}};
Expand Down
1 change: 1 addition & 0 deletions kernel/net/Socket.h
Expand Up @@ -26,6 +26,7 @@ class Socket: public File {
// Socket
virtual Result bind(SafePointer<sockaddr> addr, socklen_t addrlen) = 0;
virtual ssize_t recvfrom(FileDescriptor& fd, SafePointer<uint8_t> buf, size_t len, int flags, SafePointer<sockaddr> src_addr, SafePointer<socklen_t> addrlen) = 0;
virtual ssize_t sendto(FileDescriptor& fd, SafePointer<uint8_t> buf, size_t len, int flags, SafePointer<sockaddr> dest_addr, socklen_t addrlen) = 0;
virtual Result recv_packet(const void* buf, size_t len) = 0;

[[nodiscard]] int error() const { return m_error; }
Expand Down
48 changes: 36 additions & 12 deletions kernel/net/UDPSocket.cpp
Expand Up @@ -4,6 +4,7 @@
#include "UDPSocket.h"
#include "../kstd/KLog.h"
#include "../api/udp.h"
#include "Router.h"

#define UDP_DBG 1

Expand All @@ -16,9 +17,9 @@ UDPSocket::UDPSocket(): IPSocket(Type::Dgram, 0) {

UDPSocket::~UDPSocket() {
LOCK(s_sockets_lock);
if (m_bound && s_sockets.contains(m_port)) {
s_sockets.erase(m_port);
KLog::dbg_if<UDP_DBG>("UDPSocket", "Unbinding from port {}", m_port);
if (m_bound && s_sockets.contains(m_bound_port)) {
s_sockets.erase(m_bound_port);
KLog::dbg_if<UDP_DBG>("UDPSocket", "Unbinding from port {}", m_bound_port);
}
}

Expand All @@ -40,12 +41,10 @@ Result UDPSocket::do_bind() {
LOCK(s_sockets_lock);
if (m_bound)
return Result(set_error(EINVAL));
if (s_sockets.contains(m_port))
if (s_sockets.contains(m_bound_port))
return Result(set_error(EADDRINUSE));

KLog::dbg_if<UDP_DBG>("UDPSocket", "Binding to port {}", m_port);

if (m_port == 0) {
if (m_bound_port == 0) {
// If we didn't specify a port, we want an ephemeral port
// (Range suggested by IANA and RFC 6335)
uint16_t ephem;
Expand All @@ -59,22 +58,47 @@ Result UDPSocket::do_bind() {
return Result(set_error(EADDRINUSE));
}

m_port = ephem;
m_bound_port = ephem;
}

s_sockets[m_port] = self();
KLog::dbg_if<UDP_DBG>("UDPSocket", "Binding to port {}", m_bound_port);

s_sockets[m_bound_port] = self();
m_bound = true;

return Result(SUCCESS);
}

ssize_t UDPSocket::do_recv(const IPv4Packet* pkt, SafePointer<uint8_t> buf, size_t len) {
auto* udp_pkt = (const UDPPacket*) pkt->payload;
ASSERT(pkt->length >= sizeof(IPv4Packet) + sizeof(UDPPacket)); // Should've been rejected at IP layer
ssize_t UDPSocket::do_recv(RecvdPacket* pkt, SafePointer<uint8_t> buf, size_t len) {
auto* udp_pkt = (const UDPPacket*) pkt->packet.payload;
ASSERT(pkt->packet.length >= sizeof(IPv4Packet) + sizeof(UDPPacket)); // Should've been rejected at IP layer
ASSERT(udp_pkt->len >= sizeof(UDPPacket)); // Should've been rejected in NetworkManager

const size_t nread = min(len, udp_pkt->len.val() - sizeof(UDPPacket));
buf.write(udp_pkt->payload, nread);

KLog::dbg_if<UDP_DBG>("UDPSocket", "Received packet from {}:{} ({} bytes)", pkt->packet.source_addr, udp_pkt->source_port, nread);

pkt->port = udp_pkt->source_port;

return (ssize_t) nread;
}

ResultRet<size_t> UDPSocket::do_send(SafePointer<uint8_t> buf, size_t len) {
auto route = Router::get_route(m_dest_addr, {}, {});
if (!route.mac || !route.adapter)
return Result(set_error(EHOSTUNREACH));

const size_t packet_len = sizeof(IPv4Packet) + sizeof(UDPPacket) + len;
auto pkt = route.adapter->alloc_packet(packet_len);
auto* ipv4_packet = route.adapter->setup_ipv4_packet(pkt, route.mac, m_dest_addr, UDP, sizeof(UDPPacket) + len, m_type_of_service, m_ttl);
auto* udp_packet = (UDPPacket*) ipv4_packet->payload;
udp_packet->source_port = m_bound_port;
udp_packet->dest_port = m_dest_port;
udp_packet->len = sizeof(UDPPacket) + len;
buf.read(udp_packet->payload, len);

KLog::dbg_if<UDP_DBG>("UDPSocket", "Sending packet to {}:{} ({} bytes)", m_dest_addr, m_dest_port, len);
route.adapter->send_packet(pkt);
return len;
}

0 comments on commit 73b1229

Please sign in to comment.