diff --git a/pkg/ssh_agent/ssh_agent.go b/pkg/ssh_agent/ssh_agent.go index 05ccb138e2..a79cc54fff 100644 --- a/pkg/ssh_agent/ssh_agent.go +++ b/pkg/ssh_agent/ssh_agent.go @@ -8,12 +8,15 @@ import ( "net" "os" "path/filepath" + "strings" uuid "github.com/satori/go.uuid" "golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh/agent" + "golang.org/x/crypto/ssh/terminal" "github.com/werf/logboek" + secret_common "github.com/werf/werf/cmd/werf/helm/secret/common" "github.com/werf/werf/pkg/util" "github.com/werf/werf/pkg/werf" ) @@ -28,24 +31,93 @@ func setupProcessSSHAgent(sshAuthSock string) error { return os.Setenv("SSH_AUTH_SOCK", SSHAuthSock) } -func Init(ctx context.Context, keys []string) error { - for _, key := range keys { - if keyExists, err := util.FileExists(key); !keyExists { - return fmt.Errorf("specified ssh key %s does not exist", key) - } else if err != nil { - return fmt.Errorf("specified ssh key %s does not exist: %v", key, err) +type sshKeyConfig struct { + FilePath string + Passphrase []byte +} + +type sshKey struct { + Config sshKeyConfig + PrivateKey interface{} +} + +func getSshKeyConfig(path string) (sshKeyConfig, error) { + var filePath string + var passphrase []byte + + switch { + case strings.HasPrefix(path, "file://"): + userinfoWithPath := strings.TrimPrefix(path, "file://") + + parts := strings.SplitN(userinfoWithPath, "@", 2) + passphrase = []byte(parts[0]) + filePath = util.ExpandPath(parts[1]) + + default: + filePath = util.ExpandPath(path) + } + + if keyExists, err := util.FileExists(filePath); !keyExists { + return sshKeyConfig{}, fmt.Errorf("specified ssh key does not exist") + } else if err != nil { + return sshKeyConfig{}, fmt.Errorf("specified ssh key does not exist: %w", err) + } + + return sshKeyConfig{FilePath: filePath, Passphrase: passphrase}, nil +} + +type loadSshKeysOptions struct { + WarnInvalidKeys bool +} + +func loadSshKeys(ctx context.Context, configs []sshKeyConfig, opts loadSshKeysOptions) ([]sshKey, error) { + var res []sshKey + + for _, cfg := range configs { + sshKey, err := parsePrivateSSHKey(cfg) + if err != nil { + if opts.WarnInvalidKeys { + logboek.Context(ctx).Warn().LogF("WARNING: unable to parse ssh key %s: %s\n", cfg.FilePath, err) + continue + } else { + return nil, fmt.Errorf("unable to parse ssh key %s: %s", cfg.FilePath, err) + } } + + res = append(res, sshKey) } - if len(keys) > 0 { - agentSock, err := runSSHAgentWithKeys(ctx, keys) + return res, nil +} + +func Init(ctx context.Context, userKeys []string) error { + var configs []sshKeyConfig + + for _, key := range userKeys { + cfg, err := getSshKeyConfig(key) if err != nil { - return fmt.Errorf("unable to run ssh agent with specified keys: %s", err) + return fmt.Errorf("unable to get ssh key %s config: %s", key, err) } - if err := setupProcessSSHAgent(agentSock); err != nil { - return fmt.Errorf("unable to init ssh auth socket to %q: %s", agentSock, err) + + configs = append(configs, cfg) + } + + if len(configs) > 0 { + keys, err := loadSshKeys(ctx, configs, loadSshKeysOptions{}) + if err != nil { + return fmt.Errorf("unable to load ssh keys: %s", err) + } + + if len(keys) > 0 { + agentSock, err := runSSHAgentWithKeys(ctx, keys) + if err != nil { + return fmt.Errorf("unable to run ssh agent with specified keys: %s", err) + } + if err := setupProcessSSHAgent(agentSock); err != nil { + return fmt.Errorf("unable to init ssh auth socket to %q: %s", agentSock, err) + } + return nil } - return nil } systemAgentSock := os.Getenv("SSH_AUTH_SOCK") @@ -56,35 +128,23 @@ func Init(ctx context.Context, keys []string) error { return nil } - var defaultKeys []string + var defaultConfigs []*sshKeyConfig for _, defaultFileName := range []string{"id_rsa", "id_dsa"} { path := filepath.Join(os.Getenv("HOME"), ".ssh", defaultFileName) if keyExists, _ := util.FileExists(path); keyExists { - defaultKeys = append(defaultKeys, path) + defaultConfigs = append(defaultConfigs, &sshKeyConfig{FilePath: path}) } } - if len(defaultKeys) > 0 { - var validKeys []string - - for _, key := range defaultKeys { - keyData, err := ioutil.ReadFile(key) - if err != nil { - logboek.Context(ctx).Warn().LogF("WARNING: cannot read default key %s: %s\n", key, err) - continue - } - _, err = ssh.ParseRawPrivateKey(keyData) - if err != nil { - logboek.Context(ctx).Warn().LogF("WARNING: default key %s validation error: %s\n", key, err) - continue - } - - validKeys = append(validKeys, key) + if len(defaultConfigs) > 0 { + keys, err := loadSshKeys(ctx, configs, loadSshKeysOptions{WarnInvalidKeys: true}) + if err != nil { + return fmt.Errorf("unable to load ssh keys: %s", err) } - if len(validKeys) > 0 { - agentSock, err := runSSHAgentWithKeys(ctx, validKeys) + if len(keys) > 0 { + agentSock, err := runSSHAgentWithKeys(ctx, keys) if err != nil { return fmt.Errorf("unable to run ssh agent with specified keys: %s", err) } @@ -108,7 +168,7 @@ func Terminate() error { return nil } -func runSSHAgentWithKeys(ctx context.Context, keys []string) (string, error) { +func runSSHAgentWithKeys(ctx context.Context, keys []sshKey) (string, error) { agentSock, err := runSSHAgent(ctx) if err != nil { return "", fmt.Errorf("error running ssh agent: %s", err) @@ -117,7 +177,7 @@ func runSSHAgentWithKeys(ctx context.Context, keys []string) (string, error) { for _, key := range keys { err := addSSHKey(ctx, agentSock, key) if err != nil { - return "", fmt.Errorf("error adding ssh key %s: %s", key, err) + return "", fmt.Errorf("error adding ssh key: %s", err) } } @@ -171,7 +231,7 @@ func runSSHAgent(ctx context.Context) (string, error) { return sockPath, nil } -func addSSHKey(ctx context.Context, authSock string, key string) error { +func addSSHKey(ctx context.Context, authSock string, key sshKey) error { conn, err := net.Dial("unix", authSock) if err != nil { return fmt.Errorf("error dialing with ssh agent %s: %s", authSock, err) @@ -180,22 +240,52 @@ func addSSHKey(ctx context.Context, authSock string, key string) error { agentClient := agent.NewClient(conn) - keyData, err := ioutil.ReadFile(key) + err = agentClient.Add(agent.AddedKey{PrivateKey: key.PrivateKey}) if err != nil { - return fmt.Errorf("error reading key file %s: %s", key, err) + return err } - privateKey, err := ssh.ParseRawPrivateKey(keyData) + logboek.Context(ctx).Info().LogF("Added private key %s to ssh agent %s\n", key.Config.FilePath, authSock) + + return nil +} + +func parsePrivateSSHKey(cfg sshKeyConfig) (sshKey, error) { + keyData, err := ioutil.ReadFile(cfg.FilePath) if err != nil { - return fmt.Errorf("error parsing private key %s: %s", key, err) + return sshKey{}, fmt.Errorf("error reading key file %q: %s", cfg.FilePath, err) } - err = agentClient.Add(agent.AddedKey{PrivateKey: privateKey}) + var privateKey interface{} + + privateKey, err = ssh.ParseRawPrivateKey(keyData) if err != nil { - return err - } + switch err.(type) { + case *ssh.PassphraseMissingError: + var passphrase []byte + if len(cfg.Passphrase) == 0 { + if terminal.IsTerminal(int(os.Stdin.Fd())) { + if data, err := secret_common.InputFromInteractiveStdin(fmt.Sprintf("Enter passphrase for ssh key %s: ", cfg.FilePath)); err != nil { + return sshKey{}, fmt.Errorf("error getting passphrase for ssh key %s: %s", cfg.FilePath, err) + } else { + passphrase = data + } + } else { + return sshKey{}, fmt.Errorf(`%w: please provide passphrase using --ssh-add="file://PASSPHRASE@FILEPATH" format`, err) + } + } else { + passphrase = cfg.Passphrase + } - logboek.Context(ctx).Info().LogF("Added private key %s to ssh agent %s\n", key, authSock) + privateKey, err = ssh.ParseRawPrivateKeyWithPassphrase(keyData, passphrase) + if err != nil { + return sshKey{}, fmt.Errorf("error parsing private key %s: %s", cfg.FilePath, err) + } - return nil + default: + return sshKey{}, fmt.Errorf("error parsing private key %s: %s", cfg.FilePath, err) + } + } + + return sshKey{Config: cfg, PrivateKey: privateKey}, nil }