Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(spanner/spannertest): support queries in ExecuteSql #3640

Merged
merged 9 commits into from Feb 4, 2021
95 changes: 74 additions & 21 deletions spanner/spannertest/inmem.go
Expand Up @@ -485,9 +485,21 @@ func (s *server) readTx(ctx context.Context, session string, tsel *spannerpb.Tra
}

func (s *server) ExecuteSql(ctx context.Context, req *spannerpb.ExecuteSqlRequest) (*spannerpb.ResultSet, error) {
// Assume this is probably a DML statement. Queries tend to use ExecuteStreamingSql.
// Assume this is probably a DML statement or a ping from the session pool.
// Queries normally use ExecuteStreamingSql.
// TODO: Expand this to support more things.

// If it is a single-use transaction we assume it is a query.
if req.Transaction == nil ||
olavloite marked this conversation as resolved.
Show resolved Hide resolved
req.Transaction.GetSelector() == nil ||
(req.Transaction.GetSingleUse() != nil && req.Transaction.GetSingleUse().GetReadOnly() != nil) {
ri, err := s.executeQuery(req)
if err != nil {
return nil, err
}
return s.resultSet(ri)
}

obj, ok := req.Transaction.Selector.(*spannerpb.TransactionSelector_Id)
if !ok {
return nil, fmt.Errorf("unsupported transaction type %T", req.Transaction.Selector)
Expand Down Expand Up @@ -527,27 +539,31 @@ func (s *server) ExecuteStreamingSql(req *spannerpb.ExecuteSqlRequest, stream sp
}
defer cleanup()

ri, err := s.executeQuery(req)
if err != nil {
return err
}
return s.readStream(stream.Context(), tx, stream.Send, ri)
}

func (s *server) executeQuery(req *spannerpb.ExecuteSqlRequest) (ri rowIter, err error) {
q, err := spansql.ParseQuery(req.Sql)
if err != nil {
// TODO: check what code the real Spanner returns here.
return status.Errorf(codes.InvalidArgument, "bad query: %v", err)
return nil, status.Errorf(codes.InvalidArgument, "bad query: %v", err)
}

params, err := parseQueryParams(req.GetParams(), req.ParamTypes)
if err != nil {
return err
return nil, err
}

s.logf("Querying: %s", q.SQL())
if len(params) > 0 {
s.logf(" ▹ %v", params)
}

ri, err := s.db.Query(q, params)
if err != nil {
return err
}
return s.readStream(stream.Context(), tx, stream.Send, ri)
return s.db.Query(q, params)
}

// TODO: Read
Expand Down Expand Up @@ -591,21 +607,39 @@ func (s *server) StreamingRead(req *spannerpb.ReadRequest, stream spannerpb.Span
return s.readStream(stream.Context(), tx, stream.Send, ri)
}

func (s *server) readStream(ctx context.Context, tx *transaction, send func(*spannerpb.PartialResultSet) error, ri rowIter) error {
// Build the result set metadata.
rsm := &spannerpb.ResultSetMetadata{
RowType: &spannerpb.StructType{},
// TODO: transaction info?
func (s *server) resultSet(ri rowIter) (*spannerpb.ResultSet, error) {
rsm, err := s.buildResultSetMetaData(ri)
if err != nil {
return nil, err
}
for _, ci := range ri.Cols() {
st, err := spannerTypeFromType(ci.Type)
if err != nil {
return err
rs := &spannerpb.ResultSet{
Metadata: rsm,
}
for {
row, err := ri.Next()
if err == io.EOF {
break
} else if err != nil {
return nil, err
}
rsm.RowType.Fields = append(rsm.RowType.Fields, &spannerpb.StructType_Field{
Name: string(ci.Name),
Type: st,
})

values := make([]*structpb.Value, len(row))
for i, x := range row {
v, err := spannerValueFromValue(x)
if err != nil {
return nil, err
}
values[i] = v
}
rs.Rows = append(rs.Rows, &structpb.ListValue{Values: values})
}
return rs, nil
}

func (s *server) readStream(ctx context.Context, tx *transaction, send func(*spannerpb.PartialResultSet) error, ri rowIter) error {
rsm, err := s.buildResultSetMetaData(ri)
if err != nil {
return err
}

for {
Expand Down Expand Up @@ -640,6 +674,25 @@ func (s *server) readStream(ctx context.Context, tx *transaction, send func(*spa
return nil
}

func (s *server) buildResultSetMetaData(ri rowIter) (*spannerpb.ResultSetMetadata, error) {
olavloite marked this conversation as resolved.
Show resolved Hide resolved
// Build the result set metadata.
rsm := &spannerpb.ResultSetMetadata{
RowType: &spannerpb.StructType{},
// TODO: transaction info?
}
for _, ci := range ri.Cols() {
st, err := spannerTypeFromType(ci.Type)
if err != nil {
return nil, err
}
rsm.RowType.Fields = append(rsm.RowType.Fields, &spannerpb.StructType_Field{
Name: string(ci.Name),
Type: st,
})
}
return rsm, nil
}

func (s *server) BeginTransaction(ctx context.Context, req *spannerpb.BeginTransactionRequest) (*spannerpb.Transaction, error) {
//s.logf("BeginTransaction(%v)", req)

Expand Down
46 changes: 41 additions & 5 deletions spanner/spannertest/integration_test.go
Expand Up @@ -36,6 +36,7 @@ import (
"cloud.google.com/go/civil"
"cloud.google.com/go/spanner"
dbadmin "cloud.google.com/go/spanner/admin/database/apiv1"
v1 "cloud.google.com/go/spanner/apiv1"
"google.golang.org/api/iterator"
"google.golang.org/api/option"
"google.golang.org/grpc"
Expand All @@ -56,7 +57,7 @@ func dbName() string {
return "projects/fake-proj/instances/fake-instance/databases/fake-db"
}

func makeClient(t *testing.T) (*spanner.Client, *dbadmin.DatabaseAdminClient, func()) {
func makeClient(t *testing.T) (*spanner.Client, *dbadmin.DatabaseAdminClient, *v1.Client, func()) {
// Despite the docs, this context is also used for auth,
// so it needs to be long-lived.
ctx := context.Background()
Expand All @@ -73,7 +74,13 @@ func makeClient(t *testing.T) (*spanner.Client, *dbadmin.DatabaseAdminClient, fu
client.Close()
t.Fatalf("Connecting DB admin client: %v", err)
}
return client, adminClient, func() { client.Close(); adminClient.Close() }
gapicClient, err := v1.NewClient(ctx, dialOpt)
if err != nil {
client.Close()
adminClient.Close()
t.Fatalf("Connecting Spanner generated client: %v", err)
}
return client, adminClient, gapicClient, func() { client.Close(); adminClient.Close(); gapicClient.Close() }
}

// Don't use SPANNER_EMULATOR_HOST because we need the raw connection for
Expand Down Expand Up @@ -102,16 +109,23 @@ func makeClient(t *testing.T) (*spanner.Client, *dbadmin.DatabaseAdminClient, fu
srv.Close()
t.Fatalf("Connecting to in-memory fake DB admin: %v", err)
}
return client, adminClient, func() {
gapicClient, err := v1.NewClient(ctx, option.WithGRPCConn(conn))
if err != nil {
srv.Close()
t.Fatalf("Connecting to in-memory fake generated Spanner client: %v", err)
}

return client, adminClient, gapicClient, func() {
client.Close()
adminClient.Close()
gapicClient.Close()
conn.Close()
srv.Close()
}
}

func TestIntegration_SpannerBasics(t *testing.T) {
client, adminClient, cleanup := makeClient(t)
client, adminClient, generatedClient, cleanup := makeClient(t)
defer cleanup()

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
Expand All @@ -137,6 +151,28 @@ func TestIntegration_SpannerBasics(t *testing.T) {
}
it.Stop()

// Try to execute the equivalent of a session pool ping.
olavloite marked this conversation as resolved.
Show resolved Hide resolved
s, err := generatedClient.CreateSession(ctx, &spannerpb.CreateSessionRequest{Database: dbName()})
if err != nil {
t.Fatalf("Creating session: %v", err)
}
rs, err := generatedClient.ExecuteSql(ctx, &spannerpb.ExecuteSqlRequest{
Session: s.Name,
Sql: "SELECT 1",
})
if len(rs.Rows) != 1 {
olavloite marked this conversation as resolved.
Show resolved Hide resolved
t.Fatalf("Ping gave %v rows, want 1", len(rs.Rows))
}
if len(rs.Rows[0].Values) != 1 {
t.Fatalf("Ping gave %v cols, want 1", len(rs.Rows[0].Values))
}
if rs.Rows[0].Values[0].GetStringValue() != "1" {
t.Fatalf("Ping gave value %v, want '1'", rs.Rows[0].Values[0].GetStringValue())
}
if err = generatedClient.DeleteSession(ctx, &spannerpb.DeleteSessionRequest{Name: s.Name}); err != nil {
t.Fatalf("Deleting session: %v", err)
}

// Drop any previous test table/index, and make a fresh one in a few stages.
const tableName = "Characters"
err = updateDDL(t, adminClient, "DROP INDEX AgeIndex")
Expand Down Expand Up @@ -400,7 +436,7 @@ func TestIntegration_SpannerBasics(t *testing.T) {
}

func TestIntegration_ReadsAndQueries(t *testing.T) {
client, adminClient, cleanup := makeClient(t)
client, adminClient, _, cleanup := makeClient(t)
defer cleanup()

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
Expand Down