Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove possible file descriptor leak if safe_close_and_exec fails #5596

Merged
merged 6 commits into from
May 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
7 changes: 6 additions & 1 deletion ocaml/forkexecd/lib/forkhelpers.ml
Original file line number Diff line number Diff line change
Expand Up @@ -195,9 +195,13 @@ let safe_close_and_exec ?env stdin stdout stderr
let fds_to_close = ref [] in

let add_fd_to_close_list fd = fds_to_close := fd :: !fds_to_close in
(* let remove_fd_from_close_list fd = fds_to_close := List.filter (fun fd' -> fd' <> fd) !fds_to_close in *)
let remove_fd_from_close_list fd =
fds_to_close := List.filter (fun fd' -> fd' <> fd) !fds_to_close
in
let close_fds () = List.iter (fun fd -> Unix.close fd) !fds_to_close in

add_fd_to_close_list sock ;

finally
(fun () ->
let maybe_add_id_to_fd_map id_to_fd_map (uuid, fd, v) =
Expand Down Expand Up @@ -285,6 +289,7 @@ let safe_close_and_exec ?env stdin stdout stderr
Fecomms.write_raw_rpc sock Fe.Exec ;
match Fecomms.read_raw_rpc sock with
| Ok (Fe.Execed pid) ->
remove_fd_from_close_list sock ;
(sock, pid)
| Ok status ->
let msg =
Expand Down
113 changes: 69 additions & 44 deletions ocaml/forkexecd/test/fe_test.ml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,13 @@ let min_fds = 7

let max_fds = 1024 - 13 (* fe daemon has a bunch for its own use *)

(** [fail x] reports a test failure and aborts.

    The message [x] is handed to [Unixext.write_string_to_file] with the path
    ["/tmp/fe-test.log"], echoed on [stderr] followed by a newline, and then
    [assert false] is evaluated, so this function never returns normally.

    @raise Assert_failure always. *)
let fail x =
  Xapi_stdext_unix.Unixext.write_string_to_file "/tmp/fe-test.log" x ;
  Printf.fprintf stderr "%s\n" x ;
  assert false

(** Format-string front end for the string-based [fail] defined just above
    (which this binding shadows): the arguments are rendered with
    [Format.ksprintf] and the resulting message is passed to that [fail]. *)
let fail fmt = Format.ksprintf (fun message -> fail message) fmt

let all_combinations fds =
let y =
{
Expand Down Expand Up @@ -68,8 +75,26 @@ let shuffle x =
done ;
Array.to_list arr

(** [fds_fold f init] folds [f] over the entries of [/proc/self/fd].

    For every directory entry [fd_num] whose symbolic link can be read, [f] is
    called as [f fd_num link acc], where [link] is the result of
    [Unix.readlink] on that entry. Entries are visited with
    [Array.fold_right] over the array returned by [Sys.readdir].

    Only [Unix.Unix_error] from [Unix.readlink] is treated as "skip this
    entry"; exceptions raised by [f] itself propagate to the caller.

    @raise Sys_error if [/proc/self/fd] cannot be listed. *)
let fds_fold f init =
  let path = "/proc/self/fd" in
  Array.fold_right
    (fun fd_num acc ->
      match Unix.readlink (Filename.concat path fd_num) with
      | link ->
          f fd_num link acc
      | exception Unix.Unix_error _ ->
          (* get rid of the fd used to read the directory (and any other
             entry whose link cannot be read) *)
          acc
    )
    (Sys.readdir path) init

(** [fd_list ()] returns one [(fd_num, link)] pair for each entry that
    [fds_fold] visits, consing each pair onto the accumulated list. *)
let fd_list () =
  let cons_entry fd_num target entries = (fd_num, target) :: entries in
  fds_fold cons_entry []

(** [fd_count ()] returns the number of entries that [fds_fold] visits. *)
let fd_count () =
  let bump _fd_num _target total = succ total in
  fds_fold bump 0

(* Filler argument strings: [slave] removes any of these from its argument
   list before treating the remaining arguments as fd names. *)
let irrelevant_strings = ["irrelevant"; "not"; "important"]

(** The [/proc/<pid>/exe] path of the current process, computed once when
    this module is initialised; the tests pass it as the command to run. *)
let exe = "/proc/" ^ string_of_int (Unix.getpid ()) ^ "/exe"

let one fds x =
(*Printf.fprintf stderr "named_fds = %d\n" x.named_fds;
Printf.fprintf stderr "extra = %d\n" x.extra;*)
Expand All @@ -82,7 +107,6 @@ let one fds x =
let number_of_extra = x.extra in
let other_names = make_names number_of_extra in

let exe = Printf.sprintf "/proc/%d/exe" (Unix.getpid ()) in
let table =
(fun x -> List.combine x (List.map (fun _ -> fd) x)) (names @ other_names)
in
Expand All @@ -107,7 +131,6 @@ let one fds x =

let test_delay () =
let start = Unix.gettimeofday () in
let exe = Printf.sprintf "/proc/%d/exe" (Unix.getpid ()) in
let args = ["sleep"] in
(* Need to have fractional part because some internal usage split integer
and fractional and do computation.
Expand All @@ -117,7 +140,7 @@ let test_delay () =
let timeout = 1.7 in
try
Forkhelpers.execute_command_get_output ~timeout exe args |> ignore ;
failwith "Failed to timeout"
fail "Failed to timeout"
with
| Forkhelpers.Subprocess_timeout ->
let elapsed = Unix.gettimeofday () -. start in
Expand All @@ -127,39 +150,25 @@ let test_delay () =
if elapsed > timeout +. 0.2 then
failwith "Excessive time elapsed"
| e ->
failwith
(Printf.sprintf "Failed with unexpected exception: %s"
(Printexc.to_string e)
)
fail "Failed with unexpected exception: %s" (Printexc.to_string e)

let test_notimeout () =
let exe = Printf.sprintf "/proc/%d/exe" (Unix.getpid ()) in
let args = ["sleep"] in
try
Forkhelpers.execute_command_get_output exe args |> ignore ;
()
with e ->
failwith
(Printf.sprintf "Failed with unexpected exception: %s"
(Printexc.to_string e)
)

let fail x =
Xapi_stdext_unix.Unixext.write_string_to_file "/tmp/fe-test.log" x ;
Printf.fprintf stderr "%s\n" x ;
assert false
with e -> fail "Failed with unexpected exception: %s" (Printexc.to_string e)

let expect expected s =
if s <> expected ^ "\n" then
fail (Printf.sprintf "output %s expected %s" s expected)
fail "output %s expected %s" s expected

let test_exitcode () =
let run_expect cmd expected =
try Forkhelpers.execute_command_get_output cmd [] |> ignore
with Forkhelpers.Spawn_internal_error (_, _, Unix.WEXITED n) ->
if n <> expected then
fail
(Printf.sprintf "%s exited with code %d, expected %d" cmd n expected)
fail "%s exited with code %d, expected %d" cmd n expected
in
run_expect "/bin/false" 1 ;
run_expect "/bin/xe-fe-test-no-command" 127 ;
Expand All @@ -168,7 +177,6 @@ let test_exitcode () =
Printf.printf "\nCompleted exitcode tests\n"

let test_output () =
let exe = Printf.sprintf "/proc/%d/exe" (Unix.getpid ()) in
let expected_out = "output string" in
let expected_err = "error string" in
let args = ["echo"; expected_out; expected_err] in
Expand All @@ -178,7 +186,6 @@ let test_output () =
print_endline "Completed output tests"

let test_input () =
let exe = Printf.sprintf "/proc/%d/exe" (Unix.getpid ()) in
let input = "input string" in
let args = ["replay"] in
let out, _ =
Expand All @@ -187,6 +194,38 @@ let test_input () =
expect input out ;
print_endline "Completed input tests"

(* This test exercises a failure inside Forkhelpers.safe_close_and_exec.
   The exact reproduction used here (passing an already-closed file
   descriptor) is never supposed to happen in the real world, but an internal
   failure could happen, for instance, if the forkexecd daemon is restarted
   for a moment, so make sure we are able to detect and handle these cases
   without leaking file descriptors. *)
let test_internal_failure_error () =
  (* Baseline taken before anything is opened by this test. *)
  let initial_fd_count = fd_count () in
  let leak_fd_detect () =
    let current_fd_count = fd_count () in
    if current_fd_count <> initial_fd_count then
      fail "File descriptor leak detected initially %d files, now %d"
        initial_fd_count current_fd_count
  in
  (* Opens [num + 1] descriptors on /dev/null, closes all of them again and
     returns the last one opened (now closed); the point is to obtain an
     invalid file descriptor with some closed ones before it. *)
  let rec waste_fds num =
    let fd = Unix.openfile "/dev/null" [Unix.O_WRONLY] 0o0 in
    let ret = if num = 0 then fd else waste_fds (num - 1) in
    Unix.close fd ; ret
  in
  let fd = waste_fds 20 in
  let args = ["sleep"] in
  (* [match ... with exception] keeps the "Expected an exception" failure
     outside the handlers, so it is not caught by the catch-all case and
     re-reported as an "unexpected exception". *)
  match Forkhelpers.safe_close_and_exec None (Some fd) None [] exe args with
  | _ ->
      fail "Expected an exception"
  | exception Fd_send_recv.Unix_error _ ->
      leak_fd_detect ()
  | exception e ->
      fail "Failed with unexpected exception: %s" (Printexc.to_string e)

let master fds =
Printf.printf "\nPerforming timeout tests\n%!" ;
test_delay () ;
Expand All @@ -196,6 +235,8 @@ let master fds =
Printf.printf "\nPerforming input/output tests\n%!" ;
test_output () ;
test_input () ;
Printf.printf "\nPerforming internal failure test\n%!" ;
test_internal_failure_error () ;
let combinations = shuffle (all_combinations fds) in
Printf.printf "Starting %d tests\n%!" (List.length combinations) ;
let i = ref 0 in
Expand All @@ -215,28 +256,14 @@ let master fds =

let slave = function
| [] ->
failwith "Error, at least one fd expected"
fail "Error, at least one fd expected"
| total_fds :: rest ->
let total_fds = int_of_string total_fds in
let fds =
List.filter (fun x -> not (List.mem x irrelevant_strings)) rest
in
(* Check that these fds are present *)
let pid = Unix.getpid () in
let path = Printf.sprintf "/proc/%d/fd" pid in
let raw =
List.filter (* get rid of the fd used to read the directory *)
(fun x ->
try
ignore (Unix.readlink (Filename.concat path x)) ;
true
with _ -> false
)
(Array.to_list (Sys.readdir path))
in
let pairs =
List.map (fun x -> (x, Unix.readlink (Filename.concat path x))) raw
in
let pairs = fd_list () in
(* Filter any of stdin,stdout,stderr which have been mapped to /dev/null *)
let filtered =
List.filter
Expand All @@ -257,18 +284,16 @@ let slave = function
List.iter
(fun fd ->
if not (List.mem fd (List.map fst filtered)) then
fail (Printf.sprintf "fd %s not in /proc/%d/fd [ %s ]" fd pid ls)
fail "fd %s not in /proc/self/fd [ %s ]" fd ls
)
fds ;
(* Check that we have the expected number *)
(*
Printf.fprintf stderr "%s %d\n" total_fds (List.length present - 1)
*)
if total_fds <> List.length filtered then
fail
(Printf.sprintf "Expected %d fds; /proc/%d/fd has %d: %s" total_fds
pid (List.length filtered) ls
)
fail "Expected %d fds; /proc/self/fd has %d: %s" total_fds
(List.length filtered) ls

(** [sleep ()] blocks for three seconds via [Unix.sleep] and then writes
    ["Ok\n"] to standard output. *)
let sleep () =
  Unix.sleep 3 ;
  print_string "Ok\n"

Expand Down