Skip to content

Commit

Permalink
implement logic to handle through relationships
Browse files Browse the repository at this point in the history
  • Loading branch information
Miguel Molina committed Jun 19, 2017
1 parent fa565e8 commit a9e84f6
Show file tree
Hide file tree
Showing 9 changed files with 585 additions and 58 deletions.
132 changes: 111 additions & 21 deletions batcher.go
Expand Up @@ -14,6 +14,7 @@ type batchQueryRunner struct {
q Query
oneToOneRels []Relationship
oneToManyRels []Relationship
throughRels []Relationship
db squirrel.DBProxy
builder squirrel.SelectBuilder
total int
Expand All @@ -29,6 +30,7 @@ func newBatchQueryRunner(schema Schema, db squirrel.DBProxy, q Query) *batchQuer
var (
oneToOneRels []Relationship
oneToManyRels []Relationship
throughRels []Relationship
)

for _, rel := range q.getRelationships() {
Expand All @@ -37,6 +39,8 @@ func newBatchQueryRunner(schema Schema, db squirrel.DBProxy, q Query) *batchQuer
oneToOneRels = append(oneToOneRels, rel)
case OneToMany:
oneToManyRels = append(oneToManyRels, rel)
case Through:
throughRels = append(throughRels, rel)
}
}

Expand All @@ -46,6 +50,7 @@ func newBatchQueryRunner(schema Schema, db squirrel.DBProxy, q Query) *batchQuer
q: q,
oneToOneRels: oneToOneRels,
oneToManyRels: oneToManyRels,
throughRels: throughRels,
db: db,
builder: builder,
}
Expand Down Expand Up @@ -125,8 +130,14 @@ func (r *batchQueryRunner) processBatch(rows *sql.Rows) ([]Record, error) {
return nil, err
}

if len(records) == 0 {
return nil, nil
}

var ids = make([]interface{}, len(records))
var identType Identifier
for i, r := range records {
identType = r.GetID()
ids[i] = r.GetID().Raw()
}

Expand All @@ -136,63 +147,142 @@ func (r *batchQueryRunner) processBatch(rows *sql.Rows) ([]Record, error) {
return nil, err
}

for _, r := range records {
err := r.SetRelationship(rel.Field, indexedResults[r.GetID().Raw()])
if err != nil {
return nil, err
}
err = setIndexedResults(records, rel, indexedResults)
if err != nil {
return nil, err
}
}

// If the relationship is partial, we can not ensure the results
// in the field reflect the truth of the database.
// In this case, the parent is marked as non-writable.
if rel.Filter != nil {
r.setWritable(false)
}
for _, rel := range r.throughRels {
indexedResults, err := r.getRecordThroughRelationships(ids, rel, identType)
if err != nil {
return nil, err
}

err = setIndexedResults(records, rel, indexedResults)
if err != nil {
return nil, err
}
}

return records, nil
}

func setIndexedResults(records []Record, rel Relationship, indexedResults indexedRecords) error {
for _, r := range records {
err := r.SetRelationship(rel.Field, indexedResults[r.GetID().Raw()])
if err != nil {
return err
}

// If the relationship is partial, we can not ensure the results
// in the field reflect the truth of the database.
// In this case, the parent is marked as non-writable.
if rel.Filter != nil {
r.setWritable(false)
}
}

return nil
}

type indexedRecords map[interface{}][]Record

func (r *batchQueryRunner) getRecordRelationships(ids []interface{}, rel Relationship) (indexedRecords, error) {
fk, ok := r.schema.ForeignKey(rel.Field)
if !ok {
return nil, fmt.Errorf("kallax: cannot find foreign key on field %s for table %s", rel.Field, r.schema.Table())
return nil, fmt.Errorf("kallax: cannot find foreign key on field %s of table %s", rel.Field, r.schema.Table())
}

filter := In(fk, ids...)
if rel.Filter != nil {
And(rel.Filter, filter)
} else {
rel.Filter = filter
filter = And(rel.Filter, filter)
}

q := NewBaseQuery(rel.Schema)
q.Where(rel.Filter)
q.Where(filter)
cols, builder := q.compile()
rows, err := builder.RunWith(r.db).Query()
if err != nil {
return nil, err
}

return indexedResultsFromRows(rows, cols, rel.Schema, fk, nil)
}

func (r *batchQueryRunner) getRecordThroughRelationships(ids []interface{}, rel Relationship, identType Identifier) (indexedRecords, error) {
lfk, rfk, ok := r.schema.ForeignKeys(rel.Field)
if !ok {
return nil, fmt.Errorf("kallax: cannot find foreign keys for through relationship on field %s of table %s", rel.Field, r.schema.Table())
}

filter := In(r.schema.ID(), ids...)
if rel.Filter != nil {
filter = And(rel.Filter, filter)
}

if rel.IntermediateFilter != nil {
filter = And(rel.IntermediateFilter, filter)
}

q := NewBaseQuery(rel.Schema)
lschema := r.schema.WithAlias(rel.Schema.Alias())
intSchema := rel.IntermediateSchema.WithAlias(rel.Schema.Alias())
q.joinThrough(lschema, intSchema, rel.Schema, lfk, rfk)
q.Where(filter)
cols, builder := q.compile()
// manually add the extra column to also select the parent id
builder = builder.Column(lschema.ID().QualifiedName(lschema))
rows, err := builder.RunWith(r.db).Query()
if err != nil {
return nil, err
}

// we need to pass a new pointer of the parent identifier type so the
// resultset can fill it and we can know to which record it belongs when
// indexing by parent id.
return indexedResultsFromRows(rows, cols, rel.Schema, rfk, identType.newPtr())
}

// indexedResultsFromRows returns the results in the given rows indexed by the
// parent id. In the case of many to many relationships, the record odes not
// have a specific field with the ID of the parent to index by it,
// that's why parentIDPtr is passed for these cases. parentIDPtr is a pointer
// to an ID of the type required by the parent to be filled by the result set.
func indexedResultsFromRows(rows *sql.Rows, cols []string, schema Schema, fk SchemaField, parentIDPtr interface{}) (indexedRecords, error) {
relRs := NewResultSet(rows, false, nil, cols...)
var indexedResults = make(indexedRecords)
for relRs.Next() {
rec, err := relRs.Get(rel.Schema)
if err != nil {
return nil, err
var (
rec Record
err error
)

if parentIDPtr != nil {
rec, err = relRs.customGet(schema, parentIDPtr)
} else {
rec, err = relRs.Get(schema)
}

val, err := rec.Value(fk.String())
if err != nil {
return nil, err
}

rec.setPersisted()
rec.setWritable(true)
id := val.(Identifier).Raw()

var id interface{}
if parentIDPtr != nil {
id = parentIDPtr.(Identifier).Raw()
} else {
val, err := rec.Value(fk.String())
if err != nil {
return nil, err
}

id = val.(Identifier).Raw()
}

indexedResults[id] = append(indexedResults[id], rec)
}

Expand Down

0 comments on commit a9e84f6

Please sign in to comment.