Skip to content

Commit

Permalink
mod+server: Add seed in AI responses
Browse files Browse the repository at this point in the history
  • Loading branch information
yondonfu committed Feb 8, 2024
1 parent 6e76966 commit 650aae7
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 95 deletions.
26 changes: 21 additions & 5 deletions core/orchestrator.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ func (orch *orchestrator) ImageToImage(ctx context.Context, req worker.ImageToIm
return orch.node.imageToImage(ctx, req)
}

func (orch *orchestrator) ImageToVideo(ctx context.Context, req worker.ImageToVideoMultipartRequestBody) ([]*TranscodeResult, error) {
func (orch *orchestrator) ImageToVideo(ctx context.Context, req worker.ImageToVideoMultipartRequestBody) (*worker.ImageResponse, error) {
return orch.node.imageToVideo(ctx, req)
}

Expand Down Expand Up @@ -892,7 +892,7 @@ func (n *LivepeerNode) imageToImage(ctx context.Context, req worker.ImageToImage
return n.AIWorker.ImageToImage(ctx, req)
}

func (n *LivepeerNode) imageToVideo(ctx context.Context, req worker.ImageToVideoMultipartRequestBody) ([]*TranscodeResult, error) {
func (n *LivepeerNode) imageToVideo(ctx context.Context, req worker.ImageToVideoMultipartRequestBody) (*worker.ImageResponse, error) {
// We might support generating more than one video in the future (i.e. multiple input images/prompts)
numVideos := 1

Expand All @@ -911,8 +911,9 @@ func (n *LivepeerNode) imageToVideo(ctx context.Context, req worker.ImageToVideo
clog.V(common.DEBUG).Infof(ctx, "Generating frames took=%v", took)

sessionID := string(RandomManifestID())
// HACK: Re-use worker.ImageResponse to return results
// Transcode frames into segments.
results := make([]*TranscodeResult, len(resp.Frames))
videos := make([]worker.Media, len(resp.Frames))
for i, batch := range resp.Frames {
// Create slice of frame urls for a batch
urls := make([]string, len(batch))
Expand All @@ -926,10 +927,25 @@ func (n *LivepeerNode) imageToVideo(ctx context.Context, req worker.ImageToVideo
return nil, res.Err
}

results[i] = res
// Assume only single rendition right now
seg := res.TranscodeData.Segments[0]
name := fmt.Sprintf("%v.mp4", RandomManifestID())
segData := bytes.NewReader(seg.Data)
uri, err := res.OS.SaveData(ctx, name, segData, nil, 0)
if err != nil {
return nil, err
}

videos[i] = worker.Media{
Url: uri,
}

if len(batch) > 0 {
videos[i].Seed = batch[0].Seed
}
}

return results, nil
return &worker.ImageResponse{Images: videos}, nil
}

func (rtm *RemoteTranscoderManager) transcoderResults(tcID int64, res *RemoteTranscoderResult) {
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ require (
github.com/golang/protobuf v1.5.3
github.com/jaypipes/ghw v0.10.0
github.com/jaypipes/pcidb v1.0.0
github.com/livepeer/ai-worker v0.0.0-20240205185039-5c4895915580
github.com/livepeer/ai-worker v0.0.0-20240208153040-7c92507e2a40
github.com/livepeer/go-tools v0.3.6-0.20240130205227-92479de8531b
github.com/livepeer/livepeer-data v0.7.5-0.20231004073737-06f1f383fb18
github.com/livepeer/lpms v0.0.0-20240120150405-de94555cdc69
Expand Down
4 changes: 4 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -541,6 +541,10 @@ github.com/livepeer/ai-worker v0.0.0-20240202211855-823caeaa265f h1:8owDNiBfN0j6
github.com/livepeer/ai-worker v0.0.0-20240202211855-823caeaa265f/go.mod h1:3+A2/SYTqs+551SKTPy20AVnB8b0Yp26Va5SY37eQ/4=
github.com/livepeer/ai-worker v0.0.0-20240205185039-5c4895915580 h1:7ACCHUpeJsoWADgST/nWfGD0LVRSXFcYG6FTGvzUGn4=
github.com/livepeer/ai-worker v0.0.0-20240205185039-5c4895915580/go.mod h1:3+A2/SYTqs+551SKTPy20AVnB8b0Yp26Va5SY37eQ/4=
github.com/livepeer/ai-worker v0.0.0-20240207221157-87e4f48ec353 h1:Ee1+i+q1EpP9D3AOufAnMSyEP06zaRhcyMRfSk6GJF8=
github.com/livepeer/ai-worker v0.0.0-20240207221157-87e4f48ec353/go.mod h1:3+A2/SYTqs+551SKTPy20AVnB8b0Yp26Va5SY37eQ/4=
github.com/livepeer/ai-worker v0.0.0-20240208153040-7c92507e2a40 h1:vVbuu5wqrzq6M6Rlutk0eZv6qZ/kO2OrqQv5n6yt57s=
github.com/livepeer/ai-worker v0.0.0-20240208153040-7c92507e2a40/go.mod h1:3+A2/SYTqs+551SKTPy20AVnB8b0Yp26Va5SY37eQ/4=
github.com/livepeer/go-tools v0.3.6-0.20240130205227-92479de8531b h1:VQcnrqtCA2UROp7q8ljkh2XA/u0KRgVv0S1xoUvOweE=
github.com/livepeer/go-tools v0.3.6-0.20240130205227-92479de8531b/go.mod h1:hwJ5DKhl+pTanFWl+EUpw1H7ukPO/H+MFpgA7jjshzw=
github.com/livepeer/joy4 v0.1.2-0.20191121080656-b2fea45cbded h1:ZQlvR5RB4nfT+cOQee+WqmaDOgGtP2oDMhcVvR4L0yA=
Expand Down
57 changes: 4 additions & 53 deletions server/ai_http.go
Original file line number Diff line number Diff line change
@@ -1,20 +1,15 @@
package server

import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"time"

"github.com/getkin/kin-openapi/openapi3filter"
"github.com/golang/protobuf/proto"
"github.com/livepeer/ai-worker/worker"
"github.com/livepeer/go-livepeer/clog"
"github.com/livepeer/go-livepeer/common"
"github.com/livepeer/go-livepeer/core"
"github.com/livepeer/go-livepeer/net"
middleware "github.com/oapi-codegen/nethttp-middleware"
"github.com/oapi-codegen/runtime"
)
Expand Down Expand Up @@ -56,7 +51,7 @@ func (h *lphttp) TextToImage() http.Handler {
return
}

clog.V(common.VERBOSE).Infof(r.Context(), "Received TextToImage request prompt=%v model_id=%v", req.Prompt, *req.ModelId)
clog.V(common.VERBOSE).Infof(ctx, "Received TextToImage request prompt=%v model_id=%v", req.Prompt, *req.ModelId)

start := time.Now()
resp, err := h.orchestrator.TextToImage(r.Context(), req)
Expand Down Expand Up @@ -129,61 +124,17 @@ func (h *lphttp) ImageToVideo() http.Handler {
clog.V(common.VERBOSE).Infof(ctx, "Received ImageToVideo request imageSize=%v model_id=%v", req.Image.FileSize(), *req.ModelId)

start := time.Now()
results, err := h.orchestrator.ImageToVideo(ctx, req)
resp, err := h.orchestrator.ImageToVideo(ctx, req)
if err != nil {
respondWithError(w, err.Error(), http.StatusInternalServerError)
return
}

// TODO: Handle more than one video
if len(results) != 1 {
respondWithError(w, "failed to return results", http.StatusInternalServerError)
return
}

took := time.Since(start)
clog.Infof(ctx, "Processed ImageToVideo request imageSize=%v model_id=%v took=%v", req.Image.FileSize(), *req.ModelId, took)

res := results[0]

// Assume only single rendition right now
seg := res.TranscodeData.Segments[0]
name := fmt.Sprintf("%v.mp4", core.RandomManifestID())
segData := bytes.NewReader(seg.Data)
uri, err := res.OS.SaveData(ctx, name, segData, nil, 0)
if err != nil {
clog.Errorf(ctx, "Could not upload segment err=%q", err)
}

var result net.TranscodeResult
if err != nil {
clog.Errorf(ctx, "Could not transcode err=%q", err)
result = net.TranscodeResult{Result: &net.TranscodeResult_Error{Error: err.Error()}}
} else {
result = net.TranscodeResult{
Result: &net.TranscodeResult_Data{
Data: &net.TranscodeData{
Segments: []*net.TranscodedSegmentData{
{Url: uri, Pixels: seg.Pixels},
},
Sig: res.Sig,
},
},
}
}

tr := &net.TranscodeResult{
Result: result.Result,
// TODO: Add other fields
}

buf, err := proto.Marshal(tr)
if err != nil {
respondWithError(w, err.Error(), http.StatusInternalServerError)
return
}

w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write(buf)
_ = json.NewEncoder(w).Encode(resp)
})
}
54 changes: 19 additions & 35 deletions server/ai_process.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,18 @@ import (
"bufio"
"bytes"
"context"
"encoding/json"
"errors"
"io"
"net/http"
"path/filepath"
"strings"
"time"

"github.com/cenkalti/backoff"
"github.com/golang/protobuf/proto"
"github.com/livepeer/ai-worker/worker"
"github.com/livepeer/go-livepeer/clog"
"github.com/livepeer/go-livepeer/common"
"github.com/livepeer/go-livepeer/core"
"github.com/livepeer/go-livepeer/net"
"github.com/livepeer/go-tools/drivers"
)

Expand Down Expand Up @@ -84,7 +82,7 @@ func processTextToImage(ctx context.Context, params aiRequestParams, req worker.
return nil, err
}

newMedia[i] = worker.Media{Url: newUrl}
newMedia[i] = worker.Media{Url: newUrl, Seed: media.Seed}
}

resp.Images = newMedia
Expand Down Expand Up @@ -155,7 +153,7 @@ func processImageToImage(ctx context.Context, params aiRequestParams, req worker
return nil, err
}

newMedia[i] = worker.Media{Url: newUrl}
newMedia[i] = worker.Media{Url: newUrl, Seed: media.Seed}
}

resp.Images = newMedia
Expand Down Expand Up @@ -204,10 +202,10 @@ func processImageToVideo(ctx context.Context, params aiRequestParams, req worker

orchUrl := orchInfos[0].Transcoder

var urls []string
var resp *worker.ImageResponse
op := func() error {
var err error
urls, err = submitImageToVideo(ctx, orchUrl, req)
resp, err = submitImageToVideo(ctx, orchUrl, req)
return err
}
notify := func(err error, dur time.Duration) {
Expand All @@ -220,42 +218,43 @@ func processImageToVideo(ctx context.Context, params aiRequestParams, req worker
}

// HACK: Re-use worker.ImageResponse to return results
videos := make([]worker.Media, len(urls))
for i, url := range urls {
data, err := downloadSeg(ctx, url)
videos := make([]worker.Media, len(resp.Images))
for i, media := range resp.Images {
data, err := downloadSeg(ctx, media.Url)
if err != nil {
return nil, err
}

name := filepath.Base(url)
name := filepath.Base(media.Url)
newUrl, err := params.os.SaveData(ctx, name, bytes.NewReader(data), nil, 0)
if err != nil {
return nil, err
}

videos[i] = worker.Media{
Url: newUrl,
Url: newUrl,
Seed: media.Seed,
}
}

resp := &worker.ImageResponse{Images: videos}
resp.Images = videos

return resp, nil
}

func submitImageToVideo(ctx context.Context, url string, req worker.ImageToVideoMultipartRequestBody) ([]string, error) {
func submitImageToVideo(ctx context.Context, url string, req worker.ImageToVideoMultipartRequestBody) (*worker.ImageResponse, error) {
var buf bytes.Buffer
mw, err := worker.NewImageToVideoMultipartWriter(&buf, req)
if err != nil {
return nil, err
}

r, err := http.NewRequestWithContext(ctx, "POST", url+"/image-to-video", &buf)
client, err := worker.NewClientWithResponses(url, worker.WithHTTPClient(httpClient))
if err != nil {
return nil, err
}
r.Header.Set("Content-Type", mw.FormDataContentType())

resp, err := sendReqWithTimeout(r, imageToVideoTimeout)
resp, err := client.ImageToVideoWithBody(ctx, mw.FormDataContentType(), &buf)
if err != nil {
return nil, err
}
Expand All @@ -270,25 +269,10 @@ func submitImageToVideo(ctx context.Context, url string, req worker.ImageToVideo
return nil, errors.New(string(data))
}

var tr net.TranscodeResult
if err := proto.Unmarshal(data, &tr); err != nil {
var res worker.ImageResponse
if err := json.Unmarshal(data, &res); err != nil {
return nil, err
}

var tdata *net.TranscodeData
switch res := tr.Result.(type) {
case *net.TranscodeResult_Error:
return nil, errors.New(res.Error)
case *net.TranscodeResult_Data:
tdata = res.Data
default:
return nil, errors.New("UnknownResponse")
}

urls := make([]string, len(tdata.Segments))
for i, seg := range tdata.Segments {
urls[i] = seg.Url
}

return urls, nil
return &res, nil
}
2 changes: 1 addition & 1 deletion server/rpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ type Orchestrator interface {
AuthToken(sessionID string, expiration int64) *net.AuthToken
TextToImage(ctx context.Context, req worker.TextToImageJSONRequestBody) (*worker.ImageResponse, error)
ImageToImage(ctx context.Context, req worker.ImageToImageMultipartRequestBody) (*worker.ImageResponse, error)
ImageToVideo(ctx context.Context, req worker.ImageToVideoMultipartRequestBody) ([]*core.TranscodeResult, error)
ImageToVideo(ctx context.Context, req worker.ImageToVideoMultipartRequestBody) (*worker.ImageResponse, error)
}

// Balance describes methods for a session's balance maintenance
Expand Down

0 comments on commit 650aae7

Please sign in to comment.