Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(spanner/spannertest): fix ORDER BY combined with SELECT aliases #3043

Merged
merged 1 commit into from Oct 18, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion spanner/spannertest/README.md
Expand Up @@ -19,10 +19,10 @@ by ascending esotericism:

- expression functions
- more aggregation functions
- more joins types (INNER, CROSS, FULL, RIGHT)
- INSERT/UPDATE DML statements
- SELECT HAVING
- case insensitivity
- FULL JOIN
- alternate literal types (esp. strings)
- STRUCT types
- transaction simulation
Expand Down
257 changes: 172 additions & 85 deletions spanner/spannertest/db_query.go
Expand Up @@ -191,10 +191,27 @@ type selIter struct {
ec evalContext
cis []colInfo
list []spansql.Expr

distinct bool // whether this is a SELECT DISTINCT
seen []row
}

func (si selIter) Cols() []colInfo { return si.cis }
func (si selIter) Next() (row, error) {
func (si *selIter) Cols() []colInfo { return si.cis }
func (si *selIter) Next() (row, error) {
for {
r, err := si.next()
if err != nil {
return nil, err
}
if si.distinct && !si.keep(r) {
continue
}
return r, nil
}
}

// next retrieves the next row for the SELECT and evaluates its expression list.
func (si *selIter) next() (row, error) {
r, err := si.ri.Next()
if err != nil {
return nil, err
Expand All @@ -216,35 +233,17 @@ func (si selIter) Next() (row, error) {
return out, nil
}

// distinctIter applies a DISTINCT filter.
type distinctIter struct {
ri rowIter
seen []row
}

func (di *distinctIter) Cols() []colInfo { return di.ri.Cols() }
func (di *distinctIter) Next() (row, error) {
func (si *selIter) keep(r row) bool {
// This is hilariously inefficient; O(N^2) in the number of returned rows.
// Some sort of hashing could be done to deduplicate instead.
// This also breaks on array/struct types.
for {
row, err := di.ri.Next()
if err != nil {
return nil, err
for _, prev := range si.seen {
if rowEqual(prev, r) {
return false
}
dupe := false
for _, prev := range di.seen {
if rowEqual(prev, row) {
dupe = true
break
}
}
if dupe {
continue
}
di.seen = append(di.seen, row)
return row, nil
}
si.seen = append(si.seen, r)
return true
}

// offsetIter applies an OFFSET clause.
Expand Down Expand Up @@ -295,39 +294,70 @@ type queryParam struct {

type queryParams map[string]queryParam // TODO: change key to spansql.Param?

func (d *database) Query(q spansql.Query, params queryParams) (rowIter, error) {
// If there's an ORDER BY clause, extend the query to include the expressions we need
// so they get evaluated during evalSelect. TODO: Is this actually okay?
type queryContext struct {
params queryParams

tables []*table // sorted by name
tableIndex map[spansql.ID]*table
locks int
}

func (qc *queryContext) Lock() {
// Take locks in name order.
for _, t := range qc.tables {
t.mu.Lock()
qc.locks++
}
}

func (qc *queryContext) Unlock() {
for _, t := range qc.tables {
t.mu.Unlock()
qc.locks--
}
}

func (d *database) Query(q spansql.Query, params queryParams) (ri rowIter, err error) {
// Figure out the context of the query and take any required locks.
qc, err := d.queryContext(q, params)
if err != nil {
return nil, err
}
qc.Lock()
// On the way out, if there were locks taken, flatten the output
// and release the locks.
if qc.locks > 0 {
defer func() {
if err == nil {
ri, err = toRawIter(ri)
}
qc.Unlock()
}()
}

// Prepare auxiliary expressions to evaluate for ORDER BY.
var aux []spansql.Expr
var desc []bool
for _, o := range q.Order {
aux = append(aux, o.Expr)
desc = append(desc, o.Desc)
}
q.Select.List = append(q.Select.List, aux...)

ri, err := d.evalSelect(q.Select, params)
si, err := d.evalSelect(q.Select, qc)
if err != nil {
return nil, err
}
ri = si

// Apply ORDER BY.
if len(q.Order) > 0 {
raw, err := toRawIter(ri)
// Evaluate the selIter completely, and sort the rows by the auxiliary expressions.
rows, keys, err := evalSelectOrder(si, aux)
if err != nil {
return nil, err
}
sort.Slice(raw.rows, func(one, two int) bool {
r1, r2 := raw.rows[one], raw.rows[two]
aux1, aux2 := r1[len(r1)-len(aux):], r2[len(r2)-len(aux):] // sort keys
return compareValLists(aux1, aux2, desc) < 0
})
// Remove ORDER BY values.
raw.cols = raw.cols[:len(raw.cols)-len(aux)]
for i, row := range raw.rows {
raw.rows[i] = row[:len(row)-len(aux)]
}
ri = raw
sort.Sort(externalRowSorter{rows: rows, keys: keys, desc: desc})
ri = &rawIter{cols: si.cis, rows: rows}
}

// Apply LIMIT, OFFSET.
Expand All @@ -350,33 +380,76 @@ func (d *database) Query(q spansql.Query, params queryParams) (rowIter, error) {
return ri, nil
}

func (d *database) evalSelect(sel spansql.Select, params queryParams) (ri rowIter, evalErr error) {
ri = &nullIter{}
ec := evalContext{
func (d *database) queryContext(q spansql.Query, params queryParams) (*queryContext, error) {
qc := &queryContext{
params: params,
}

// Look for any mentioned tables and add them to qc.tableIndex.
addTable := func(name spansql.ID) error {
if _, ok := qc.tableIndex[name]; ok {
return nil // Already found this table.
}
t, err := d.table(name)
if err != nil {
return err
}
if qc.tableIndex == nil {
qc.tableIndex = make(map[spansql.ID]*table)
}
qc.tableIndex[name] = t
return nil
}
var findTables func(sf spansql.SelectFrom) error
findTables = func(sf spansql.SelectFrom) error {
switch sf := sf.(type) {
default:
return fmt.Errorf("can't prepare query context for SelectFrom of type %T", sf)
case spansql.SelectFromTable:
return addTable(sf.Table)
case spansql.SelectFromJoin:
if err := findTables(sf.LHS); err != nil {
return err
}
return findTables(sf.RHS)
}
}
for _, sf := range q.Select.From {
if err := findTables(sf); err != nil {
return nil, err
}
}

// Build qc.tables in name order so we can take locks in a well-defined order.
var names []spansql.ID
for name := range qc.tableIndex {
names = append(names, name)
}
sort.Slice(names, func(i, j int) bool { return names[i] < names[j] })
for _, name := range names {
qc.tables = append(qc.tables, qc.tableIndex[name])
}

return qc, nil
}

func (d *database) evalSelect(sel spansql.Select, qc *queryContext) (si *selIter, evalErr error) {
var ri rowIter = &nullIter{}
ec := evalContext{
dsymonds marked this conversation as resolved.
Show resolved Hide resolved
params: qc.params,
}

// First stage is to identify the data source.
// If there's a FROM then that names a table to use.
if len(sel.From) > 1 {
return nil, fmt.Errorf("selecting with more than one FROM clause not yet supported")
}
if len(sel.From) == 1 {
var unlock func()
var err error
ec, ri, unlock, err = d.evalSelectFrom(ec, sel.From[0])
ec, ri, err = d.evalSelectFrom(qc, ec, sel.From[0])
if err != nil {
return nil, err
}
defer unlock()

// On the way out, convert the result to a rawIter
// so that any locked tables may be safely unlocked.
defer func() {
if evalErr == nil {
ri, evalErr = toRawIter(ri)
}
}()
}

// Apply WHERE.
Expand Down Expand Up @@ -577,31 +650,27 @@ func (d *database) evalSelect(sel spansql.Select, params queryParams) (ri rowIte
colInfos = append(colInfos, ci)
}
}
ri = selIter{

return &selIter{
ri: ri,
ec: ec,
cis: colInfos,
list: sel.List,
}

// Apply DISTINCT.
if sel.Distinct {
ri = &distinctIter{ri: ri}
}

return ri, nil
distinct: sel.Distinct, // Apply DISTINCT.
}, nil
}

func (d *database) evalSelectFrom(ec evalContext, sf spansql.SelectFrom) (evalContext, rowIter, func(), error) {
func (d *database) evalSelectFrom(qc *queryContext, ec evalContext, sf spansql.SelectFrom) (evalContext, rowIter, error) {
switch sf := sf.(type) {
default:
return ec, nil, nil, fmt.Errorf("selecting with FROM clause of type %T not yet supported", sf)
return ec, nil, fmt.Errorf("selecting with FROM clause of type %T not yet supported", sf)
case spansql.SelectFromTable:
t, err := d.table(sf.Table)
if err != nil {
return ec, nil, nil, err
t, ok := qc.tableIndex[sf.Table]
if !ok {
// This shouldn't be possible; the queryContext should have discovered missing tables already.
return ec, nil, fmt.Errorf("unknown table %q", sf.Table)
}
t.mu.Lock()
ti := &tableIter{t: t}
if sf.Alias != "" {
ti.alias = sf.Alias
Expand All @@ -611,36 +680,33 @@ func (d *database) evalSelectFrom(ec evalContext, sf spansql.SelectFrom) (evalCo
ti.alias = sf.Table
}
ec.cols = ti.Cols()
return ec, ti, t.mu.Unlock, nil
return ec, ti, nil
case spansql.SelectFromJoin:
// TODO: Avoid the toRawIter calls here by rethinking how locking works throughout evalSelect,
// then doing the RHS recursive evalSelectFrom in joinIter.Next on demand.
// TODO: Avoid the toRawIter calls here by doing the RHS recursive evalSelectFrom in joinIter.Next on demand.

lhsEC, lhs, unlock, err := d.evalSelectFrom(ec, sf.LHS)
lhsEC, lhs, err := d.evalSelectFrom(qc, ec, sf.LHS)
if err != nil {
return ec, nil, nil, err
return ec, nil, err
}
lhsRaw, err := toRawIter(lhs)
unlock()
if err != nil {
return ec, nil, nil, err
return ec, nil, err
}

rhsEC, rhs, unlock, err := d.evalSelectFrom(ec, sf.RHS)
rhsEC, rhs, err := d.evalSelectFrom(qc, ec, sf.RHS)
if err != nil {
return ec, nil, nil, err
return ec, nil, err
}
rhsRaw, err := toRawIter(rhs)
unlock()
if err != nil {
return ec, nil, nil, err
return ec, nil, err
}

ji, ec, err := newJoinIter(lhsRaw, rhsRaw, lhsEC, rhsEC, sf)
if err != nil {
return ec, nil, nil, err
return ec, nil, err
}
return ec, ji, func() {}, nil
return ec, ji, nil
}
}

Expand Down Expand Up @@ -893,16 +959,37 @@ func (ji *joinIter) Next() (row, error) {
}
}

func evalSelectOrder(si *selIter, aux []spansql.Expr) (rows []row, keys [][]interface{}, err error) {
// This is like toRawIter except it also evaluates the auxiliary expressions for ORDER BY.
for {
r, err := si.Next()
if err == io.EOF {
break
} else if err != nil {
return nil, nil, err
}
key, err := si.ec.evalExprList(aux)
if err != nil {
return nil, nil, err
}

rows = append(rows, r.copyAllData())
keys = append(keys, key)
}
return
}

// externalRowSorter implements sort.Interface for a slice of rows
// with an external sort key.
type externalRowSorter struct {
rows []row
keys [][]interface{}
desc []bool // may be nil
}

func (ers externalRowSorter) Len() int { return len(ers.rows) }
func (ers externalRowSorter) Less(i, j int) bool {
return compareValLists(ers.keys[i], ers.keys[j], nil) < 0
return compareValLists(ers.keys[i], ers.keys[j], ers.desc) < 0
}
func (ers externalRowSorter) Swap(i, j int) {
ers.rows[i], ers.rows[j] = ers.rows[j], ers.rows[i]
Expand Down