Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update Cmd IO handling #1937

Merged
merged 1 commit into from Feb 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
53 changes: 42 additions & 11 deletions internal/cmd/cmd.go
Expand Up @@ -5,6 +5,7 @@ package cmd
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"strings"
Expand All @@ -20,6 +21,8 @@ import (
"golang.org/x/sys/windows"
)

var errIOTimeOut = errors.New("timed out waiting for stdio relay")

// CmdProcessRequest stores information on command requests made through this package.
type CmdProcessRequest struct {
Args []string
Expand Down Expand Up @@ -62,7 +65,7 @@ type Cmd struct {
// ExitState is filled out after Wait() (or Run() or Output()) completes.
ExitState *ExitState

iogrp errgroup.Group
ioGrp errgroup.Group
stdinErr atomic.Value
allDoneCh chan struct{}
}
Expand Down Expand Up @@ -90,13 +93,13 @@ func (err *ExitError) Error() string {
return fmt.Sprintf("process exited with exit code %d", err.ExitCode())
}

// Additional fields to hcsschema.ProcessParameters used by LCOW
// Additional fields to hcsschema.ProcessParameters used by LCOW.
type lcowProcessParameters struct {
hcsschema.ProcessParameters
OCIProcess *specs.Process `json:"OciProcess,omitempty"`
}

// escapeArgs makes a Windows-style escaped command line from a set of arguments
// escapeArgs makes a Windows-style escaped command line from a set of arguments.
func escapeArgs(args []string) string {
escapedArgs := make([]string, len(args))
for i, a := range args {
Expand Down Expand Up @@ -136,9 +139,19 @@ func CommandContext(ctx context.Context, host cow.ProcessHost, name string, arg
// Start starts a command. The caller must ensure that if Start succeeds,
// Wait is eventually called to clean up resources.
func (c *Cmd) Start() error {
if c.Host == nil {
return errors.New("empty ProcessHost")
}

// closed in (*Cmd).Wait; signals command execution is done
c.allDoneCh = make(chan struct{})

var x interface{}
if !c.Host.IsOCI() {
if c.Spec == nil {
return errors.New("process spec is required for non-OCI ProcessHost")
}

wpp := &hcsschema.ProcessParameters{
CommandLine: c.Spec.CommandLine,
User: c.Spec.User.Username,
Expand Down Expand Up @@ -199,6 +212,7 @@ func (c *Cmd) Start() error {
// Start relaying process IO.
stdin, stdout, stderr := p.Stdio()
if c.Stdin != nil {
c.Log.Info("coping stdin")
// Do not make stdin part of the error group because there is no way for
// us or the caller to reliably unblock the c.Stdin read when the
// process exits.
Expand All @@ -218,20 +232,20 @@ func (c *Cmd) Start() error {
}

if c.Stdout != nil {
c.iogrp.Go(func() error {
c.ioGrp.Go(func() error {
_, err := relayIO(c.Stdout, stdout, c.Log, "stdout")
if err := p.CloseStdout(context.TODO()); err != nil {
c.Log.WithError(err).Warn("failed to close Cmd stdout")
if cErr := p.CloseStdout(context.TODO()); cErr != nil && c.Log != nil {
c.Log.WithError(cErr).Warn("failed to close Cmd stdout")
}
return err
})
}

if c.Stderr != nil {
c.iogrp.Go(func() error {
c.ioGrp.Go(func() error {
_, err := relayIO(c.Stderr, stderr, c.Log, "stderr")
if err := p.CloseStderr(context.TODO()); err != nil {
c.Log.WithError(err).Warn("failed to close Cmd stderr")
if cErr := p.CloseStderr(context.TODO()); cErr != nil && c.Log != nil {
c.Log.WithError(cErr).Warn("failed to close Cmd stderr")
}
return err
})
Expand Down Expand Up @@ -270,27 +284,44 @@ func (c *Cmd) Wait() error {
state.exited = true
state.code = code
}

// Terminate the IO if the copy does not complete in the requested time.
// Closing the process should (eventually) lead to unblocking `ioGrp`, but we still need
// `timeoutErrCh` to:
// 1. communitate that the IO copy timed out; and
// 2. prevent a race condition between setting the timeout err in the goroutine and setting it for `ioErr`.
timeoutErrCh := make(chan error)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of having a separate channel for timeout error, can we just have a separate error variable that will be set only in the <-t.C: block? That way we don't need to worry about closing the channel and the following code block:

if tErr := <-timeoutErrCh; ioErr == nil {
    ioErr = tErr
}

can just be

if ioErr == nil {
    ioErr = tErr
}

This approach seems easier to read/understand for me at least. What do you think?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we need to block on timeoutErrCh regardless, so we need some mechanism to have Wait actually wait for the timeout

the alternatives (imo) would be to add a function that selects on <-allDoneCh and ctx.Done(), and pass it a context with the CopyAfterExitTimeout or something similar, which i would rather save for when we refactor Cmd wholesale to properly thread contexts everywhere

if c.CopyAfterExitTimeout != 0 {
go func() {
defer close(timeoutErrCh)
t := time.NewTimer(c.CopyAfterExitTimeout)
defer t.Stop()
select {
case <-c.allDoneCh:
case <-t.C:
// Close the process to cancel any reads to stdout or stderr.
c.Process.Close()
err := errIOTimeOut
// log the timeout, since we may not return it to the caller
if c.Log != nil {
c.Log.Warn("timed out waiting for stdio relay")
c.Log.WithField("timeout", c.CopyAfterExitTimeout).Warn(err.Error())
}
timeoutErrCh <- err
}
}()
} else {
close(timeoutErrCh)
}
ioErr := c.iogrp.Wait()

// TODO (go1.20): use multierror for these
ioErr := c.ioGrp.Wait()
if ioErr == nil {
ioErr, _ = c.stdinErr.Load().(error)
}
close(c.allDoneCh)
if tErr := <-timeoutErrCh; ioErr == nil {
helsaawy marked this conversation as resolved.
Show resolved Hide resolved
ioErr = tErr
}
c.Process.Close()
c.ExitState = state
if exitErr != nil {
Expand Down
56 changes: 40 additions & 16 deletions internal/cmd/cmd_test.go
Expand Up @@ -7,6 +7,7 @@ import (
"bytes"
"context"
"errors"
"fmt"
"io"
"os"
"os/exec"
Expand Down Expand Up @@ -213,46 +214,69 @@ func TestCmdStdinBlocked(t *testing.T) {
}
}

type stuckIoProcessHost struct {
type stuckIOProcessHost struct {
cow.ProcessHost
}

type stuckIoProcess struct {
type stuckIOProcess struct {
cow.Process
stdin, pstdout, pstderr *io.PipeWriter
pstdin, stdout, stderr *io.PipeReader

// don't initialize p.stdin, since it complicates the logic
pstdout, pstderr *os.File
stdout, stderr *os.File
}

func (h *stuckIoProcessHost) CreateProcess(ctx context.Context, cfg interface{}) (cow.Process, error) {
func (h *stuckIOProcessHost) CreateProcess(ctx context.Context, cfg interface{}) (cow.Process, error) {
p, err := h.ProcessHost.CreateProcess(ctx, cfg)
if err != nil {
return nil, err
}
sp := &stuckIoProcess{
sp := &stuckIOProcess{
Process: p,
}
sp.pstdin, sp.stdin = io.Pipe()
sp.stdout, sp.pstdout = io.Pipe()
sp.stderr, sp.pstderr = io.Pipe()

if sp.stdout, sp.pstdout, err = os.Pipe(); err != nil {
return nil, fmt.Errorf("create stdout pipe: %w", err)
}
if sp.stderr, sp.pstderr, err = os.Pipe(); err != nil {
return nil, fmt.Errorf("create stderr pipe: %w", err)
}
return sp, nil
}

func (p *stuckIoProcess) Stdio() (io.Writer, io.Reader, io.Reader) {
return p.stdin, p.stdout, p.stderr
func (p *stuckIOProcess) Stdio() (io.Writer, io.Reader, io.Reader) {
return nil, p.stdout, p.stderr
}

func (p *stuckIoProcess) Close() error {
p.stdin.Close()
func (*stuckIOProcess) CloseStdin(context.Context) error {
return nil
}

func (p *stuckIOProcess) CloseStdout(context.Context) error {
_ = p.pstdout.Close()
return p.stdout.Close()
}

func (p *stuckIOProcess) CloseStderr(context.Context) error {
_ = p.pstderr.Close()
return p.stderr.Close()
}

func (p *stuckIOProcess) Close() error {
p.pstdout.Close()
p.pstderr.Close()

p.stdout.Close()
p.stderr.Close()

return p.Process.Close()
}

func TestCmdStuckIo(t *testing.T) {
cmd := Command(&stuckIoProcessHost{&localProcessHost{}}, "cmd", "/c", "echo", "hello")
cmd := Command(&stuckIOProcessHost{&localProcessHost{}}, "cmd", "/c", "(exit 0)")
cmd.CopyAfterExitTimeout = time.Millisecond * 200
_, err := cmd.Output()
if err != io.ErrClosedPipe { //nolint:errorlint
t.Fatal(err)
if !errors.Is(err, errIOTimeOut) {
t.Fatalf("expected: %v; got: %v", errIOTimeOut, err)
}
}
4 changes: 2 additions & 2 deletions internal/cmd/io.go
Expand Up @@ -4,11 +4,11 @@ package cmd

import (
"context"
"fmt"
"io"
"net/url"
"time"

"github.com/pkg/errors"
"github.com/sirupsen/logrus"
)

Expand Down Expand Up @@ -57,7 +57,7 @@ func NewUpstreamIO(ctx context.Context, id, stdout, stderr, stdin string, termin

// Create IO for binary logging driver.
if u.Scheme != "binary" {
return nil, errors.Errorf("scheme must be 'binary', got: '%s'", u.Scheme)
return nil, fmt.Errorf("scheme must be 'binary', got: '%s'", u.Scheme)
}

return NewBinaryIO(ctx, id, u)
Expand Down