Skip to content

Commit

Permalink
fix: renew token of vault k8s auth method
Browse files Browse the repository at this point in the history
  • Loading branch information
Kryvchun committed May 18, 2023
1 parent f2d85d7 commit aa9b02c
Show file tree
Hide file tree
Showing 4 changed files with 128 additions and 7 deletions.
51 changes: 46 additions & 5 deletions dependency/vault_token.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ import (
"github.com/pkg/errors"
)

// VaultTokenRefreshCurrent tells to refresh the current client token.
const VaultTokenRefreshCurrent = ""

// Ensure implements
var _ Dependency = (*VaultTokenQuery)(nil)

Expand All @@ -16,6 +19,8 @@ type VaultTokenQuery struct {
stopCh chan struct{}
secret *Secret
vaultSecret *api.Secret

initialToken string
}

// NewVaultTokenQuery creates a new dependency.
Expand All @@ -28,9 +33,10 @@ func NewVaultTokenQuery(token string) (*VaultTokenQuery, error) {
},
}
return &VaultTokenQuery{
stopCh: make(chan struct{}, 1),
vaultSecret: vaultSecret,
secret: transformSecret(vaultSecret),
stopCh: make(chan struct{}, 1),
vaultSecret: vaultSecret,
secret: transformSecret(vaultSecret),
initialToken: token,
}, nil
}

Expand All @@ -43,8 +49,16 @@ func (d *VaultTokenQuery) Fetch(clients *ClientSet, opts *QueryOptions,
default:
}

if vaultSecretRenewable(d.secret) {
err := renewSecret(clients, d)
var currentRenewer renewer = d

if d.initialToken == VaultTokenRefreshCurrent {
currentRenewer = newVaultSecretsOverrideRenewer(d, clients.Vault().Token())
}

secret, _ := currentRenewer.secrets()

if vaultSecretRenewable(secret) {
err := renewSecret(clients, currentRenewer)
if err != nil {
return nil, nil, errors.Wrap(err, d.String())
}
Expand Down Expand Up @@ -80,3 +94,30 @@ func (d *VaultTokenQuery) String() string {
func (d *VaultTokenQuery) Type() Type {
return TypeVault
}

func newVaultSecretsOverrideRenewer(parent renewer, token string) *vaultSecretsOverrideRenewer {
vaultSecret := &api.Secret{
Auth: &api.SecretAuth{
ClientToken: token,
Renewable: true,
LeaseDuration: 1,
},
}

return &vaultSecretsOverrideRenewer{
renewer: parent,
vaultSecret: vaultSecret,
secret: transformSecret(vaultSecret),
}
}

type vaultSecretsOverrideRenewer struct {
renewer

secret *Secret
vaultSecret *api.Secret
}

func (d *vaultSecretsOverrideRenewer) secrets() (*Secret, *api.Secret) {
return d.secret, d.vaultSecret
}
20 changes: 20 additions & 0 deletions dependency/vault_token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (

"github.com/hashicorp/vault/api"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestNewVaultTokenQuery(t *testing.T) {
Expand Down Expand Up @@ -76,3 +77,22 @@ func TestVaultTokenQuery_String(t *testing.T) {
})
}
}

func TestTewVaultSecretsOverrideRenewer(t *testing.T) {
const token = "expected_token"

parent, err := NewVaultTokenQuery(VaultTokenRefreshCurrent)
require.NoError(t, err)

vaultTokenSecretsOverride := newVaultSecretsOverrideRenewer(parent, token)

secret, vaultSecret := vaultTokenSecretsOverride.secrets()

if assert.NotNil(t, secret) && assert.NotNil(t, secret.Auth) {
assert.Equal(t, token, secret.Auth.ClientToken)
}

if assert.NotNil(t, vaultSecret) && assert.NotNil(t, vaultSecret.Auth) {
assert.Equal(t, token, vaultSecret.Auth.ClientToken)
}
}
33 changes: 31 additions & 2 deletions watch/vault_token.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,36 @@ import (
"github.com/hashicorp/vault/api"
)

func newVaultK8SAuthMethodRefreshTokenWatcher(
clients *dep.ClientSet, c *config.VaultConfig, doneCh chan struct{},
) (*Watcher, error) {
isK8SAuthMethod := config.StringVal(c.K8SServiceAccountToken) != "" || config.StringVal(c.K8SServiceAccountTokenPath) != ""

if !isK8SAuthMethod || !config.BoolVal(c.RenewToken) {
return nil, nil
}

watcher := NewWatcher(&NewWatcherInput{
Clients: clients,
RetryFuncVault: RetryFunc(c.Retry.RetryFunc()),
})

vaultQuery, err := dep.NewVaultTokenQuery(dep.VaultTokenRefreshCurrent)
if err != nil {
watcher.Stop()

return nil, fmt.Errorf("vaultwatcher: %w", err)
}

if _, err := watcher.Add(vaultQuery); err != nil {
watcher.Stop()

return nil, fmt.Errorf("vaultwatcher: %w", err)
}

return watcher, nil
}

// VaultTokenWatcher monitors the vault token for updates
func VaultTokenWatcher(
clients *dep.ClientSet, c *config.VaultConfig, doneCh chan struct{},
Expand All @@ -24,7 +54,7 @@ func VaultTokenWatcher(
// tokens are not being used.
raw_token := strings.TrimSpace(config.StringVal(c.Token))
if raw_token == "" {
return nil, nil
return newVaultK8SAuthMethodRefreshTokenWatcher(clients, c, doneCh)
}

unwrap := config.BoolVal(c.UnwrapToken)
Expand Down Expand Up @@ -76,7 +106,6 @@ func VaultTokenWatcher(

return watcher, nil
}

func watchTokenFile(
w *Watcher, tokenFile, raw_token string, unwrap bool, doneCh chan struct{},
) (func(), error) {
Expand Down
31 changes: 31 additions & 0 deletions watch/vault_token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ import (
"github.com/hashicorp/consul-template/config"
dep "github.com/hashicorp/consul-template/dependency"
"github.com/hashicorp/vault/api"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

// approle auto-auth setup in watch_test.go, TestMain()
Expand Down Expand Up @@ -141,6 +143,35 @@ func TestVaultTokenWatcher(t *testing.T) {
// give it a chance to throw an error
}
})

t.Run("renew_kubernetes", func(t *testing.T) {
// Check that there is an attempt to refresh token.
testClients.Vault().SetToken(vaultToken)

_, err := testClients.Vault().Auth().Token().Create(
&api.TokenCreateRequest{
ID: "c_token",
TTL: "1m",
Renewable: config.Bool(true),
})
require.NoError(t, err)

conf := config.DefaultVaultConfig()
conf.Token = config.String("")
conf.RenewToken = config.Bool(true)
conf.K8SServiceAccountToken = config.String("any_k8s_token")

watcher, err := VaultTokenWatcher(testClients, conf, nil)
require.NoError(t, err)

defer watcher.Stop()

select {
case err := <-watcher.ErrCh():
assert.ErrorIs(t, err, dep.ErrLeaseExpired)
case <-time.After(time.Millisecond * 100):
}
})
}

func TestVaultTokenRefreshToken(t *testing.T) {
Expand Down

0 comments on commit aa9b02c

Please sign in to comment.