Skip to content

Commit

Permalink
Add fields to ApiError struct (#32)
Browse files Browse the repository at this point in the history
* Add fields to ApiError struct

* Add documentation comments
  • Loading branch information
mattt committed Dec 4, 2023
1 parent 95eec81 commit c5d8c36
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 34 deletions.
68 changes: 61 additions & 7 deletions apierror.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,78 @@ package replicate
import (
"encoding/json"
"fmt"
"net/http"
"strings"
)

// APIError represents an error returned by the Replicate API
type APIError struct {
Detail string `json:"detail"`
// Type is a URI that identifies the error type.
Type string `json:"type,omitempty"`

// Title is a short human-readable summary of the error.
Title string `json:"title,omitempty"`

// Status is the HTTP status code.
Status int `json:"status,omitempty"`

// Detail is a human-readable explanation of the error.
Detail string `json:"detail,omitempty"`

// Instance is a URI that identifies the specific occurrence of the error.
Instance string `json:"instance,omitempty"`
}

func unmarshalAPIError(data []byte) *APIError {
apiError := &APIError{}
err := json.Unmarshal(data, apiError)
func unmarshalAPIError(resp *http.Response, data []byte) *APIError {
apiError := APIError{}
err := json.Unmarshal(data, &apiError)
if err != nil {
apiError.Detail = fmt.Sprintf("Unknown error: %s", err)
}

return apiError
if apiError.Status == 0 && resp != nil {
apiError.Status = resp.StatusCode
}

return &apiError
}

// Error implements the error interface
func (e APIError) Error() string {
return fmt.Sprintf("Replicate API error: %s", e.Detail)
components := []string{}
if e.Type != "" {
components = append(components, e.Type)
}

if e.Title != "" {
components = append(components, e.Title)
}

if e.Detail != "" {
components = append(components, e.Detail)
}

output := strings.Join(components, ": ")
if output == "" {
output = "Unknown error"
}

if e.Instance != "" {
output = fmt.Sprintf("%s (%s)", output, e.Instance)
}

return output
}

func (e *APIError) WriteHTTPResponse(w http.ResponseWriter) {
status := http.StatusBadGateway
if e.Status != 0 {
status = e.Status
}

w.WriteHeader(status)
err := json.NewEncoder(w).Encode(e)
if err != nil {
err = fmt.Errorf("failed to write error response: %w", err)
http.Error(w, err.Error(), http.StatusInternalServerError)
}
}
2 changes: 1 addition & 1 deletion client.go
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ func (r *Client) request(ctx context.Context, method, path string, body interfac
}

if response.StatusCode < 200 || response.StatusCode >= 400 {
apiError = unmarshalAPIError(responseBytes)
apiError = unmarshalAPIError(response, responseBytes)
if !r.shouldRetry(response, method) {
return apiError
}
Expand Down
35 changes: 10 additions & 25 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"time"

"github.com/replicate/replicate-go"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -1148,19 +1149,11 @@ func TestAutomaticallyRetryGetRequests(t *testing.T) {
w.Header().Set("Retry-After", "0")
w.WriteHeader(status)

if status == http.StatusInternalServerError {
err := &replicate.APIError{
Detail: "Internal server error",
}
body, _ := json.Marshal(err)
w.Write(body)
} else if status == http.StatusTooManyRequests {
err := &replicate.APIError{
Detail: "Too many requests",
}
body, _ := json.Marshal(err)
w.Write(body)
err := replicate.APIError{
Detail: http.StatusText(status),
}
body, _ := json.Marshal(err)
w.Write(body)
}
}))
defer mockServer.Close()
Expand Down Expand Up @@ -1191,19 +1184,11 @@ func TestAutomaticallyRetryPostRequests(t *testing.T) {
w.Header().Set("Retry-After", "0")
w.WriteHeader(status)

if status == http.StatusInternalServerError {
err := &replicate.APIError{
Detail: "Internal server error",
}
body, _ := json.Marshal(err)
w.Write(body)
} else if status == http.StatusTooManyRequests {
err := &replicate.APIError{
Detail: "Too many requests",
}
body, _ := json.Marshal(err)
w.Write(body)
err := replicate.APIError{
Detail: http.StatusText(status),
}
body, _ := json.Marshal(err)
w.Write(body)
}))
defer mockServer.Close()

Expand All @@ -1224,7 +1209,7 @@ func TestAutomaticallyRetryPostRequests(t *testing.T) {
version := "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa"
_, err = client.CreatePrediction(ctx, version, input, &webhook, true)

assert.ErrorContains(t, err, "Internal server error")
assert.ErrorContains(t, err, http.StatusText(http.StatusInternalServerError))
}

func TestStream(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ func (r *Client) streamPrediction(ctx context.Context, prediction *Prediction, l

switch event.Type {
case "error":
errChan <- unmarshalAPIError([]byte(event.Data))
errChan <- unmarshalAPIError(nil, []byte(event.Data))
case "done":
close(done)
return
Expand Down

0 comments on commit c5d8c36

Please sign in to comment.