Skip to content

Commit

Permalink
updated to current vault plugin api
Browse files Browse the repository at this point in the history
  • Loading branch information
cypherhat committed Jul 16, 2020
1 parent 5ab8c8f commit ad07cb9
Show file tree
Hide file tree
Showing 14 changed files with 410 additions and 99 deletions.
4 changes: 3 additions & 1 deletion .gitignore
Expand Up @@ -2,4 +2,6 @@
**/data
**/data/*.json
**/test
**/releases/*
**/releases/*
**/scripts
**/scripts/*.sh
35 changes: 18 additions & 17 deletions backend.go
Expand Up @@ -18,19 +18,16 @@ import (
"context"
"fmt"

"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/logical"
)

// New returns a new backend as an interface. This func
// is only necessary for builtin backend plugins.
func New() (interface{}, error) {
return Backend(), nil
}

// Factory returns a new backend as logical.Backend.
// Factory returns the backend
func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
b := Backend()
b, err := Backend(conf)
if err != nil {
return nil, err
}
if err := b.Setup(ctx, conf); err != nil {
return nil, err
}
Expand All @@ -40,18 +37,21 @@ func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend,
// FactoryType returns the factory
func FactoryType(backendType logical.BackendType) logical.Factory {
return func(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
b := Backend()
b, err := Backend(conf)
if err != nil {
return nil, err
}
b.BackendType = backendType
if err := b.Setup(ctx, conf); err != nil {
if err = b.Setup(ctx, conf); err != nil {
return nil, err
}
return b, nil
}
}

// Backend returns the backend
func Backend() *backend {
var b backend
func Backend(conf *logical.BackendConfig) (*PluginBackend, error) {
var b PluginBackend
b.Backend = &framework.Backend{
Help: "",
Paths: framework.PathAppend(
Expand All @@ -76,14 +76,15 @@ func Backend() *backend {
Secrets: []*framework.Secret{},
BackendType: logical.TypeLogical,
}
return &b
return &b, nil
}

type backend struct {
// PluginBackend implements the Backend for this plugin
type PluginBackend struct {
*framework.Backend
}

func (b *backend) pathExistenceCheck(ctx context.Context, req *logical.Request, data *framework.FieldData) (bool, error) {
func (b *PluginBackend) pathExistenceCheck(ctx context.Context, req *logical.Request, data *framework.FieldData) (bool, error) {
out, err := req.Storage.Get(ctx, req.Path)
if err != nil {
return false, fmt.Errorf("existence check failed: %v", err)
Expand Down
15 changes: 15 additions & 0 deletions go.mod
@@ -0,0 +1,15 @@
module github.com/immutability-io/trustee

go 1.14

require (
github.com/btcsuite/btcd v0.20.1-beta
github.com/dgrijalva/jwt-go v3.2.0+incompatible
github.com/ethereum/go-ethereum v1.9.16
github.com/hashicorp/errwrap v1.0.0
github.com/hashicorp/vault/api v1.0.4
github.com/hashicorp/vault/sdk v0.1.13
github.com/pborman/uuid v1.2.0
github.com/sethvargo/go-diceware v0.2.0
golang.org/x/crypto v0.0.0-20200709230013-948cd5f35899
)
295 changes: 295 additions & 0 deletions go.sum

Large diffs are not rendered by default.

12 changes: 5 additions & 7 deletions main.go
Expand Up @@ -18,22 +18,20 @@ import (
"log"
"os"

"github.com/hashicorp/vault/helper/pluginutil"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/plugin"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/sdk/plugin"
)

func main() {
apiClientMeta := &pluginutil.APIClientMeta{}
apiClientMeta := &api.PluginAPIClientMeta{}
flags := apiClientMeta.FlagSet()
flags.Parse(os.Args[1:]) // Ignore command, strictly parse flags

tlsConfig := apiClientMeta.GetTLSConfig()
tlsProviderFunc := pluginutil.VaultPluginTLSProvider(tlsConfig)
tlsProviderFunc := api.VaultPluginTLSProvider(tlsConfig)

factoryFunc := FactoryType(logical.TypeLogical)
err := plugin.Serve(&plugin.ServeOpts{
BackendFactoryFunc: factoryFunc,
BackendFactoryFunc: Factory,
TLSProviderFunc: tlsProviderFunc,
})
if err != nil {
Expand Down
18 changes: 9 additions & 9 deletions path_addresses.go
Expand Up @@ -18,16 +18,16 @@ import (
"context"
"fmt"

"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/logical"
)

// AccountAddress stores the name of the account to allow reverse lookup by address
type AccountAddress struct {
Address string `json:"address"`
}

func addressesPaths(b *backend) []*framework.Path {
func addressesPaths(b *PluginBackend) []*framework.Path {
return []*framework.Path{
&framework.Path{
Pattern: "addresses/?",
Expand Down Expand Up @@ -80,7 +80,7 @@ func addressesPaths(b *backend) []*framework.Path {
}
}

func (b *backend) pathAddressesRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
func (b *PluginBackend) pathAddressesRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
_, err := b.configured(ctx, req)
if err != nil {
return nil, err
Expand All @@ -104,7 +104,7 @@ func (b *backend) pathAddressesRead(ctx context.Context, req *logical.Request, d
}, nil
}

func (b *backend) pathAddressesList(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
func (b *PluginBackend) pathAddressesList(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
_, err := b.configured(ctx, req)
if err != nil {
return nil, err
Expand All @@ -117,7 +117,7 @@ func (b *backend) pathAddressesList(ctx context.Context, req *logical.Request, d
return logical.ListResponse(vals), nil
}

func (b *backend) readAddress(ctx context.Context, req *logical.Request, address string) (*AccountNames, error) {
func (b *PluginBackend) readAddress(ctx context.Context, req *logical.Request, address string) (*AccountNames, error) {
path := fmt.Sprintf("addresses/%s", address)
entry, err := req.Storage.Get(ctx, path)
if err != nil {
Expand All @@ -136,7 +136,7 @@ func (b *backend) readAddress(ctx context.Context, req *logical.Request, address
return &accountNames, nil
}

func (b *backend) pathAddressesVerify(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
func (b *PluginBackend) pathAddressesVerify(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
_, err := b.configured(ctx, req)
if err != nil {
return nil, err
Expand All @@ -158,7 +158,7 @@ func (b *backend) pathAddressesVerify(ctx context.Context, req *logical.Request,
return b.verifySignature(ctx, req, data, account.Names[0])
}

func (b *backend) crossReference(ctx context.Context, req *logical.Request, name, address string) error {
func (b *PluginBackend) crossReference(ctx context.Context, req *logical.Request, name, address string) error {
accountAddress := &AccountAddress{Address: address}
accountNames, err := b.readAddress(ctx, req, address)

Expand Down Expand Up @@ -191,7 +191,7 @@ func (b *backend) crossReference(ctx context.Context, req *logical.Request, name
return nil
}

func (b *backend) removeCrossReference(ctx context.Context, req *logical.Request, name, address string) error {
func (b *PluginBackend) removeCrossReference(ctx context.Context, req *logical.Request, name, address string) error {
pathAccountAddress := fmt.Sprintf("addresses/%s", address)
pathAccountName := fmt.Sprintf("names/%s", name)

Expand Down
22 changes: 11 additions & 11 deletions path_audience.go
Expand Up @@ -23,8 +23,8 @@ import (
jwt "github.com/dgrijalva/jwt-go"
"github.com/ethereum/go-ethereum/common/hexutil"
"github.com/ethereum/go-ethereum/crypto"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/logical"
)

// Audience is a public key known to vault. A Trustee has an address (Ethereum-compatible)
Expand All @@ -33,7 +33,7 @@ type Audience struct {
PublicKey string `json:"public_key"`
}

func audiencesPaths(b *backend) []*framework.Path {
func audiencesPaths(b *PluginBackend) []*framework.Path {
return []*framework.Path{
&framework.Path{
Pattern: "audiences/?",
Expand Down Expand Up @@ -70,7 +70,7 @@ Creates (or updates) an audience.
}
}

func (b *backend) pathAudiencesList(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
func (b *PluginBackend) pathAudiencesList(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
_, err := b.configured(ctx, req)
if err != nil {
return nil, err
Expand All @@ -82,7 +82,7 @@ func (b *backend) pathAudiencesList(ctx context.Context, req *logical.Request, d
return logical.ListResponse(vals), nil
}

func (b *backend) pathAudiencesCreate(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
func (b *PluginBackend) pathAudiencesCreate(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
_, err := b.configured(ctx, req)
if err != nil {
return nil, err
Expand Down Expand Up @@ -118,7 +118,7 @@ func (b *backend) pathAudiencesCreate(ctx context.Context, req *logical.Request,
}, nil
}

func (b *backend) pathAudiencesRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
func (b *PluginBackend) pathAudiencesRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
_, err := b.configured(ctx, req)
if err != nil {
return nil, err
Expand All @@ -142,7 +142,7 @@ func (b *backend) pathAudiencesRead(ctx context.Context, req *logical.Request, d
}, nil
}

func (b *backend) readAudience(ctx context.Context, req *logical.Request, name string) (*Audience, error) {
func (b *PluginBackend) readAudience(ctx context.Context, req *logical.Request, name string) (*Audience, error) {
path := fmt.Sprintf("audiences/%s", name)
entry, err := req.Storage.Get(ctx, path)
if err != nil {
Expand All @@ -162,7 +162,7 @@ func (b *backend) readAudience(ctx context.Context, req *logical.Request, name s
return &audience, nil
}

func (b *backend) pathAudiencesDelete(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
func (b *PluginBackend) pathAudiencesDelete(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
_, err := b.configured(ctx, req)
if err != nil {
return nil, err
Expand All @@ -173,7 +173,7 @@ func (b *backend) pathAudiencesDelete(ctx context.Context, req *logical.Request,
return nil, nil
}

func (b *backend) pathEncryptForAudience(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
func (b *PluginBackend) pathEncryptForAudience(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
_, err := b.configured(ctx, req)
if err != nil {
return nil, err
Expand All @@ -197,7 +197,7 @@ func (b *backend) pathEncryptForAudience(ctx context.Context, req *logical.Reque

}

func (b *backend) encryptForAudience(ctx context.Context, audience *Audience, plaintext string) (string, error) {
func (b *PluginBackend) encryptForAudience(ctx context.Context, audience *Audience, plaintext string) (string, error) {

publicKeyBytes, err := hex.DecodeString(audience.PublicKey)
if err != nil {
Expand All @@ -220,7 +220,7 @@ func (b *backend) encryptForAudience(ctx context.Context, audience *Audience, pl

}

func (b *backend) encryptClaims(ctx context.Context, audience *Audience, claims jwt.MapClaims) (jwt.MapClaims, error) {
func (b *PluginBackend) encryptClaims(ctx context.Context, audience *Audience, claims jwt.MapClaims) (jwt.MapClaims, error) {
encryptedClaims := make(jwt.MapClaims)
for key, value := range claims {
ciphertext, err := b.encryptForAudience(ctx, audience, value.(string))
Expand Down
18 changes: 9 additions & 9 deletions path_config.go
Expand Up @@ -19,16 +19,16 @@ import (
"fmt"

"github.com/hashicorp/errwrap"
"github.com/hashicorp/vault/helper/cidrutil"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/cidrutil"
"github.com/hashicorp/vault/sdk/logical"
)

type config struct {
BoundCIDRList []string `json:"bound_cidr_list_list" structs:"bound_cidr_list" mapstructure:"bound_cidr_list"`
}

func configPaths(b *backend) []*framework.Path {
func configPaths(b *PluginBackend) []*framework.Path {
return []*framework.Path{
&framework.Path{
Pattern: "config",
Expand All @@ -52,7 +52,7 @@ IP addresses which can perform the login operation.`,
}
}

func (b *backend) pathWriteConfig(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
func (b *PluginBackend) pathWriteConfig(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
var boundCIDRList []string
if boundCIDRListRaw, ok := data.GetOk("bound_cidr_list"); ok {
boundCIDRList = boundCIDRListRaw.([]string)
Expand All @@ -77,7 +77,7 @@ func (b *backend) pathWriteConfig(ctx context.Context, req *logical.Request, dat
}, nil
}

func (b *backend) pathReadConfig(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
func (b *PluginBackend) pathReadConfig(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
configBundle, err := b.readConfig(ctx, req.Storage)
if err != nil {
return nil, err
Expand All @@ -96,7 +96,7 @@ func (b *backend) pathReadConfig(ctx context.Context, req *logical.Request, data
}

// Config returns the configuration for this backend.
func (b *backend) readConfig(ctx context.Context, s logical.Storage) (*config, error) {
func (b *PluginBackend) readConfig(ctx context.Context, s logical.Storage) (*config, error) {
entry, err := s.Get(ctx, "config")
if err != nil {
return nil, err
Expand All @@ -116,7 +116,7 @@ func (b *backend) readConfig(ctx context.Context, s logical.Storage) (*config, e
return &result, nil
}

func (b *backend) configured(ctx context.Context, req *logical.Request) (*config, error) {
func (b *PluginBackend) configured(ctx context.Context, req *logical.Request) (*config, error) {
config, err := b.readConfig(ctx, req.Storage)
if err != nil {
return nil, err
Expand All @@ -128,7 +128,7 @@ func (b *backend) configured(ctx context.Context, req *logical.Request) (*config
return config, nil
}

func (b *backend) validIPConstraints(config *config, req *logical.Request) (bool, error) {
func (b *PluginBackend) validIPConstraints(config *config, req *logical.Request) (bool, error) {
if len(config.BoundCIDRList) != 0 {
if req.Connection == nil || req.Connection.RemoteAddr == "" {
return false, fmt.Errorf("failed to get connection information")
Expand Down
14 changes: 7 additions & 7 deletions path_import.go
Expand Up @@ -22,12 +22,12 @@ import (

"github.com/ethereum/go-ethereum/common/hexutil"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/crypto/sha3"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/logical"
"golang.org/x/crypto/sha3"
)

func importPaths(b *backend) []*framework.Path {
func importPaths(b *PluginBackend) []*framework.Path {
return []*framework.Path{
&framework.Path{
Pattern: "import/" + framework.GenericNameRegex("name"),
Expand Down Expand Up @@ -56,12 +56,12 @@ Reads a JSON keystore, decrypts it and stores the passphrase.
}
}

func (b *backend) pathImportExistenceCheck(ctx context.Context, req *logical.Request, data *framework.FieldData) (bool, error) {
func (b *PluginBackend) pathImportExistenceCheck(ctx context.Context, req *logical.Request, data *framework.FieldData) (bool, error) {
trusteePath := strings.Replace(req.Path, RequestPathImport, RequestPathTrustees, -1)
return pathExists(ctx, req, trusteePath)
}

func (b *backend) pathImportCreate(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
func (b *PluginBackend) pathImportCreate(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
_, err := b.configured(ctx, req)
if err != nil {
return nil, err
Expand Down Expand Up @@ -91,7 +91,7 @@ func (b *backend) pathImportCreate(ctx context.Context, req *logical.Request, da
publicKeyBytes := crypto.FromECDSAPub(publicKeyECDSA)
publicKeyString := hexutil.Encode(publicKeyBytes)[4:]

hash := sha3.NewKeccak256()
hash := sha3.NewLegacyKeccak256()
hash.Write(publicKeyBytes[1:])
address := hexutil.Encode(hash.Sum(nil)[12:])

Expand Down

0 comments on commit ad07cb9

Please sign in to comment.