Skip to content

Commit

Permalink
Add auth login subcommand (#48)
Browse files Browse the repository at this point in the history
* Move version to internal package

* Add internal config package

* Add internal client package

* Refactor subcommands to use internal client constructor

* Add VerifyToken method

* Add login subcommand

* Update demo tape
  • Loading branch information
mattt committed Dec 14, 2023
1 parent 8123d2b commit 7aaff16
Show file tree
Hide file tree
Showing 22 changed files with 371 additions and 68 deletions.
2 changes: 1 addition & 1 deletion Makefile
Expand Up @@ -20,7 +20,7 @@ all: replicate

replicate:
CGO_ENABLED=0 $(GO) build -o $@ \
-ldflags "-X github.com/replicate/cli/internal/cmd.version=$(REPLICATE_CLI_VERSION) -w" \
-ldflags "-X github.com/replicate/cli/internal.version=$(REPLICATE_CLI_VERSION) -w" \
main.go

demo.gif: replicate demo.tape
Expand Down
5 changes: 4 additions & 1 deletion cmd/replicate/main.go
Expand Up @@ -5,7 +5,9 @@ import (

"github.com/spf13/cobra"

"github.com/replicate/cli/internal"
"github.com/replicate/cli/internal/cmd"
"github.com/replicate/cli/internal/cmd/auth"
"github.com/replicate/cli/internal/cmd/hardware"
"github.com/replicate/cli/internal/cmd/model"
"github.com/replicate/cli/internal/cmd/prediction"
Expand All @@ -15,7 +17,7 @@ import (
// rootCmd represents the base command when called without any subcommands
var rootCmd = &cobra.Command{
Use: "replicate",
Version: cmd.Version(),
Version: internal.Version(),
}

// Execute adds all child commands to the root command and sets flags appropriately.
Expand All @@ -33,6 +35,7 @@ func init() {
Title: "Core commands:",
})
for _, cmd := range []*cobra.Command{
auth.RootCmd,
model.RootCmd,
prediction.RootCmd,
training.RootCmd,
Expand Down
Binary file modified demo.gif
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
8 changes: 5 additions & 3 deletions demo.tape
Expand Up @@ -8,16 +8,18 @@ Set FontSize 24
Set Width 1200
Set Height 600

Type "export REPLICATE_API_TOKEN="
Type "echo "
Sleep 100ms
Hide
Type "r8_•••••••••••••••••••••••••••••••••••••"
Show
Sleep 100ms
Type " | replicate auth login"
Sleep 100ms
Ctrl+C # Don't actually set the API key
Sleep 1s

Type 'replicate stream meta/llama-2-70b-chat \'
Type 'replicate run meta/llama-2-70b-chat \'
Enter
Type@50ms ' prompt="write a haiku about corgis"'
Enter
Expand All @@ -34,4 +36,4 @@ Enter
Type@50ms ' width=512 height=512 seed=42069'
Enter

Sleep 15s
Sleep 30s
2 changes: 1 addition & 1 deletion go.mod
Expand Up @@ -20,6 +20,7 @@ require (
github.com/stretchr/testify v1.8.4
go.uber.org/nilaway v0.0.0-20231130193605-0ef3071d1630
golang.org/x/sync v0.5.0
gopkg.in/yaml.v3 v3.0.1
)

require (
Expand Down Expand Up @@ -216,7 +217,6 @@ require (
google.golang.org/protobuf v1.28.0 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
honnef.co/go/tools v0.4.6 // indirect
mvdan.cc/gofumpt v0.5.0 // indirect
mvdan.cc/interfacer v0.0.0-20180901003855-c20040233aed // indirect
Expand Down
21 changes: 7 additions & 14 deletions go.sum
Expand Up @@ -522,8 +522,6 @@ github.com/quasilyte/regex/syntax v0.0.0-20210819130434-b3f0c404a727 h1:TCg2WBOl
github.com/quasilyte/regex/syntax v0.0.0-20210819130434-b3f0c404a727/go.mod h1:rlzQ04UMyJXu/aOvhd8qT+hvDrFpiwqp8MRXDY9szc0=
github.com/quasilyte/stdinfo v0.0.0-20220114132959-f7386bf02567 h1:M8mH9eK4OUR4lu7Gd+PU1fV2/qnDNfzT635KRSObncs=
github.com/quasilyte/stdinfo v0.0.0-20220114132959-f7386bf02567/go.mod h1:DWNGW8A4Y+GyBgPuaQJuWiy0XYftx4Xm/y5Jqk9I6VQ=
github.com/replicate/replicate-go v0.13.2 h1:S+ENs0cKMlizZzh9Ht/Diy66FCPKSdXNRh/9QvKyf+8=
github.com/replicate/replicate-go v0.13.2/go.mod h1:k9C4+PaYa9+hMRjn4D7ZPHOCUFb8P4jhytsCqcGa2vU=
github.com/replicate/replicate-go v0.14.2 h1:XgK+REvYrWs7qDeyugxHA93h31qBhEFk/3p1/p2w3W8=
github.com/replicate/replicate-go v0.14.2/go.mod h1:otIrl1vDmyjNhTzmVmp/mQU3Wt1+3387gFNEsAZq0ig=
github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
Expand Down Expand Up @@ -662,8 +660,7 @@ go.uber.org/goleak v1.2.1 h1:NBol2c7O1ZokfZ0LEU9K6Whx/KnwvepVetCUhtKja4A=
go.uber.org/goleak v1.2.1/go.mod h1:qlT2yGI9QafXHhZZLxlSuNsMw3FFLxBr+tBRlmO1xH4=
go.uber.org/multierr v1.6.0 h1:y6IPFStTAIT5Ytl7/XYmHvzXQ7S3g/IeZW9hyZ5thw4=
go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU=
go.uber.org/nilaway v0.0.0-20231117175943-a267567c6fff h1:MSLkMoDm4RpwG6QRJPaiNK4dAnFMX81SmEaRBSs/zws=
go.uber.org/nilaway v0.0.0-20231117175943-a267567c6fff/go.mod h1:5u4JTf2doTUy5fmx1tiTc8YUDP+F9LW+MjtXwuSePDk=
go.uber.org/nilaway v0.0.0-20231130193605-0ef3071d1630 h1:d78bN/STgxmCT0mDg+ZgR9EB3Z/lKq2i34sCvw39eMQ=
go.uber.org/nilaway v0.0.0-20231130193605-0ef3071d1630/go.mod h1:5q7m8ZeGRjsKsOAcEP1CxOPLAKA7lznQwHY1+Vex0Yo=
go.uber.org/zap v1.24.0 h1:FiJd5l1UOLj0wCgbSE0rwwXHzEdAZS6hiiSnxJN/D60=
go.uber.org/zap v1.24.0/go.mod h1:2kMP+WWQ8aoFoedH3T2sq6iJ2yDWpHbP0f6MQbS9Gkg=
Expand All @@ -687,8 +684,7 @@ golang.org/x/exp v0.0.0-20191227195350-da58074b4299/go.mod h1:2RIsYlXP63K8oxa1u0
golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4=
golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM=
golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU=
golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1 h1:MGwJjxBy0HJshjDNfLsYO8xppfqWlA5ZT9OhtUUhTNw=
golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc=
golang.org/x/exp v0.0.0-20231108232855-2478ac86f678 h1:mchzmB1XO2pMaKFRqk/+MV3mgGG96aqaPXaMifQU47w=
golang.org/x/exp v0.0.0-20231108232855-2478ac86f678/go.mod h1:zk2irFbV9DP96SEBUUAy67IdHUaZuSnrz1n472HUCLE=
golang.org/x/exp/typeparams v0.0.0-20220428152302-39d4317da171/go.mod h1:AbB0pIl9nAr9wVwH+Z2ZpaocVmF5I4GyWCDIsVjR0bk=
golang.org/x/exp/typeparams v0.0.0-20230203172020-98cc5a0785f9/go.mod h1:AbB0pIl9nAr9wVwH+Z2ZpaocVmF5I4GyWCDIsVjR0bk=
Expand Down Expand Up @@ -724,8 +720,7 @@ golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91
golang.org/x/mod v0.6.0/go.mod h1:4mET923SAdbXp2ki8ey+zGs1SLqsuM2Y0uvdZR/fUNI=
golang.org/x/mod v0.7.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.13.0 h1:I/DsJXRlw/8l/0c24sM9yb0T4z9liZTduXvdAWYiysY=
golang.org/x/mod v0.13.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/mod v0.14.0 h1:dGoOF9QVLYng8IHTm7BAyWqCqSheQ5pYWGhzW00YJr0=
golang.org/x/mod v0.14.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
Expand Down Expand Up @@ -768,8 +763,8 @@ golang.org/x/net v0.1.0/go.mod h1:Cx3nUiGt4eDBEyega/BKRp+/AlGL8hYe7U9odMt2Cco=
golang.org/x/net v0.2.0/go.mod h1:KqCZLdyyvdV855qA2rE3GC2aiw5xGR5TEjj8smXukLY=
golang.org/x/net v0.5.0/go.mod h1:DivGGAXEgPSlEBzxGzZI+ZLohi+xUj054jfeKui00ws=
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM=
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
golang.org/x/net v0.18.0 h1:mIYleuAkSbHh0tCv7RvjL3F6ZVbLjq4+R7zbOn3Kokg=
golang.org/x/net v0.18.0/go.mod h1:/czyP5RqHAH4odGYxBJ1qz0+CE5WZ+2j1YgoEo8F2jQ=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
Expand Down Expand Up @@ -854,8 +849,7 @@ golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.4.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.14.0 h1:Vz7Qs629MkJkGyHxUlRHizWJRG2j8fbQKjELVSNhy7Q=
golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
Expand Down Expand Up @@ -950,8 +944,7 @@ golang.org/x/tools v0.2.0/go.mod h1:y4OqIKeOV/fWJetJ8bXPU1sEVniLMIyDAZWeHdV+NTA=
golang.org/x/tools v0.3.0/go.mod h1:/rWhSS2+zyEVwoJf8YAX6L2f0ntZ7Kn/mGgAWcipA5k=
golang.org/x/tools v0.5.0/go.mod h1:N+Kgy78s5I24c24dU8OfWNEotWjutIs8SnJvn5IDq+k=
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/tools v0.14.0 h1:jvNa2pY0M4r62jkRQ6RwEZZyPcymeL9XZMLBbV7U2nc=
golang.org/x/tools v0.14.0/go.mod h1:uYBEerGOWcJyEORxN+Ek8+TT266gXkNlHdJBwexUsBg=
golang.org/x/tools v0.15.0 h1:zdAyfUGbYmuVokhzVmghFl2ZJh5QhcfebBgmVPFYA+8=
golang.org/x/tools v0.15.0/go.mod h1:hpksKq4dtpQWS1uQ61JkdqWM3LscIS6Slf+VVkm+wQk=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
Expand Down
52 changes: 52 additions & 0 deletions internal/client/client.go
@@ -0,0 +1,52 @@
package client

import (
"context"
"fmt"

"github.com/replicate/cli/internal"
"github.com/replicate/cli/internal/config"
"github.com/replicate/replicate-go"
)

func NewClient(opts ...replicate.ClientOption) (*replicate.Client, error) {
baseURL := config.GetAPIBaseURL()

token, err := config.GetAPIToken()
if err != nil {
return nil, fmt.Errorf("failed to get API token: %w", err)
}
if token == "" {
return nil, fmt.Errorf("please authenticate with `replicate login`")
}

userAgent := fmt.Sprintf("replicate/%s", internal.Version())

opts = append([]replicate.ClientOption{
replicate.WithBaseURL(baseURL),
replicate.WithToken(token),
replicate.WithUserAgent(userAgent),
}, opts...)

r8, err := replicate.NewClient(opts...)
if err != nil {
return nil, err
}

return r8, nil
}

func VerifyToken(ctx context.Context, token string) (bool, error) {
r8, err := NewClient(replicate.WithToken(token))
if err != nil {
return false, err
}

// FIXME: Add better endpoint for verifying token
_, err = r8.ListHardware(ctx)
if err != nil {
return false, nil
}

return true, nil
}
75 changes: 75 additions & 0 deletions internal/cmd/auth/login.go
@@ -0,0 +1,75 @@
package auth

import (
"fmt"
"io"
"os"
"strings"

"github.com/replicate/cli/internal/client"
"github.com/replicate/cli/internal/config"
"github.com/spf13/cobra"
)

// loginCmd represents the login command
var loginCmd = &cobra.Command{
Use: "login --token-stdin",
Short: "Log in to Replicate",
Long: `Log in to Replicate
You can find your Replicate API token at https://replicate.com/account`,
Example: `
# Log in with environment variable
$ echo $REPLICATE_API_TOKEN | replicate auth login --token-stdin
# Log in with token file
$ replicate auth login --token-stdin < path/to/token`,
RunE: func(cmd *cobra.Command, args []string) error {
ctx := cmd.Context()

tokenStdin, err := cmd.Flags().GetBool("token-stdin")
if err != nil {
return err
}

var token string
if tokenStdin {
token, err = readTokenFromStdin()
if err != nil {
return fmt.Errorf("failed to read token from stdin: %w", err)
}
} else {
return fmt.Errorf("token must be passed to stdin with --token-stdin flag")
}
token = strings.TrimSpace(token)

ok, err := client.VerifyToken(ctx, token)
if err != nil {
return fmt.Errorf("error verifying token: %w", err)
}
if !ok {
return fmt.Errorf("invalid token")
}

if err := config.SetAPIToken(token); err != nil {
return fmt.Errorf("failed to set API token: %w", err)
}

fmt.Println("Login successful.")

return nil
},
}

func readTokenFromStdin() (string, error) {
tokenBytes, err := io.ReadAll(os.Stdin)
if err != nil {
return "", fmt.Errorf("Failed to read token from stdin: %w", err)
}
return string(tokenBytes), nil
}

func init() {
loginCmd.Flags().Bool("token-stdin", false, "Take the token from stdin.")
_ = loginCmd.MarkFlagRequired("token-stdin")
}
23 changes: 23 additions & 0 deletions internal/cmd/auth/root.go
@@ -0,0 +1,23 @@
package auth

import (
"github.com/spf13/cobra"
)

var RootCmd = &cobra.Command{
Use: "auth [subcommand]",
Short: "Authenticate with Replicate",
}

func init() {
RootCmd.AddGroup(&cobra.Group{
ID: "subcommand",
Title: "Subcommands:",
})
for _, cmd := range []*cobra.Command{
loginCmd,
} {
RootCmd.AddCommand(cmd)
cmd.GroupID = "subcommand"
}
}
8 changes: 4 additions & 4 deletions internal/cmd/hardware/list.go
Expand Up @@ -7,8 +7,8 @@ import (
"github.com/cli/browser"
"github.com/spf13/cobra"

"github.com/replicate/cli/internal/client"
"github.com/replicate/cli/internal/util"
"github.com/replicate/replicate-go"
)

// listCmd represents the list hardware command
Expand All @@ -32,12 +32,12 @@ var listCmd = &cobra.Command{
return nil
}

client, err := replicate.NewClient(replicate.WithTokenFromEnv())
r8, err := client.NewClient()
if err != nil {
return fmt.Errorf("failed to create client: %w", err)
return err
}

hardware, err := client.ListHardware(ctx)
hardware, err := r8.ListHardware(ctx)
if err != nil {
return fmt.Errorf("failed to list hardware: %w", err)
}
Expand Down
7 changes: 4 additions & 3 deletions internal/cmd/model/create.go
Expand Up @@ -4,6 +4,7 @@ import (
"fmt"

"github.com/cli/browser"
"github.com/replicate/cli/internal/client"
"github.com/replicate/cli/internal/identifier"
"github.com/replicate/cli/internal/util"
"github.com/replicate/replicate-go"
Expand Down Expand Up @@ -46,12 +47,12 @@ var createCmd = &cobra.Command{
}
}

client, err := replicate.NewClient(replicate.WithTokenFromEnv())
r8, err := client.NewClient()
if err != nil {
return fmt.Errorf("failed to create client: %w", err)
return err
}

model, err := client.CreateModel(cmd.Context(), id.Owner, id.Name, *opts)
model, err := r8.CreateModel(cmd.Context(), id.Owner, id.Name, *opts)
if err != nil {
return fmt.Errorf("failed to create model: %w", err)
}
Expand Down
9 changes: 5 additions & 4 deletions internal/cmd/model/schema.go
Expand Up @@ -4,6 +4,7 @@ import (
"encoding/json"
"fmt"

"github.com/replicate/cli/internal/client"
"github.com/replicate/cli/internal/identifier"
"github.com/replicate/cli/internal/util"

Expand All @@ -24,14 +25,14 @@ var schemaCmd = &cobra.Command{

ctx := cmd.Context()

client, err := replicate.NewClient(replicate.WithTokenFromEnv())
r8, err := client.NewClient()
if err != nil {
return fmt.Errorf("failed to create client: %w", err)
return err
}

var version *replicate.ModelVersion
if id.Version == "" {
model, err := client.GetModel(ctx, id.Owner, id.Name)
model, err := r8.GetModel(ctx, id.Owner, id.Name)
if err != nil {
return fmt.Errorf("failed to get model: %w", err)
}
Expand All @@ -42,7 +43,7 @@ var schemaCmd = &cobra.Command{

version = model.LatestVersion
} else {
version, err = client.GetModelVersion(ctx, id.Owner, id.Name, id.Version)
version, err = r8.GetModelVersion(ctx, id.Owner, id.Name, id.Version)
if err != nil {
return fmt.Errorf("failed to get model version: %w", err)
}
Expand Down

0 comments on commit 7aaff16

Please sign in to comment.