Skip to content

Commit

Permalink
Add stream command (#42)
Browse files Browse the repository at this point in the history
* Update github.com/replicate/replicate-go

* First pass at implementing stream command

* Fallback to creating a prediction on latest version of model

* Change GetSchemas to return *openapi3.Schema

* Update github.com/replicate/replicate-go

* Remove unused import

* Implement prediction create --stream option

* Implement stream as alias to replicate prediction create --stream

* Update README

* Fix version fallback logic for creating a prediction

* Update demo tape
  • Loading branch information
mattt committed Dec 4, 2023
1 parent 68bf642 commit e14440a
Show file tree
Hide file tree
Showing 11 changed files with 257 additions and 94 deletions.
15 changes: 15 additions & 0 deletions README.md
Expand Up @@ -43,6 +43,7 @@ Core commands:

Alias commands:
run Alias for "prediction create"
stream Alias for "prediction create --stream"
train Alias for "training create"

Additional Commands:
Expand All @@ -67,6 +68,20 @@ $ replicate run stability-ai/sdxl \
Prediction created: https://replicate.com/p/jpgp263bdekvxileu2ppsy46v4
```

### Stream prediction output

Run [LLaMA 2] and stream output tokens to your terminal.

```console
$ replicate stream meta/llama-2-70b-chat \
prompt="Tell me a joke about llamas"
Sure, here's a joke about llamas for you:

Why did the llama refuse to play poker?

Because he always got fleeced!
```

### Create a local development environment from a prediction

Create a Node.js or Python project from a prediction.
Expand Down
1 change: 1 addition & 0 deletions cmd/replicate/main.go
Expand Up @@ -50,6 +50,7 @@ func init() {
for _, cmd := range []*cobra.Command{
cmd.RunCmd,
cmd.TrainCmd,
cmd.StreamCmd,
} {
rootCmd.AddCommand(cmd)
cmd.GroupID = "alias"
Expand Down
Binary file modified demo.gif
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
12 changes: 11 additions & 1 deletion demo.tape
Expand Up @@ -17,11 +17,21 @@ Sleep 100ms
Ctrl+C # Don't actually set the API key
Sleep 1s

Type 'replicate stream meta/llama-2-70b-chat \'
Enter
Type@50ms ' prompt="write a haiku about corgis"'
Enter

Sleep 2s

Enter

Enter
Type 'replicate run stability-ai/sdxl \'
Enter
Type@50ms ' prompt="a studio photo of a rainbow colored corgi" \'
Enter
Type@50ms ' width=512 height=512 seed=42069'
Enter

Sleep 10s
Sleep 15s
2 changes: 1 addition & 1 deletion go.mod
Expand Up @@ -14,7 +14,7 @@ require (
github.com/getkin/kin-openapi v0.120.0
github.com/golangci/golangci-lint v1.55.2
github.com/mattn/go-isatty v0.0.20
github.com/replicate/replicate-go v0.12.0
github.com/replicate/replicate-go v0.13.2
github.com/schollz/progressbar/v3 v3.13.1
github.com/spf13/cobra v1.8.0
github.com/stretchr/testify v1.8.4
Expand Down
8 changes: 6 additions & 2 deletions go.sum
Expand Up @@ -522,8 +522,12 @@ github.com/quasilyte/regex/syntax v0.0.0-20210819130434-b3f0c404a727 h1:TCg2WBOl
github.com/quasilyte/regex/syntax v0.0.0-20210819130434-b3f0c404a727/go.mod h1:rlzQ04UMyJXu/aOvhd8qT+hvDrFpiwqp8MRXDY9szc0=
github.com/quasilyte/stdinfo v0.0.0-20220114132959-f7386bf02567 h1:M8mH9eK4OUR4lu7Gd+PU1fV2/qnDNfzT635KRSObncs=
github.com/quasilyte/stdinfo v0.0.0-20220114132959-f7386bf02567/go.mod h1:DWNGW8A4Y+GyBgPuaQJuWiy0XYftx4Xm/y5Jqk9I6VQ=
github.com/replicate/replicate-go v0.12.0 h1:gd/hw4hCBO5G4M3Fezb3zdKYSbe9NEfRLzGoktFk3Ks=
github.com/replicate/replicate-go v0.12.0/go.mod h1:k9C4+PaYa9+hMRjn4D7ZPHOCUFb8P4jhytsCqcGa2vU=
github.com/replicate/replicate-go v0.13.0 h1:DWpSw8ck+dVK79jcVbg0iWJt4/ajcDaYX7FmiqSh2iI=
github.com/replicate/replicate-go v0.13.0/go.mod h1:k9C4+PaYa9+hMRjn4D7ZPHOCUFb8P4jhytsCqcGa2vU=
github.com/replicate/replicate-go v0.13.1 h1:+WgP8hoWuw8e0ZCA1RlVQZZrDkaksMZJCk8C1i+icp0=
github.com/replicate/replicate-go v0.13.1/go.mod h1:k9C4+PaYa9+hMRjn4D7ZPHOCUFb8P4jhytsCqcGa2vU=
github.com/replicate/replicate-go v0.13.2 h1:S+ENs0cKMlizZzh9Ht/Diy66FCPKSdXNRh/9QvKyf+8=
github.com/replicate/replicate-go v0.13.2/go.mod h1:k9C4+PaYa9+hMRjn4D7ZPHOCUFb8P4jhytsCqcGa2vU=
github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/rivo/uniseg v0.4.4 h1:8TfxU8dW6PdqD27gjM8MVNuicgxIjxpm4K7x4jp8sis=
Expand Down
2 changes: 2 additions & 0 deletions go.work.sum
Expand Up @@ -33,6 +33,8 @@ github.com/quasilyte/go-ruleguard/dsl v0.3.22/go.mod h1:KeCP03KrjuSO0H1kTuZQCWlQ
github.com/quasilyte/go-ruleguard/rules v0.0.0-20211022131956-028d6511ab71/go.mod h1:4cgAphtvu7Ftv7vOT2ZOYhC6CvBxZixcasr8qIOTA50=
github.com/replicate/replicate-go v0.8.1 h1:Mza5hWR/R1akZRKwXtA/CQJ2pY4/B9fSCYX+2nTb8zo=
github.com/replicate/replicate-go v0.8.1/go.mod h1:k9C4+PaYa9+hMRjn4D7ZPHOCUFb8P4jhytsCqcGa2vU=
github.com/replicate/replicate-go v0.13.0 h1:DWpSw8ck+dVK79jcVbg0iWJt4/ajcDaYX7FmiqSh2iI=
github.com/replicate/replicate-go v0.13.0/go.mod h1:k9C4+PaYa9+hMRjn4D7ZPHOCUFb8P4jhytsCqcGa2vU=
github.com/sahilm/fuzzy v0.1.0/go.mod h1:VFvziUEIMCrT6A6tw2RFIXPXXmzXbOsSHF0DOI8ZK9Y=
github.com/shoenig/go-m1cpu v0.1.6/go.mod h1:1JJMcUBvfNwpq05QDQVAnx3gUHr9IYF7GNg9SUEw2VQ=
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
Expand Down
10 changes: 5 additions & 5 deletions internal/cmd/model/schema.go
Expand Up @@ -71,8 +71,8 @@ func printModelVersionSchema(version *replicate.ModelVersion) error {
if inputSchema != nil {
fmt.Println("Inputs:")

for _, propName := range util.SortedKeys(inputSchema.Value.Properties) {
prop, ok := inputSchema.Value.Properties[propName]
for _, propName := range util.SortedKeys(inputSchema.Properties) {
prop, ok := inputSchema.Properties[propName]
if !ok {
continue
}
Expand All @@ -91,9 +91,9 @@ func printModelVersionSchema(version *replicate.ModelVersion) error {

if outputSchema != nil {
fmt.Println("Output:")
fmt.Printf("- type: %s\n", outputSchema.Value.Type)
if outputSchema.Value.Type == "array" {
fmt.Printf("- items: %s %s\n", outputSchema.Value.Items.Value.Type, outputSchema.Value.Items.Value.Format)
fmt.Printf("- type: %s\n", outputSchema.Type)
if outputSchema.Type == "array" {
fmt.Printf("- items: %s %s\n", outputSchema.Items.Value.Type, outputSchema.Items.Value.Format)
}
fmt.Println()
}
Expand Down

0 comments on commit e14440a

Please sign in to comment.