diff --git a/cli/azd/.vscode/cspell-azd-dictionary.txt b/cli/azd/.vscode/cspell-azd-dictionary.txt index a27619f0f6..39547d4395 100644 --- a/cli/azd/.vscode/cspell-azd-dictionary.txt +++ b/cli/azd/.vscode/cspell-azd-dictionary.txt @@ -123,6 +123,7 @@ pyproject pyvenv reauthentication relogin +remarshal restoreapp retriable rzip diff --git a/cli/azd/pkg/auth/cache.go b/cli/azd/pkg/auth/cache.go index 3ff4ec82ff..4f08a7ce2f 100644 --- a/cli/azd/pkg/auth/cache.go +++ b/cli/azd/pkg/auth/cache.go @@ -5,11 +5,24 @@ package auth import ( "context" + "encoding/json" "errors" + "log" + "strings" + "unicode" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/cache" ) +// Known entries from msal cache contract. This is not an exhaustive list. +var contractFields = []string{ + "AccessToken", + "RefreshToken", + "IdToken", + "Account", + "AppMetadata", +} + // The MSAL cache key for the current user. The stored MSAL cached data contains // all accounts with stored credentials, across all tenants. // Currently, the underlying MSAL cache data is represented as [Contract] inside the library. @@ -36,6 +49,46 @@ func (a *msalCacheAdapter) Replace(ctx context.Context, cache cache.Unmarshaler, return err } + // In msal v1.0, keys were stored with mixed casing; in v1.1., it was changed to lower case. + // This handles upgrades where we have a v1.0 cache, and we need to convert it to v1.1, + // by normalizing the appropriate key entries. + c := map[string]json.RawMessage{} + if err = json.Unmarshal(val, &c); err == nil { + for _, contractKey := range contractFields { + if _, found := c[contractKey]; found { + msg := []byte(c[contractKey]) + inner := map[string]json.RawMessage{} + + err := json.Unmarshal(msg, &inner) + if err != nil { + log.Printf("msal-upgrade: failed to unmarshal inner: %v", err) + continue + } + + updated := normalizeKeys(inner) + if !updated { + continue + } + + newMsg, err := json.Marshal(inner) + if err != nil { + log.Printf("msal-upgrade: failed to remarshal inner: %v", err) + continue + } + + c[contractKey] = json.RawMessage(newMsg) + } + } + + if newVal, err := json.Marshal(c); err == nil { + val = newVal + } else { + log.Printf("msal-upgrade: failed to remarshal msal cache: %v", err) + } + } else { + log.Printf("msal-upgrade: failed to unmarshal msal cache: %v", err) + } + // Replace the msal cache contents with the new value retrieved. if err := cache.Unmarshal(val); err != nil { return err @@ -52,6 +105,38 @@ func (a *msalCacheAdapter) Export(ctx context.Context, cache cache.Marshaler, _ return a.cache.Set(cCurrentUserCacheKey, val) } +// Normalize keys by removing upper-case keys and replacing them with lower-case keys. +// In the case where a lower-case key and upper-case key exists, the lower-case key entry +// takes precedence. +func normalizeKeys(m map[string]json.RawMessage) (normalized bool) { + for k, v := range m { + if hasUpper(k) { + // An upper-case key entry exists. Delete it as it is no longer allowed. + delete(m, k) + + // If a lower-case key entry exists, that supersedes it and we are done. + // Otherwise, we can safely upgrade the cache entry by re-adding it with lower case. + lower := strings.ToLower(k) + if _, isLower := m[lower]; !isLower { + m[lower] = v + } + + normalized = true + } + } + + return normalized +} + +func hasUpper(s string) bool { + for _, r := range s { + if unicode.IsUpper(r) && unicode.IsLetter(r) { + return true + } + } + return false +} + type Cache interface { Read(key string) ([]byte, error) Set(key string, value []byte) error diff --git a/cli/azd/pkg/auth/cache_test.go b/cli/azd/pkg/auth/cache_test.go index 9e8f3871a3..cf89d6690d 100644 --- a/cli/azd/pkg/auth/cache_test.go +++ b/cli/azd/pkg/auth/cache_test.go @@ -5,6 +5,7 @@ package auth import ( "context" + "encoding/json" "math/rand" "testing" @@ -113,3 +114,108 @@ func TestCredentialCache(t *testing.T) { _, err = c.Read("nonExist") require.ErrorIs(t, err, errCacheKeyNotFound) } + +type mockContractHolder struct { + contract *mockContract +} + +// Marshal implements cache.Marshaler in msal/apps/cache. +func (c *mockContractHolder) Marshal() ([]byte, error) { + return json.Marshal(c.contract) +} + +// Unmarshal implements cache.Unmarshaler in msal/apps/cache. +func (c *mockContractHolder) Unmarshal(b []byte) error { + contract := &mockContract{} + + err := json.Unmarshal(b, contract) + if err != nil { + return err + } + + c.contract = contract + return nil +} + +type val struct { + Value string `json:"value"` +} + +// mockContract that simulates the MSAL cache contract. +type mockContract struct { + AccessTokens map[string]val `json:"AccessToken,omitempty"` + RefreshTokens map[string]val `json:"RefreshToken,omitempty"` + IDTokens map[string]val `json:"IdToken,omitempty"` + Accounts map[string]val `json:"Account,omitempty"` + AppMetaData map[string]val `json:"AppMetadata,omitempty"` + + // mock remainder fields + Remainder map[string]val `json:"Remainder,omitempty"` +} + +func TestKeyNormalization(t *testing.T) { + entries := map[string]val{ + "Upper": {"Upper"}, + "lower": {"lower"}, + "Upper-And-Lower": {"Upper-And-Lower"}, + "upper-and-lower": {"upper-and-lower"}, + } + orig := mockContract{ + AccessTokens: entries, + RefreshTokens: entries, + IDTokens: entries, + Accounts: entries, + AppMetaData: entries, + Remainder: map[string]val{ + "remainder": {"remainder"}, + }, + } + + normalizedEntries := map[string]val{ + "upper": {"Upper"}, + "lower": {"lower"}, + "upper-and-lower": {"upper-and-lower"}, + } + normalized := mockContract{ + AccessTokens: normalizedEntries, + RefreshTokens: normalizedEntries, + IDTokens: normalizedEntries, + Accounts: normalizedEntries, + AppMetaData: normalizedEntries, + Remainder: map[string]val{ + "remainder": {"remainder"}, + }, + } + + ctx := context.Background() + c := msalCacheAdapter{&memoryCache{ + cache: map[string][]byte{}, + inner: nil, + }} + + // Replace (retrieve) when cache is empty, expect nil + h := mockContractHolder{} + err := c.Replace(ctx, &h, cache.ReplaceHints{}) + require.NoError(t, err) + require.Nil(t, h.contract) + + // Export (save) with original entry + h.contract = &orig + err = c.Export(ctx, &h, cache.ExportHints{}) + require.NoError(t, err) + require.JSONEq(t, mustJson(orig), mustJson(h.contract)) + + // Replace (retrieve) that will normalize the keys + err = c.Replace(ctx, &h, cache.ReplaceHints{}) + require.NoError(t, err) + require.JSONEq(t, mustJson(normalized), mustJson(h.contract)) +} + +func mustJson(v any) string { + b, err := json.Marshal(v) + if err != nil { + panic(err) + } + + return string(b) +}