diff --git a/spanner/client.go b/spanner/client.go
index fcba322ac1d..afa35315e29 100644
--- a/spanner/client.go
+++ b/spanner/client.go
@@ -234,7 +234,9 @@ func getQueryOptions(opts QueryOptions) QueryOptions {
 // Close closes the client.
 func (c *Client) Close() {
 	if c.idleSessions != nil {
-		c.idleSessions.close()
+		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+		defer cancel()
+		c.idleSessions.close(ctx)
 	}
 	c.sc.close()
 }
diff --git a/spanner/client_test.go b/spanner/client_test.go
index 8a2504706ac..c7ec16bc535 100644
--- a/spanner/client_test.go
+++ b/spanner/client_test.go
@@ -2367,3 +2367,36 @@ func TestClient_Single_Read_WithNumericKey(t *testing.T) {
 		t.Fatalf("row count mismatch\nGot: %v\nWant: %v", rowCount, SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount)
 	}
 }
+
+func TestClient_CloseWithUnresponsiveBackend(t *testing.T) {
+	t.Parallel()
+
+	minOpened := uint64(5)
+	server, client, teardown := setupMockedTestServerWithConfig(t,
+		ClientConfig{
+			SessionPoolConfig: SessionPoolConfig{
+				MinOpened: minOpened,
+			},
+		})
+	defer teardown()
+	sp := client.idleSessions
+
+	waitFor(t, func() error {
+		sp.mu.Lock()
+		defer sp.mu.Unlock()
+		if uint64(sp.idleList.Len()) != minOpened {
+			return fmt.Errorf("num open sessions mismatch\nWant: %d\nGot: %d", sp.MinOpened, sp.numOpened)
+		}
+		return nil
+	})
+	server.TestSpanner.Freeze()
+	defer server.TestSpanner.Unfreeze()
+
+	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
+	defer cancel()
+	sp.close(ctx)
+
+	if w, g := context.DeadlineExceeded, ctx.Err(); w != g {
+		t.Fatalf("context error mismatch\nWant: %v\nGot: %v", w, g)
+	}
+}
diff --git a/spanner/session.go b/spanner/session.go
index 00199cac6a5..d577fca5510 100644
--- a/spanner/session.go
+++ b/spanner/session.go
@@ -328,6 +328,12 @@ func (s *session) recycle() {
 // destroy removes the session from its home session pool, healthcheck queue
 // and Cloud Spanner service.
 func (s *session) destroy(isExpire bool) bool {
+	ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
+	defer cancel()
+	return s.destroyWithContext(ctx, isExpire)
+}
+
+func (s *session) destroyWithContext(ctx context.Context, isExpire bool) bool {
 	// Remove s from session pool.
 	if !s.pool.remove(s, isExpire) {
 		return false
@@ -335,8 +341,6 @@ func (s *session) destroy(isExpire bool) bool {
 	// Unregister s from healthcheck queue.
 	s.pool.hc.unregister(s)
 	// Remove s from Cloud Spanner service.
-	ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
-	defer cancel()
 	s.delete(ctx)
 	return true
 }
@@ -345,7 +349,9 @@ func (s *session) delete(ctx context.Context) {
 	// Ignore the error because even if we fail to explicitly destroy the
 	// session, it will be eventually garbage collected by Cloud Spanner.
 	err := s.client.DeleteSession(contextWithOutgoingMetadata(ctx, s.md), &sppb.DeleteSessionRequest{Name: s.getID()})
-	if err != nil {
+	// Do not log DeadlineExceeded errors when deleting sessions, as these do
+	// not indicate anything the user can or should act upon.
+	if err != nil && ErrCode(err) != codes.DeadlineExceeded {
 		logf(s.logger, "Failed to delete session %v. Error: %v", s.getID(), err)
 	}
 }
@@ -725,8 +731,11 @@ func (p *sessionPool) isValid() bool {
 	return p.valid
 }
 
-// close marks the session pool as closed.
-func (p *sessionPool) close() {
+// close marks the session pool as closed and deletes all sessions in parallel.
+// Any errors that are returned by the Delete RPC are logged but otherwise
+// ignored, except for DeadlineExceeded errors, which are ignored and not
+// logged.
+func (p *sessionPool) close(ctx context.Context) {
 	if p == nil {
 		return
 	}
@@ -743,9 +752,17 @@
 	allSessions := make([]*session, len(p.hc.queue.sessions))
 	copy(allSessions, p.hc.queue.sessions)
 	p.hc.mu.Unlock()
+	wg := sync.WaitGroup{}
 	for _, s := range allSessions {
-		s.destroy(false)
+		wg.Add(1)
+		go deleteSession(s, ctx, &wg)
 	}
+	wg.Wait()
+}
+
+func deleteSession(s *session, ctx context.Context, wg *sync.WaitGroup) {
+	defer wg.Done()
+	s.destroyWithContext(ctx, false)
 }
 
 // errInvalidSessionPool is the error for using an invalid session pool.
diff --git a/spanner/session_test.go b/spanner/session_test.go
index 4f4fdfc1a1e..d8e523c86a6 100644
--- a/spanner/session_test.go
+++ b/spanner/session_test.go
@@ -1355,7 +1355,7 @@ func TestSessionHealthCheck(t *testing.T) {
 	if err != nil {
 		t.Fatalf("cannot get session from session pool: %v", err)
 	}
-	sp.close()
+	sp.close(context.Background())
 	if sh.session.isValid() {
 		t.Fatalf("session(%v) is still alive, want it to be garbage collected", s)
 	}
@@ -1454,7 +1454,7 @@ func TestStressSessionPool(t *testing.T) {
 			t.Fatalf("%v: session in healthcheck queue (%v) was not found on server", ti, id)
 		}
 	}
-	sp.close()
+	sp.close(context.Background())
 	mockSessions = server.TestSpanner.DumpSessions()
 	for id, b := range hcSessions {
 		if b && mockSessions[id] {
@@ -1477,7 +1477,7 @@ func testStressSessionPool(t *testing.T, cfg SessionPoolConfig, ti int, idx int,
 		if idx%10 == 0 && j >= 900 {
 			// Close the pool in selected set of workers during the
 			// middle of the test.
-			pool.close()
+			pool.close(context.Background())
 		}
 		// Take a write sessions ~ 20% of the times.
 		takeWrite := rand.Intn(5) == 4
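For context, the pool change above boils down to one pattern: fan the Delete calls out to goroutines that share a single context with a deadline, wait on a sync.WaitGroup, and stay quiet when that deadline expires. Below is a minimal, self-contained sketch of that pattern only; fakeResource, Delete, and closeAll are illustrative names, not part of the Spanner client API.

package main

import (
	"context"
	"log"
	"sync"
	"time"
)

// fakeResource stands in for a session; Delete simulates an RPC that can
// outlive the caller's deadline.
type fakeResource struct{ id int }

func (r *fakeResource) Delete(ctx context.Context) error {
	select {
	case <-time.After(50 * time.Millisecond):
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}

// closeAll deletes all resources in parallel and returns once every deletion
// has either finished or given up because the shared context expired.
func closeAll(ctx context.Context, resources []*fakeResource) {
	var wg sync.WaitGroup
	for _, r := range resources {
		wg.Add(1)
		go func(r *fakeResource) {
			defer wg.Done()
			if err := r.Delete(ctx); err != nil && err != context.DeadlineExceeded {
				// Mirror the session pool: log unexpected errors, but do not
				// log a plain deadline expiry.
				log.Printf("failed to delete resource %d: %v", r.id, err)
			}
		}(r)
	}
	wg.Wait()
}

func main() {
	// The caller bounds the whole cleanup, just as Client.Close now bounds
	// sessionPool.close with a 5-second timeout.
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	closeAll(ctx, []*fakeResource{{1}, {2}, {3}})
}

The point of the shared context is that an unresponsive backend can no longer make Close block for 15 seconds per session in sequence; all deletions race against the same deadline and Close returns as soon as they have all finished or given up.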