Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add config option to enforce a single party per user socket #1179

Merged
merged 4 commits into from May 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Expand Up @@ -7,6 +7,7 @@ The format is based on [keep a changelog](http://keepachangelog.com) and this pr
### Added
- Add runtime support for registering a shutdown hook function.
- Add support to custom sorting in storage index search.
- New config options to enforce a single party per user socket.

### Changed
- When a user is blocked, any DM streams between the blocker and blocked user are torn down.
Expand Down
2 changes: 1 addition & 1 deletion main.go
Expand Up @@ -166,7 +166,7 @@ func main() {
startupLogger.Fatal("Failed initializing runtime modules", zap.Error(err))
}
matchmaker := server.NewLocalMatchmaker(logger, startupLogger, config, router, metrics, runtime)
partyRegistry := server.NewLocalPartyRegistry(logger, matchmaker, tracker, streamManager, router, config.GetName())
partyRegistry := server.NewLocalPartyRegistry(logger, config, matchmaker, tracker, streamManager, router, config.GetName())
tracker.SetPartyJoinListener(partyRegistry.Join)
tracker.SetPartyLeaveListener(partyRegistry.Leave)

Expand Down
4 changes: 4 additions & 0 deletions server/config.go
Expand Up @@ -152,6 +152,9 @@ func CheckConfig(logger *zap.Logger, config Config) map[string]string {
if config.GetSession().SingleMatch && !config.GetSession().SingleSocket {
logger.Fatal("Single match cannot be enabled without single socket", zap.Strings("param", []string{"session.single_match", "session.single_socket"}))
}
if config.GetSession().SingleParty && !config.GetSession().SingleSocket {
logger.Fatal("Single party cannot be enabled without single socket", zap.Strings("param", []string{"session.single_party", "session.single_socket"}))
}
if config.GetRuntime().HTTPKey == "" {
logger.Fatal("Runtime HTTP key must be set", zap.String("param", "runtime.http_key"))
}
Expand Down Expand Up @@ -682,6 +685,7 @@ type SessionConfig struct {
RefreshTokenExpirySec int64 `yaml:"refresh_token_expiry_sec" json:"refresh_token_expiry_sec" usage:"Refresh token expiry in seconds."`
SingleSocket bool `yaml:"single_socket" json:"single_socket" usage:"Only allow one socket per user. Older sessions are disconnected. Default false."`
SingleMatch bool `yaml:"single_match" json:"single_match" usage:"Only allow one match per user. Older matches receive a leave. Requires single socket to enable. Default false."`
SingleParty bool `yaml:"single_party" json:"single_party" usage:"Only allow one party per user. Older parties receive a leave. Requires single socket to enable. Default false."`
}

func NewSessionConfig() *SessionConfig {
Expand Down
7 changes: 6 additions & 1 deletion server/party_handler.go
Expand Up @@ -372,7 +372,7 @@ func (p *PartyHandler) Promote(sessionID, node string, presence *rtapi.UserPrese
return nil
}

func (p *PartyHandler) Accept(sessionID, node string, presence *rtapi.UserPresence) error {
func (p *PartyHandler) Accept(sessionID, node string, presence *rtapi.UserPresence, singleParty bool) error {
p.Lock()
if p.stopped {
p.Unlock()
Expand Down Expand Up @@ -427,6 +427,11 @@ func (p *PartyHandler) Accept(sessionID, node string, presence *rtapi.UserPresen
return err
}

if singleParty {
// Kick the user from any other parties they may be part of.
p.tracker.UntrackLocalByModes(joinRequestPresence.ID.SessionID, partyStreamMode, p.Stream)
}

// The party membership has changed, stop any ongoing matchmaking processes.
_ = p.matchmaker.RemovePartyAll(p.IDStr)

Expand Down
3 changes: 2 additions & 1 deletion server/party_handler_test.go
Expand Up @@ -61,9 +61,10 @@ func createTestPartyHandler(t *testing.T, logger *zap.Logger) (*PartyHandler, fu
mm, cleanup, _ := createTestMatchmaker(t, logger, true, nil)
tt := testTracker{}
tsm := testStreamManager{}

dmr := DummyMessageRouter{}

pr := NewLocalPartyRegistry(logger, mm, &tt, &tsm, &dmr, node)
pr := NewLocalPartyRegistry(logger, cfg, mm, &tt, &tsm, &dmr, node)
ph := NewPartyHandler(logger, pr, mm, &tt, &tsm, &dmr, uuid.UUID{}, node, true, 10, nil)
return ph, cleanup
}
6 changes: 4 additions & 2 deletions server/party_registry.go
Expand Up @@ -45,6 +45,7 @@ type PartyRegistry interface {

type LocalPartyRegistry struct {
logger *zap.Logger
config Config
matchmaker Matchmaker
tracker Tracker
streamManager StreamManager
Expand All @@ -54,9 +55,10 @@ type LocalPartyRegistry struct {
parties *MapOf[uuid.UUID, *PartyHandler]
}

func NewLocalPartyRegistry(logger *zap.Logger, matchmaker Matchmaker, tracker Tracker, streamManager StreamManager, router MessageRouter, node string) PartyRegistry {
func NewLocalPartyRegistry(logger *zap.Logger, config Config, matchmaker Matchmaker, tracker Tracker, streamManager StreamManager, router MessageRouter, node string) PartyRegistry {
return &LocalPartyRegistry{
logger: logger,
config: config,
matchmaker: matchmaker,
tracker: tracker,
streamManager: streamManager,
Expand Down Expand Up @@ -132,7 +134,7 @@ func (p *LocalPartyRegistry) PartyAccept(ctx context.Context, id uuid.UUID, node
return ErrPartyNotFound
}

return ph.Accept(sessionID, fromNode, presence)
return ph.Accept(sessionID, fromNode, presence, p.config.GetSession().SingleParty)
}

func (p *LocalPartyRegistry) PartyRemove(ctx context.Context, id uuid.UUID, node, sessionID, fromNode string, presence *rtapi.UserPresence) error {
Expand Down
15 changes: 14 additions & 1 deletion server/pipeline_party.go
Expand Up @@ -23,6 +23,8 @@ import (
"go.uber.org/zap"
)

var partyStreamMode = map[uint8]struct{}{StreamModeParty: {}}

func (p *Pipeline) partyCreate(logger *zap.Logger, session Session, envelope *rtapi.Envelope) (bool, *rtapi.Envelope) {
incoming := envelope.GetPartyCreate()

Expand Down Expand Up @@ -65,6 +67,11 @@ func (p *Pipeline) partyCreate(logger *zap.Logger, session Session, envelope *rt
return false, nil
}

if p.config.GetSession().SingleParty {
// Kick the user from any other parties they may be part of.
p.tracker.UntrackLocalByModes(session.ID(), partyStreamMode, ph.Stream)
}

out := &rtapi.Envelope{Cid: envelope.Cid, Message: &rtapi.Envelope_Party{Party: &rtapi.Party{
PartyId: ph.IDStr,
Open: incoming.Open,
Expand Down Expand Up @@ -123,7 +130,8 @@ func (p *Pipeline) partyJoin(logger *zap.Logger, session Session, envelope *rtap

// If the party was open and the join was successful, track the new member immediately.
if autoJoin {
success, _ := p.tracker.Track(session.Context(), session.ID(), PresenceStream{Mode: StreamModeParty, Subject: partyID, Label: node}, session.UserID(), PresenceMeta{
stream := PresenceStream{Mode: StreamModeParty, Subject: partyID, Label: node}
success, _ := p.tracker.Track(session.Context(), session.ID(), stream, session.UserID(), PresenceMeta{
Format: session.Format(),
Username: session.Username(),
Status: "",
Expand All @@ -135,6 +143,11 @@ func (p *Pipeline) partyJoin(logger *zap.Logger, session Session, envelope *rtap
}}}, true)
return false, nil
}

if p.config.GetSession().SingleParty {
// Kick the user from any other parties they may be part of.
p.tracker.UntrackLocalByModes(session.ID(), partyStreamMode, stream)
}
}

out := &rtapi.Envelope{Cid: envelope.Cid}
Expand Down