
Capture_stdout causes crash, GPU, Lux.jl #2913

Open
Dale-Black opened this issue May 7, 2024 · 7 comments
Labels
bug Something isn't working other packages Integration with other Julia packages

Comments

@Dale-Black

Description:

I am encountering GPU memory issues while training a Lux model (1.4 million parameters, input image size (128, 128, 96)) in a Pluto notebook on an A100 GPU. The issue occurs around the 50th epoch, resulting in Malt.TerminatedWorkerException() errors. However, the same code runs without issues when executed outside of Pluto.

Environment:

  • CUDA v5.3.3
  • Lux v0.5.42
  • LuxCUDA v0.3.2

Steps to Reproduce:

  1. Open the Pluto notebook: https://github.com/Dale-Black/ComputerVisionTutorials.jl/blob/main/tutorials/03_image_segmentation.jl
  2. Run the notebook and observe the GPU memory usage and errors.

Expected Behavior:

The training should proceed without encountering GPU memory issues or Malt.TerminatedWorkerException() errors.

Actual Behavior:

The training encounters GPU memory issues and Malt.TerminatedWorkerException() errors around the 50th epoch when running in a Pluto notebook. The same code runs without issues when executed outside of Pluto.

Additional Context:

  • Initially, the issue was thought to be related to string interpolation in logging statements, which can introduce try/catch blocks that force GPU memory to be retained (Unrelated try-catch causes CUDA arrays to not be freed JuliaLang/julia#52533). However, removing string interpolation did not resolve the problem.
  • The issue is reproducible in Pluto.jl but not when running the code outside of Pluto.
Attached video: pluto.lux.error.mov
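The notebook's actual logging statements are not shown in this thread. As a hedged illustration only, one common way to remove string interpolation from Julia logging calls is to pass a constant message plus key-value pairs, so the values are attached to the log record rather than interpolated into the message string. The variable names and values below are hypothetical stand-ins for the training loop's state, and, as reported above, this kind of change alone did not fix the crash.

```julia
using Logging

# Hypothetical stand-ins for the training loop's state.
epoch = 50
loss = 0.0123

# Interpolated form (the kind of statement that was removed):
#   @info "Epoch $epoch: loss = $loss"

# Key-value form: a constant message, with `epoch` and `loss` attached
# to the log record as `epoch = epoch` and `loss = loss`.
@info "Training progress" epoch loss
```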
@fonsp
Owner

fonsp commented May 8, 2024

Hey Dale!! Thanks for the clear bug report!

Do you see any errors in the terminal where you launched Pluto? Does Pluto.run(workspace_use_distributed_stdlib=true) fix it, or does it still crash, but with additional error messages?

@fonsp fonsp added bug Something isn't working other packages Integration with other Julia packages labels May 8, 2024
@Dale-Black
Author

I just tested with workspace_use_distributed_stdlib = true and got this error at around epoch 50:

Worker 2 terminated.
Unhandled Task ERROR: EOFError: read end of file
Stacktrace:
 [1] (::Base.var"#wait_locked#739")(s::Sockets.TCPSocket, buf::IOBuffer, nb::Int64)
   @ Base ./stream.jl:947
 [2] unsafe_read(s::Sockets.TCPSocket, p::Ptr{UInt8}, nb::UInt64)
   @ Base ./stream.jl:955
 [3] unsafe_read
   @ ./io.jl:774 [inlined]
 [4] unsafe_read(s::Sockets.TCPSocket, p::Base.RefValue{NTuple{4, Int64}}, n::Int64)
   @ Base ./io.jl:773
 [5] read!
   @ ./io.jl:775 [inlined]
 [6] deserialize_hdr_raw
   @ ~/.julia/juliaup/julia-1.10.3+0.x64.linux.gnu/share/julia/stdlib/v1.10/Distributed/src/messages.jl:167 [inlined]
 [7] message_handler_loop(r_stream::Sockets.TCPSocket, w_stream::Sockets.TCPSocket, incoming::Bool)
   @ Distributed ~/.julia/juliaup/julia-1.10.3+0.x64.linux.gnu/share/julia/stdlib/v1.10/Distributed/src/process_messages.jl:172
 [8] process_tcp_streams(r_stream::Sockets.TCPSocket, w_stream::Sockets.TCPSocket, incoming::Bool)
   @ Distributed ~/.julia/juliaup/julia-1.10.3+0.x64.linux.gnu/share/julia/stdlib/v1.10/Distributed/src/process_messages.jl:133
 [9] (::Distributed.var"#103#104"{Sockets.TCPSocket, Sockets.TCPSocket, Bool})()
   @ Distributed ~/.julia/juliaup/julia-1.10.3+0.x64.linux.gnu/share/julia/stdlib/v1.10/Distributed/src/process_messages.jl:121

When using Malt, there are no errors in the terminal.

@fonsp
Owner

fonsp commented May 9, 2024

Hm! That doesn't help; it just complains that it couldn't read data from a worker that had already shut down.

I made a branch, disable-logger-and-stdout, where log and stdout capture are disabled. Can you try this branch?

pkg (@1.10)> add Pluto#disable-logger-and-stdout

@Dale-Black
Author

Interesting! Of course, all the logging now goes to the terminal, but that branch works fine.

@fonsp
Owner

fonsp commented May 10, 2024

How about Pluto.run(capture_stdout=false), with the regular released version of Pluto?

@Dale-Black
Author

Sorry for the late reply; I'm just getting around to this. It looks like that works. I wonder if the fact that I am port-forwarding Pluto from a remote cluster has anything to do with this issue? The current Pluto release now works with this launch setup, although it logs output to the terminal, of course:

Pluto.run(launch_browser = false, host = "0.0.0.0", capture_stdout = false)

@fonsp fonsp changed the title GPU memory issues when training Lux models in Pluto notebooks Capture_stdout causes crash, GPU, Lux.jl May 23, 2024
@fonsp
Owner

fonsp commented May 23, 2024

Strange! That means that capture_stdout is the problem. When running Pluto with capture_stdout=false, is the stdout printed to your terminal different from what you get when running the notebook without Pluto?

I wonder if Lux has a problem with having stdout captured into a buffer (what Pluto does) instead of printing to stdout directly.

Can you try to run your code without Pluto, but using Pluto's stdout capturing system?

Step 1: open a REPL
Step 2: run the following code, which consists of snippets from Pluto's source code. Alternatively, you could run import Pluto.PlutoRunner.with_io_to_logs instead.

import Logging


const default_stdout_iocontext = IOContext(devnull, 
    :color => true, 
    :limit => true, 
    :displaysize => (18, 75), 
    :is_pluto => false,
)

const stdout_log_level = Logging.LogLevel(-555) # https://en.wikipedia.org/wiki/555_timer_IC

function _send_stdio_output!(output, loglevel)
    output_str = String(take!(output))
    if !isempty(output_str)
        Logging.@logmsg loglevel output_str
    end
end


function with_io_to_logs(f::Function; enabled::Bool=true, loglevel::Logging.LogLevel=Logging.LogLevel(1))
    if !enabled
        return f()
    end
    # Taken from https://github.com/JuliaDocs/IOCapture.jl/blob/master/src/IOCapture.jl with some modifications to make it log.

    # Original implementation from Documenter.jl (MIT license)
    # Save the default output streams.
    default_stdout = stdout
    default_stderr = stderr
    # Redirect both the `stdout` and `stderr` streams to a single `Pipe` object.
    pipe = Pipe()
    Base.link_pipe!(pipe; reader_supports_async = true, writer_supports_async = true)
    pe_stdout = pipe.in
    pe_stderr = pipe.in
    redirect_stdout(pe_stdout)
    redirect_stderr(pe_stderr)

    # Bytes written to the `pipe` are captured in `output` and eventually converted to a
    # `String`. We need to use an asynchronous task to continuously transfer bytes from the
    # pipe to `output` in order to avoid the buffer filling up and stalling write() calls in
    # user code.
    execution_done = Ref(false)
    output = IOBuffer()

    @async begin
        pipe_reader = Base.pipe_reader(pipe)
        try
            while !eof(pipe_reader)
                write(output, readavailable(pipe_reader))

                # NOTE: we don't really have to wait for the end of execution to stream output logs
                #       so maybe we should just enable it?
                if execution_done[]
                    _send_stdio_output!(output, loglevel)
                end
            end
            _send_stdio_output!(output, loglevel)
        catch err
            @error "Failed to redirect stdout/stderr to logs"  exception=(err,catch_backtrace())
            if err isa InterruptException
                rethrow(err)
            end
        end
    end

    # To make the `display` function work.
    redirect_display = TextDisplay(IOContext(pe_stdout, default_stdout_iocontext))
    pushdisplay(redirect_display)

    # Run the function `f`, capturing all output that it might have generated.
    # If `f` throws, the `finally` block still restores the display and the output streams.
    result = try
        f()
    finally
        # Restore display
        try
            popdisplay(redirect_display)
        catch e
            # This happens when the user calls `popdisplay()`, fine.
            # @warn "Pluto's display was already removed?" e
        end

        execution_done[] = true

        # Restore the original output streams.
        redirect_stdout(default_stdout)
        redirect_stderr(default_stderr)
        close(pe_stdout)
        close(pe_stderr)
    end

    result
end

Step 3: save your notebook as script.jl

Step 4:

with_io_to_logs() do
	include("path/to/script.jl")
end
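For a quicker first check that does not involve the notebook, the core idea of the capture code above (redirect `stdout` into a `Pipe` and drain it from an async task) can be reduced to a few lines. This is a minimal sketch, not Pluto's implementation: `capture_stdout_string` is a hypothetical helper, it captures only `stdout` (not `stderr`), and it returns the captured text as a `String` instead of emitting log messages.

```julia
# Minimal sketch of Pipe-based stdout capture (hypothetical helper, not Pluto's code).
# stdout is redirected into a Pipe while `f` runs, and a background task drains
# the pipe so that writes from `f` never stall on a full OS pipe buffer.
function capture_stdout_string(f::Function)
    original_stdout = stdout
    pipe = Pipe()
    Base.link_pipe!(pipe; reader_supports_async = true, writer_supports_async = true)

    # Read everything from the pipe until EOF, i.e. until the write end is closed.
    reader = @async read(Base.pipe_reader(pipe), String)

    redirect_stdout(Base.pipe_writer(pipe))
    try
        f()
    finally
        # Restore the original stdout first, then close the write end so the reader sees EOF.
        redirect_stdout(original_stdout)
        close(Base.pipe_writer(pipe))
    end
    return fetch(reader)
end

# Usage: a hypothetical stand-in for a training loop that prints every epoch.
captured = capture_stdout_string() do
    for epoch in 1:3
        println("epoch ", epoch)
    end
end
print(captured)
```

Running the real training loop inside a helper like this (instead of `with_io_to_logs`) might help separate whether plain pipe capture is enough to trigger the crash, or whether the log-forwarding part of Pluto's version is involved.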
