Skip to content

Commit

Permalink
Oauth unlink social accounts (#253)
Browse files Browse the repository at this point in the history
This PR is adding the ability of share oauth account information and delete them.

- Adding email column to oauth_accounts table in order to facilitate users identify their accounts.
- Updating GET /account/{id} to return oauth_providers containing OAuth information.
- Adding new public endpoint DELETE /oauth/{provider} allowing user unlink OAuth accounts.
- Adding new public endpoint GET /oauth/info returning oauth information for linked accounts.
- Adding new private endpoint DELETE /account/{id}/oauth allowing unlink OAuth accounts for the given user.

Fixes #252
  • Loading branch information
diegosperes committed Apr 13, 2024
1 parent 20fe8d8 commit 9c9929a
Show file tree
Hide file tree
Showing 32 changed files with 850 additions and 69 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ Based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).

## HEAD

### Added

* Public and private APIs for oauth account visibility and removal - requires migration to record user email on oauth accounts (#253)

## 1.19.0

### Added
Expand Down
4 changes: 3 additions & 1 deletion app/data/account_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ type AccountStore interface {
Find(id int) (*models.Account, error)
FindByUsername(u string) (*models.Account, error)
FindByOauthAccount(p string, pid string) (*models.Account, error)
AddOauthAccount(id int, p string, pid string, tok string) error
AddOauthAccount(id int, p string, pid string, email string, tok string) error
UpdateOauthAccount(id int, p string, email string) (bool, error)
DeleteOauthAccount(id int, p string) (bool, error)
GetOauthAccounts(id int) ([]*models.OauthAccount, error)
Archive(id int) (bool, error)
Lock(id int) (bool, error)
Expand Down
29 changes: 28 additions & 1 deletion app/data/mock/account_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ func (s *accountStore) Create(u string, p []byte) (*models.Account, error) {
return dupAccount(acc), nil
}

func (s *accountStore) AddOauthAccount(accountID int, provider string, providerID string, tok string) error {
func (s *accountStore) AddOauthAccount(accountID int, provider, providerID, email, tok string) error {
p := provider + "|" + providerID
if s.idByOauthID[p] != 0 {
return Error{ErrNotUnique}
Expand All @@ -107,6 +107,7 @@ func (s *accountStore) AddOauthAccount(accountID int, provider string, providerI

now := time.Now()
oauthAccount := &models.OauthAccount{
Email: email,
AccountID: accountID,
Provider: provider,
ProviderID: providerID,
Expand All @@ -124,6 +125,32 @@ func (s *accountStore) GetOauthAccounts(accountID int) ([]*models.OauthAccount,
return s.oauthAccountsByID[accountID], nil
}

func (s *accountStore) UpdateOauthAccount(accountID int, provider, email string) (bool, error) {
oauthAccounts := s.oauthAccountsByID[accountID]

for i, oauthAccount := range oauthAccounts {
if oauthAccount.Provider == provider {
s.oauthAccountsByID[accountID][i].Email = email
return true, nil
}
}

return false, nil
}

func (s *accountStore) DeleteOauthAccount(accountID int, provider string) (bool, error) {
oauthAccounts := s.oauthAccountsByID[accountID]

for i, oauthAccount := range oauthAccounts {
if oauthAccount.Provider == provider {
s.oauthAccountsByID[accountID] = append(oauthAccounts[:i], oauthAccounts[i+1:]...)
return true, nil
}
}

return false, nil
}

func (s *accountStore) Archive(id int) (bool, error) {
account := s.accountsByID[id]
if account == nil {
Expand Down
25 changes: 22 additions & 3 deletions app/data/mysql/account_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,13 +76,14 @@ func (db *AccountStore) Create(u string, p []byte) (*models.Account, error) {
return account, nil
}

func (db *AccountStore) AddOauthAccount(accountID int, provider string, providerID string, accessToken string) error {
func (db *AccountStore) AddOauthAccount(accountID int, provider, providerID, email, accessToken string) error {
now := time.Now()

_, err := sqlx.NamedExec(db, `
INSERT INTO oauth_accounts (account_id, provider, provider_id, access_token, created_at, updated_at)
VALUES (:account_id, :provider, :provider_id, :access_token, :created_at, :updated_at)
INSERT INTO oauth_accounts (account_id, provider, provider_id, email, access_token, created_at, updated_at)
VALUES (:account_id, :provider, :provider_id, :email, :access_token, :created_at, :updated_at)
`, map[string]interface{}{
"email": email,
"account_id": accountID,
"provider": provider,
"provider_id": providerID,
Expand All @@ -99,6 +100,24 @@ func (db *AccountStore) GetOauthAccounts(accountID int) ([]*models.OauthAccount,
return accounts, err
}

func (db *AccountStore) UpdateOauthAccount(accountId int, provider, email string) (bool, error) {
result, err := db.Exec("UDPATE oauth_accounts SET email = ? WHERE account_id = ? AND provider = ?", email, accountId, provider)
if err != nil {
return false, err
}

return ok(result, err)
}

func (db *AccountStore) DeleteOauthAccount(accountId int, provider string) (bool, error) {
result, err := db.Exec("DELETE FROM oauth_accounts WHERE account_id = ? AND provider = ?", accountId, provider)
if err != nil {
return false, err
}

return ok(result, err)
}

func (db *AccountStore) Archive(id int) (bool, error) {
_, err := db.Exec("DELETE FROM oauth_accounts WHERE account_id = ?", id)
if err != nil {
Expand Down
13 changes: 13 additions & 0 deletions app/data/mysql/migrations.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ func MigrateDB(db *sqlx.DB) error {
createOauthAccounts,
createAccountLastLoginAtField,
createAccountTOTPFields,
addOauthAccountEmail,
}
for _, m := range migrations {
if err := m(db); err != nil {
Expand Down Expand Up @@ -61,6 +62,18 @@ func createOauthAccounts(db *sqlx.DB) error {
return err
}

func addOauthAccountEmail(db *sqlx.DB) error {
_, err := db.Exec(`
ALTER TABLE oauth_accounts ADD COLUMN email VARCHAR(255) DEFAULT NULL;
`)
if mysqlError, ok := err.(*mysql.MySQLError); ok {
if mysqlError.Number == 1060 { // 1060 = Duplicate column name
err = nil
}
}
return err
}

func createAccountLastLoginAtField(db *sqlx.DB) error {
_, err := db.Exec(`
ALTER TABLE accounts ADD last_login_at DATETIME DEFAULT NULL
Expand Down
25 changes: 22 additions & 3 deletions app/data/postgres/account_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,13 +88,14 @@ func (db *AccountStore) Create(u string, p []byte) (*models.Account, error) {
return account, nil
}

func (db *AccountStore) AddOauthAccount(accountID int, provider string, providerID string, accessToken string) error {
func (db *AccountStore) AddOauthAccount(accountID int, provider, providerID, email, accessToken string) error {
now := time.Now()

_, err := sqlx.NamedExec(db, `
INSERT INTO oauth_accounts (account_id, provider, provider_id, access_token, created_at, updated_at)
VALUES (:account_id, :provider, :provider_id, :access_token, :created_at, :updated_at)
INSERT INTO oauth_accounts (account_id, provider, provider_id, email, access_token, created_at, updated_at)
VALUES (:account_id, :provider, :provider_id, :email, :access_token, :created_at, :updated_at)
`, map[string]interface{}{
"email": email,
"account_id": accountID,
"provider": provider,
"provider_id": providerID,
Expand All @@ -111,6 +112,24 @@ func (db *AccountStore) GetOauthAccounts(accountID int) ([]*models.OauthAccount,
return accounts, err
}

func (db *AccountStore) UpdateOauthAccount(accountId int, provider, email string) (bool, error) {
result, err := db.Exec("UPDATE oauth_accounts SET email = $1 WHERE account_id = $2 AND provider = $3", email, accountId, provider)
if err != nil {
return false, err
}

return ok(result, err)
}

func (db *AccountStore) DeleteOauthAccount(accountId int, provider string) (bool, error) {
result, err := db.Exec("DELETE FROM oauth_accounts WHERE account_id = $1 AND provider = $2", accountId, provider)
if err != nil {
return false, err
}

return ok(result, err)
}

func (db *AccountStore) Archive(id int) (bool, error) {
_, err := db.Exec("DELETE FROM oauth_accounts WHERE account_id = $1", id)
if err != nil {
Expand Down
8 changes: 8 additions & 0 deletions app/data/postgres/migrations.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ func MigrateDB(db *sqlx.DB) error {
createAccountLastLoginAtField,
caseInsensitiveUsername,
createAccountTOTPFields,
addOauthAccountEmail,
}
for _, m := range migrations {
if err := m(db); err != nil {
Expand Down Expand Up @@ -56,6 +57,13 @@ func createOauthAccounts(db *sqlx.DB) error {
return err
}

func addOauthAccountEmail(db *sqlx.DB) error {
_, err := db.Exec(`
ALTER TABLE oauth_accounts ADD COLUMN IF NOT EXISTS email VARCHAR(255) DEFAULT NULL;
`)
return err
}

func createAccountLastLoginAtField(db *sqlx.DB) error {
_, err := db.Exec(`
ALTER TABLE accounts ADD COLUMN IF NOT EXISTS last_login_at timestamptz DEFAULT NULL
Expand Down
25 changes: 22 additions & 3 deletions app/data/sqlite3/account_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,13 +76,14 @@ func (db *AccountStore) Create(u string, p []byte) (*models.Account, error) {
return account, nil
}

func (db *AccountStore) AddOauthAccount(accountID int, provider string, providerID string, accessToken string) error {
func (db *AccountStore) AddOauthAccount(accountID int, provider, providerID, email, accessToken string) error {
now := time.Now()

_, err := sqlx.NamedExec(db, `
INSERT INTO oauth_accounts (account_id, provider, provider_id, access_token, created_at, updated_at)
VALUES (:account_id, :provider, :provider_id, :access_token, :created_at, :updated_at)
INSERT INTO oauth_accounts (account_id, provider, provider_id, email, access_token, created_at, updated_at)
VALUES (:account_id, :provider, :provider_id, :email, :access_token, :created_at, :updated_at)
`, map[string]interface{}{
"email": email,
"account_id": accountID,
"provider": provider,
"provider_id": providerID,
Expand All @@ -99,6 +100,24 @@ func (db *AccountStore) GetOauthAccounts(accountID int) ([]*models.OauthAccount,
return accounts, err
}

func (db *AccountStore) UpdateOauthAccount(accountId int, provider, email string) (bool, error) {
result, err := db.Exec("UPDATE oauth_accounts SET email = ? WHERE account_id = ? AND provider = ?", email, accountId, provider)
if err != nil {
return false, err
}

return ok(result, err)
}

func (db *AccountStore) DeleteOauthAccount(accountId int, provider string) (bool, error) {
result, err := db.Exec("DELETE FROM oauth_accounts WHERE account_id = ? AND provider = ?", accountId, provider)
if err != nil {
return false, err
}

return ok(result, err)
}

func (db *AccountStore) Archive(id int) (bool, error) {
_, err := db.Exec("DELETE FROM oauth_accounts WHERE account_id = ?", id)
if err != nil {
Expand Down
11 changes: 11 additions & 0 deletions app/data/sqlite3/migrations.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ func MigrateDB(db *sqlx.DB) error {
createAccountLastLoginAtField,
caseInsensitiveUsername,
createAccountTOTPFields,
addOauthAccountEmail,
}
for _, m := range migrations {
if err := m(db); err != nil {
Expand Down Expand Up @@ -96,6 +97,16 @@ func createOauthAccounts(db *sqlx.DB) error {
return err
}

func addOauthAccountEmail(db *sqlx.DB) error {
_, err := db.Exec(`
ALTER TABLE oauth_accounts ADD COLUMN email VARCHAR(255) DEFAULT NULL;
`)
if isDuplicateError(err) {
return nil
}
return err
}

func createAccountLastLoginAtField(db *sqlx.DB) error {
_, err := db.Exec(`
ALTER TABLE accounts ADD last_login_at DATETIME
Expand Down
8 changes: 4 additions & 4 deletions app/data/testers/account_store_testers.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ func testArchive(t *testing.T, store data.AccountStore) {
func testArchiveWithOauth(t *testing.T, store data.AccountStore) {
account, err := store.Create("authn@keratin.tech", []byte("password"))
require.NoError(t, err)
err = store.AddOauthAccount(account.ID, "PROVIDER", "PROVIDERID", "token")
err = store.AddOauthAccount(account.ID, "PROVIDER", "PROVIDERID", "email", "token")
require.NoError(t, err)

ok, err := store.Archive(account.ID)
Expand Down Expand Up @@ -261,7 +261,7 @@ func testAddOauthAccount(t *testing.T, store data.AccountStore) {

account, err := store.Create("authn@keratin.tech", []byte("password"))
assert.NoError(t, err)
err = store.AddOauthAccount(account.ID, "OAUTHPROVIDER", "PROVIDERID", "TOKEN")
err = store.AddOauthAccount(account.ID, "OAUTHPROVIDER", "PROVIDERID", "email", "TOKEN")
assert.NoError(t, err)

found, err = store.GetOauthAccounts(account.ID)
Expand All @@ -274,7 +274,7 @@ func testAddOauthAccount(t *testing.T, store data.AccountStore) {
assert.NotEmpty(t, found[0].CreatedAt)
assert.NotEmpty(t, found[0].UpdatedAt)

err = store.AddOauthAccount(account.ID, "OAUTHPROVIDER", "PROVIDERID2", "TOKEN")
err = store.AddOauthAccount(account.ID, "OAUTHPROVIDER", "PROVIDERID2", "email", "TOKEN")
if err == nil || !data.IsUniquenessError(err) {
t.Errorf("expected uniqueness error, got %T %v", err, err)
}
Expand All @@ -290,7 +290,7 @@ func testFindByOauthAccount(t *testing.T, store data.AccountStore) {

account, err := store.Create("authn@keratin.tech", []byte("password"))
require.NoError(t, err)
err = store.AddOauthAccount(account.ID, "OAUTHPROVIDER", "PROVIDERID", "TOKEN")
err = store.AddOauthAccount(account.ID, "OAUTHPROVIDER", "PROVIDERID", "email", "TOKEN")
require.NoError(t, err)

found, err = store.FindByOauthAccount("unknown", "PROVIDERID")
Expand Down
40 changes: 36 additions & 4 deletions app/models/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package models

import (
"database/sql"
"encoding/json"
"time"
)

Expand All @@ -13,10 +14,11 @@ type Account struct {
RequireNewPassword bool `db:"require_new_password"`
PasswordChangedAt time.Time `db:"password_changed_at"`
TOTPSecret sql.NullString `db:"totp_secret"`
LastLoginAt *time.Time `db:"last_login_at"`
CreatedAt time.Time `db:"created_at"`
UpdatedAt time.Time `db:"updated_at"`
DeletedAt *time.Time `db:"deleted_at"`
OauthAccounts []*OauthAccount
LastLoginAt *time.Time `db:"last_login_at"`
CreatedAt time.Time `db:"created_at"`
UpdatedAt time.Time `db:"updated_at"`
DeletedAt *time.Time `db:"deleted_at"`
}

func (a Account) Archived() bool {
Expand All @@ -30,3 +32,33 @@ func (a Account) TOTPEnabled() bool {
}
return false
}

func (a Account) MarshalJSON() ([]byte, error) {
formattedLastLogin := ""
if a.LastLoginAt != nil {
formattedLastLogin = a.LastLoginAt.Format(time.RFC3339)
}

formattedPasswordChangedAt := ""
if !a.PasswordChangedAt.IsZero() {
formattedPasswordChangedAt = a.PasswordChangedAt.Format(time.RFC3339)
}

return json.Marshal(struct {
ID int `json:"id"`
Username string `json:"username"`
OauthAccounts []*OauthAccount `json:"oauth_accounts"`
LastLoginAt string `json:"last_login_at"`
PasswordChangedAt string `json:"password_changed_at"`
Locked bool `json:"locked"`
Deleted bool `json:"deleted"`
}{
ID: a.ID,
Username: a.Username,
OauthAccounts: a.OauthAccounts,
LastLoginAt: formattedLastLogin,
PasswordChangedAt: formattedPasswordChangedAt,
Locked: a.Locked,
Deleted: a.DeletedAt != nil,
})
}
18 changes: 17 additions & 1 deletion app/models/oauth_account.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,29 @@
package models

import "time"
import (
"encoding/json"
"time"
)

type OauthAccount struct {
ID int
AccountID int `db:"account_id"`
Provider string
ProviderID string `db:"provider_id"`
Email string `db:"email"`
AccessToken string `db:"access_token"`
CreatedAt time.Time `db:"created_at"`
UpdatedAt time.Time `db:"updated_at"`
}

func (o OauthAccount) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Provider string `json:"provider"`
ProviderID string `json:"provider_account_id"`
Email string `json:"email"`
}{
Provider: o.Provider,
ProviderID: o.ProviderID,
Email: o.Email,
})
}

0 comments on commit 9c9929a

Please sign in to comment.