From b951d8bd194b76da0a8bf2ce7cf85b546d2e051c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Wed, 4 Nov 2020 15:10:39 +0100 Subject: [PATCH] feat(spanner): export ToSpannerError (#3133) * feat: export ToSpannerError The ToSpannerError method should be exported to allow users to create Spanner errors to use with unit tests. Currently, that was only possible by creating a struct directly and filling the deprecated Code field. That is however no longer recommended, as the code field might be removed in a future release. Fixes #3122 * fix: use New instead of Newf --- spanner/client.go | 4 ++-- spanner/client_test.go | 2 +- spanner/errors.go | 11 +++++++++-- spanner/errors_test.go | 2 +- spanner/mutation.go | 2 +- spanner/pdml.go | 6 +++--- spanner/read.go | 2 +- spanner/session.go | 6 +++--- spanner/sessionclient.go | 6 +++--- spanner/transaction.go | 8 ++++---- spanner/transaction_test.go | 4 ++-- spanner/value.go | 2 +- 12 files changed, 31 insertions(+), 24 deletions(-) diff --git a/spanner/client.go b/spanner/client.go index a4ffbb8bdcc..c867afc71b1 100644 --- a/spanner/client.go +++ b/spanner/client.go @@ -117,7 +117,7 @@ type ClientConfig struct { // errDial returns error for dialing to Cloud Spanner. func errDial(ci int, err error) error { - e := toSpannerError(err).(*Error) + e := ToSpannerError(err).(*Error) e.decorate(fmt.Sprintf("dialing fails for channel[%v]", ci)) return e } @@ -341,7 +341,7 @@ func (c *Client) BatchReadOnlyTransaction(ctx context.Context, tb TimestampBound }, }) if err != nil { - return nil, toSpannerError(err) + return nil, ToSpannerError(err) } tx = res.Id if res.ReadTimestamp != nil { diff --git a/spanner/client_test.go b/spanner/client_test.go index 7ea14f0d07b..92cac5564bb 100644 --- a/spanner/client_test.go +++ b/spanner/client_test.go @@ -537,7 +537,7 @@ func TestClient_ReadOnlyTransaction_SessionNotFoundOnExecuteStreamingSql(t *test err := testReadOnlyTransaction(t, map[string]SimulatedExecutionTime{ MethodExecuteStreamingSql: {Errors: []error{newSessionNotFoundError("projects/p/instances/i/databases/d/sessions/s")}}, }) - want := toSpannerError(newSessionNotFoundError("projects/p/instances/i/databases/d/sessions/s")) + want := ToSpannerError(newSessionNotFoundError("projects/p/instances/i/databases/d/sessions/s")) if err == nil { t.Fatalf("missing expected error\nGot: nil\nWant: %v", want) } diff --git a/spanner/errors.go b/spanner/errors.go index d4d159b5b8a..98de741662c 100644 --- a/spanner/errors.go +++ b/spanner/errors.go @@ -115,8 +115,15 @@ func spannerErrorf(code codes.Code, format string, args ...interface{}) error { } } -// toSpannerError converts general Go error to *spanner.Error. -func toSpannerError(err error) error { +// ToSpannerError converts a general Go error to *spanner.Error. If the given +// error is already a *spanner.Error, the original error will be returned. +// +// Spanner Errors are normally created by the Spanner client library from the +// returned status of a RPC. This method can also be used to create Spanner +// errors for use in tests. The recommended way to create test errors is +// calling this method with a status error, e.g. +// ToSpannerError(status.New(codes.NotFound, "Table not found").Err()) +func ToSpannerError(err error) error { return toSpannerErrorWithCommitInfo(err, false) } diff --git a/spanner/errors_test.go b/spanner/errors_test.go index c1e90f738c4..42d95aaa98f 100644 --- a/spanner/errors_test.go +++ b/spanner/errors_test.go @@ -63,7 +63,7 @@ func TestToSpannerError(t *testing.T) { wrapped: errors.New("wha?"), msg: "error with wrapped non-gRPC and non-Spanner error"}}, } { - err := toSpannerError(test.err) + err := ToSpannerError(test.err) errDuringCommit := toSpannerErrorWithCommitInfo(test.err, true) if got, want := ErrCode(err), test.wantCode; got != want { t.Errorf("%v: got %s, want %s", test.err, got, want) diff --git a/spanner/mutation.go b/spanner/mutation.go index 6d6e41ce9eb..0d64d9b04d9 100644 --- a/spanner/mutation.go +++ b/spanner/mutation.go @@ -179,7 +179,7 @@ func structToMutationParams(in interface{}) ([]string, []interface{}, error) { } fields, err := fieldCache.Fields(t) if err != nil { - return nil, nil, toSpannerError(err) + return nil, nil, ToSpannerError(err) } var cols []string var vals []interface{} diff --git a/spanner/pdml.go b/spanner/pdml.go index e50331235ab..52e56ab1d7f 100644 --- a/spanner/pdml.go +++ b/spanner/pdml.go @@ -51,7 +51,7 @@ func (c *Client) partitionedUpdate(ctx context.Context, statement Statement, opt sh, err := c.idleSessions.take(ctx) if err != nil { - return 0, toSpannerError(err) + return 0, ToSpannerError(err) } if sh != nil { defer sh.recycle() @@ -61,7 +61,7 @@ func (c *Client) partitionedUpdate(ctx context.Context, statement Statement, opt // The transaction reference will be added by the executePdml method. params, paramTypes, err := statement.convertParams() if err != nil { - return 0, toSpannerError(err) + return 0, ToSpannerError(err) } req := &sppb.ExecuteSqlRequest{ Session: sh.getID(), @@ -107,7 +107,7 @@ func executePdml(ctx context.Context, sh *sessionHandle, req *sppb.ExecuteSqlReq }, }) if err != nil { - return 0, toSpannerError(err) + return 0, ToSpannerError(err) } // Add a reference to the PDML transaction on the ExecuteSql request. req.Transaction = &sppb.TransactionSelector{ diff --git a/spanner/read.go b/spanner/read.go index 296e1036f18..eaf8c339e02 100644 --- a/spanner/read.go +++ b/spanner/read.go @@ -157,7 +157,7 @@ func (r *RowIterator) Next() (*Row, error) { return row, nil } if err := r.streamd.lastErr(); err != nil { - r.err = toSpannerError(err) + r.err = ToSpannerError(err) } else if !r.rowd.done() { r.err = errEarlyReadEnd() } else { diff --git a/spanner/session.go b/spanner/session.go index f8935725929..00199cac6a5 100644 --- a/spanner/session.go +++ b/spanner/session.go @@ -1033,13 +1033,13 @@ func (p *sessionPool) takeWriteSession(ctx context.Context) (*sessionHandle, err s.destroy(false) trace.TracePrintf(ctx, map[string]interface{}{"sessionID": s.getID()}, "Session not found for write") - return nil, toSpannerError(err) + return nil, ToSpannerError(err) } s.recycle() trace.TracePrintf(ctx, map[string]interface{}{"sessionID": s.getID()}, "Error preparing session for write") - return nil, toSpannerError(err) + return nil, ToSpannerError(err) } } p.incNumInUse(ctx) @@ -1506,7 +1506,7 @@ func (hc *healthChecker) worker(i int) { // cycle. // Don't log about permission errors, which may be expected // (e.g. using read-only auth). - serr := toSpannerError(err).(*Error) + serr := ToSpannerError(err).(*Error) if serr.Code != codes.PermissionDenied { logf(hc.pool.sc.logger, "Failed to prepare session, error: %v", serr) } diff --git a/spanner/sessionclient.go b/spanner/sessionclient.go index a1c9f66d805..4646db577ef 100644 --- a/spanner/sessionclient.go +++ b/spanner/sessionclient.go @@ -135,7 +135,7 @@ func (sc *sessionClient) createSession(ctx context.Context) (*session, error) { Session: &sppb.Session{Labels: sc.sessionLabels}, }) if err != nil { - return nil, toSpannerError(err) + return nil, ToSpannerError(err) } return &session{valid: true, client: client, id: sid.Name, createTime: time.Now(), md: sc.md, logger: sc.logger}, nil } @@ -227,7 +227,7 @@ func (sc *sessionClient) executeBatchCreateSessions(client *vkit.Client, createC } if ctx.Err() != nil { trace.TracePrintf(ctx, nil, "Context error while creating a batch of %d sessions: %v", createCount, ctx.Err()) - consumer.sessionCreationFailed(toSpannerError(ctx.Err()), remainingCreateCount) + consumer.sessionCreationFailed(ToSpannerError(ctx.Err()), remainingCreateCount) break } response, err := client.BatchCreateSessions(ctx, &sppb.BatchCreateSessionsRequest{ @@ -237,7 +237,7 @@ func (sc *sessionClient) executeBatchCreateSessions(client *vkit.Client, createC }) if err != nil { trace.TracePrintf(ctx, nil, "Error creating a batch of %d sessions: %v", remainingCreateCount, err) - consumer.sessionCreationFailed(toSpannerError(err), remainingCreateCount) + consumer.sessionCreationFailed(ToSpannerError(err), remainingCreateCount) break } actuallyCreated := int32(len(response.Session)) diff --git a/spanner/transaction.go b/spanner/transaction.go index 93087b5a9e8..cd7ea25bcaf 100644 --- a/spanner/transaction.go +++ b/spanner/transaction.go @@ -489,7 +489,7 @@ func (t *ReadOnlyTransaction) begin(ctx context.Context) error { rts = time.Unix(res.ReadTimestamp.Seconds, int64(res.ReadTimestamp.Nanos)) } } else { - err = toSpannerError(err) + err = ToSpannerError(err) } break } @@ -827,7 +827,7 @@ func (t *ReadWriteTransaction) update(ctx context.Context, stmt Statement, opts } resultSet, err := sh.getClient().ExecuteSql(contextWithOutgoingMetadata(ctx, sh.getMetadata()), req) if err != nil { - return 0, toSpannerError(err) + return 0, ToSpannerError(err) } if resultSet.Stats == nil { return 0, spannerErrorf(codes.InvalidArgument, "query passed to Update: %q", stmt.SQL) @@ -876,7 +876,7 @@ func (t *ReadWriteTransaction) BatchUpdate(ctx context.Context, stmts []Statemen Seqno: atomic.AddInt64(&t.sequenceNumber, 1), }) if err != nil { - return nil, toSpannerError(err) + return nil, ToSpannerError(err) } var counts []int64 @@ -1209,7 +1209,7 @@ func (t *writeOnlyTransaction) applyAtLeastOnce(ctx context.Context, ms ...*Muta break } } - return ts, toSpannerError(err) + return ts, ToSpannerError(err) } // isAbortedErr returns true if the error indicates that an gRPC call is diff --git a/spanner/transaction_test.go b/spanner/transaction_test.go index 07cdc24d541..686e1c81032 100644 --- a/spanner/transaction_test.go +++ b/spanner/transaction_test.go @@ -78,7 +78,7 @@ func TestReadOnlyTransaction_RecoverFromFailure(t *testing.T) { }) _, _, e := txn.acquire(ctx) - if wantErr := toSpannerError(errUsr); !testEqual(e, wantErr) { + if wantErr := ToSpannerError(errUsr); !testEqual(e, wantErr) { t.Fatalf("Acquire for multi use, got %v, want %v.", e, wantErr) } _, _, e = txn.acquire(ctx) @@ -238,7 +238,7 @@ func TestTransaction_SessionNotFound(t *testing.T) { t.Fatalf("Expect Read to succeed, got %v, want %v.", got.err, wantErr) } - wantErr = toSpannerError(newSessionNotFoundError("projects/p/instances/i/databases/d/sessions/s")) + wantErr = ToSpannerError(newSessionNotFoundError("projects/p/instances/i/databases/d/sessions/s")) ms := []*Mutation{ Insert("Accounts", []string{"AccountId", "Nickname", "Balance"}, []interface{}{int64(1), "Foo", int64(50)}), Insert("Accounts", []string{"AccountId", "Nickname", "Balance"}, []interface{}{int64(2), "Bar", int64(1)}), diff --git a/spanner/value.go b/spanner/value.go index 779ec497282..8ed147e36de 100644 --- a/spanner/value.go +++ b/spanner/value.go @@ -2389,7 +2389,7 @@ func decodeStruct(ty *sppb.StructType, pb *proto3.ListValue, ptr interface{}) er fields, err := fieldCache.Fields(t) if err != nil { - return toSpannerError(err) + return ToSpannerError(err) } seen := map[string]bool{} for i, f := range ty.Fields {