Skip to content

Commit

Permalink
server: Add AISessionManager
Browse files Browse the repository at this point in the history
For managing the sessions per AI capability + model ID in a way that is compatible with existing broadcast session code
  • Loading branch information
yondonfu committed Mar 11, 2024
1 parent cabb34a commit f1718fa
Show file tree
Hide file tree
Showing 4 changed files with 340 additions and 26 deletions.
5 changes: 3 additions & 2 deletions server/ai_mediaserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,9 @@ func (ls *LivepeerServer) TextToImage() http.Handler {
clog.V(common.VERBOSE).Infof(r.Context(), "Received TextToImage request prompt=%v model_id=%v", req.Prompt, *req.ModelId)

params := aiRequestParams{
node: ls.LivepeerNode,
os: drivers.NodeStorage.NewSession(requestID),
node: ls.LivepeerNode,
os: drivers.NodeStorage.NewSession(requestID),
sessManager: ls.AISessionManager,
}

start := time.Now()
Expand Down
42 changes: 18 additions & 24 deletions server/ai_process.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,9 @@ func (e *ServiceUnavailableError) Error() string {
}

type aiRequestParams struct {
node *core.LivepeerNode
os drivers.OSSession
node *core.LivepeerNode
os drivers.OSSession
sessManager *AISessionManager
}

func getOrchestratorsForAIRequest(ctx context.Context, params aiRequestParams, cap core.Capability, modelID string) ([]*net.OrchestratorInfo, error) {
Expand Down Expand Up @@ -85,37 +86,30 @@ func processTextToImage(ctx context.Context, params aiRequestParams, req worker.
modelID = *req.ModelId
}

orchInfos, err := getOrchestratorsForAIRequest(ctx, params, core.Capability_TextToImage, modelID)
if err != nil {
return nil, err
}

if len(orchInfos) == 0 {
return nil, &ServiceUnavailableError{err: errors.New("no orchestrators available")}
}

var resp *worker.ImageResponse

// Round robin up to maxProcessingRetries times
orchIdx := 0
tries := 0
for tries < maxProcessingRetries {
orchUrl := orchInfos[orchIdx].Transcoder
sess, err := params.sessManager.Select(ctx, core.Capability_TextToImage, modelID)
if err != nil {
return nil, err
}

var err error
resp, err = submitTextToImage(ctx, orchUrl, req)
if sess == nil {
break
}

resp, err = submitTextToImage(ctx, params, sess, req)
if err == nil {
params.sessManager.Complete(ctx, sess)
break
}

clog.Infof(ctx, "Error submitting TextToImage request try=%v orch=%v err=%v", tries, orchUrl, err)
clog.Infof(ctx, "Error submitting TextToImage request try=%v orch=%v err=%v", tries, sess.Transcoder(), err)

params.sessManager.Remove(ctx, sess)

tries++
orchIdx++
// Wrap back around
if orchIdx >= len(orchInfos) {
orchIdx = 0
}
}

if resp == nil {
Expand Down Expand Up @@ -143,8 +137,8 @@ func processTextToImage(ctx context.Context, params aiRequestParams, req worker.
return resp, nil
}

func submitTextToImage(ctx context.Context, url string, req worker.TextToImageJSONRequestBody) (*worker.ImageResponse, error) {
client, err := worker.NewClientWithResponses(url, worker.WithHTTPClient(httpClient))
func submitTextToImage(ctx context.Context, params aiRequestParams, sess *AISession, req worker.TextToImageJSONRequestBody) (*worker.ImageResponse, error) {
client, err := worker.NewClientWithResponses(sess.Transcoder(), worker.WithHTTPClient(httpClient))
if err != nil {
return nil, err
}
Expand Down

0 comments on commit f1718fa

Please sign in to comment.