Ai video fix selection pr #3033

Draft: wants to merge 4 commits into base: ai-video
1 change: 1 addition & 0 deletions cmd/livepeer/livepeer.go
@@ -153,6 +153,7 @@ func parseLivepeerConfig() starter.LivepeerConfig {
cfg.AIWorker = flag.Bool("aiWorker", *cfg.AIWorker, "Set to true to run an AI worker")
cfg.AIModels = flag.String("aiModels", *cfg.AIModels, "Set models (pipeline:model_id) for AI worker to load upon initialization")
cfg.AIModelsDir = flag.String("aiModelsDir", *cfg.AIModelsDir, "Set directory where AI model weights are stored")
+cfg.AIWorkerNoManagedContainers = flag.Bool("aiWorkerNoManagedContainers", *cfg.AIWorkerNoManagedContainers, "Set to true to not use managed containers with the AI worker")

// Onchain:
cfg.EthAcctAddr = flag.String("ethAcctAddr", *cfg.EthAcctAddr, "Existing Eth account address. For use when multiple ETH accounts exist in the keystore directory")
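With the new flag wired through, an operator who runs the AI runner containers out-of-band would start the node with something like `livepeer -aiWorker -aiModels <pipeline:model_id> -aiWorkerNoManagedContainers` (an invocation sketch; the remaining flags depend on the deployment).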
155 changes: 79 additions & 76 deletions cmd/livepeer/starter/starter.go
@@ -77,78 +77,79 @@
)

type LivepeerConfig struct {
Network *string
RtmpAddr *string
CliAddr *string
HttpAddr *string
ServiceAddr *string
OrchAddr *string
VerifierURL *string
EthController *string
VerifierPath *string
LocalVerify *bool
HttpIngest *bool
Orchestrator *bool
Transcoder *bool
AIWorker *bool
Broadcaster *bool
OrchSecret *string
TranscodingOptions *string
AIModels *string
+AIWorkerNoManagedContainers *bool
MaxAttempts *int
SelectRandWeight *float64
SelectStakeWeight *float64
SelectPriceWeight *float64
SelectPriceExpFactor *float64
OrchPerfStatsURL *string
Region *string
MaxPricePerUnit *int
MinPerfScore *float64
MaxSessions *string
CurrentManifest *bool
Nvidia *string
Netint *string
TestTranscoder *bool
EthAcctAddr *string
EthPassword *string
EthKeystorePath *string
EthOrchAddr *string
EthUrl *string
TxTimeout *time.Duration
MaxTxReplacements *int
GasLimit *int
MinGasPrice *int64
MaxGasPrice *int
InitializeRound *bool
TicketEV *string
MaxFaceValue *string
MaxTicketEV *string
MaxTotalEV *string
DepositMultiplier *int
PricePerUnit *int
PixelsPerUnit *int
AutoAdjustPrice *bool
PricePerBroadcaster *string
BlockPollingInterval *int
Redeemer *bool
RedeemerAddr *string
Reward *bool
Monitor *bool
MetricsPerStream *bool
MetricsExposeClientIP *bool
MetadataQueueUri *string
MetadataAmqpExchange *string
MetadataPublishTimeout *time.Duration
Datadir *string
AIModelsDir *string
Objectstore *string
Recordstore *string
FVfailGsBucket *string
FVfailGsKey *string
AuthWebhookURL *string
OrchWebhookURL *string
OrchBlacklist *string
TestOrchAvail *bool
}

// DefaultLivepeerConfig creates LivepeerConfig exactly the same as when no flags are passed to the livepeer process.
@@ -187,6 +188,7 @@
defaultAIWorker := false
defaultAIModels := ""
defaultAIModelsDir := ""
+defaultAIWorkerNoManagedContainers := false

// Onchain:
defaultEthAcctAddr := ""
@@ -273,9 +275,10 @@
TestTranscoder: &defaultTestTranscoder,

// AI:
AIWorker: &defaultAIWorker,
AIModels: &defaultAIModels,
AIModelsDir: &defaultAIModelsDir,
+AIWorkerNoManagedContainers: &defaultAIWorkerNoManagedContainers,

// Onchain:
EthAcctAddr: &defaultEthAcctAddr,
@@ -525,7 +528,7 @@
return
}

-n.AIWorker, err = worker.NewWorker(aiWorkerContainerImageID, gpus, modelsDir)
+n.AIWorker, err = worker.NewWorker(aiWorkerContainerImageID, gpus, modelsDir, *cfg.AIWorkerNoManagedContainers)

Check failure on line 531 in cmd/livepeer/starter/starter.go (GitHub Actions: Build binaries for linux-cpu-amd64, darwin-arm64, and linux-gpu-amd64):
too many arguments in call to worker.NewWorker
if err != nil {
glog.Errorf("Error starting AI worker: %v", err)
return
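The three build failures above say the worker package still exposes a three-argument constructor, so the ai-worker dependency needs a matching update before this compiles. Below is a minimal sketch of the constructor shape this call site assumes; the parameter names, the gpus type, and the struct fields are illustrative, not the real ai-worker API:

```go
package worker

// Worker is a hypothetical stand-in for the ai-worker type; only the
// fields needed to illustrate the new argument are shown.
type Worker struct {
	containerImageID    string
	gpus                []string
	modelsDir           string
	noManagedContainers bool // true: attach to externally started runner containers
}

// NewWorker gains a fourth parameter matching the call in starter.go, so
// the node can skip starting and stopping its own runner containers.
func NewWorker(containerImageID string, gpus []string, modelsDir string, noManagedContainers bool) (*Worker, error) {
	return &Worker{
		containerImageID:    containerImageID,
		gpus:                gpus,
		modelsDir:           modelsDir,
		noManagedContainers: noManagedContainers,
	}, nil
}
```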
49 changes: 38 additions & 11 deletions server/ai_session.go
@@ -29,14 +29,16 @@ type AISessionPool struct {
sessMap map[string]*BroadcastSession
inUseSess []*BroadcastSession
suspender *suspender
+penalty int
mu sync.RWMutex
}

-func NewAISessionPool(selector BroadcastSessionsSelector, suspender *suspender) *AISessionPool {
+func NewAISessionPool(selector BroadcastSessionsSelector, suspender *suspender, penalty int) *AISessionPool {
return &AISessionPool{
selector: selector,
sessMap: make(map[string]*BroadcastSession),
suspender: suspender,
+penalty: penalty,
mu: sync.RWMutex{},
}
}
@@ -101,10 +103,6 @@ func (pool *AISessionPool) Add(sessions []*BroadcastSession) {
pool.mu.Lock()
defer pool.mu.Unlock()

-// If we try to add new sessions to the pool the suspender
-// should treat this as a refresh
-pool.suspender.signalRefresh()

var uniqueSessions []*BroadcastSession
for _, sess := range sessions {
if _, ok := pool.sessMap[sess.Transcoder()]; ok {
@@ -126,10 +124,17 @@ func (pool *AISessionPool) Remove(sess *BroadcastSession) {
delete(pool.sessMap, sess.Transcoder())
pool.inUseSess = removeSessionFromList(pool.inUseSess, sess)

-// Magic number for now
-penalty := 3
+penalty := 0
// If this method is called assume that the orch should be suspended
-// as well
+// as well. Since AISessionManager reuses the pools, the suspension
+// penalty must account for the current suspender count.
+lastCount, ok := pool.suspender.list[sess.Transcoder()]
+if ok {
+	penalty = pool.suspender.count - lastCount + pool.penalty
+} else {
+	penalty = pool.suspender.count + pool.penalty
+}

pool.suspender.suspend(sess.Transcoder(), penalty)
}
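For context on the arithmetic above, here is a self-contained sketch of how the suspender is assumed to behave: a shared count that every signalRefresh increments, with suspensions stored as the count at which they expire. This is an illustration under those assumptions, not the node's actual suspender:

```go
package main

import "fmt"

// suspender sketch: count ticks up on every refresh signal, and an
// orchestrator stays suspended until count reaches its recorded expiry.
type suspender struct {
	count int
	list  map[string]int // orchestrator -> expiry count
}

func (s *suspender) signalRefresh() {
	s.count++
	for orch, expiry := range s.list {
		if expiry <= s.count {
			delete(s.list, orch) // suspension has run its course
		}
	}
}

func (s *suspender) suspend(orch string, penalty int) {
	s.list[orch] = s.count + penalty
}

func main() {
	s := &suspender{list: make(map[string]int)}

	s.suspend("orch-A", 3) // base penalty of 3: expires once count reaches 3
	s.signalRefresh()
	s.signalRefresh()
	fmt.Println(len(s.list)) // 1: still suspended after two refreshes

	s.signalRefresh()
	fmt.Println(len(s.list)) // 0: three refresh signals cleared a penalty of 3
}
```

Because the pools are reused across refreshes while the shared count keeps growing, Remove derives the new penalty relative to the orchestrator's previous entry (count - lastCount + pool.penalty) rather than passing the base penalty directly.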

@@ -156,12 +161,14 @@ type AISessionSelector struct {
// The time until the pools should be refreshed with orchs from discovery
ttl time.Duration
lastRefreshTime time.Time
+initialPoolSize int

cap core.Capability
modelID string

node *core.LivepeerNode
suspender *suspender
+penalty int
os drivers.OSSession
}

@@ -176,8 +183,9 @@ func NewAISessionSelector(cap core.Capability, modelID string, node *core.Livepe
// The latency score in this context is just the latency of the last completed request for a session
// The "good enough" latency score is set to 0.0 so the selector will always select unknown sessions first
minLS := 0.0
-warmPool := NewAISessionPool(NewMinLSSelector(stakeRdr, minLS, node.SelectionAlgorithm, node.OrchPerfScore), suspender)
-coldPool := NewAISessionPool(NewMinLSSelector(stakeRdr, minLS, node.SelectionAlgorithm, node.OrchPerfScore), suspender)
+penalty := 3
+warmPool := NewAISessionPool(NewMinLSSelector(stakeRdr, minLS, node.SelectionAlgorithm, node.OrchPerfScore), suspender, penalty)
+coldPool := NewAISessionPool(NewMinLSSelector(stakeRdr, minLS, node.SelectionAlgorithm, node.OrchPerfScore), suspender, penalty)
sel := &AISessionSelector{
warmPool: warmPool,
coldPool: coldPool,
@@ -186,6 +194,7 @@
modelID: modelID,
node: node,
suspender: suspender,
+penalty: penalty,
os: drivers.NodeStorage.NewSession(strconv.Itoa(int(cap)) + "_" + modelID),
}

@@ -200,7 +209,17 @@ func (sel *AISessionSelector) Select(ctx context.Context) *AISession {
shouldRefreshSelector := func() bool {
// Refresh if the # of sessions across warm and cold pools falls below the smaller of the maxRefreshSessionsThreshold and
// 1/2 the total # of orchs that can be queried during discovery
-discoveryPoolSize := sel.node.OrchestratorPool.Size()
+discoveryPoolSize := int(math.Min(float64(sel.node.OrchestratorPool.Size()), float64(sel.initialPoolSize)))
+
+if (sel.warmPool.Size() + sel.coldPool.Size()) == 0 {
+	// Release all orchestrators from suspension so the refresh can repopulate the pools.
+	clog.Infof(ctx, "refreshing sessions, no orchestrators in pools")
+	for i := 0; i < sel.penalty; i++ {
+		sel.suspender.signalRefresh()
+	}
+}

if sel.warmPool.Size()+sel.coldPool.Size() < int(math.Min(maxRefreshSessionsThreshold, math.Ceil(float64(discoveryPoolSize)/2.0))) {
return true
}
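A worked example of the refresh trigger, assuming maxRefreshSessionsThreshold keeps its existing value of 8 (the constant is defined elsewhere in this package) and capping discoveryPoolSize by initialPoolSize as above:

```go
package main

import (
	"fmt"
	"math"
)

func main() {
	const maxRefreshSessionsThreshold = 8.0 // assumed value of the package constant

	orchPoolSize := 30.0    // orchestrators known to discovery
	initialPoolSize := 10.0 // sessions in the pools after the last Refresh

	// Capping by initialPoolSize keeps the threshold proportional to the
	// sessions this selector actually got, not every orch on the network.
	discoveryPoolSize := math.Min(orchPoolSize, initialPoolSize)                         // 10
	threshold := math.Min(maxRefreshSessionsThreshold, math.Ceil(discoveryPoolSize/2.0)) // 5

	fmt.Println(threshold) // a refresh fires once fewer than 5 sessions remain
}
```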
@@ -250,13 +269,18 @@ func (sel *AISessionSelector) Refresh(ctx context.Context) error {
}

func (sel *AISessionSelector) Refresh(ctx context.Context) error {
+// If we try to add new sessions to the pool the suspender
+// should treat this as a refresh
+sel.suspender.signalRefresh()

sessions, err := sel.getSessions(ctx)
if err != nil {
return err
}

var warmSessions []*BroadcastSession
var coldSessions []*BroadcastSession

for _, sess := range sessions {
// If the constraints are missing for this capability skip this session
constraints, ok := sess.OrchestratorInfo.Capabilities.Constraints[uint32(sel.cap)]
Expand All @@ -279,6 +303,7 @@ func (sel *AISessionSelector) Refresh(ctx context.Context) error {

sel.warmPool.Add(warmSessions)
sel.coldPool.Add(coldSessions)
+sel.initialPoolSize = len(warmSessions) + len(coldSessions) + len(sel.suspender.list)

sel.lastRefreshTime = time.Now()

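Note the relocation: signalRefresh used to fire inside AISessionPool.Add, once per pool, every time sessions were added. Calling it once at the top of Refresh ties the suspension clock to actual discovery refreshes and keeps the warm and cold pools from double-counting a single refresh.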
@@ -353,6 +378,8 @@ func (c *AISessionManager) Select(ctx context.Context, cap core.Capability, mode
}
}

+clog.Infof(ctx, "session selected orchestrator=%s", sess.Transcoder())

return sess, nil
}
