How to manage empty items in a batch? #5251

bveldhoen opened this issue Dec 22, 2023 · 5 comments
bveldhoen commented Dec 22, 2023

Describe the question.

Thanks in advance for your help.

I'm running into an issue in a pipeline with ~11 operators. During processing, some processing steps may become irrelevant for certain items in the batch. For these empty batch items, processing should be skipped in all subsequent operators.

Currently, it seems necessary to work around this by filling the returned Tensors with signal values. For some operators, I can get away with returning a Tensor (i.e. a torch Tensor from fn.torch_python_function, or a cupy.ndarray from fn.python_function) whose shape has its first dimension set to 0, for instance (0, 640, 640, 3). But this does not always work (some operators raise exceptions), and in some cases it has been necessary to return bogus arrays filled with -1 values. Custom operators then have to test for these signal values, skip processing, and return empty values for these empty batch items.

Below is a code snippet to (hopefully) clarify:

def postprocess(previous_step_output_batch, ...):
    postprocess_output_batch = []
    for previous_step_output_sample in previous_step_output_batch:
        if (... condition that will produce a valid output sample ...):
            ....
            postprocess_output_sample = ...
        else:
            # Create sample to represent an empty batch item
            postprocess_output_sample = torch.ones(0, 640, 640, 3) * -1
            # In some scenarios (other functions, other shapes), it's required to return a shape with first dimension > 0
            # other_output_sample = torch.ones(1, 6) * -1
        postprocess_output_batch.append(postprocess_output_sample.to("cuda"))
    return postprocess_output_batch
...
def create_pipeline(...):
    ...
    postprocess_output_batch = dalitorch.fn.torch_python_function(
        previous_step_output_batch
        , function=lambda input1: postprocess(input1, ...)
        , batch_processing=True
        , device="gpu"
    )

Note that this implementation uses batch_processing=True. Would this be different/improved if using batch_processing=False? (i.e. does DALI then check for empty/None batch items?)

In general, what is the correct approach to deal with empty batch items?

Check for duplicates

  • I have searched the open bugs/issues and have found no duplicates for this bug report
@bveldhoen bveldhoen added the question Further information is requested label Dec 22, 2023
JanuszL commented Dec 27, 2023

Hi @bveldhoen,

Thank you for reaching out. DALI operators cannot necessarily handle empty samples. Some do, some don't.
Have you tried applying conditional execution to your workflow? If the condition determines that further processing should stop, you can take the else branch and assign an empty sample to the output:

if condition:
    out = do_processing
else:
    out = empty

bveldhoen (Author) commented
Hi @JanuszL,

Thanks for your response.

In our scenario, any number of items in the batch could be 'empty' (or not). For instance, in a batch of 4, only 1 item could be empty, with the rest containing valid items. This cascades through the subsequent operators: each operator that receives an empty (or invalid) item should produce an empty (or invalid) item at the same index in the resulting batch (or batches, if the operator has more than one output).

Using conditional execution will stop the execution of the entire batch, which is not the goal in our scenario.

I think this could be implemented in a straightforward way by allowing batch items to be None (in Python), with a corresponding implementation in C++ (for instance, an is_empty flag on Tensor, or something similar?). This would require each operator to check each item for emptiness during processing, which might require a lot of changes.

For now, I'll continue using signal values (arrays filled with -1, or with a shape with first dimension set to 0).

Thanks!

JanuszL commented Dec 27, 2023

Using conditional execution will stop the execution of the entire batch, which is not the goal in our scenario.

Despite the convenient Python syntax (if/else), conditional execution works per sample. So in:

if condition:
    out = do_processing
else:
    out = empty

each sample can take a separate execution path. Under the hood, split/merge operators are added that partition samples according to the condition. It doesn't stop the execution; it just redirects samples in different directions.
I recommend taking a closer look at that.
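The per-sample split/merge behavior described above can be pictured with a plain-Python simulation (illustration only, not DALI code; all names are made up):

```python
# Plain-Python illustration of per-sample conditional execution:
# samples are split by the predicate, each subset takes its own branch,
# and the results are merged back into the original batch order.
def run_conditional(batch, condition, then_fn, else_fn):
    true_idx = [i for i, s in enumerate(batch) if condition(s)]
    false_idx = [i for i, s in enumerate(batch) if not condition(s)]
    # "split": each branch sees only its own subset of the batch
    then_out = then_fn([batch[i] for i in true_idx])
    else_out = else_fn([batch[i] for i in false_idx])
    # "merge": restore the original per-sample order
    merged = [None] * len(batch)
    for i, s in zip(true_idx, then_out):
        merged[i] = s
    for i, s in zip(false_idx, else_out):
        merged[i] = s
    return merged
```

Note that the whole batch is still executed; no sample blocks any other, and each branch receives between 0 and batch-size samples.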

bveldhoen (Author) commented
I see! I didn't know that conditional execution was per sample; thanks for the clarification. I'll give it a try (next week). Will this work with a fn.python_function with batch_processing=True? (Or is it required to do the check within the called Python function in this case?)

Would this work with an additional output, containing a batch with Tensors, which contain True/False?
Something like:

output_batch_A, is_valid_batch_A = fn.python_function(...)
if is_valid_batch_A:
    output_batch_B, is_valid_batch_B = fn.B(output_batch_A, ...)
else:
    output_batch_B = empty

if is_valid_batch_A and is_valid_batch_B:
    output_batch_C = fn.C(output_batch_A, output_batch_B, ...)
else:
    output_batch_C = empty
...

JanuszL commented Dec 27, 2023

Will this work with a fn.python_function with batch_processing=True? (or is it required to do the check within the called python function in this case?)

As far as I understand it should. However, each time the Python function will get only the samples in a given condition branch, i.e. between 0 and batch size.

Would this work with an additional output, containing a batch with Tensors, which contain True/False?

I think the variable needs to be defined in both branches.
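In plain-Python terms, the same rule applies as in the conditional sketch above: an output that is read after the if/else must be assigned on every path. A minimal sketch (names are hypothetical):

```python
# Every output consumed after the conditional must exist in both
# branches; otherwise the merged result would be undefined for the
# samples that took the other path.
def classify(sample):
    if sample >= 0:
        out = "valid"
        is_valid = True
    else:
        out = "empty"      # assigned here as well, so both outputs
        is_valid = False   # are defined on every execution path
    return out, is_valid
```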

3 participants