Support predictions.create with model, version, or deployment parameters (#290)

This PR updates `predictions.create` to support overloads that take a `model`,
`version`, or `deployment` parameter. With these changes, API consumers
can more easily switch between official models, model versions, and
deployments.

```python
import replicate

prediction = replicate.predictions.create(
    model="meta/meta-llama-3-8b-instruct",
    input={"prompt": "write a haiku about corgis"},
)

prediction = replicate.predictions.create(
    version="39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b",
    input={"prompt": "a studio photo of a rainbow colored corgi"},
)

prediction = replicate.predictions.create(
    deployment="my-username/my-embeddings-model",
    input={"text": "hello world"},
)
```
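
The diff also adds the same overloads to `async_create`. Below is a minimal sketch of the asynchronous counterpart, assuming a configured `REPLICATE_API_TOKEN` and the module-level default client; the tuple form for `deployment` follows the `Union[str, Tuple[str, str], "Deployment"]` annotation in the new signatures:

```python
import asyncio

import replicate


async def main() -> None:
    # Create a prediction against an official model, mirroring the sync example.
    prediction = await replicate.predictions.async_create(
        model="meta/meta-llama-3-8b-instruct",
        input={"prompt": "write a haiku about corgis"},
    )
    print(prediction.id)

    # `deployment` (and `model`) also accept an (owner, name) tuple,
    # per the new overload signatures.
    prediction = await replicate.predictions.async_create(
        deployment=("my-username", "my-embeddings-model"),
        input={"text": "hello world"},
    )
    print(prediction.id)


asyncio.run(main())
```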

---------

Signed-off-by: Mattt Zmuda <mattt@replicate.com>
mattt committed May 7, 2024
1 parent c0cf1ec commit eb3cebf
Showing 17 changed files with 3,509 additions and 1,972 deletions.
130 changes: 128 additions & 2 deletions replicate/prediction.py
@@ -11,7 +11,9 @@
List,
Literal,
Optional,
Tuple,
Union,
overload,
)

from typing_extensions import NotRequired, TypedDict, Unpack
@@ -31,6 +33,8 @@

if TYPE_CHECKING:
from replicate.client import Client
from replicate.deployment import Deployment
from replicate.model import Model
from replicate.stream import ServerSentEvent


@@ -380,21 +384,82 @@ class CreatePredictionParams(TypedDict):
stream: NotRequired[bool]
"""Enable streaming of prediction output."""

@overload
def create(
self,
version: Union[Version, str],
input: Optional[Dict[str, Any]],
**params: Unpack["Predictions.CreatePredictionParams"],
) -> Prediction: ...

@overload
def create(
self,
*,
model: Union[str, Tuple[str, str], "Model"],
input: Optional[Dict[str, Any]],
**params: Unpack["Predictions.CreatePredictionParams"],
) -> Prediction: ...

@overload
def create(
self,
*,
deployment: Union[str, Tuple[str, str], "Deployment"],
input: Optional[Dict[str, Any]],
**params: Unpack["Predictions.CreatePredictionParams"],
) -> Prediction: ...

def create( # type: ignore
self,
*args,
model: Optional[Union[str, Tuple[str, str], "Model"]] = None,
version: Optional[Union[Version, str, "Version"]] = None,
deployment: Optional[Union[str, Tuple[str, str], "Deployment"]] = None,
input: Optional[Dict[str, Any]] = None,
**params: Unpack["Predictions.CreatePredictionParams"],
) -> Prediction:
"""
Create a new prediction for the specified model version.
Create a new prediction for the specified model, version, or deployment.
"""

if args:
version = args[0] if len(args) > 0 else None
input = args[1] if len(args) > 1 else input

if sum(bool(x) for x in [model, version, deployment]) != 1:
raise ValueError(
"Exactly one of 'model', 'version', or 'deployment' must be specified."
)

if model is not None:
from replicate.model import ( # pylint: disable=import-outside-toplevel
Models,
)

return Models(self._client).predictions.create(
model=model,
input=input or {},
**params,
)

if deployment is not None:
from replicate.deployment import ( # pylint: disable=import-outside-toplevel
Deployments,
)

return Deployments(self._client).predictions.create(
deployment=deployment,
input=input or {},
**params,
)

body = _create_prediction_body(
version,
input,
**params,
)

resp = self._client._request(
"POST",
"/v1/predictions",
@@ -403,21 +468,82 @@ def create(

return _json_to_prediction(self._client, resp.json())

@overload
async def async_create(
self,
version: Union[Version, str],
input: Optional[Dict[str, Any]],
**params: Unpack["Predictions.CreatePredictionParams"],
) -> Prediction: ...

@overload
async def async_create(
self,
*,
model: Union[str, Tuple[str, str], "Model"],
input: Optional[Dict[str, Any]],
**params: Unpack["Predictions.CreatePredictionParams"],
) -> Prediction: ...

@overload
async def async_create(
self,
*,
deployment: Union[str, Tuple[str, str], "Deployment"],
input: Optional[Dict[str, Any]],
**params: Unpack["Predictions.CreatePredictionParams"],
) -> Prediction: ...

async def async_create( # type: ignore
self,
*args,
model: Optional[Union[str, Tuple[str, str], "Model"]] = None,
version: Optional[Union[Version, str, "Version"]] = None,
deployment: Optional[Union[str, Tuple[str, str], "Deployment"]] = None,
input: Optional[Dict[str, Any]] = None,
**params: Unpack["Predictions.CreatePredictionParams"],
) -> Prediction:
"""
Create a new prediction for the specified model version.
Create a new prediction for the specified model, version, or deployment.
"""

if args:
version = args[0] if len(args) > 0 else None
input = args[1] if len(args) > 1 else input

if sum(bool(x) for x in [model, version, deployment]) != 1:
raise ValueError(
"Exactly one of 'model', 'version', or 'deployment' must be specified."
)

if model is not None:
from replicate.model import ( # pylint: disable=import-outside-toplevel
Models,
)

return await Models(self._client).predictions.async_create(
model=model,
input=input or {},
**params,
)

if deployment is not None:
from replicate.deployment import ( # pylint: disable=import-outside-toplevel
Deployments,
)

return await Deployments(self._client).predictions.async_create(
deployment=deployment,
input=input or {},
**params,
)

body = _create_prediction_body(
version,
input,
**params,
)

resp = await self._client._async_request(
"POST",
"/v1/predictions",
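
A short sketch of the new validation's runtime behavior, assuming a configured client; exactly one of `model`, `version`, or `deployment` must be provided, as enforced by the check added above:

```python
import replicate

# The new check in `create` rejects calls that specify more than one of
# `model`, `version`, or `deployment` (or none of them).
try:
    replicate.predictions.create(
        model="meta/meta-llama-3-8b-instruct",
        version="39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b",
        input={"prompt": "write a haiku about corgis"},
    )
except ValueError as error:
    print(error)  # Exactly one of 'model', 'version', or 'deployment' must be specified.
```

The `*args` handling shown above also keeps the existing positional form, `predictions.create(version, input)`, working as before.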
