Skip to content

Commit

Permalink
all: unwrap database.LoginSourcesStore interface (#7694)
Browse files Browse the repository at this point in the history
  • Loading branch information
unknwon committed Mar 18, 2024
1 parent 3a5132b commit e634aa6
Show file tree
Hide file tree
Showing 16 changed files with 1,181 additions and 2,055 deletions.
5 changes: 3 additions & 2 deletions internal/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ const (
PAM // 4
DLDAP // 5
GitHub // 6

Mock Type = 999
)

// Name returns the human-readable name for given authentication type.
Expand All @@ -45,8 +47,7 @@ type ErrBadCredentials struct {
// IsErrBadCredentials returns true if the underlying error has the type
// ErrBadCredentials.
func IsErrBadCredentials(err error) bool {
_, ok := errors.Cause(err).(ErrBadCredentials)
return ok
return errors.As(err, &ErrBadCredentials{})
}

func (err ErrBadCredentials) Error() string {
Expand Down
11 changes: 9 additions & 2 deletions internal/database/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,13 +117,12 @@ func NewConnection(w logger.Writer) (*gorm.DB, error) {
log.Trace("Auto migrated %q", name)
}

sourceFiles, err := loadLoginSourceFiles(filepath.Join(conf.CustomDir(), "conf", "auth.d"), db.NowFunc)
loadedLoginSourceFilesStore, err = loadLoginSourceFiles(filepath.Join(conf.CustomDir(), "conf", "auth.d"), db.NowFunc)
if err != nil {
return nil, errors.Wrap(err, "load login source files")
}

// Initialize stores, sorted in alphabetical order.
LoginSources = &loginSourcesStore{DB: db, files: sourceFiles}
Notices = NewNoticesStore(db)
Orgs = NewOrgsStore(db)
Perms = NewPermsStore(db)
Expand Down Expand Up @@ -166,3 +165,11 @@ func (db *DB) Actions() *ActionsStore {
func (db *DB) LFS() *LFSStore {
return newLFSStore(db.db)
}

// NOTE: It is not guarded by a mutex because it only gets written during the
// service start.
var loadedLoginSourceFilesStore loginSourceFilesStore

func (db *DB) LoginSources() *LoginSourcesStore {
return newLoginSourcesStore(db.db, loadedLoginSourceFilesStore)
}
3 changes: 1 addition & 2 deletions internal/database/login_source_files.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,7 @@ type ErrLoginSourceNotExist struct {
}

func IsErrLoginSourceNotExist(err error) bool {
_, ok := err.(ErrLoginSourceNotExist)
return ok
return errors.As(err, &ErrLoginSourceNotExist{})
}

func (err ErrLoginSourceNotExist) Error() string {
Expand Down
105 changes: 55 additions & 50 deletions internal/database/login_sources.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,30 +22,6 @@ import (
"gogs.io/gogs/internal/errutil"
)

// LoginSourcesStore is the persistent interface for login sources.
type LoginSourcesStore interface {
// Create creates a new login source and persist to database. It returns
// ErrLoginSourceAlreadyExist when a login source with same name already exists.
Create(ctx context.Context, opts CreateLoginSourceOptions) (*LoginSource, error)
// Count returns the total number of login sources.
Count(ctx context.Context) int64
// DeleteByID deletes a login source by given ID. It returns ErrLoginSourceInUse
// if at least one user is associated with the login source.
DeleteByID(ctx context.Context, id int64) error
// GetByID returns the login source with given ID. It returns
// ErrLoginSourceNotExist when not found.
GetByID(ctx context.Context, id int64) (*LoginSource, error)
// List returns a list of login sources filtered by options.
List(ctx context.Context, opts ListLoginSourceOptions) ([]*LoginSource, error)
// ResetNonDefault clears default flag for all the other login sources.
ResetNonDefault(ctx context.Context, source *LoginSource) error
// Save persists all values of given login source to database or local file. The
// Updated field is set to current time automatically.
Save(ctx context.Context, t *LoginSource) error
}

var LoginSources LoginSourcesStore

// LoginSource represents an external way for authorizing users.
type LoginSource struct {
ID int64 `gorm:"primaryKey"`
Expand Down Expand Up @@ -88,6 +64,10 @@ func (s *LoginSource) BeforeUpdate(tx *gorm.DB) error {
return nil
}

type mockProviderConfig struct {
ExternalAccount *auth.ExternalAccount
}

// AfterFind implements the GORM query hook.
func (s *LoginSource) AfterFind(_ *gorm.DB) error {
s.Created = time.Unix(s.CreatedUnix, 0).Local()
Expand Down Expand Up @@ -134,6 +114,16 @@ func (s *LoginSource) AfterFind(_ *gorm.DB) error {
}
s.Provider = github.NewProvider(&cfg)

case auth.Mock:
var cfg mockProviderConfig
err := jsoniter.UnmarshalFromString(s.Config, &cfg)
if err != nil {
return err
}
mockProvider := NewMockProvider()
mockProvider.AuthenticateFunc.SetDefaultReturn(cfg.ExternalAccount, nil)
s.Provider = mockProvider

default:
return fmt.Errorf("unrecognized login source type: %v", s.Type)
}
Expand Down Expand Up @@ -180,13 +170,19 @@ func (s *LoginSource) GitHub() *github.Config {
return s.Provider.Config().(*github.Config)
}

var _ LoginSourcesStore = (*loginSourcesStore)(nil)

type loginSourcesStore struct {
*gorm.DB
// LoginSourcesStore is the storage layer for login sources.
type LoginSourcesStore struct {
db *gorm.DB
files loginSourceFilesStore
}

func newLoginSourcesStore(db *gorm.DB, files loginSourceFilesStore) *LoginSourcesStore {
return &LoginSourcesStore{
db: db,
files: files,
}
}

type CreateLoginSourceOptions struct {
Type auth.Type
Name string
Expand All @@ -200,19 +196,20 @@ type ErrLoginSourceAlreadyExist struct {
}

func IsErrLoginSourceAlreadyExist(err error) bool {
_, ok := err.(ErrLoginSourceAlreadyExist)
return ok
return errors.As(err, &ErrLoginSourceAlreadyExist{})
}

func (err ErrLoginSourceAlreadyExist) Error() string {
return fmt.Sprintf("login source already exists: %v", err.args)
}

func (s *loginSourcesStore) Create(ctx context.Context, opts CreateLoginSourceOptions) (*LoginSource, error) {
err := s.WithContext(ctx).Where("name = ?", opts.Name).First(new(LoginSource)).Error
// Create creates a new login source and persists it to the database. It returns
// ErrLoginSourceAlreadyExist when a login source with same name already exists.
func (s *LoginSourcesStore) Create(ctx context.Context, opts CreateLoginSourceOptions) (*LoginSource, error) {
err := s.db.WithContext(ctx).Where("name = ?", opts.Name).First(new(LoginSource)).Error
if err == nil {
return nil, ErrLoginSourceAlreadyExist{args: errutil.Args{"name": opts.Name}}
} else if err != gorm.ErrRecordNotFound {
} else if !errors.Is(err, gorm.ErrRecordNotFound) {
return nil, err
}

Expand All @@ -226,12 +223,13 @@ func (s *loginSourcesStore) Create(ctx context.Context, opts CreateLoginSourceOp
if err != nil {
return nil, err
}
return source, s.WithContext(ctx).Create(source).Error
return source, s.db.WithContext(ctx).Create(source).Error
}

func (s *loginSourcesStore) Count(ctx context.Context) int64 {
// Count returns the total number of login sources.
func (s *LoginSourcesStore) Count(ctx context.Context) int64 {
var count int64
s.WithContext(ctx).Model(new(LoginSource)).Count(&count)
s.db.WithContext(ctx).Model(new(LoginSource)).Count(&count)
return count + int64(s.files.Len())
}

Expand All @@ -240,31 +238,34 @@ type ErrLoginSourceInUse struct {
}

func IsErrLoginSourceInUse(err error) bool {
_, ok := err.(ErrLoginSourceInUse)
return ok
return errors.As(err, &ErrLoginSourceInUse{})
}

func (err ErrLoginSourceInUse) Error() string {
return fmt.Sprintf("login source is still used by some users: %v", err.args)
}

func (s *loginSourcesStore) DeleteByID(ctx context.Context, id int64) error {
// DeleteByID deletes a login source by given ID. It returns ErrLoginSourceInUse
// if at least one user is associated with the login source.
func (s *LoginSourcesStore) DeleteByID(ctx context.Context, id int64) error {
var count int64
err := s.WithContext(ctx).Model(new(User)).Where("login_source = ?", id).Count(&count).Error
err := s.db.WithContext(ctx).Model(new(User)).Where("login_source = ?", id).Count(&count).Error
if err != nil {
return err
} else if count > 0 {
return ErrLoginSourceInUse{args: errutil.Args{"id": id}}
}

return s.WithContext(ctx).Where("id = ?", id).Delete(new(LoginSource)).Error
return s.db.WithContext(ctx).Where("id = ?", id).Delete(new(LoginSource)).Error
}

func (s *loginSourcesStore) GetByID(ctx context.Context, id int64) (*LoginSource, error) {
// GetByID returns the login source with given ID. It returns
// ErrLoginSourceNotExist when not found.
func (s *LoginSourcesStore) GetByID(ctx context.Context, id int64) (*LoginSource, error) {
source := new(LoginSource)
err := s.WithContext(ctx).Where("id = ?", id).First(source).Error
err := s.db.WithContext(ctx).Where("id = ?", id).First(source).Error
if err != nil {
if err == gorm.ErrRecordNotFound {
if errors.Is(err, gorm.ErrRecordNotFound) {
return s.files.GetByID(id)
}
return nil, err
Expand All @@ -277,9 +278,10 @@ type ListLoginSourceOptions struct {
OnlyActivated bool
}

func (s *loginSourcesStore) List(ctx context.Context, opts ListLoginSourceOptions) ([]*LoginSource, error) {
// List returns a list of login sources filtered by options.
func (s *LoginSourcesStore) List(ctx context.Context, opts ListLoginSourceOptions) ([]*LoginSource, error) {
var sources []*LoginSource
query := s.WithContext(ctx).Order("id ASC")
query := s.db.WithContext(ctx).Order("id ASC")
if opts.OnlyActivated {
query = query.Where("is_actived = ?", true)
}
Expand All @@ -291,8 +293,9 @@ func (s *loginSourcesStore) List(ctx context.Context, opts ListLoginSourceOption
return append(sources, s.files.List(opts)...), nil
}

func (s *loginSourcesStore) ResetNonDefault(ctx context.Context, dflt *LoginSource) error {
err := s.WithContext(ctx).
// ResetNonDefault clears default flag for all the other login sources.
func (s *LoginSourcesStore) ResetNonDefault(ctx context.Context, dflt *LoginSource) error {
err := s.db.WithContext(ctx).
Model(new(LoginSource)).
Where("id != ?", dflt.ID).
Updates(map[string]any{"is_default": false}).
Expand All @@ -314,9 +317,11 @@ func (s *loginSourcesStore) ResetNonDefault(ctx context.Context, dflt *LoginSour
return nil
}

func (s *loginSourcesStore) Save(ctx context.Context, source *LoginSource) error {
// Save persists all values of given login source to database or local file. The
// Updated field is set to current time automatically.
func (s *LoginSourcesStore) Save(ctx context.Context, source *LoginSource) error {
if source.File == nil {
return s.WithContext(ctx).Save(source).Error
return s.db.WithContext(ctx).Save(source).Error
}

source.File.SetGeneral("name", source.Name)
Expand Down

0 comments on commit e634aa6

Please sign in to comment.