Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TCP/UDP: new function is_listening: t -> ~port:int -> callback option #508

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions src/core/tcp.ml
Expand Up @@ -34,6 +34,7 @@ module type S = sig
val writev_nodelay: flow -> Cstruct.t list -> (unit, write_error) result Lwt.t
val create_connection: ?keepalive:Keepalive.t -> t -> ipaddr * int -> (flow, error) result Lwt.t
val listen : t -> port:int -> ?keepalive:Keepalive.t -> (flow -> unit Lwt.t) -> unit
val is_listening : t -> port:int -> (flow -> unit Lwt.t) option
val unlisten : t -> port:int -> unit
val input: t -> src:ipaddr -> dst:ipaddr -> Cstruct.t -> unit Lwt.t
end
8 changes: 6 additions & 2 deletions src/core/tcp.mli
Expand Up @@ -83,8 +83,12 @@ module type S = sig
executed for each flow that was established. If [keepalive] is provided,
this configuration will be applied before calling [callback].

@raise Invalid_argument if [port < 0] or [port > 65535]
*)
@raise Invalid_argument if [port < 0] or [port > 65535] *)

val is_listening : t -> port:int -> (flow -> unit Lwt.t) option
(** [is_listening t ~port] returns the [callback] on [port], if it exists.

@raise Invalid_argument if [port < 0] or [port > 65535] *)

val unlisten : t -> port:int -> unit
(** [unlisten t ~port] stops any listener on [port]. *)
Expand Down
1 change: 1 addition & 0 deletions src/core/udp.ml
Expand Up @@ -6,6 +6,7 @@ module type S = sig
val disconnect : t -> unit Lwt.t
type callback = src:ipaddr -> dst:ipaddr -> src_port:int -> Cstruct.t -> unit Lwt.t
val listen : t -> port:int -> callback -> unit
val is_listening : t -> port:int -> callback option
val unlisten : t -> port:int -> unit
val input: t -> src:ipaddr -> dst:ipaddr -> Cstruct.t -> unit Lwt.t
val write: ?src:ipaddr -> ?src_port:int -> ?ttl:int -> dst:ipaddr -> dst_port:int -> t -> Cstruct.t ->
Expand Down
5 changes: 5 additions & 0 deletions src/core/udp.mli
Expand Up @@ -29,6 +29,11 @@ module type S = sig

@raise Invalid_argument if [port < 0] or [port > 65535] *)

val is_listening : t -> port:int -> callback option
(** [is_listening t ~port] returns the [callback] on [port], if it exists.

@raise Invalid_argument if [port < 0] or [port > 65535] *)

val unlisten : t -> port:int -> unit
(** [unlisten t ~port] stops any listeners on [port]. *)

Expand Down
11 changes: 7 additions & 4 deletions src/stack-unix/tcpv4v6_socket.ml
Expand Up @@ -26,7 +26,7 @@ type flow = Lwt_unix.file_descr
type t = {
interface: [ `Any | `Ip of Unix.inet_addr * Unix.inet_addr | `V4_only of Unix.inet_addr | `V6_only of Unix.inet_addr ]; (* source ip to bind to *)
mutable active_connections : Lwt_unix.file_descr list;
listen_sockets : (int, Lwt_unix.file_descr list) Hashtbl.t;
listen_sockets : (int, Lwt_unix.file_descr list * (flow -> unit Lwt.t)) Hashtbl.t;
mutable switched_off : unit Lwt.t;
}

Expand Down Expand Up @@ -63,7 +63,7 @@ let connect ~ipv4_only ~ipv6_only ipv4 ipv6 =
let disconnect t =
Lwt_list.iter_p close t.active_connections >>= fun () ->
Lwt_list.iter_p close
(Hashtbl.fold (fun _ fd acc -> fd @ acc) t.listen_sockets []) >>= fun () ->
(Hashtbl.fold (fun _ (fds, _) acc -> fds @ acc) t.listen_sockets []) >>= fun () ->
Lwt.cancel t.switched_off ; Lwt.return_unit

let dst fd =
Expand Down Expand Up @@ -113,10 +113,13 @@ let create_connection ?keepalive t (dst,dst_port) =
let unlisten t ~port =
match Hashtbl.find_opt t.listen_sockets port with
| None -> ()
| Some fds ->
| Some (fds, _) ->
Hashtbl.remove t.listen_sockets port;
try List.iter (fun fd -> Unix.close (Lwt_unix.unix_file_descr fd)) fds with _ -> ()

let is_listening t ~port =
Option.map snd (Hashtbl.find_opt t.listen_sockets port)

let listen t ~port ?keepalive callback =
if port < 0 || port > 65535 then
raise (Invalid_argument (Printf.sprintf "invalid port number (%d)" port));
Expand Down Expand Up @@ -147,7 +150,7 @@ let listen t ~port ?keepalive callback =
in
List.iter (fun (fd, addr) ->
Unix.bind (Lwt_unix.unix_file_descr fd) addr;
Hashtbl.replace t.listen_sockets port (List.map fst fds);
Hashtbl.replace t.listen_sockets port (List.map fst fds, callback);
Lwt_unix.listen fd 10;
(* FIXME: we should not ignore the result *)
Lwt.async (fun () ->
Expand Down
26 changes: 16 additions & 10 deletions src/stack-unix/udpv4v6_socket.ml
Expand Up @@ -27,7 +27,7 @@ let any_v6 = Ipaddr_unix.V6.to_inet_addr Ipaddr.V6.unspecified

type t = {
interface: [ `Any | `Ip of Unix.inet_addr * Unix.inet_addr | `V4_only of Unix.inet_addr | `V6_only of Unix.inet_addr ]; (* source ip to bind to *)
listen_fds: (int, Lwt_unix.file_descr * Lwt_unix.file_descr option) Hashtbl.t; (* UDP fds bound to a particular port *)
listen_fds: (int, Lwt_unix.file_descr * Lwt_unix.file_descr option * callback) Hashtbl.t; (* UDP fds bound to a particular port *)
mutable switched_off : unit Lwt.t;
}

Expand All @@ -38,12 +38,12 @@ let ignore_canceled = function
| Lwt.Canceled -> Lwt.return_unit
| exn -> raise exn

let get_udpv4v6_listening_fd ?(preserve = true) ?(v4_or_v6 = `Both) {listen_fds;interface;_} port =
let get_udpv4v6_listening_fd ?preserve ?(v4_or_v6 = `Both) {listen_fds;interface;_} port =
try
Lwt.return
(match Hashtbl.find listen_fds port with
| (fd, None) -> false, [ fd ]
| (fd, Some fd') -> false, [ fd ; fd' ])
| (fd, None, _) -> false, [ fd ]
| (fd, Some fd', _) -> false, [ fd ; fd' ])
with Not_found ->
(match interface with
| `Any ->
Expand Down Expand Up @@ -76,8 +76,8 @@ let get_udpv4v6_listening_fd ?(preserve = true) ?(v4_or_v6 = `Both) {listen_fds;
| `V6_only ip ->
let fd = Lwt_unix.(socket PF_INET6 SOCK_DGRAM 0) in
Lwt_unix.bind fd (Lwt_unix.ADDR_INET (ip, port)) >|= fun () ->
((fd, None), [ fd ])) >|= fun (fds, r) ->
if preserve then Hashtbl.add listen_fds port fds;
((fd, None), [ fd ])) >|= fun ((fd1, fd2), r) ->
Option.iter (fun cb -> Hashtbl.add listen_fds port (fd1, fd2, cb)) preserve;
true, r


Expand Down Expand Up @@ -121,7 +121,7 @@ let connect ~ipv4_only ~ipv6_only ipv4 ipv6 =
Lwt.return { interface; listen_fds; switched_off = fst (Lwt.wait ()) }

let disconnect t =
Hashtbl.fold (fun _ (fd, fd') r ->
Hashtbl.fold (fun _ (fd, fd', _) r ->
r >>= fun () ->
close fd >>= fun () ->
match fd' with None -> Lwt.return_unit | Some fd -> close fd)
Expand All @@ -146,7 +146,7 @@ let write ?src:_ ?src_port ?ttl:_ttl ~dst ~dst_port t buf =
match t.interface, v4_or_v6 with
| `Any, _ | `Ip _, _ | `V4_only _, `V4 | `V6_only _, `V6 ->
let p = match src_port with None -> 0 | Some x -> x in
get_udpv4v6_listening_fd ~preserve:false ~v4_or_v6 t p >>= fun (created, fds) ->
get_udpv4v6_listening_fd ~v4_or_v6 t p >>= fun (created, fds) ->
((match fds, v4_or_v6 with
| [ fd ], _ -> Lwt.return (Ok fd)
| [ v4 ; _v6 ], `V4 -> Lwt.return (Ok v4)
Expand All @@ -161,19 +161,25 @@ let write ?src:_ ?src_port ?ttl:_ttl ~dst ~dst_port t buf =

let unlisten t ~port =
try
let fd, fd' = Hashtbl.find t.listen_fds port in
let fd, fd', _ = Hashtbl.find t.listen_fds port in
Hashtbl.remove t.listen_fds port;
(match fd' with None -> () | Some fd' -> Unix.close (Lwt_unix.unix_file_descr fd'));
Unix.close (Lwt_unix.unix_file_descr fd)
with _ -> ()

let is_listening t ~port =
if port < 0 || port > 65535 then
raise (Invalid_argument (Printf.sprintf "invalid port number (%d)" port))
else
Option.map (fun (_, _, cb) -> cb) (Hashtbl.find_opt t.listen_fds port)

let listen t ~port callback =
if port < 0 || port > 65535 then
raise (Invalid_argument (Printf.sprintf "invalid port number (%d)" port))
else
(* FIXME: we should not ignore the result *)
Lwt.async (fun () ->
get_udpv4v6_listening_fd t port >|= fun (_, fds) ->
get_udpv4v6_listening_fd ~preserve:callback t port >|= fun (_, fds) ->
List.iter (fun fd ->
Lwt.async (fun () ->
let buf = Cstruct.create 4096 in
Expand Down
6 changes: 6 additions & 0 deletions src/tcp/flow.ml
Expand Up @@ -83,6 +83,12 @@ struct
else
Hashtbl.replace t.listeners port (keepalive, cb)

let is_listening t ~port =
if port < 0 || port > 65535 then
raise (Invalid_argument (Printf.sprintf "invalid port number (%d)" port))
else
Option.map snd (Hashtbl.find_opt t.listeners port)

let unlisten t ~port = Hashtbl.remove t.listeners port

let _pp_pcb fmt pcb =
Expand Down
6 changes: 6 additions & 0 deletions src/udp/udp.ml
Expand Up @@ -40,6 +40,12 @@ module Make (Ip : Tcpip.Ip.S) (Random : Mirage_random.S) = struct
else
Hashtbl.replace t.listeners port callback

let is_listening t ~port =
if port < 0 || port > 65535 then
raise (Invalid_argument (Printf.sprintf "invalid port number (%d)" port))
else
Hashtbl.find_opt t.listeners port

let unlisten t ~port = Hashtbl.remove t.listeners port

(* TODO: ought we to check to make sure the destination is relevant
Expand Down