[TT-922] Universal concurrent executor (#927)
* generic concurrent executor
* Seth-specific utils
Tofel committed Apr 29, 2024
1 parent acc490a commit a89f080
Showing 9 changed files with 1,025 additions and 4 deletions.
110 changes: 110 additions & 0 deletions concurrency/example_test.go
@@ -0,0 +1,110 @@
package concurrency_test

import (
"fmt"
"testing"

"github.com/stretchr/testify/require"

"github.com/smartcontractkit/chainlink-testing-framework/concurrency"
"github.com/smartcontractkit/chainlink-testing-framework/logging"
)

type client struct{}

func (c *client) getConcurrency() int {
return 1
}

func (c *client) deplyContractConfigurableFromKey(_ int, _ contractConfiguration) (ContractIntstance, error) {
return ContractIntstance{}, nil
}

func (c *client) deplyContractFromKey(_ int) (ContractIntstance, error) {
return ContractIntstance{}, nil
}

type ContractIntstance struct{}

type contractConfiguration struct{}

type contractResult struct {
instance ContractIntstance
}

func (k contractResult) GetResult() ContractIntstance {
return k.instance
}

func TestExampleContractsWithConfiguration(t *testing.T) {
instances, err := DeployContractsWithConfiguration(&client{}, []contractConfiguration{{}, {}})
require.NoError(t, err, "failed to deploy contract instances")
require.Equal(t, 2, len(instances), "expected 2 contract instances")
}

// DeployContractsWithConfiguration shows a very simplified method that concurrently deploys contract instances with the given configurations
func DeployContractsWithConfiguration(client *client, contractConfigs []contractConfiguration) ([]ContractIntstance, error) {
l := logging.GetTestLogger(nil)

executor := concurrency.NewConcurrentExecutor[ContractIntstance, contractResult, contractConfiguration](l)

var deployContractFn = func(channel chan contractResult, errorCh chan error, executorNum int, payload contractConfiguration) {
keyNum := executorNum + 1 // key 0 is the root key

instance, err := client.deplyContractConfigurableFromKey(keyNum, payload)
if err != nil {
errorCh <- err
return
}

channel <- contractResult{instance: instance}
}

results, err := executor.Execute(client.getConcurrency(), contractConfigs, deployContractFn)
if err != nil {
return []ContractIntstance{}, err
}

if len(results) != len(contractConfigs) {
return []ContractIntstance{}, fmt.Errorf("expected %v results, got %v", len(contractConfigs), len(results))
}

return results, nil
}

func TestExampleContractsWithoutConfiguration(t *testing.T) {
instances, err := DeployIdenticalContracts(&client{}, 2)
require.NoError(t, err, "failed to deploy contract instances")
require.Equal(t, 2, len(instances), "expected 2 contract instances")
}

// DeployIdenticalContracts shows a very simplified method that concurrently deploys identical contract instances
// which require no configuration and just need to be executed N times
func DeployIdenticalContracts(client *client, numberOfContracts int) ([]ContractIntstance, error) {
l := logging.GetTestLogger(nil)

executor := concurrency.NewConcurrentExecutor[ContractIntstance, contractResult, concurrency.NoTaskType](l)

var deployContractFn = func(channel chan contractResult, errorCh chan error, executorNum int) {
keyNum := executorNum + 1 // key 0 is the root key

instance, err := client.deplyContractFromKey(keyNum)
if err != nil {
errorCh <- err
return
}

channel <- contractResult{instance: instance}
}

results, err := executor.ExecuteSimple(client.getConcurrency(), numberOfContracts, deployContractFn)
if err != nil {
return []ContractIntstance{}, err
}

if len(results) != numberOfContracts {
return []ContractIntstance{}, fmt.Errorf("expected %v results, got %v", numberOfContracts, len(results))
}

return results, nil
}
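
Neither example exercises the error path. Below is a minimal sketch (not part of this commit) of how failed tasks surface, assuming it lives alongside the tests above and reuses their imports and types, plus the WithoutFailFast option and GetErrors helper added in executor.go further down:

func TestExampleFailingTasks(t *testing.T) {
	l := logging.GetTestLogger(nil)

	// WithoutFailFast lets every task run to completion, so each failure is collected
	executor := concurrency.NewConcurrentExecutor[ContractIntstance, contractResult, contractConfiguration](
		l,
		concurrency.WithoutFailFast[ContractIntstance, contractResult, contractConfiguration](),
	)

	// hypothetical processor that always fails and therefore only writes to the error channel
	var failingFn = func(resultCh chan contractResult, errorCh chan error, executorNum int, payload contractConfiguration) {
		errorCh <- fmt.Errorf("deployment failed on executor %d", executorNum)
	}

	results, err := executor.Execute(2, []contractConfiguration{{}, {}}, failingFn)
	require.Error(t, err, "expected Execute to report the failed tasks")
	require.Empty(t, results, "expected no results when every task fails")
	require.Equal(t, 2, len(executor.GetErrors()), "expected one error per failed task")
}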
221 changes: 221 additions & 0 deletions concurrency/executor.go
@@ -0,0 +1,221 @@
package concurrency

import (
"context"
"fmt"
"sync"
"sync/atomic"

"github.com/rs/zerolog"

"github.com/smartcontractkit/chainlink-testing-framework/utils/slice"
)

// NoTaskType is a dummy type to be used when no task type is needed
type NoTaskType struct{}

type ConcurrentExecutorOpt[ResultType any, ResultChannelType ChannelWithResult[ResultType], TaskType any] func(c *ConcurrentExecutor[ResultType, ResultChannelType, TaskType])

// ConcurrentExecutor is a utility to execute tasks concurrently
type ConcurrentExecutor[ResultType any, ResultChannelType ChannelWithResult[ResultType], TaskType any] struct {
results []ResultType
errors []error
logger zerolog.Logger
failFast bool
context context.Context
}

// NewConcurrentExecutor creates a new ConcurrentExecutor
func NewConcurrentExecutor[ResultType any, ResultChannelType ChannelWithResult[ResultType], TaskType any](logger zerolog.Logger, opts ...ConcurrentExecutorOpt[ResultType, ResultChannelType, TaskType]) *ConcurrentExecutor[ResultType, ResultChannelType, TaskType] {
c := &ConcurrentExecutor[ResultType, ResultChannelType, TaskType]{
logger: logger,
results: []ResultType{},
errors: []error{},
failFast: true,
context: context.Background(),
}

for _, opt := range opts {
opt(c)
}

return c
}

// WithContext sets the context for the executor; if not set, it defaults to context.Background()
func WithContext[ResultType any, ResultChannelType ChannelWithResult[ResultType], TaskType any](context context.Context) ConcurrentExecutorOpt[ResultType, ResultChannelType, TaskType] {
return func(c *ConcurrentExecutor[ResultType, ResultChannelType, TaskType]) {
c.context = context
}
}

// WithoutFailFast disables fail fast. Executor will wait for all tasks to finish even if some of them fail.
func WithoutFailFast[ResultType any, ResultChannelType ChannelWithResult[ResultType], TaskType any]() ConcurrentExecutorOpt[ResultType, ResultChannelType, TaskType] {
return func(c *ConcurrentExecutor[ResultType, ResultChannelType, TaskType]) {
c.failFast = false
}
}

// TaskProcessorFn is a function that processes a task that requires a payload. It should send the result to the resultCh and any error to the errorCh. It should
// never send to both channels. The executorNum is the index of the executor that is processing the task. The payload is the task's payload. If the task doesn't require
// a payload, use SimpleTaskProcessorFn instead.
type TaskProcessorFn[ResultChannelType, TaskType any] func(resultCh chan ResultChannelType, errorCh chan error, executorNum int, payload TaskType)

// SimpleTaskProcessorFn is a function that processes a task that doesn't require a payload. It should send the result to the resultCh and any error to the errorCh. It should
// never send to both channels. The executorNum is the index of the executor that is processing the task.
type SimpleTaskProcessorFn[ResultChannelType any] func(resultCh chan ResultChannelType, errorCh chan error, executorNum int)

// ChannelWithResult is an interface that should be implemented by the result channel
type ChannelWithResult[ResultType any] interface {
GetResult() ResultType
}

// GetErrors returns all errors that occurred during processing
func (e *ConcurrentExecutor[ResultType, ResultChannelType, TaskType]) GetErrors() []error {
return e.errors
}

// ExecuteSimple executes a task that doesn't require a payload. It is executed repeatTimes times with the given concurrency. The simpleProcessorFn is the function that processes the task.
// The executor will attempt to distribute the tasks evenly among the executors.
func (e *ConcurrentExecutor[ResultType, ResultChannelType, TaskType]) ExecuteSimple(concurrency int, repeatTimes int, simpleProcessorFn SimpleTaskProcessorFn[ResultChannelType]) ([]ResultType, error) {
dummy := make([]TaskType, repeatTimes)
for i := 0; i < repeatTimes; i++ {
dummy[i] = *new(TaskType)
}

return e.Execute(concurrency, dummy, adaptSimpleToTaskProcessorFn[ResultChannelType, TaskType](simpleProcessorFn))
}

// Execute executes a task that requires a payload. It is executed with the given concurrency. The processorFn is the function that processes the task.
// The executor will attempt to distribute the tasks evenly among the executors.
func (e *ConcurrentExecutor[ResultType, ResultChannelType, TaskType]) Execute(concurrency int, payload []TaskType, processorFn TaskProcessorFn[ResultChannelType, TaskType]) ([]ResultType, error) {
if len(payload) == 0 {
return []ResultType{}, nil
}

if concurrency <= 0 {
e.logger.Warn().Msg("Concurrency is less than 1, setting it to 1")
concurrency = 1
}

var wgProcesses sync.WaitGroup
wgProcesses.Add(len(payload))

canSafelyContinueCh := make(chan struct{}) // waits until listening goroutine finishes, so we can safely return from the function
doneProcessingCh := make(chan struct{}) // signals that both result and error channels are closed
errorCh := make(chan error, len(payload))
resultCh := make(chan ResultChannelType, len(payload))

// mutex to protect shared state
mutex := sync.Mutex{}

// atomic counter to keep track of processed tasks
var atomicCounter atomic.Int32

ctx, cancel := context.WithCancel(e.context)

// listen in the background until all tasks are processed (this loop itself never fails fast; it drains all results and errors)
go func() {
defer func() {
e.logger.Trace().Msg("Finished listening to task processing results")
close(canSafelyContinueCh)
}()
for {
select {
case err, ok := <-errorCh:
if !ok {
e.logger.Trace().Msg("Error channel closed")
return
}
if err != nil {
mutex.Lock()
e.errors = append(e.errors, err)
e.logger.Err(err).Msg("Error processing a task")
mutex.Unlock()
wgProcesses.Done()

// cancel the context if failFast is enabled and it hasn't been cancelled yet
if e.failFast && ctx.Err() == nil {
cancel()
}
}
case result, ok := <-resultCh:
if !ok {
e.logger.Trace().Msg("Result channel closed")
return
}

counter := atomicCounter.Add(1)
mutex.Lock()
e.results = append(e.results, result.GetResult())
e.logger.Trace().Str("Done/Total", fmt.Sprintf("%d/%d", counter, len(payload))).Msg("Finished aggregating task result")
mutex.Unlock()
wgProcesses.Done()
case <-doneProcessingCh:
e.logger.Trace().Msg("Signaling that processing is done")
return
}
}
}()

// split the payload into one chunk per executor, distributing tasks as evenly as possible
dividedPayload := slice.DivideSlice(payload, concurrency)

for executorNum := 0; executorNum < concurrency; executorNum++ {
go func(key int) {
payloads := dividedPayload[key]

if len(payloads) == 0 {
return
}

e.logger.Debug().
Int("Executor Index", key).
Int("Tasks to process", len(payloads)).
Msg("Started processing tasks")

for i := 0; i < len(payloads); i++ {

// if the context is cancelled and failFast is enabled, mark all remaining tasks as finished
if e.failFast && ctx.Err() != nil {
e.logger.Trace().
Int("Executor Index", key).
Str("Cancelled/Total", fmt.Sprintf("%d/%d", (i+1), len(payloads))).
Msg("Canelling remaining tasks")
wgProcesses.Done()

continue
}

processorFn(resultCh, errorCh, key, payloads[i])
e.logger.Trace().
Int("Executor Index", key).
Str("Done/Total", fmt.Sprintf("%d/%d", (i+1), len(payloads))).
Msg("Processed a tasks")
}

e.logger.Debug().
Int("Executor Index", key).
Msg("Finished processing tasks")
}(executorNum)
}

wgProcesses.Wait()
close(resultCh)
close(errorCh)
close(doneProcessingCh)
<-canSafelyContinueCh

if len(e.errors) > 0 {
return []ResultType{}, fmt.Errorf("failed to process %d task(s)", len(e.errors))
}

return e.results, nil
}

func adaptSimpleToTaskProcessorFn[ResultChannelType any, TaskType any](
simpleFn SimpleTaskProcessorFn[ResultChannelType],
) TaskProcessorFn[ResultChannelType, TaskType] {
return func(resultCh chan ResultChannelType, errorCh chan error, executorNum int, _ TaskType) {
simpleFn(resultCh, errorCh, executorNum)
}
}
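
The example tests above never set the WithContext option. Below is a minimal sketch (not part of this commit) of a deadline-bound variant of DeployContractsWithConfiguration, assuming the standard context and time packages are imported alongside the example's existing imports and that it reuses the example's types:

func DeployContractsWithDeadline(client *client, contractConfigs []contractConfiguration, timeout time.Duration) ([]ContractIntstance, error) {
	l := logging.GetTestLogger(nil)

	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	// WithContext bounds the whole run; with fail-fast enabled (the default) the executor
	// skips any tasks it has not started yet once the context expires
	executor := concurrency.NewConcurrentExecutor[ContractIntstance, contractResult, contractConfiguration](
		l,
		concurrency.WithContext[ContractIntstance, contractResult, contractConfiguration](ctx),
	)

	var deployContractFn = func(resultCh chan contractResult, errorCh chan error, executorNum int, payload contractConfiguration) {
		instance, err := client.deplyContractConfigurableFromKey(executorNum+1, payload) // key 0 is the root key
		if err != nil {
			errorCh <- err
			return
		}
		resultCh <- contractResult{instance: instance}
	}

	return executor.Execute(client.getConcurrency(), contractConfigs, deployContractFn)
}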
