Skip to content

Commit

Permalink
Merge pull request #256 from bojand/stream_message_provider
Browse files Browse the repository at this point in the history
add stream message provider api
  • Loading branch information
bojand committed Jan 10, 2021
2 parents 4c77580 + 65733b6 commit 2a63c46
Show file tree
Hide file tree
Showing 4 changed files with 198 additions and 28 deletions.
25 changes: 25 additions & 0 deletions runner/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ type RunConfig struct {

dataFunc BinaryDataFunc
dataProviderFunc DataProviderFunc
dataStreamFunc StreamMessageProviderFunc
mdProviderFunc MetadataProviderFunc

funcs template.FuncMap
Expand Down Expand Up @@ -1006,6 +1007,30 @@ func WithMetadataProvider(fn MetadataProviderFunc) Option {
}
}

// WithStreamMessageProvider sets custom stream message provider
// WithStreamMessageProvider(func(cd *CallData) (*dynamic.Message, error) {
// protoMsg := &helloworld.HelloRequest{Name: cd.WorkerID + ": " + strconv.FormatInt(cd.RequestNumber, 10)}
// dynamicMsg, err := dynamic.AsDynamicMessage(protoMsg)
// if err != nil {
// return nil, err
// }
//
// callCounter++
//
// if callCounter == 5 {
// err = ErrLastMessage
// }
//
// return dynamicMsg, err
// }),
func WithStreamMessageProvider(fn StreamMessageProviderFunc) Option {
return func(o *RunConfig) error {
o.dataStreamFunc = fn

return nil
}
}

func createClientTransportCredentials(skipVerify bool, cacertFile, clientCertFile, clientKeyFile, cname string) (credentials.TransportCredentials, error) {
var tlsConf tls.Config

Expand Down
1 change: 1 addition & 0 deletions runner/requester.go
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,7 @@ func (b *Requester) runWorkers(wt load.WorkerTicker, p load.Pacer) error {
dataProvider: b.dataProvider,
metadataProvider: b.metadataProvider,
streamRecv: b.config.recvMsgFunc,
msgProvider: b.config.dataStreamFunc,
}

wc++ // increment worker id
Expand Down
142 changes: 141 additions & 1 deletion runner/run_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -702,7 +702,7 @@ func TestRunUnary(t *testing.T) {
}
}

assert.Equal(t, []string{"0", "__record_metadata__||token:secret2", "2", "__record_metadata__||token:secret4", "4"}, names)
assert.Equal(t, []string{"0", "__record_metadata__||token:secret1", "2", "__record_metadata__||token:secret3", "4"}, names)
})
}

Expand Down Expand Up @@ -1553,6 +1553,76 @@ func TestRunClientStreaming(t *testing.T) {
}
}
})

t.Run("with stream message provider", func(t *testing.T) {
gs.ResetCounters()

callCounter := 0

report, err := Run(
"helloworld.Greeter.SayHelloCS",
internal.TestLocalhost,
WithProtoFile("../testdata/greeter.proto", []string{}),
WithTotalRequests(1),
WithConcurrency(1),
WithTimeout(time.Duration(20*time.Second)),
WithDialTimeout(time.Duration(20*time.Second)),
WithInsecure(true),
WithStreamMessageProvider(func(cd *CallData) (*dynamic.Message, error) {
protoMsg := &helloworld.HelloRequest{Name: cd.WorkerID + ": " + strconv.Itoa(callCounter)}
dynamicMsg, err := dynamic.AsDynamicMessage(protoMsg)
if err != nil {
return nil, err
}

callCounter++

if callCounter == 5 {
err = ErrLastMessage
}

return dynamicMsg, err
}),
)

assert.NoError(t, err)

assert.NotNil(t, report)

assert.NotZero(t, report.Total)
assert.Equal(t, 1, int(report.Count))
assert.NotZero(t, report.Average)
assert.NotZero(t, report.Fastest)
assert.NotZero(t, report.Slowest)
assert.NotZero(t, report.Rps)
assert.Empty(t, report.Name)
assert.NotEmpty(t, report.Date)
assert.NotEmpty(t, report.Details)
assert.NotEmpty(t, report.Options)
assert.Equal(t, true, report.Options.Insecure)
assert.NotEmpty(t, report.LatencyDistribution)
assert.Equal(t, ReasonNormalEnd, report.EndReason)
assert.Empty(t, report.ErrorDist)

assert.Equal(t, report.Average, report.Slowest)
assert.Equal(t, report.Average, report.Fastest)
assert.Equal(t, report.Slowest, report.Fastest)

count := gs.GetCount(callType)
assert.Equal(t, 1, count)

connCount := gs.GetConnectionCount()
assert.Equal(t, 1, connCount)

calls := gs.GetCalls(callType)
assert.NotNil(t, calls)
assert.Len(t, calls, 1)
msgs := calls[0]
assert.Len(t, msgs, 5)

assert.Equal(t, "g0c0: 0", msgs[0].GetName())
assert.Equal(t, "g0c0: 4", msgs[4].GetName())
})
}

func TestRunClientStreamingBinary(t *testing.T) {
Expand Down Expand Up @@ -2274,6 +2344,76 @@ func TestRunBidi(t *testing.T) {
msgs := calls[0]
assert.Len(t, msgs, 6)
})

t.Run("with stream message provider", func(t *testing.T) {
gs.ResetCounters()

callCounter := 0

report, err := Run(
"helloworld.Greeter.SayHelloBidi",
internal.TestLocalhost,
WithProtoFile("../testdata/greeter.proto", []string{}),
WithTotalRequests(1),
WithConcurrency(1),
WithTimeout(time.Duration(20*time.Second)),
WithDialTimeout(time.Duration(20*time.Second)),
WithInsecure(true),
WithStreamMessageProvider(func(cd *CallData) (*dynamic.Message, error) {
protoMsg := &helloworld.HelloRequest{Name: cd.WorkerID + ": " + strconv.Itoa(callCounter)}
dynamicMsg, err := dynamic.AsDynamicMessage(protoMsg)
if err != nil {
return nil, err
}

callCounter++

if callCounter == 7 {
err = ErrLastMessage
}

return dynamicMsg, err
}),
)

assert.NoError(t, err)

assert.NotNil(t, report)

assert.NotZero(t, report.Total)
assert.Equal(t, 1, int(report.Count))
assert.NotZero(t, report.Average)
assert.NotZero(t, report.Fastest)
assert.NotZero(t, report.Slowest)
assert.NotZero(t, report.Rps)
assert.Empty(t, report.Name)
assert.NotEmpty(t, report.Date)
assert.NotEmpty(t, report.Details)
assert.NotEmpty(t, report.Options)
assert.NotEmpty(t, report.LatencyDistribution)
assert.Equal(t, ReasonNormalEnd, report.EndReason)
assert.Equal(t, true, report.Options.Insecure)
assert.Empty(t, report.ErrorDist)

assert.Equal(t, report.Average, report.Slowest)
assert.Equal(t, report.Average, report.Fastest)
assert.Equal(t, report.Slowest, report.Fastest)

count := gs.GetCount(callType)
assert.Equal(t, 1, count)

connCount := gs.GetConnectionCount()
assert.Equal(t, 1, connCount)

calls := gs.GetCalls(callType)
assert.NotNil(t, calls)
assert.Len(t, calls, 1)
msgs := calls[0]
assert.Len(t, msgs, 7)

assert.Equal(t, "g0c0: 0", msgs[0].GetName())
assert.Equal(t, "g0c0: 6", msgs[6].GetName())
})
}

func TestRunUnarySecure(t *testing.T) {
Expand Down
58 changes: 31 additions & 27 deletions runner/worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ type Worker struct {

dataProvider DataProviderFunc
metadataProvider MetadataProviderFunc
streamRecv StreamRecvMsgInterceptFunc
msgProvider StreamMessageProviderFunc

streamRecv StreamRecvMsgInterceptFunc
}

func (w *Worker) runWorker() error {
Expand Down Expand Up @@ -80,14 +82,6 @@ func (w *Worker) makeRequest(tv TickValue) error {

ctd := newCallData(w.mtd, w.config.funcs, w.workerID, reqNum)

inputs, err := w.dataProvider(ctd)
if err != nil {
return err
}
if len(inputs) == 0 {
return fmt.Errorf("no data provided for request")
}

reqMD, err := w.metadataProvider(ctd)
if err != nil {
return err
Expand All @@ -112,25 +106,15 @@ func (w *Worker) makeRequest(tv TickValue) error {
ctx = metadata.NewOutgoingContext(ctx, *reqMD)
}

var callType string
if w.config.hasLog {
callType = "unary"
if w.mtd.IsClientStreaming() && w.mtd.IsServerStreaming() {
callType = "bidi"
} else if w.mtd.IsServerStreaming() {
callType = "server-streaming"
} else if w.mtd.IsClientStreaming() {
callType = "client-streaming"
}

w.config.log.Debugw("Making request", "workerID", w.workerID,
"call type", callType, "call", w.mtd.GetFullyQualifiedName(),
"input", inputs, "metadata", reqMD)
inputs, err := w.dataProvider(ctd)
if err != nil {
return err
}

unaryInput := inputs[0]
var msgProvider StreamMessageProviderFunc
if w.mtd.IsClientStreaming() {
if w.msgProvider != nil {
msgProvider = w.msgProvider
} else if w.mtd.IsClientStreaming() {
if w.config.streamDynamicMessages {
mp, err := newDynamicMessageProvider(w.mtd, w.config.data, w.config.streamCallCount)
if err != nil {
Expand All @@ -148,15 +132,35 @@ func (w *Worker) makeRequest(tv TickValue) error {
}
}

if len(inputs) == 0 && msgProvider == nil {
return fmt.Errorf("no data provided for request")
}

var callType string
if w.config.hasLog {
callType = "unary"
if w.mtd.IsClientStreaming() && w.mtd.IsServerStreaming() {
callType = "bidi"
} else if w.mtd.IsServerStreaming() {
callType = "server-streaming"
} else if w.mtd.IsClientStreaming() {
callType = "client-streaming"
}

w.config.log.Debugw("Making request", "workerID", w.workerID,
"call type", callType, "call", w.mtd.GetFullyQualifiedName(),
"input", inputs, "metadata", reqMD)
}

// RPC errors are handled via stats handler
if w.mtd.IsClientStreaming() && w.mtd.IsServerStreaming() {
_ = w.makeBidiRequest(&ctx, ctd, msgProvider)
} else if w.mtd.IsClientStreaming() {
_ = w.makeClientStreamingRequest(&ctx, ctd, msgProvider)
} else if w.mtd.IsServerStreaming() {
_ = w.makeServerStreamingRequest(&ctx, unaryInput)
_ = w.makeServerStreamingRequest(&ctx, inputs[0])
} else {
_ = w.makeUnaryRequest(&ctx, reqMD, unaryInput)
_ = w.makeUnaryRequest(&ctx, reqMD, inputs[0])
}

return err
Expand Down

0 comments on commit 2a63c46

Please sign in to comment.