Skip to content

Commit

Permalink
SSO: fix settings merge for SAML fields (#86970)
Browse files Browse the repository at this point in the history
* fix sso settings merge for saml fields

* change func name
  • Loading branch information
dmihai committed Apr 30, 2024
1 parent 125ac18 commit 76d94b3
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 19 deletions.
18 changes: 17 additions & 1 deletion pkg/services/ssosettings/ssosettingsimpl/service.go
Expand Up @@ -470,7 +470,9 @@ func mergeSettings(storedSettings, systemSettings map[string]any) map[string]any

for k, v := range systemSettings {
if _, ok := settings[k]; !ok {
settings[k] = v
if isMergingAllowed(k) {
settings[k] = v
}
} else if isURL(k) && isEmptyString(settings[k]) {
// Overwrite all URL settings from the DB containing an empty string with their value
// from the system settings. This fixes an issue with empty auth_url, api_url and token_url
Expand All @@ -483,6 +485,20 @@ func mergeSettings(storedSettings, systemSettings map[string]any) map[string]any
return settings
}

// isMergingAllowed returns true if the field provided can be merged from the system settings.
// It won't allow SAML fields that are part of a group of settings to be merged from system settings
// because the DB settings already contain one valid setting from each group.
func isMergingAllowed(fieldName string) bool {
forbiddenMergePatterns := []string{"certificate", "private_key", "idp_metadata"}

for _, v := range forbiddenMergePatterns {
if strings.Contains(strings.ToLower(fieldName), strings.ToLower(v)) {
return false
}
}
return true
}

// mergeSecrets returns a new map with the current value for secrets that have not been updated
func mergeSecrets(settings map[string]any, storedSettings map[string]any) (map[string]any, error) {
settingsWithSecrets := map[string]any{}
Expand Down
85 changes: 67 additions & 18 deletions pkg/services/ssosettings/ssosettingsimpl/service_test.go
Expand Up @@ -40,13 +40,15 @@ func TestService_GetForProvider(t *testing.T) {
t.Parallel()

testCases := []struct {
name string
setup func(env testEnv)
want *models.SSOSettings
wantErr bool
name string
provider string
setup func(env testEnv)
want *models.SSOSettings
wantErr bool
}{
{
name: "should return successfully",
name: "should return successfully",
provider: "github",
setup: func(env testEnv) {
env.store.ExpectedSSOSetting = &models.SSOSettings{
Provider: "github",
Expand All @@ -72,13 +74,15 @@ func TestService_GetForProvider(t *testing.T) {
wantErr: false,
},
{
name: "should return error if store returns an error different than not found",
setup: func(env testEnv) { env.store.ExpectedError = fmt.Errorf("error") },
want: nil,
wantErr: true,
name: "should return error if store returns an error different than not found",
provider: "github",
setup: func(env testEnv) { env.store.ExpectedError = fmt.Errorf("error") },
want: nil,
wantErr: true,
},
{
name: "should fallback to the system settings if store returns not found",
name: "should fallback to the system settings if store returns not found",
provider: "github",
setup: func(env testEnv) {
env.store.ExpectedError = ssosettings.ErrNotFound
env.fallbackStrategy.ExpectedIsMatch = true
Expand All @@ -99,7 +103,8 @@ func TestService_GetForProvider(t *testing.T) {
wantErr: false,
},
{
name: "should return error if the fallback strategy was not found",
name: "should return error if the fallback strategy was not found",
provider: "github",
setup: func(env testEnv) {
env.store.ExpectedError = ssosettings.ErrNotFound
env.fallbackStrategy.ExpectedIsMatch = false
Expand All @@ -108,7 +113,8 @@ func TestService_GetForProvider(t *testing.T) {
wantErr: true,
},
{
name: "should return error if fallback strategy returns error",
name: "should return error if fallback strategy returns error",
provider: "github",
setup: func(env testEnv) {
env.store.ExpectedError = ssosettings.ErrNotFound
env.fallbackStrategy.ExpectedIsMatch = true
Expand All @@ -118,7 +124,8 @@ func TestService_GetForProvider(t *testing.T) {
wantErr: true,
},
{
name: "should decrypt secrets if data is coming from store",
name: "should decrypt secrets if data is coming from store",
provider: "github",
setup: func(env testEnv) {
env.store.ExpectedSSOSetting = &models.SSOSettings{
Provider: "github",
Expand Down Expand Up @@ -152,7 +159,8 @@ func TestService_GetForProvider(t *testing.T) {
wantErr: false,
},
{
name: "should not decrypt secrets if data is coming from the fallback strategy",
name: "should not decrypt secrets if data is coming from the fallback strategy",
provider: "github",
setup: func(env testEnv) {
env.store.ExpectedError = ssosettings.ErrNotFound
env.fallbackStrategy.ExpectedIsMatch = true
Expand All @@ -176,7 +184,8 @@ func TestService_GetForProvider(t *testing.T) {
wantErr: false,
},
{
name: "should return an error if the data in the store is invalid",
name: "should return an error if the data in the store is invalid",
provider: "github",
setup: func(env testEnv) {
env.store.ExpectedSSOSetting = &models.SSOSettings{
Provider: "github",
Expand All @@ -196,7 +205,8 @@ func TestService_GetForProvider(t *testing.T) {
wantErr: true,
},
{
name: "correctly merge the DB and system settings",
name: "correctly merge URLs from the DB and system settings",
provider: "github",
setup: func(env testEnv) {
env.store.ExpectedSSOSetting = &models.SSOSettings{
Provider: "github",
Expand Down Expand Up @@ -231,6 +241,45 @@ func TestService_GetForProvider(t *testing.T) {
},
wantErr: false,
},
{
name: "correctly merge group of settings for SAML",
provider: "saml",
setup: func(env testEnv) {
env.store.ExpectedSSOSetting = &models.SSOSettings{
Provider: "saml",
Settings: map[string]any{
"certificate": base64.RawStdEncoding.EncodeToString([]byte("valid-certificate")),
"private_key_path": base64.RawStdEncoding.EncodeToString([]byte("path/to/private/key")),
"idp_metadata_url": "https://idp-metadata.com",
},
Source: models.DB,
}
env.fallbackStrategy.ExpectedIsMatch = true
env.fallbackStrategy.ExpectedConfigs = map[string]map[string]any{
"saml": {
"name": "test-settings",
"certificate_path": "path/to/certificate",
"private_key": "this-is-a-valid-private-key",
"idp_metadata_path": "path/to/metadata",
"max_issue_delay": "1h",
},
}
env.secrets.On("Decrypt", mock.Anything, []byte("valid-certificate"), mock.Anything).Return([]byte("decrypted-valid-certificate"), nil).Once()
env.secrets.On("Decrypt", mock.Anything, []byte("path/to/private/key"), mock.Anything).Return([]byte("decrypted/path/to/private/key"), nil).Once()
},
want: &models.SSOSettings{
Provider: "saml",
Settings: map[string]any{
"name": "test-settings",
"certificate": "decrypted-valid-certificate",
"private_key_path": "decrypted/path/to/private/key",
"idp_metadata_url": "https://idp-metadata.com",
"max_issue_delay": "1h",
},
Source: models.DB,
},
wantErr: false,
},
}

for _, tc := range testCases {
Expand All @@ -241,12 +290,12 @@ func TestService_GetForProvider(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

env := setupTestEnv(t, false, false, false)
env := setupTestEnv(t, true, false, true)
if tc.setup != nil {
tc.setup(env)
}

actual, err := env.service.GetForProvider(context.Background(), "github")
actual, err := env.service.GetForProvider(context.Background(), tc.provider)

if tc.wantErr {
require.Error(t, err)
Expand Down

0 comments on commit 76d94b3

Please sign in to comment.