-
Notifications
You must be signed in to change notification settings - Fork 37
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[TT-922] Universal concurrent executor (#927)
* generic concurrent executor * Seth-specific utils
- Loading branch information
Showing
9 changed files
with
1,025 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
package concurrency_test | ||
|
||
import ( | ||
"fmt" | ||
"testing" | ||
|
||
"github.com/stretchr/testify/require" | ||
|
||
"github.com/smartcontractkit/chainlink-testing-framework/concurrency" | ||
"github.com/smartcontractkit/chainlink-testing-framework/logging" | ||
) | ||
|
||
type client struct{} | ||
|
||
func (c *client) getConcurrency() int { | ||
return 1 | ||
} | ||
|
||
func (c *client) deplyContractConfigurableFromKey(_ int, _ contractConfiguration) (ContractIntstance, error) { | ||
return ContractIntstance{}, nil | ||
} | ||
|
||
func (c *client) deplyContractFromKey(_ int) (ContractIntstance, error) { | ||
return ContractIntstance{}, nil | ||
} | ||
|
||
type ContractIntstance struct{} | ||
|
||
type contractConfiguration struct{} | ||
|
||
type contractResult struct { | ||
instance ContractIntstance | ||
} | ||
|
||
func (k contractResult) GetResult() ContractIntstance { | ||
return k.instance | ||
} | ||
|
||
func TestExampleContractsWithConfiguration(t *testing.T) { | ||
instances, err := DeployContractsWithConfiguration(&client{}, []contractConfiguration{{}, {}}) | ||
require.NoError(t, err, "failed to deploy contract instances") | ||
require.Equal(t, 2, len(instances), "expected 2 contract instances") | ||
} | ||
|
||
// DeployContractsWithConfiguration shows a very simplified method that deploys concurrently contract instances with given configurations | ||
func DeployContractsWithConfiguration(client *client, contractConfigs []contractConfiguration) ([]ContractIntstance, error) { | ||
l := logging.GetTestLogger(nil) | ||
|
||
executor := concurrency.NewConcurrentExecutor[ContractIntstance, contractResult, contractConfiguration](l) | ||
|
||
var deployContractFn = func(channel chan contractResult, errorCh chan error, executorNum int, payload contractConfiguration) { | ||
keyNum := executorNum + 1 // key 0 is the root key | ||
|
||
instance, err := client.deplyContractConfigurableFromKey(keyNum, payload) | ||
if err != nil { | ||
errorCh <- err | ||
return | ||
} | ||
|
||
channel <- contractResult{instance: instance} | ||
} | ||
|
||
results, err := executor.Execute(client.getConcurrency(), contractConfigs, deployContractFn) | ||
if err != nil { | ||
return []ContractIntstance{}, err | ||
} | ||
|
||
if len(results) != len(contractConfigs) { | ||
return []ContractIntstance{}, fmt.Errorf("expected %v results, got %v", len(contractConfigs), len(results)) | ||
} | ||
|
||
return results, nil | ||
} | ||
|
||
func TestExampleContractsWithoutConfiguration(t *testing.T) { | ||
instances, err := DeployIdenticalContracts(&client{}, 2) | ||
require.NoError(t, err, "failed to deploy contract instances") | ||
require.Equal(t, 2, len(instances), "expected 2 contract instances") | ||
} | ||
|
||
// DeployIdenticalContracts shows a very simplified method that deploys concurrently identical contract instances | ||
// which require no configuration, just need to be exected N amount of times | ||
func DeployIdenticalContracts(client *client, numberOfContracts int) ([]ContractIntstance, error) { | ||
l := logging.GetTestLogger(nil) | ||
|
||
executor := concurrency.NewConcurrentExecutor[ContractIntstance, contractResult, concurrency.NoTaskType](l) | ||
|
||
var deployContractFn = func(channel chan contractResult, errorCh chan error, executorNum int) { | ||
keyNum := executorNum + 1 // key 0 is the root key | ||
|
||
instance, err := client.deplyContractFromKey(keyNum) | ||
if err != nil { | ||
errorCh <- err | ||
return | ||
} | ||
|
||
channel <- contractResult{instance: instance} | ||
} | ||
|
||
results, err := executor.ExecuteSimple(client.getConcurrency(), numberOfContracts, deployContractFn) | ||
if err != nil { | ||
return []ContractIntstance{}, err | ||
} | ||
|
||
if len(results) != numberOfContracts { | ||
return []ContractIntstance{}, fmt.Errorf("expected %v results, got %v", numberOfContracts, len(results)) | ||
} | ||
|
||
return results, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,221 @@ | ||
package concurrency | ||
|
||
import ( | ||
"context" | ||
"fmt" | ||
"sync" | ||
"sync/atomic" | ||
|
||
"github.com/rs/zerolog" | ||
|
||
"github.com/smartcontractkit/chainlink-testing-framework/utils/slice" | ||
) | ||
|
||
// NoTaskType is a dummy type to be used when no task type is needed | ||
type NoTaskType struct{} | ||
|
||
type ConcurrentExecutorOpt[ResultType any, ResultChannelType ChannelWithResult[ResultType], TaskType any] func(c *ConcurrentExecutor[ResultType, ResultChannelType, TaskType]) | ||
|
||
// ConcurrentExecutor is a utility to execute tasks concurrently | ||
type ConcurrentExecutor[ResultType any, ResultChannelType ChannelWithResult[ResultType], TaskType any] struct { | ||
results []ResultType | ||
errors []error | ||
logger zerolog.Logger | ||
failFast bool | ||
context context.Context | ||
} | ||
|
||
// NewConcurrentExecutor creates a new ConcurrentExecutor | ||
func NewConcurrentExecutor[ResultType any, ResultChannelType ChannelWithResult[ResultType], TaskType any](logger zerolog.Logger, opts ...ConcurrentExecutorOpt[ResultType, ResultChannelType, TaskType]) *ConcurrentExecutor[ResultType, ResultChannelType, TaskType] { | ||
c := &ConcurrentExecutor[ResultType, ResultChannelType, TaskType]{ | ||
logger: logger, | ||
results: []ResultType{}, | ||
errors: []error{}, | ||
failFast: true, | ||
context: context.Background(), | ||
} | ||
|
||
for _, opt := range opts { | ||
opt(c) | ||
} | ||
|
||
return c | ||
} | ||
|
||
// / WithContext sets the context for the executor, if not set it defaults to context.Background() | ||
func WithContext[ResultType any, ResultChannelType ChannelWithResult[ResultType], TaskType any](context context.Context) ConcurrentExecutorOpt[ResultType, ResultChannelType, TaskType] { | ||
return func(c *ConcurrentExecutor[ResultType, ResultChannelType, TaskType]) { | ||
c.context = context | ||
} | ||
} | ||
|
||
// WithoutFailFast disables fail fast. Executor will wait for all tasks to finish even if some of them fail. | ||
func WithoutFailFast[ResultType any, ResultChannelType ChannelWithResult[ResultType], TaskType any]() ConcurrentExecutorOpt[ResultType, ResultChannelType, TaskType] { | ||
return func(c *ConcurrentExecutor[ResultType, ResultChannelType, TaskType]) { | ||
c.failFast = false | ||
} | ||
} | ||
|
||
// TaskProcessorFn is a function that processes a task that requires a payload. It should send the result to the resultCh and any error to the errorCh. It should | ||
// never send to both channels. The executorNum is the index of the executor that is processing the task. The payload is the task's payload. If task doesn't require | ||
// one use SimpleTaskProcessorFn instead. | ||
type TaskProcessorFn[ResultChannelType, TaskType any] func(resultCh chan ResultChannelType, errorCh chan error, executorNum int, payload TaskType) | ||
|
||
// SimpleTaskProcessorFn is a function that processes a task that doesn't require a payload. It should send the result to the resultCh and any error to the errorCh. It should | ||
// never send to both channels. The executorNum is the index of the executor that is processing the task. | ||
type SimpleTaskProcessorFn[ResultChannelType any] func(resultCh chan ResultChannelType, errorCh chan error, executorNum int) | ||
|
||
// ChannelWithResult is an interface that should be implemented by the result channel | ||
type ChannelWithResult[ResultType any] interface { | ||
GetResult() ResultType | ||
} | ||
|
||
// GetErrors returns all errors that occurred during processing | ||
func (e *ConcurrentExecutor[ResultType, ResultChannelType, TaskType]) GetErrors() []error { | ||
return e.errors | ||
} | ||
|
||
// ExecuteSimple executes a task that doesn't require a payload. It is executed repeatTimes times with given concurrency. The simpleProcessorFn is the function that processes the task. | ||
// Executor will attempt to distribute the tasks evenly among the executors. | ||
func (e *ConcurrentExecutor[ResultType, ResultChannelType, TaskType]) ExecuteSimple(concurrency int, repeatTimes int, simpleProcessorFn SimpleTaskProcessorFn[ResultChannelType]) ([]ResultType, error) { | ||
dummy := make([]TaskType, repeatTimes) | ||
for i := 0; i < repeatTimes; i++ { | ||
dummy[i] = *new(TaskType) | ||
} | ||
|
||
return e.Execute(concurrency, dummy, adaptSimpleToTaskProcessorFn[ResultChannelType, TaskType](simpleProcessorFn)) | ||
} | ||
|
||
// Execute executes a task that requires a payload. It is executed with given concurrency. The processorFn is the function that processes the task. | ||
// Executor will attempt to distribute the tasks evenly among the executors. | ||
func (e *ConcurrentExecutor[ResultType, ResultChannelType, TaskType]) Execute(concurrency int, payload []TaskType, processorFn TaskProcessorFn[ResultChannelType, TaskType]) ([]ResultType, error) { | ||
if len(payload) == 0 { | ||
return []ResultType{}, nil | ||
} | ||
|
||
if concurrency <= 0 { | ||
e.logger.Warn().Msg("Concurrency is less than 1, setting it to 1") | ||
concurrency = 1 | ||
} | ||
|
||
var wgProcesses sync.WaitGroup | ||
wgProcesses.Add(len(payload)) | ||
|
||
canSafelyContinueCh := make(chan struct{}) // waits until listening goroutine finishes, so we can safely return from the function | ||
doneProcessingCh := make(chan struct{}) // signals that both result and error channels are closed | ||
errorCh := make(chan error, len(payload)) | ||
resultCh := make(chan ResultChannelType, len(payload)) | ||
|
||
// mutex to protect shared state | ||
mutex := sync.Mutex{} | ||
|
||
// atomic counter to keep track of processed tasks | ||
var atomicCounter atomic.Int32 | ||
|
||
ctx, cancel := context.WithCancel(e.context) | ||
|
||
// listen in the background until all tasks are processed (no fail-fast) | ||
go func() { | ||
defer func() { | ||
e.logger.Trace().Msg("Finished listening to task processing results") | ||
close(canSafelyContinueCh) | ||
}() | ||
for { | ||
select { | ||
case err, ok := <-errorCh: | ||
if !ok { | ||
e.logger.Trace().Msg("Error channel closed") | ||
return | ||
} | ||
if err != nil { | ||
mutex.Lock() | ||
e.errors = append(e.errors, err) | ||
e.logger.Err(err).Msg("Error processing a task") | ||
mutex.Unlock() | ||
wgProcesses.Done() | ||
|
||
// cancel the context if failFast is enabled and it hasn't been cancelled yet | ||
if e.failFast && ctx.Err() == nil { | ||
cancel() | ||
} | ||
} | ||
case result, ok := <-resultCh: | ||
if !ok { | ||
e.logger.Trace().Msg("Result channel closed") | ||
return | ||
} | ||
|
||
counter := atomicCounter.Add(1) | ||
mutex.Lock() | ||
e.results = append(e.results, result.GetResult()) | ||
e.logger.Trace().Str("Done/Total", fmt.Sprintf("%d/%d", counter, len(payload))).Msg("Finished aggregating task result") | ||
mutex.Unlock() | ||
wgProcesses.Done() | ||
case <-doneProcessingCh: | ||
e.logger.Trace().Msg("Signaling that processing is done") | ||
return | ||
} | ||
} | ||
}() | ||
|
||
dividedPayload := slice.DivideSlice(payload, concurrency) | ||
|
||
for executorNum := 0; executorNum < concurrency; executorNum++ { | ||
go func(key int) { | ||
payloads := dividedPayload[key] | ||
|
||
if len(payloads) == 0 { | ||
return | ||
} | ||
|
||
e.logger.Debug(). | ||
Int("Executor Index", key). | ||
Int("Tasks to process", len(payloads)). | ||
Msg("Started processing tasks") | ||
|
||
for i := 0; i < len(payloads); i++ { | ||
|
||
// if context is cancelled and failFast is enabled mark all remaining tasks as finished | ||
if e.failFast && ctx.Err() != nil { | ||
e.logger.Trace(). | ||
Int("Executor Index", key). | ||
Str("Cancelled/Total", fmt.Sprintf("%d/%d", (i+1), len(payloads))). | ||
Msg("Canelling remaining tasks") | ||
wgProcesses.Done() | ||
|
||
continue | ||
} | ||
|
||
processorFn(resultCh, errorCh, key, payloads[i]) | ||
e.logger.Trace(). | ||
Int("Executor Index", key). | ||
Str("Done/Total", fmt.Sprintf("%d/%d", (i+1), len(payloads))). | ||
Msg("Processed a tasks") | ||
} | ||
|
||
e.logger.Debug(). | ||
Int("Executor Index", key). | ||
Msg("Finished processing tasks") | ||
}(executorNum) | ||
} | ||
|
||
wgProcesses.Wait() | ||
close(resultCh) | ||
close(errorCh) | ||
close(doneProcessingCh) | ||
<-canSafelyContinueCh | ||
|
||
if len(e.errors) > 0 { | ||
return []ResultType{}, fmt.Errorf("Failed to process %d task(s)", len(e.errors)) | ||
} | ||
|
||
return e.results, nil | ||
} | ||
|
||
func adaptSimpleToTaskProcessorFn[ResultChannelType any, TaskType any]( | ||
simpleFn SimpleTaskProcessorFn[ResultChannelType], | ||
) TaskProcessorFn[ResultChannelType, TaskType] { | ||
return func(resultCh chan ResultChannelType, errorCh chan error, executorNum int, _ TaskType) { | ||
simpleFn(resultCh, errorCh, executorNum) | ||
} | ||
} |
Oops, something went wrong.