Skip to content

Commit

Permalink
Merge pull request #656 from upper/allow-select-from-db-result
Browse files Browse the repository at this point in the history
allow using db.Result as subquery
  • Loading branch information
xiam committed Jun 17, 2022
2 parents bb6a386 + 9709cdd commit b5aff2b
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 10 deletions.
14 changes: 7 additions & 7 deletions internal/sqladapter/result.go
Expand Up @@ -213,7 +213,7 @@ func (r *Result) Select(fields ...interface{}) db.Result {

// String satisfies fmt.Stringer
func (r *Result) String() string {
query, err := r.buildPaginator()
query, err := r.Paginator()
if err != nil {
panic(err.Error())
}
Expand All @@ -222,7 +222,7 @@ func (r *Result) String() string {

// All dumps all Results into a pointer to an slice of structs or maps.
func (r *Result) All(dst interface{}) error {
query, err := r.buildPaginator()
query, err := r.Paginator()
if err != nil {
r.setErr(err)
return err
Expand All @@ -235,7 +235,7 @@ func (r *Result) All(dst interface{}) error {
// One fetches only one Result from the set.
func (r *Result) One(dst interface{}) error {
one := r.Limit(1).(*Result)
query, err := one.buildPaginator()
query, err := one.Paginator()
if err != nil {
r.setErr(err)
return err
Expand All @@ -251,7 +251,7 @@ func (r *Result) Next(dst interface{}) bool {
defer r.iterMu.Unlock()

if r.iter == nil {
query, err := r.buildPaginator()
query, err := r.Paginator()
if err != nil {
r.setErr(err)
return false
Expand Down Expand Up @@ -309,7 +309,7 @@ func (r *Result) Update(values interface{}) error {
}

func (r *Result) TotalPages() (uint, error) {
query, err := r.buildPaginator()
query, err := r.Paginator()
if err != nil {
r.setErr(err)
return 0, err
Expand All @@ -325,7 +325,7 @@ func (r *Result) TotalPages() (uint, error) {
}

func (r *Result) TotalEntries() (uint64, error) {
query, err := r.buildPaginator()
query, err := r.Paginator()
if err != nil {
r.setErr(err)
return 0, err
Expand Down Expand Up @@ -391,7 +391,7 @@ func (r *Result) Count() (uint64, error) {
return counter.Count, nil
}

func (r *Result) buildPaginator() (db.Paginator, error) {
func (r *Result) Paginator() (db.Paginator, error) {
if err := r.Err(); err != nil {
return nil, err
}
Expand Down
18 changes: 16 additions & 2 deletions internal/sqlbuilder/builder.go
Expand Up @@ -51,7 +51,11 @@ var defaultMapOptions = MapOptions{
IncludeNil: false,
}

type compilable interface {
type hasPaginator interface {
Paginator() (db.Paginator, error)
}

type isCompilable interface {
Compile() (string, error)
Arguments() []interface{}
}
Expand Down Expand Up @@ -347,7 +351,17 @@ func columnFragments(columns []interface{}) ([]exql.Fragment, []interface{}, err

for i := range columns {
switch v := columns[i].(type) {
case compilable:
case hasPaginator:
p, err := v.Paginator()
if err != nil {
return nil, nil, err
}

q, a := Preprocess(p.String(), p.Arguments())

f[i] = exql.RawValue("(" + q + ")")
args = append(args, a...)
case isCompilable:
c, err := v.Compile()
if err != nil {
return nil, nil, err
Expand Down
8 changes: 7 additions & 1 deletion internal/sqlbuilder/convert.go
Expand Up @@ -122,7 +122,13 @@ func preprocessFn(arg interface{}) (string, []interface{}) {
switch t := arg.(type) {
case *adapter.RawExpr:
return Preprocess(t.Raw(), t.Arguments())
case compilable:
case hasPaginator:
p, err := t.Paginator()
if err == nil {
return `(` + p.String() + `)`, p.Arguments()
}
panic(err.Error())
case isCompilable:
c, err := t.Compile()
if err == nil {
return `(` + c + `)`, t.Arguments()
Expand Down
31 changes: 31 additions & 0 deletions internal/testsuite/sql_suite.go
Expand Up @@ -1898,3 +1898,34 @@ func (s *SQLTestSuite) Test_Issue565() {
s.Error(err)
s.Zero(result.Name)
}

func (s *SQLTestSuite) TestSelectFromSubquery() {
sess := s.Session()

{
var artists []artistType
q := sess.SQL().SelectFrom(
sess.SQL().SelectFrom("artist").Where(db.Cond{
"name": db.IsNotNull(),
}),
).As("_q")
err := q.All(&artists)
s.NoError(err)

s.NotZero(len(artists))
}

{
var artists []artistType
q := sess.SQL().SelectFrom(
sess.Collection("artist").Find(db.Cond{
"name": db.IsNotNull(),
}),
).As("_q")
err := q.All(&artists)
s.NoError(err)

s.NotZero(len(artists))
}

}

0 comments on commit b5aff2b

Please sign in to comment.