Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add msal cache 1.0 -> 1.1 upgrade #2664

Merged
merged 2 commits into from
Aug 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions cli/azd/.vscode/cspell-azd-dictionary.txt
Expand Up @@ -115,6 +115,7 @@ pyapp
pyvenv
reauthentication
relogin
remarshal
restoreapp
retriable
rzip
Expand Down
85 changes: 85 additions & 0 deletions cli/azd/pkg/auth/cache.go
Expand Up @@ -5,11 +5,24 @@ package auth

import (
"context"
"encoding/json"
"errors"
"log"
"strings"
"unicode"

"github.com/AzureAD/microsoft-authentication-library-for-go/apps/cache"
)

// Known entries from msal cache contract. This is not an exhaustive list.
var contractFields = []string{
"AccessToken",
"RefreshToken",
"IdToken",
"Account",
"AppMetadata",
}

// The MSAL cache key for the current user. The stored MSAL cached data contains
// all accounts with stored credentials, across all tenants.
// Currently, the underlying MSAL cache data is represented as [Contract] inside the library.
Expand All @@ -36,6 +49,46 @@ func (a *msalCacheAdapter) Replace(ctx context.Context, cache cache.Unmarshaler,
return err
}

// In msal v1.0, keys were stored with mixed casing; in v1.1., it was changed to lower case.
// This handles upgrades where we have a v1.0 cache, and we need to convert it to v1.1,
// by normalizing the appropriate key entries.
c := map[string]json.RawMessage{}
if err = json.Unmarshal(val, &c); err == nil {
weikanglim marked this conversation as resolved.
Show resolved Hide resolved
for _, contractKey := range contractFields {
if _, found := c[contractKey]; found {
msg := []byte(c[contractKey])
inner := map[string]json.RawMessage{}

err := json.Unmarshal(msg, &inner)
if err != nil {
log.Printf("msal-upgrade: failed to unmarshal inner: %v", err)
continue
}

updated := normalizeKeys(inner)
if !updated {
continue
}

newMsg, err := json.Marshal(inner)
if err != nil {
log.Printf("msal-upgrade: failed to remarshal inner: %v", err)
continue
}

c[contractKey] = json.RawMessage(newMsg)
}
}

if newVal, err := json.Marshal(c); err == nil {
val = newVal
} else {
log.Printf("msal-upgrade: failed to remarshal msal cache: %v", err)
}
} else {
log.Printf("msal-upgrade: failed to unmarshal msal cache: %v", err)
}

// Replace the msal cache contents with the new value retrieved.
if err := cache.Unmarshal(val); err != nil {
return err
Expand All @@ -52,6 +105,38 @@ func (a *msalCacheAdapter) Export(ctx context.Context, cache cache.Marshaler, _
return a.cache.Set(cCurrentUserCacheKey, val)
}

// Normalize keys by removing upper-case keys and replacing them with lower-case keys.
// In the case where a lower-case key and upper-case key exists, the lower-case key entry
// takes precedence.
func normalizeKeys(m map[string]json.RawMessage) (normalized bool) {
for k, v := range m {
if hasUpper(k) {
weikanglim marked this conversation as resolved.
Show resolved Hide resolved
// An upper-case key entry exists. Delete it as it is no longer allowed.
delete(m, k)

// If a lower-case key entry exists, that supersedes it and we are done.
// Otherwise, we can safely upgrade the cache entry by re-adding it with lower case.
lower := strings.ToLower(k)
if _, isLower := m[lower]; !isLower {
m[lower] = v
}

normalized = true
}
}

return normalized
}

func hasUpper(s string) bool {
weikanglim marked this conversation as resolved.
Show resolved Hide resolved
for _, r := range s {
if unicode.IsUpper(r) && unicode.IsLetter(r) {
return true
}
}
return false
}

type Cache interface {
Read(key string) ([]byte, error)
Set(key string, value []byte) error
Expand Down
106 changes: 106 additions & 0 deletions cli/azd/pkg/auth/cache_test.go
Expand Up @@ -5,6 +5,7 @@ package auth

import (
"context"
"encoding/json"
"math/rand"
"testing"

Expand Down Expand Up @@ -113,3 +114,108 @@ func TestCredentialCache(t *testing.T) {
_, err = c.Read("nonExist")
require.ErrorIs(t, err, errCacheKeyNotFound)
}

type mockContractHolder struct {
contract *mockContract
}

// Marshal implements cache.Marshaler in msal/apps/cache.
func (c *mockContractHolder) Marshal() ([]byte, error) {
return json.Marshal(c.contract)
}

// Unmarshal implements cache.Unmarshaler in msal/apps/cache.
func (c *mockContractHolder) Unmarshal(b []byte) error {
contract := &mockContract{}

err := json.Unmarshal(b, contract)
if err != nil {
return err
}

c.contract = contract
return nil
}

type val struct {
Value string `json:"value"`
}

// mockContract that simulates the MSAL cache contract.
type mockContract struct {
AccessTokens map[string]val `json:"AccessToken,omitempty"`
RefreshTokens map[string]val `json:"RefreshToken,omitempty"`
IDTokens map[string]val `json:"IdToken,omitempty"`
Accounts map[string]val `json:"Account,omitempty"`
AppMetaData map[string]val `json:"AppMetadata,omitempty"`

// mock remainder fields
Remainder map[string]val `json:"Remainder,omitempty"`
}

func TestKeyNormalization(t *testing.T) {
entries := map[string]val{
"Upper": {"Upper"},
"lower": {"lower"},
"Upper-And-Lower": {"Upper-And-Lower"},
"upper-and-lower": {"upper-and-lower"},
}
orig := mockContract{
AccessTokens: entries,
RefreshTokens: entries,
IDTokens: entries,
Accounts: entries,
AppMetaData: entries,
Remainder: map[string]val{
"remainder": {"remainder"},
},
}

normalizedEntries := map[string]val{
"upper": {"Upper"},
"lower": {"lower"},
"upper-and-lower": {"upper-and-lower"},
}
normalized := mockContract{
AccessTokens: normalizedEntries,
RefreshTokens: normalizedEntries,
IDTokens: normalizedEntries,
Accounts: normalizedEntries,
AppMetaData: normalizedEntries,
Remainder: map[string]val{
"remainder": {"remainder"},
},
}

ctx := context.Background()
c := msalCacheAdapter{&memoryCache{
cache: map[string][]byte{},
inner: nil,
}}

// Replace (retrieve) when cache is empty, expect nil
h := mockContractHolder{}
err := c.Replace(ctx, &h, cache.ReplaceHints{})
require.NoError(t, err)
require.Nil(t, h.contract)

// Export (save) with original entry
h.contract = &orig
err = c.Export(ctx, &h, cache.ExportHints{})
require.NoError(t, err)
require.JSONEq(t, mustJson(orig), mustJson(h.contract))

// Replace (retrieve) that will normalize the keys
err = c.Replace(ctx, &h, cache.ReplaceHints{})
require.NoError(t, err)
require.JSONEq(t, mustJson(normalized), mustJson(h.contract))
}

func mustJson(v any) string {
b, err := json.Marshal(v)
if err != nil {
panic(err)
}

return string(b)
}