fix: ensure proxy shuts down gracefully on SIGTERM (#877)
enocom committed Aug 11, 2021
1 parent 669ad1f commit 9793555
Showing 4 changed files with 53 additions and 42 deletions.
66 changes: 30 additions & 36 deletions cmd/cloud_sql_proxy/cloud_sql_proxy.go
@@ -442,29 +442,12 @@ func gcloudProject() ([]string, error) {
return []string{cfg.Configuration.Properties.Core.Project}, nil
}

// Main executes the main function of the proxy, allowing it to be called from tests.
//
// Setting timeout to a value greater than 0 causes the process to panic after
// that amount of time. This is to sidestep an issue where sending a Signal to
// the process (via the SSH library) doesn't seem to have an effect, and
// closing the SSH session causes the process to get leaked. This timeout will
// at least cause the proxy to exit eventually.
func Main(timeout time.Duration) {
if timeout > 0 {
go func() {
time.Sleep(timeout)
panic("timeout exceeded")
}()
}
main()
}

func main() {
func runProxy() int {
flag.Parse()

if *version {
fmt.Println("Cloud SQL Auth proxy:", semanticVersion())
return
return 0
}

if *logDebugStdout {
@@ -479,7 +462,7 @@ func main() {
cleanup, err := logging.EnableStructuredLogs(*logDebugStdout, *verbose)
if err != nil {
logging.Errorf("failed to enable structured logs: %v", err)
os.Exit(1)
return 1
}
defer cleanup()
}
@@ -501,7 +484,7 @@ func main() {
if *host != "" && !strings.HasSuffix(*host, "/") {
logging.Errorf("Flag host should always end with /")
flag.PrintDefaults()
return
return 0
}

// TODO: needs a better place for consolidation
@@ -520,36 +503,36 @@ func main() {
logging.Infof("Using gcloud's active project: %v", projList)
} else if gErr, ok := err.(*util.GcloudError); ok && gErr.Status == util.GcloudNotFound {
logging.Errorf("gcloud is not in the path and -instances and -projects are empty")
os.Exit(1)
return 1
} else {
logging.Errorf("unable to retrieve the active gcloud project and -instances and -projects are empty: %v", err)
os.Exit(1)
return 1
}
}

onGCE := metadata.OnGCE()
if err := checkFlags(onGCE); err != nil {
logging.Errorf(err.Error())
os.Exit(1)
return 1
}

ctx := context.Background()
ctx, cancel := context.WithCancel(context.Background())
client, tokSrc, err := authenticatedClient(ctx)
if err != nil {
logging.Errorf(err.Error())
os.Exit(1)
return 1
}

ins, err := listInstances(ctx, client, projList)
if err != nil {
logging.Errorf(err.Error())
os.Exit(1)
return 1
}
instList = append(instList, ins...)
cfgs, err := CreateInstanceConfigs(*dir, *useFuse, instList, *instanceSrc, client, *skipInvalidInstanceConfigs)
if err != nil {
logging.Errorf(err.Error())
os.Exit(1)
return 1
}

// We only need to store connections in a ConnSet if FUSE is used; otherwise
@@ -590,7 +573,7 @@ func main() {
hc, err = healthcheck.NewServer(proxyClient, *healthCheckPort)
if err != nil {
logging.Errorf("Could not initialize health check server: %v", err)
os.Exit(1)
return 1
}
defer hc.Close(ctx)
}
@@ -601,7 +584,7 @@ func main() {
c, fuse, err := fuse.NewConnSrc(*dir, *fuseTmp, proxyClient, connset)
if err != nil {
logging.Errorf("Could not start fuse directory at %q: %v", *dir, err)
os.Exit(1)
return 1
}
connSrc = c
defer fuse.Close()
@@ -627,7 +610,7 @@ func main() {
c, err := WatchInstances(*dir, cfgs, updates, client)
if err != nil {
logging.Errorf(err.Error())
os.Exit(1)
return 1
}
connSrc = c
}
@@ -641,7 +624,9 @@ func main() {
signals := make(chan os.Signal, 1)
signal.Notify(signals, syscall.SIGTERM, syscall.SIGINT)

shutdown := make(chan int, 1)
go func() {
defer func() { cancel(); close(shutdown) }()
<-signals
logging.Infof("Received TERM signal. Waiting up to %s before terminating.", *termTimeout)
go func() {
@@ -651,11 +636,11 @@ func main() {
}()

err := proxyClient.Shutdown(*termTimeout)
if err == nil {
os.Exit(0)
if err != nil {
logging.Errorf("Error during SIGTERM shutdown: %v", err)
shutdown <- 2
return
}
logging.Errorf("Error during SIGTERM shutdown: %v", err)
os.Exit(2)
}()

// If running under systemd with Type=notify, we'll send a message to the
@@ -666,5 +651,14 @@ func main() {
logging.Errorf("Failed to notify systemd of readiness: %v", err)
}
}()
proxyClient.Run(connSrc)
proxyClient.RunContext(ctx, connSrc)
if code, ok := <-shutdown; ok {
return code
}
return 0
}

func main() {
code := runProxy()
os.Exit(code)
}
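
To make the change above easier to follow, here is a minimal, self-contained sketch of the shutdown pattern this file now uses: all work happens in a function that returns an exit code (so deferred cleanup runs before the process exits), a signal handler cancels a context, and any shutdown error is reported back over a buffered channel. The names runServer, doWork, drainConnections, and the fixed 10-second drain timeout are hypothetical stand-ins, not part of the proxy.

package main

import (
	"context"
	"fmt"
	"os"
	"os/signal"
	"syscall"
	"time"
)

func runServer() int {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	signals := make(chan os.Signal, 1)
	signal.Notify(signals, syscall.SIGTERM, syscall.SIGINT)

	shutdown := make(chan int, 1)
	go func() {
		// Always cancel the context and close the channel so the main
		// goroutine unblocks regardless of how shutdown went.
		defer func() { cancel(); close(shutdown) }()
		<-signals
		if err := drainConnections(10 * time.Second); err != nil {
			fmt.Fprintln(os.Stderr, "shutdown error:", err)
			shutdown <- 2 // non-zero exit code reported back to the caller
		}
	}()

	doWork(ctx) // blocks until ctx is canceled

	// A value on the channel means the signal handler hit an error; a
	// closed, empty channel means a clean shutdown.
	if code, ok := <-shutdown; ok {
		return code
	}
	return 0
}

// drainConnections stands in for waiting on active connections to finish.
func drainConnections(d time.Duration) error {
	time.Sleep(50 * time.Millisecond)
	return nil
}

// doWork stands in for the proxy's accept loop.
func doWork(ctx context.Context) { <-ctx.Done() }

func main() {
	// os.Exit skips deferred functions, so it is called only here, after
	// runServer has returned and its defers have already run.
	os.Exit(runServer())
}
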
2 changes: 1 addition & 1 deletion proxy/fuse/fuse.go
@@ -74,7 +74,7 @@ func NewConnSrc(mountdir, tmpdir string, client *proxy.Client, connset *proxy.Co
c, err := fuse.Mount(mountdir, fuse.AllowOther())
if err != nil {
// a common cause of failed mounts is that a previous instance did not shutdown cleanly, leaving an abandoned mount
logging.Errorf("WARNING: Mount failed - attempting to unmount dir to resolve...", mountdir)
logging.Errorf("WARNING: Mount failed - attempting to unmount dir to resolve..., dir=%v", mountdir)
if err = fuse.Unmount(mountdir); err != nil {
logging.Errorf("Unmount failed: %v", err)
}
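
The fuse.go change fixes a format string: the original call passed mountdir without a corresponding verb, so (assuming logging.Errorf is fmt-style, as the name suggests) the directory was appended as an %!(EXTRA ...) blob rather than rendered in the message. A tiny sketch of the difference using the standard fmt package:

package main

import "fmt"

func main() {
	dir := "/cloudsql"
	// Without a verb, the extra argument is appended verbatim:
	// "... to resolve...%!(EXTRA string=/cloudsql)"
	fmt.Println(fmt.Sprintf("WARNING: Mount failed - attempting to unmount dir to resolve...", dir))
	// With a verb, the directory lands where intended:
	// "... to resolve..., dir=/cloudsql"
	fmt.Println(fmt.Sprintf("WARNING: Mount failed - attempting to unmount dir to resolve..., dir=%v", dir))
}
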
25 changes: 21 additions & 4 deletions proxy/proxy/client.go
@@ -140,16 +140,33 @@ type cacheEntry struct {
// Run causes the client to start waiting for new connections to connSrc and
// proxy them to the destination instance. It blocks until connSrc is closed.
func (c *Client) Run(connSrc <-chan Conn) {
for conn := range connSrc {
go c.handleConn(conn)
c.RunContext(context.Background(), connSrc)
}

func (c *Client) run(ctx context.Context, connSrc <-chan Conn) {
for {
select {
case conn, ok := <-connSrc:
if !ok {
return
}
go c.handleConn(ctx, conn)
case <-ctx.Done():
return
}
}
}

// RunContext is like Run with an additional context.Context argument.
func (c *Client) RunContext(ctx context.Context, connSrc <-chan Conn) {
c.run(ctx, connSrc)

if err := c.Conns.Close(); err != nil {
logging.Errorf("closing client had error: %v", err)
}
}

func (c *Client) handleConn(conn Conn) {
func (c *Client) handleConn(ctx context.Context, conn Conn) {
active := atomic.AddUint64(&c.ConnectionsCounter, 1)

// Deferred decrement of ConnectionsCounter upon connection closing
@@ -161,7 +178,7 @@ func (c *Client) handleConn(conn Conn) {
return
}

server, err := c.Dial(conn.Instance)
server, err := c.DialContext(ctx, conn.Instance)
if err != nil {
logging.Errorf("couldn't connect to %q: %v", conn.Instance, err)
conn.Conn.Close()
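
A self-contained sketch of the run-loop idiom added above: drain a channel until it closes or the context is canceled. The string channel and the run and main names here are illustrative stand-ins for the proxy's Conn channel and RunContext, not part of the real package.

package main

import (
	"context"
	"fmt"
	"time"
)

// run drains src until it closes or ctx is canceled, mirroring the select
// loop above; the string channel stands in for the proxy's Conn channel.
func run(ctx context.Context, src <-chan string) {
	for {
		select {
		case v, ok := <-src:
			if !ok {
				return // source closed: normal end of input
			}
			go fmt.Println("handling", v) // stands in for handleConn
		case <-ctx.Done():
			return // canceled: stop accepting new work
		}
	}
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	src := make(chan string)
	go func() {
		src <- "conn-1"
		src <- "conn-2"
		time.Sleep(50 * time.Millisecond)
		cancel() // simulates SIGTERM canceling the proxy's context
	}()
	run(ctx, src)
	fmt.Println("run returned after cancellation")
}
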
2 changes: 1 addition & 1 deletion proxy/proxy/client_test.go
@@ -214,7 +214,7 @@ func TestMaximumConnectionsCount(t *testing.T) {
Instance: instanceName,
Conn: &dummyConn{},
}
c.handleConn(conn)
c.handleConn(context.Background(), conn)

firstDialOnce.Do(func() { close(firstDialExited) })
}(instanceName)
