Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add consumer Owner ID #5157

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
25 changes: 16 additions & 9 deletions server/consumer.go
Expand Up @@ -109,6 +109,7 @@ type ConsumerConfig struct {

// PauseUntil is for suspending the consumer until the deadline.
PauseUntil *time.Time `json:"pause_until,omitempty"`
OwnerID string `json:"owner_id,omitempty"`
}

// SequenceInfo has both the consumer and the stream sequence and last activity.
Expand Down Expand Up @@ -2976,36 +2977,36 @@ func (o *consumer) needAck(sseq uint64, subj string) bool {
}

// Helper for the next message requests.
func nextReqFromMsg(msg []byte) (time.Time, int, int, bool, time.Duration, time.Time, error) {
func nextReqFromMsg(msg []byte) (time.Time, int, int, bool, time.Duration, time.Time, string, error) {
req := bytes.TrimSpace(msg)

switch {
case len(req) == 0:
return time.Time{}, 1, 0, false, 0, time.Time{}, nil
return time.Time{}, 1, 0, false, 0, time.Time{}, "", nil

case req[0] == '{':
var cr JSApiConsumerGetNextRequest
if err := json.Unmarshal(req, &cr); err != nil {
return time.Time{}, -1, 0, false, 0, time.Time{}, err
return time.Time{}, -1, 0, false, 0, time.Time{}, "", err
}
var hbt time.Time
if cr.Heartbeat > 0 {
if cr.Heartbeat*2 > cr.Expires {
return time.Time{}, 1, 0, false, 0, time.Time{}, errors.New("heartbeat value too large")
return time.Time{}, 1, 0, false, 0, time.Time{}, "", errors.New("heartbeat value too large")
}
hbt = time.Now().Add(cr.Heartbeat)
}
if cr.Expires == time.Duration(0) {
return time.Time{}, cr.Batch, cr.MaxBytes, cr.NoWait, cr.Heartbeat, hbt, nil
return time.Time{}, cr.Batch, cr.MaxBytes, cr.NoWait, cr.Heartbeat, hbt, cr.OwnerID, nil
}
return time.Now().Add(cr.Expires), cr.Batch, cr.MaxBytes, cr.NoWait, cr.Heartbeat, hbt, nil
return time.Now().Add(cr.Expires), cr.Batch, cr.MaxBytes, cr.NoWait, cr.Heartbeat, hbt, cr.OwnerID, nil
default:
if n, err := strconv.Atoi(string(req)); err == nil {
return time.Time{}, n, 0, false, 0, time.Time{}, nil
return time.Time{}, n, 0, false, 0, time.Time{}, "", nil
}
}

return time.Time{}, 1, 0, false, 0, time.Time{}, nil
return time.Time{}, 1, 0, false, 0, time.Time{}, "", nil
}

// Represents a request that is on the internal waiting queue
Expand Down Expand Up @@ -3321,12 +3322,18 @@ func (o *consumer) processNextMsgRequest(reply string, msg []byte) {
}

// Check payload here to see if they sent in batch size or a formal request.
expires, batchSize, maxBytes, noWait, hb, hbt, err := nextReqFromMsg(msg)
expires, batchSize, maxBytes, noWait, hb, hbt, ownerID, err := nextReqFromMsg(msg)
if err != nil {
sendErr(400, fmt.Sprintf("Bad Request - %v", err))
return
}

// Check the owner for exclusive consumer.
if o.cfg.OwnerID != _EMPTY_ && ownerID != o.cfg.OwnerID {
sendErr(412, "Consumer is owned by another client")
return
}

// Check for request limits
if o.cfg.MaxRequestBatch > 0 && batchSize > o.cfg.MaxRequestBatch {
sendErr(409, fmt.Sprintf("Exceeded MaxRequestBatch of %d", o.cfg.MaxRequestBatch))
Expand Down
1 change: 1 addition & 0 deletions server/jetstream_api.go
Expand Up @@ -730,6 +730,7 @@ type JSApiConsumerGetNextRequest struct {
MaxBytes int `json:"max_bytes,omitempty"`
NoWait bool `json:"no_wait,omitempty"`
Heartbeat time.Duration `json:"idle_heartbeat,omitempty"`
OwnerID string `json:"owner_id,omitempty"`
}

// JSApiStreamTemplateCreateResponse for creating templates.
Expand Down
120 changes: 120 additions & 0 deletions server/jetstream_consumer_test.go
Expand Up @@ -31,6 +31,126 @@ import (
"github.com/nats-io/nuid"
)

func TestJetStreamConsumerExclusive(t *testing.T) {
s := RunBasicJetStreamServer(t)
defer s.Shutdown()

nc, js := jsClientConnect(t, s)
defer nc.Close()
acc := s.GlobalAccount()

mset, err := acc.addStream(&StreamConfig{
Name: "TEST",
Retention: LimitsPolicy,
Subjects: []string{"events.>"},
MaxAge: time.Second * 90,
})
require_NoError(t, err)

_, err = mset.addConsumer(&ConsumerConfig{
Durable: "consumer",
AckPolicy: AckExplicit,
DeliverPolicy: DeliverAll,
FilterSubject: "events.>",
OwnerID: "me",
})
require_NoError(t, err)

for i := 0; i < 10; i++ {
_, err = js.Publish("events.1", []byte("hello"))
require_NoError(t, err)
}

// set ID that is not owned by us.
cr := JSApiConsumerGetNextRequest{
Batch: 1,
OwnerID: "notMe",
}
crBytes, err := json.Marshal(cr)
require_NoError(t, err)

inbox := nats.NewInbox()
err = nc.PublishRequest(fmt.Sprintf(JSApiRequestNextT, "TEST", "consumer"), inbox, crBytes)
require_NoError(t, err)

consumerSub, err := nc.SubscribeSync(inbox)
require_NoError(t, err)

msg, err := consumerSub.NextMsg(time.Second)
require_NoError(t, err)

// check if message header contains error "Consumer is owned by another client"
if !strings.Contains(string(msg.Header.Get("Status")), "412") {
t.Fatalf("Expected exclusive consumer error, got %q", msg.Header.Get("Description"))
}

// now set our ID
cr = JSApiConsumerGetNextRequest{
Batch: 2,
OwnerID: "me",
}
crBytes, err = json.Marshal(cr)
require_NoError(t, err)

err = nc.PublishRequest(fmt.Sprintf(JSApiRequestNextT, "TEST", "consumer"), inbox, crBytes)
require_NoError(t, err)

msg, err = consumerSub.NextMsg(time.Second)
require_NoError(t, err)
require_Equal(t, string(msg.Data), "hello")

// update the consumer to different ID
_, err = mset.addConsumer(&ConsumerConfig{
Durable: "consumer",
AckPolicy: AckExplicit,
DeliverPolicy: DeliverAll,
FilterSubject: "events.>",
OwnerID: "differentMe",
})
require_NoError(t, err)

// we should still get messages from the pending pull requests
msg, err = consumerSub.NextMsg(time.Second)
require_NoError(t, err)
require_Equal(t, string(msg.Data), "hello")

// check if the previous ID works. It should not
cr = JSApiConsumerGetNextRequest{
Batch: 1,
OwnerID: "me",
}
crBytes, err = json.Marshal(cr)
require_NoError(t, err)

err = nc.PublishRequest(fmt.Sprintf(JSApiRequestNextT, "TEST", "consumer"), inbox, crBytes)
require_NoError(t, err)

msg, err = consumerSub.NextMsg(time.Second)
require_NoError(t, err)

// we should now get an error
if !strings.Contains(string(msg.Header.Get("Status")), "412") {
t.Fatalf("Expected exclusive consumer error, got %q", msg.Header.Get("Description"))
}

// and this should work now

cr = JSApiConsumerGetNextRequest{
Batch: 1,
OwnerID: "differentMe",
}
crBytes, err = json.Marshal(cr)
require_NoError(t, err)

err = nc.PublishRequest(fmt.Sprintf(JSApiRequestNextT, "TEST", "consumer"), inbox, crBytes)
require_NoError(t, err)

msg, err = consumerSub.NextMsg(time.Second)
require_NoError(t, err)
require_Equal(t, string(msg.Data), "hello")

}

func TestJetStreamConsumerMultipleFiltersRemoveFilters(t *testing.T) {

s := RunBasicJetStreamServer(t)
Expand Down
2 changes: 1 addition & 1 deletion server/jetstream_test.go
Expand Up @@ -639,7 +639,7 @@ func TestJetStreamConsumerMaxDeliveries(t *testing.T) {

func TestJetStreamNextReqFromMsg(t *testing.T) {
bef := time.Now()
expires, _, _, _, _, _, err := nextReqFromMsg([]byte(`{"expires":5000000000}`)) // nanoseconds
expires, _, _, _, _, _, _, err := nextReqFromMsg([]byte(`{"expires":5000000000}`)) // nanoseconds
require_NoError(t, err)
now := time.Now()
if expires.Before(bef.Add(5*time.Second)) || expires.After(now.Add(5*time.Second)) {
Expand Down