Skip to content

Commit

Permalink
Add support for /predictions/{id}/cancel endpoint (#46)
Browse files Browse the repository at this point in the history
* Add cancel prediction endpoint.

* Add documentation.

* Add test for CancelPrediction

* Reorder training methods

---------

Co-authored-by: Mattt Zmuda <mattt@replicate.com>
  • Loading branch information
mehmettokgoz and mattt committed Feb 27, 2024
1 parent b0335b6 commit ca660fe
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 10 deletions.
35 changes: 35 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -610,6 +610,41 @@ func TestCreatePredictionWithModel(t *testing.T) {
assert.Equal(t, replicate.Starting, prediction.Status)
}

func TestCancelPrediction(t *testing.T) {
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
assert.Equal(t, "/predictions/ufawqhfynnddngldkgtslldrkq/cancel", r.URL.Path)

response := replicate.Prediction{
ID: "ufawqhfynnddngldkgtslldrkq",
Status: replicate.Canceled,
}
responseBytes, err := json.Marshal(response)
if err != nil {
t.Fatal(err)
}

w.WriteHeader(http.StatusOK)
w.Write(responseBytes)
}))
defer mockServer.Close()

client, err := replicate.NewClient(
replicate.WithToken("test-token"),
replicate.WithBaseURL(mockServer.URL),
)
require.NotNil(t, client)
require.NoError(t, err)

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()

prediction, err := client.CancelPrediction(ctx, "ufawqhfynnddngldkgtslldrkq")
assert.NoError(t, err)
assert.Equal(t, "ufawqhfynnddngldkgtslldrkq", prediction.ID)
assert.Equal(t, replicate.Canceled, prediction.Status)
}

func TestPredictionProgress(t *testing.T) {
prediction := replicate.Prediction{
ID: "ufawqhfynnddngldkgtslldrkq",
Expand Down
10 changes: 10 additions & 0 deletions prediction.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,3 +141,13 @@ func (r *Client) GetPrediction(ctx context.Context, id string) (*Prediction, err
}
return prediction, nil
}

// CancelPrediction cancels a running prediction by its ID.
func (r *Client) CancelPrediction(ctx context.Context, id string) (*Prediction, error) {
prediction := &Prediction{}
err := r.fetch(ctx, "POST", fmt.Sprintf("/predictions/%s/cancel", id), nil, prediction)
if err != nil {
return nil, fmt.Errorf("failed to cancel prediction: %w", err)
}
return prediction, nil
}
20 changes: 10 additions & 10 deletions training.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,16 @@ func (r *Client) CreateTraining(ctx context.Context, model_owner string, model_n
return training, nil
}

// ListTrainings returns a list of trainings.
func (r *Client) ListTrainings(ctx context.Context) (*Page[Training], error) {
response := &Page[Training]{}
err := r.fetch(ctx, "GET", "/trainings", nil, response)
if err != nil {
return nil, fmt.Errorf("failed to list trainings: %w", err)
}
return response, nil
}

// GetTraining sends a request to the Replicate API to get a training.
func (r *Client) GetTraining(ctx context.Context, trainingID string) (*Training, error) {
training := &Training{}
Expand All @@ -52,13 +62,3 @@ func (r *Client) CancelTraining(ctx context.Context, trainingID string) (*Traini

return training, nil
}

// ListTrainings returns a list of trainings.
func (r *Client) ListTrainings(ctx context.Context) (*Page[Training], error) {
response := &Page[Training]{}
err := r.fetch(ctx, "GET", "/trainings", nil, response)
if err != nil {
return nil, fmt.Errorf("failed to list trainings: %w", err)
}
return response, nil
}

0 comments on commit ca660fe

Please sign in to comment.