Skip to content

Commit

Permalink
Add models create subcommand (#30)
Browse files Browse the repository at this point in the history
* Upgrade replicate-go dependency

* Upgrade toolchain to Go 1.21

* Add hardware list command

* Add missing usage to models show subcommand

* Add models create subcommand

* Update go.work.sum
  • Loading branch information
mattt committed Nov 12, 2023
1 parent f9d688f commit e63f2ae
Show file tree
Hide file tree
Showing 10 changed files with 249 additions and 4 deletions.
2 changes: 2 additions & 0 deletions cmd/replicate/main.go
Expand Up @@ -6,6 +6,7 @@ import (
"github.com/spf13/cobra"

"github.com/replicate/cli/internal/cmd"
"github.com/replicate/cli/internal/cmd/hardware"
"github.com/replicate/cli/internal/cmd/model"
"github.com/replicate/cli/internal/cmd/prediction"
"github.com/replicate/cli/internal/cmd/training"
Expand Down Expand Up @@ -35,6 +36,7 @@ func init() {
model.RootCmd,
prediction.RootCmd,
training.RootCmd,
hardware.RootCmd,
} {
rootCmd.AddCommand(cmd)
cmd.GroupID = "core"
Expand Down
6 changes: 4 additions & 2 deletions go.mod
@@ -1,6 +1,8 @@
module github.com/replicate/cli

go 1.20
go 1.21

toolchain go1.21.1

require (
github.com/PaesslerAG/jsonpath v0.1.1
Expand All @@ -12,7 +14,7 @@ require (
github.com/getkin/kin-openapi v0.120.0
github.com/golangci/golangci-lint v1.55.2
github.com/mattn/go-isatty v0.0.20
github.com/replicate/replicate-go v0.10.0
github.com/replicate/replicate-go v0.11.0
github.com/schollz/progressbar/v3 v3.13.1
github.com/spf13/cobra v1.8.0
github.com/stretchr/testify v1.8.4
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Expand Up @@ -507,6 +507,8 @@ github.com/quasilyte/stdinfo v0.0.0-20220114132959-f7386bf02567 h1:M8mH9eK4OUR4l
github.com/quasilyte/stdinfo v0.0.0-20220114132959-f7386bf02567/go.mod h1:DWNGW8A4Y+GyBgPuaQJuWiy0XYftx4Xm/y5Jqk9I6VQ=
github.com/replicate/replicate-go v0.10.0 h1:01G4TVT1CokgWHvty+WaWX9Cl7fQMNRCjddumJpXEW0=
github.com/replicate/replicate-go v0.10.0/go.mod h1:k9C4+PaYa9+hMRjn4D7ZPHOCUFb8P4jhytsCqcGa2vU=
github.com/replicate/replicate-go v0.11.0 h1:zjjdXVvot2TqtiL8usgzdr6CEhrcI51DhygC2Gw/qjE=
github.com/replicate/replicate-go v0.11.0/go.mod h1:k9C4+PaYa9+hMRjn4D7ZPHOCUFb8P4jhytsCqcGa2vU=
github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/rivo/uniseg v0.4.4 h1:8TfxU8dW6PdqD27gjM8MVNuicgxIjxpm4K7x4jp8sis=
Expand Down
4 changes: 3 additions & 1 deletion go.work
@@ -1,3 +1,5 @@
go 1.20
go 1.21

toolchain go1.21.1

use .
32 changes: 32 additions & 0 deletions go.work.sum
@@ -1,29 +1,61 @@
cloud.google.com/go/firestore v1.6.1/go.mod h1:asNXNOzBdyVQmEU+ggO8UPodTkEVFW5Qx+rwHnAz+EY=
github.com/armon/go-metrics v0.3.10/go.mod h1:4O98XIr/9W0sxpJ8UaYkvjk10Iff7SnFrb4QAOwNTFc=
github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI=
github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA=
github.com/charmbracelet/harmonica v0.2.0/go.mod h1:KSri/1RMQOZLbw7AHqgcBycp8pgJnQMYYT8QZRqZ1Ao=
github.com/coreos/go-semver v0.3.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk=
github.com/coreos/go-systemd/v22 v22.3.2/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
github.com/cristalhq/acmd v0.11.1/go.mod h1:LG5oa43pE/BbxtfMoImHCQN++0Su7dzipdgBjMCBVDQ=
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/frankban/quicktest v1.14.4/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
github.com/go-logr/logr v1.2.4/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls=
github.com/go-test/deep v1.0.8/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE=
github.com/go-toolsmith/pkgload v1.2.2/go.mod h1:R2hxLNRKuAsiXCo2i5J6ZQPhnPMOVtU+f0arbFPWCus=
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/uuid v1.3.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gookit/color v1.5.4/go.mod h1:pZJOeOS8DM43rXbp4AZo1n9zCU2qjpcRko0b6/QJi9w=
github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So=
github.com/gostaticanalysis/testutil v0.4.0/go.mod h1:bLIoPefWXrRi/ssLFWX1dx7Repi5x3CuviD3dgAZaBU=
github.com/hashicorp/go-cleanhttp v0.5.2/go.mod h1:kO/YDlP8L1346E6Sodw+PrpBSV4/SoxCXGY6BqNFT48=
github.com/hashicorp/go-hclog v1.2.0/go.mod h1:whpDNt7SSdeAju8AWKIWsul05p54N/39EeqMAyrmvFQ=
github.com/hashicorp/go-immutable-radix v1.3.1/go.mod h1:0y9vanUI8NX6FsYoO3zeMjhV/C5i9g4Q3DwcSNZ4P60=
github.com/hashicorp/go-rootcerts v1.0.2/go.mod h1:pqUvnprVnM5bf7AOirdbb01K4ccR319Vf4pU3K5EGc8=
github.com/hashicorp/golang-lru v0.5.4/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I=
github.com/mgechev/dots v0.0.0-20210922191527-e955255bf517/go.mod h1:KQ7+USdGKfpPjXk4Ga+5XxQM4Lm4e3gAogrreFAYpOg=
github.com/mitchellh/go-ps v1.0.0/go.mod h1:J4lOc8z8yJs6vUwklHw2XEIiT4z4C40KtWVN3nvg8Pg=
github.com/mozilla/tls-observatory v0.0.0-20210609171429-7bc42856d2e5/go.mod h1:FUqVoUPHSEdDR0MnFM3Dh8AU0pZHLXUD127SAJGER/s=
github.com/nbutton23/zxcvbn-go v0.0.0-20210217022336-fa2cb2858354/go.mod h1:KSVJerMDfblTH7p5MZaTt+8zaT2iEk3AkVb9PQdZuE8=
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno=
github.com/phayes/checkstyle v0.0.0-20170904204023-bfd46e6a821d/go.mod h1:3OzsM7FXDQlpCiw2j81fOmAwQLnZnLGXVKUzeKQXIAw=
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
github.com/quasilyte/go-ruleguard/dsl v0.3.22/go.mod h1:KeCP03KrjuSO0H1kTuZQCWlQPulDV6YMIXmpQss17rU=
github.com/quasilyte/go-ruleguard/rules v0.0.0-20211022131956-028d6511ab71/go.mod h1:4cgAphtvu7Ftv7vOT2ZOYhC6CvBxZixcasr8qIOTA50=
github.com/replicate/replicate-go v0.8.1 h1:Mza5hWR/R1akZRKwXtA/CQJ2pY4/B9fSCYX+2nTb8zo=
github.com/replicate/replicate-go v0.8.1/go.mod h1:k9C4+PaYa9+hMRjn4D7ZPHOCUFb8P4jhytsCqcGa2vU=
github.com/sahilm/fuzzy v0.1.0/go.mod h1:VFvziUEIMCrT6A6tw2RFIXPXXmzXbOsSHF0DOI8ZK9Y=
github.com/shoenig/go-m1cpu v0.1.6/go.mod h1:1JJMcUBvfNwpq05QDQVAnx3gUHr9IYF7GNg9SUEw2VQ=
github.com/ugorji/go/codec v1.2.7/go.mod h1:WGN1fab3R1fzQlVQTkfxVtIBhWDRqOviHU95kRgeqEY=
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
github.com/valyala/quicktemplate v1.7.0/go.mod h1:sqKJnoaOF88V07vkO+9FL8fb9uZg/VPSJnLYn+LmLk8=
github.com/xo/terminfo v0.0.0-20210125001918-ca9a967f8778/go.mod h1:2MuV+tbUrU1zIOPMxZ5EncGwgmMJsa+9ucAQZXxsObs=
github.com/yusufpapurcu/wmi v1.2.3/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
go-simpler.org/assert v0.6.0/go.mod h1:74Eqh5eI6vCK6Y5l3PI8ZYFXG4Sa+tkr70OIPJAUr28=
go.opencensus.io v0.23.0/go.mod h1:XItmlyltB5F7CS4xOC1DcqMoFqwtC6OG2xF7mCv7P7E=
go.uber.org/goleak v1.1.11/go.mod h1:cwTWslyiVhfpKIDGSZEM2HlOvcqm+tG4zioyIeLoqMQ=
golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4=
golang.org/x/lint v0.0.0-20210508222113-6edffad5e616/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI=
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
70 changes: 70 additions & 0 deletions internal/cmd/hardware/list.go
@@ -0,0 +1,70 @@
package hardware

import (
"encoding/json"
"fmt"

"github.com/cli/browser"
"github.com/spf13/cobra"

"github.com/replicate/cli/internal/util"
"github.com/replicate/replicate-go"
)

// listCmd represents the list hardware command
var listCmd = &cobra.Command{
Use: "list",
Short: "List hardware",
RunE: func(cmd *cobra.Command, args []string) error {
ctx := cmd.Context()

if cmd.Flags().Changed("web") {
if util.IsTTY() {
fmt.Println("Opening in browser...")
}

url := "https://replicate.com/pricing#hardware"
err := browser.OpenURL(url)
if err != nil {
return fmt.Errorf("failed to open browser: %w", err)
}

return nil
}

client, err := replicate.NewClient(replicate.WithTokenFromEnv())
if err != nil {
return fmt.Errorf("failed to create client: %w", err)
}

hardware, err := client.ListHardware(ctx)
if err != nil {
return fmt.Errorf("failed to list hardware: %w", err)
}

if cmd.Flags().Changed("json") || !util.IsTTY() {
bytes, err := json.MarshalIndent(hardware, "", " ")
if err != nil {
return fmt.Errorf("failed to marshal hardware: %w", err)
}
fmt.Println(string(bytes))
return nil
}

for _, hw := range *hardware {
fmt.Printf("- %s: %s\n", hw.SKU, hw.Name)
}

return nil
},
}

func init() {
addListFlags(listCmd)
}

func addListFlags(cmd *cobra.Command) {
cmd.Flags().Bool("json", false, "Emit JSON")
cmd.Flags().Bool("web", false, "View on web")
cmd.MarkFlagsMutuallyExclusive("json", "web")
}
24 changes: 24 additions & 0 deletions internal/cmd/hardware/root.go
@@ -0,0 +1,24 @@
package hardware

import (
"github.com/spf13/cobra"
)

var RootCmd = &cobra.Command{
Use: "hardware [subcommand]",
Short: "Interact with hardware",
Aliases: []string{"hw"},
}

func init() {
RootCmd.AddGroup(&cobra.Group{
ID: "subcommand",
Title: "Subcommands:",
})
for _, cmd := range []*cobra.Command{
listCmd,
} {
RootCmd.AddCommand(cmd)
cmd.GroupID = "subcommand"
}
}
110 changes: 110 additions & 0 deletions internal/cmd/model/create.go
@@ -0,0 +1,110 @@
package model

import (
"fmt"

"github.com/cli/browser"
"github.com/replicate/cli/internal/identifier"
"github.com/replicate/cli/internal/util"
"github.com/replicate/replicate-go"
"github.com/spf13/cobra"
)

// createCmd represents the create command
var createCmd = &cobra.Command{
Use: "create <owner>/<name> [flags]",
Short: "Create a new model",
Args: cobra.ExactArgs(1),
RunE: func(cmd *cobra.Command, args []string) error {
id, err := identifier.ParseIdentifier(args[0])
if err != nil || id.Version != "" {
return fmt.Errorf("expected <owner>/<name> but got %s", args[0])
}

opts := &replicate.CreateModelOptions{}
flags := cmd.Flags()

if flags.Changed("public") {
opts.Visibility, _ = flags.GetString("public")
} else if flags.Changed("private") {
opts.Visibility, _ = flags.GetString("private")
}

opts.Hardware, _ = flags.GetString("hardware")

flagMap := map[string]**string{
"description": &opts.Description,
"github-url": &opts.GithubURL,
"paper-url": &opts.PaperURL,
"license-url": &opts.LicenseURL,
"cover-image-url": &opts.CoverImageURL,
}
for flagName, optPtr := range flagMap {
if flags.Changed(flagName) {
value, _ := flags.GetString(flagName)
*optPtr = &value
}
}

client, err := replicate.NewClient(replicate.WithTokenFromEnv())
if err != nil {
return fmt.Errorf("failed to create client: %w", err)
}

model, err := client.CreateModel(cmd.Context(), id.Owner, id.Name, *opts)
if err != nil {
return fmt.Errorf("failed to create model: %w", err)
}

if flags.Changed("json") || !util.IsTTY() {
bytes, err := model.MarshalJSON()
if err != nil {
return fmt.Errorf("failed to serialize model: %w", err)
}
fmt.Println(string(bytes))
return nil
}

url := fmt.Sprintf("https://replicate.com/%s/%s", id.Owner, id.Name)
if flags.Changed("web") {
if util.IsTTY() {
fmt.Println("Opening in browser...")
}

err := browser.OpenURL(url)
if err != nil {
return fmt.Errorf("failed to open browser: %w", err)
}

return nil
}

fmt.Printf("Model created: %s\n", url)

return nil
},
}

func init() {
addCreateFlags(createCmd)
}

func addCreateFlags(cmd *cobra.Command) {
cmd.Flags().Bool("public", false, "Make the new model public")
cmd.Flags().Bool("private", false, "Make the new model private")
cmd.MarkFlagsOneRequired("public", "private")
cmd.MarkFlagsMutuallyExclusive("public", "private")

cmd.Flags().String("hardware", "", "SKU of the hardware to run the model")
_ = cmd.MarkFlagRequired("hardware")

cmd.Flags().String("description", "", "Description of the model")
cmd.Flags().String("github-url", "", "URL of the GitHub repository")
cmd.Flags().String("paper-url", "", "URL of the paper")
cmd.Flags().String("license-url", "", "URL of the license")
cmd.Flags().String("cover-image-url", "", "URL of the cover image")

cmd.Flags().Bool("json", false, "Emit JSON")
cmd.Flags().Bool("web", false, "View on web")
cmd.MarkFlagsMutuallyExclusive("json", "web")
}
1 change: 1 addition & 0 deletions internal/cmd/model/root.go
Expand Up @@ -18,6 +18,7 @@ func init() {
for _, cmd := range []*cobra.Command{
showCmd,
schemaCmd,
createCmd,
} {
RootCmd.AddCommand(cmd)
cmd.GroupID = "subcommand"
Expand Down
2 changes: 1 addition & 1 deletion internal/cmd/model/show.go
Expand Up @@ -12,7 +12,7 @@ import (
)

var showCmd = &cobra.Command{
Use: "show",
Use: "show <owner/model[:version]> [flags]",
Short: "Show a model",
Args: cobra.ExactArgs(1),
Aliases: []string{"view"},
Expand Down

0 comments on commit e63f2ae

Please sign in to comment.