Skip to content

Commit

Permalink
Add support for deployments.get endpoint (#41)
Browse files Browse the repository at this point in the history
  • Loading branch information
mattt committed Feb 16, 2024
1 parent 157dd19 commit 48b7cbb
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 0 deletions.
66 changes: 66 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1653,3 +1653,69 @@ func TestValidateWebhook(t *testing.T) {
require.NoError(t, err)
assert.True(t, isValid)
}

func TestGetDeployment(t *testing.T) {
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "/deployments/acme/image-upscaler", r.URL.Path)
assert.Equal(t, http.MethodGet, r.Method)

deployment := &replicate.Deployment{
Owner: "acme",
Name: "image-upscaler",
CurrentRelease: replicate.DeploymentRelease{
Number: 1,
Model: "acme/esrgan",
Version: "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
CreatedAt: "2022-01-01T00:00:00Z",
CreatedBy: replicate.Account{
Type: "organization",
Username: "acme",
Name: "Acme, Inc.",
},
Configuration: replicate.DeploymentConfiguration{
Hardware: "gpu-t4",
MinInstances: 1,
MaxInstances: 5,
},
},
}

responseBytes, err := json.Marshal(deployment)
if err != nil {
t.Fatal(err)
}

w.WriteHeader(http.StatusOK)
w.Write(responseBytes)
}))
defer mockServer.Close()

client, err := replicate.NewClient(
replicate.WithToken("test-token"),
replicate.WithBaseURL(mockServer.URL),
)
require.NotNil(t, client)
require.NoError(t, err)

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()

deployment, err := client.GetDeployment(ctx, "acme", "image-upscaler")
if err != nil {
t.Fatal(err)
}

assert.NotNil(t, deployment)
assert.Equal(t, "acme", deployment.Owner)
assert.Equal(t, "image-upscaler", deployment.Name)
assert.Equal(t, 1, deployment.CurrentRelease.Number)
assert.Equal(t, "acme/esrgan", deployment.CurrentRelease.Model)
assert.Equal(t, "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", deployment.CurrentRelease.Version)
assert.Equal(t, "2022-01-01T00:00:00Z", deployment.CurrentRelease.CreatedAt)
assert.Equal(t, "organization", deployment.CurrentRelease.CreatedBy.Type)
assert.Equal(t, "acme", deployment.CurrentRelease.CreatedBy.Username)
assert.Equal(t, "Acme, Inc.", deployment.CurrentRelease.CreatedBy.Name)
assert.Equal(t, "gpu-t4", deployment.CurrentRelease.Configuration.Hardware)
assert.Equal(t, 1, deployment.CurrentRelease.Configuration.MinInstances)
assert.Equal(t, 5, deployment.CurrentRelease.Configuration.MaxInstances)
}
52 changes: 52 additions & 0 deletions deployment.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,61 @@ package replicate

import (
"context"
"encoding/json"
"fmt"
)

type Deployment struct {
Owner string `json:"owner"`
Name string `json:"name"`
CurrentRelease DeploymentRelease `json:"current_release"`

rawJSON json.RawMessage `json:"-"`
}

type DeploymentRelease struct {
Number int `json:"number"`
Model string `json:"model"`
Version string `json:"version"`
CreatedAt string `json:"created_at"`
CreatedBy Account `json:"created_by"`
Configuration DeploymentConfiguration `json:"configuration"`
}

type DeploymentConfiguration struct {
Hardware string `json:"hardware"`
MinInstances int `json:"min_instances"`
MaxInstances int `json:"max_instances"`
}

func (d Deployment) MarshalJSON() ([]byte, error) {
if d.rawJSON != nil {
return d.rawJSON, nil
} else {
type Alias Deployment
return json.Marshal(&struct{ *Alias }{Alias: (*Alias)(&d)})
}
}

func (d *Deployment) UnmarshalJSON(data []byte) error {
d.rawJSON = data
type Alias Deployment
alias := &struct{ *Alias }{Alias: (*Alias)(d)}
return json.Unmarshal(data, alias)
}

// GetDeployment retrieves the details of a specific deployment.
func (r *Client) GetDeployment(ctx context.Context, deployment_owner string, deployment_name string) (*Deployment, error) {
deployment := &Deployment{}
path := fmt.Sprintf("/deployments/%s/%s", deployment_owner, deployment_name)
err := r.fetch(ctx, "GET", path, nil, deployment)
if err != nil {
return nil, fmt.Errorf("failed to get deployment: %w", err)
}

return deployment, nil
}

// CreateDeploymentPrediction sends a request to the Replicate API to create a prediction using the specified deployment.
func (r *Client) CreatePredictionWithDeployment(ctx context.Context, deployment_owner string, deployment_name string, input PredictionInput, webhook *Webhook, stream bool) (*Prediction, error) {
data := map[string]interface{}{
Expand Down

0 comments on commit 48b7cbb

Please sign in to comment.