Skip to content

Commit

Permalink
Merge pull request TheThingsNetwork#1752
Browse files Browse the repository at this point in the history
OAuth Clients outside Tenant scope
  • Loading branch information
htdvisser committed Oct 3, 2019
2 parents 1ae2266 + f9e9a4a commit 6bd01ac
Show file tree
Hide file tree
Showing 9 changed files with 112 additions and 10 deletions.
2 changes: 1 addition & 1 deletion pkg/identityserver/store/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ type Client struct {
Model
SoftDelete

TenantID string `gorm:"unique_index:client_id_index;type:VARCHAR(36)"`
TenantID *string `gorm:"unique_index:client_id_index;type:VARCHAR(36)"`

// BEGIN common fields
ClientID string `gorm:"unique_index:client_id_index;type:VARCHAR(36);not null"`
Expand Down
2 changes: 1 addition & 1 deletion pkg/identityserver/store/client_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ func (s *clientStore) GetClient(ctx context.Context, id *ttnpb.ClientIdentifiers
var cliModel Client
if err := query.First(&cliModel).Error; err != nil {
if gorm.IsRecordNotFoundError(err) {
return nil, errNotFoundForID(id)
return s.getClientWithoutTenant(ctx, id, fieldMask)
}
return nil, err
}
Expand Down
50 changes: 50 additions & 0 deletions pkg/identityserver/store/client_store.tti.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
// Copyright © 2019 The Things Industries B.V.

package store

import (
"context"
"net/url"
"strings"

"github.com/gogo/protobuf/types"
"github.com/jinzhu/gorm"
"go.thethings.network/lorawan-stack/pkg/license"
"go.thethings.network/lorawan-stack/pkg/tenant"
"go.thethings.network/lorawan-stack/pkg/ttipb"
"go.thethings.network/lorawan-stack/pkg/ttnpb"
)

func (s *clientStore) getClientWithoutTenant(ctx context.Context, id *ttnpb.ClientIdentifiers, fieldMask *types.FieldMask) (*ttnpb.Client, error) {
tenantID := tenant.FromContext(ctx).TenantID
if license.RequireMultiTenancy(ctx) != nil || tenantID == "" {
return nil, errNotFoundForID(id)
}
query := s.query(tenant.NewContext(ctx, ttipb.TenantIdentifiers{}), Client{}, withClientID(id.GetClientID()))
query = selectClientFields(ctx, query, fieldMask)
var cliModel Client
if err := query.First(&cliModel).Error; err != nil {
if gorm.IsRecordNotFoundError(err) {
return nil, errNotFoundForID(id)
}
return nil, err
}
cliProto := &ttnpb.Client{}
cliModel.toPB(cliProto, fieldMask)

// Add tenant ID as prefix in Redirect URIs:
if fieldPaths := fieldMask.GetPaths(); len(fieldPaths) > 0 && ttnpb.HasAnyField(fieldPaths, "redirect_uris") {
var tenantRedirectURIs []string
for _, redirectURI := range cliProto.RedirectURIs {
if !strings.Contains(redirectURI, "://") {
continue
}
if uri, err := url.Parse(redirectURI); err == nil {
uri.Host = tenantID + "." + uri.Host
tenantRedirectURIs = append(tenantRedirectURIs, uri.String())
}
}
cliProto.RedirectURIs = append(cliProto.RedirectURIs, tenantRedirectURIs...)
}
return cliProto, nil
}
4 changes: 3 additions & 1 deletion pkg/identityserver/store/model_context.tti.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@ func (app *Application) SetContext(ctx context.Context) {

// SetContext needs to be called before creating models.
func (cli *Client) SetContext(ctx context.Context) {
cli.TenantID = tenant.FromContext(ctx).TenantID
if tenantID := tenant.FromContext(ctx).TenantID; tenantID != "" {
cli.TenantID = &tenantID
}
cli.Model.SetContext(ctx)
}

Expand Down
9 changes: 7 additions & 2 deletions pkg/identityserver/store/scope.tti.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"fmt"

"github.com/jinzhu/gorm"
"go.thethings.network/lorawan-stack/pkg/license"
"go.thethings.network/lorawan-stack/pkg/tenant"
)

Expand All @@ -20,10 +21,14 @@ func init() {
}
if _, ok := db.Value.(interface{ _isMultiTenant() }); ok {
table := db.NewScope(db.Value).TableName()
tenantID := tenant.FromContext(ctx).TenantID
if table == "users" || table == "organizations" {
return db.Where("accounts.tenant_id = ?", tenant.FromContext(ctx).TenantID)
return db.Where("accounts.tenant_id = ?", tenantID)
}
return db.Where(fmt.Sprintf("%s.tenant_id = ?", table), tenant.FromContext(ctx).TenantID)
if table == "clients" && tenantID == "" && license.RequireMultiTenancy(ctx) == nil {
return db.Where(fmt.Sprintf("%s.tenant_id IS NULL", table))
}
return db.Where(fmt.Sprintf("%s.tenant_id = ?", table), tenantID)
}
return db
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/identityserver/store/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func (s *store) findEntity(ctx context.Context, entityID ttnpb.Identifiers, fiel
}
if err := query.First(model).Error; err != nil {
if gorm.IsRecordNotFoundError(err) {
return nil, errNotFoundForID(entityID)
return s.findEntityWithoutTenant(ctx, entityID, fields...)
}
return nil, convertError(err)
}
Expand Down
43 changes: 43 additions & 0 deletions pkg/identityserver/store/store.tti.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// Copyright © 2019 The Things Industries B.V.

package store

import (
"context"

"github.com/jinzhu/gorm"
"go.thethings.network/lorawan-stack/pkg/license"
"go.thethings.network/lorawan-stack/pkg/tenant"
"go.thethings.network/lorawan-stack/pkg/ttipb"
"go.thethings.network/lorawan-stack/pkg/ttnpb"
)

func (s *store) findEntityWithoutTenant(ctx context.Context, entityID ttnpb.Identifiers, fields ...string) (modelInterface, error) {
tenantID := tenant.FromContext(ctx).TenantID
if license.RequireMultiTenancy(ctx) != nil || tenantID == "" {
return nil, errNotFoundForID(entityID)
}

var model modelInterface
switch entityID.EntityType() {
case "client":
model = &Client{}
default:
return nil, errNotFoundForID(entityID)
}

query := s.query(tenant.NewContext(ctx, ttipb.TenantIdentifiers{}), model, withID(entityID))
if len(fields) == 1 && fields[0] == "id" {
fields[0] = s.DB.NewScope(model).TableName() + ".id"
}
if len(fields) > 0 {
query = query.Select(fields)
}
if err := query.First(model).Error; err != nil {
if gorm.IsRecordNotFoundError(err) {
return nil, errNotFoundForID(entityID)
}
return nil, convertError(err)
}
return model, nil
}
3 changes: 2 additions & 1 deletion pkg/web/oauthclient/callback.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,5 +69,6 @@ func (oc *OAuthClient) HandleCallback(c echo.Context) error {
return err
}

return c.Redirect(http.StatusFound, oc.config.RootURL+stateCookie.Next)
config := oc.configFromContext(c.Request().Context())
return c.Redirect(http.StatusFound, config.RootURL+stateCookie.Next)
}
7 changes: 4 additions & 3 deletions pkg/web/oauthclient/oauthclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,15 +86,16 @@ func (oc *OAuthClient) configFromContext(ctx context.Context) *Config {
if config, ok := ctx.Value(ctxKey).(*Config); ok {
return config
}
return &oc.config
config := oc.config.Apply(ctx)
return &config
}

func (oc *OAuthClient) oauth(c echo.Context) *oauth2.Config {
config := oc.configFromContext(c.Request().Context())

authorizeURL := config.AuthorizeURL
redirectURL := fmt.Sprintf("%s/oauth/callback", strings.TrimSuffix(oc.config.RootURL, "/"))
if oauthRootURL, err := url.Parse(oc.config.RootURL); err == nil {
redirectURL := fmt.Sprintf("%s/oauth/callback", strings.TrimSuffix(config.RootURL, "/"))
if oauthRootURL, err := url.Parse(config.RootURL); err == nil {
rootURL := (&url.URL{Scheme: oauthRootURL.Scheme, Host: oauthRootURL.Host}).String()
if strings.HasPrefix(authorizeURL, rootURL) {
authorizeURL = strings.TrimPrefix(authorizeURL, rootURL)
Expand Down

0 comments on commit 6bd01ac

Please sign in to comment.