
provides a workaround for unreasonable overhead encountered in preprocessors - specifically in datasets.map applied to the tokenizer #303

Open

jsmcibm wants to merge 3 commits into main
Conversation

@jsmcibm (Collaborator) commented Aug 26, 2022


PrimeQA Pull Request

What does this PR do?

Provides a workaround for unreasonable overhead encountered in preprocessors, specifically in `datasets.map` applied to the tokenizer.

Closes #(issue)

Notes:

  • Replace (issue) above ↑↑↑ with the issue this PR closes to automatically link the two.
    This must be done when the PR is created.
  • Add multiple Closes #(issue) as needed.
  • If this PR is work towards but does not close an issue, simply tag the issue without mentioning Closes.

Description

Describe the changes proposed by this PR to give the reviewer context ↓↓↓

Wraps the output of the tokenizer in a dictionary of np.ndarrays; `datasets.map` is observed to be much faster with this data structure than with the standard tokenizer output object.

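A minimal sketch of the idea (the function and field names below are illustrative, not the PR's actual code):

```python
import numpy as np

def tokenize_batch(examples, tokenizer):
    # Tokenize to a fixed length so every row has the same shape.
    encodings = tokenizer(
        examples["question"],
        examples["context"],
        truncation="only_second",
        padding="max_length",
        max_length=512,
    )
    # Re-wrap the BatchEncoding as a plain dict of np.ndarrays before handing
    # it back to datasets.map: Arrow can ingest contiguous arrays directly,
    # without the per-element type inference it runs on nested Python lists.
    return {k: np.asarray(v) for k, v in encodings.items()}
```

Padding to a fixed `max_length` matters here: each value then converts cleanly to a 2-D array, whereas ragged outputs would have to stay as nested lists.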

Request Review

Be sure to request a review from one or more reviewers (unless the PR is to an unprotected branch).

Versioning

When opening a PR that changes PrimeQA (i.e. primeqa/) on master, be sure to increment the version following
semantic versioning. The VERSION is stored here
and is incremented using `bump2version {patch,minor,major}` as described in the development guide documentation (https://github.com/primeqa/primeqa/blob/main/docs/development.md).

  • Have you updated the VERSION?
  • Or does this PR leave the primeqa package unchanged, or target a branch other than master?

After pulling in changes from master to an existing PR, ensure the VERSION is updated appropriately.
This may require bumping the version again if it has been previously bumped.

If you're not quite ready yet to post a PR for review, feel free to open a draft PR.

Releases

After Merging

If merging into master and VERSION was updated, after this PR is merged:

Checklist

Review the following and mark as completed:

@jsmcibm (Collaborator, Author) commented Aug 26, 2022

To verify: insert `sys.exit(0)` after `eval_examples, eval_dataset = preprocessor.process_eval(eval_examples)` in run_mrc.py.
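In context, the change looks roughly like this (a sketch; the surrounding run_mrc.py code is elided):

```python
import sys

# ... run_mrc.py setup elided ...
eval_examples, eval_dataset = preprocessor.process_eval(eval_examples)
sys.exit(0)  # stop right after preprocessing so the profile isolates it
```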
Run with profiling:

```
python -m cProfile -o profile.baseline.stats primeqa/mrc/run_mrc.py --model_name_or_path PrimeQA/tydiqa-primary-task-xlm-roberta-large --output_dir $OUTPUTDIR --fp16 --per_device_eval_batch_size 128 --overwrite_output_dir --do_boolean --boolean_config primeqa/boolqa/tydi_boolqa_config.json --max_eval_samples 1000 --overwrite_cache
```

Examine profile stats with:

```
python -m pstats profile.baseline.stats
sort cumtime
stats base|arrow
```


`_process_batch` is where the actual work occurs; `process_eval` (via `_process`) is where `datasets.map` is called. The difference in time is overhead.

Before the fix:

```
   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000  132.789  132.789 /dccstor/jsmc-nmt-01/bool/git/primeqa/primeqa/mrc/processors/preprocessors/base.py:77(process_eval)
        1    0.000    0.000  132.789  132.789 /dccstor/jsmc-nmt-01/bool/git/primeqa/primeqa/mrc/processors/preprocessors/base.py:80(_process)
      6/5    0.000    0.000  132.787   26.557 /dccstor/jsmc-nmt-01/anaconda3/envs/primeqa/lib/python3.7/site-packages/datasets/arrow_dataset.py:508(wrapper)
        5    0.000    0.000  132.786   26.557 /dccstor/jsmc-nmt-01/anaconda3/envs/primeqa/lib/python3.7/site-packages/datasets/arrow_dataset.py:549(wrapper)
        4    0.000    0.000  132.784   33.196 /dccstor/jsmc-nmt-01/anaconda3/envs/primeqa/lib/python3.7/site-packages/datasets/arrow_dataset.py:2216(map)
        4    0.469    0.117  132.220   33.055 /dccstor/jsmc-nmt-01/anaconda3/envs/primeqa/lib/python3.7/site-packages/datasets/arrow_dataset.py:2512(_map_single)
        4    0.001    0.000  109.637   27.409 /dccstor/jsmc-nmt-01/anaconda3/envs/primeqa/lib/python3.7/site-packages/datasets/arrow_writer.py:480(write_batch)
   120/58    3.696    0.031  109.320    1.885 {pyarrow.lib.array}
       58    0.633    0.011  109.320    1.885 /dccstor/jsmc-nmt-01/anaconda3/envs/primeqa/lib/python3.7/site-packages/datasets/arrow_writer.py:150(__arrow_array__)
       29    0.000    0.000   52.832    1.822 /dccstor/jsmc-nmt-01/anaconda3/envs/primeqa/lib/python3.7/site-packages/datasets/arrow_writer.py:110(get_inferred_type)
     2002    0.010    0.000   21.076    0.011 /dccstor/jsmc-nmt-01/anaconda3/envs/primeqa/lib/python3.7/site-packages/datasets/arrow_dataset.py:2642(apply_function_on_filtered_inputs)
     2002    0.006    0.000   21.064    0.011 /dccstor/jsmc-nmt-01/anaconda3/envs/primeqa/lib/python3.7/site-packages/datasets/arrow_dataset.py:2340(decorated)
        1    0.001    0.001   17.567   17.567 /dccstor/jsmc-nmt-01/bool/git/primeqa/primeqa/mrc/processors/preprocessors/base.py:102(_process_batch)
```

After the fix:

```
   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000   27.436   27.436 /dccstor/jsmc-nmt-01/bool/git/primeqa/primeqa/mrc/processors/preprocessors/base.py:89(process_eval)
        1    0.000    0.000   27.436   27.436 /dccstor/jsmc-nmt-01/bool/git/primeqa/primeqa/mrc/processors/preprocessors/base.py:92(_process)
      6/5    0.000    0.000   27.434    5.487 /dccstor/jsmc-nmt-01/anaconda3/envs/primeqa/lib/python3.7/site-packages/datasets/arrow_dataset.py:508(wrapper)
        5    0.000    0.000   27.433    5.487 /dccstor/jsmc-nmt-01/anaconda3/envs/primeqa/lib/python3.7/site-packages/datasets/arrow_dataset.py:549(wrapper)
        4    0.000    0.000   27.431    6.858 /dccstor/jsmc-nmt-01/anaconda3/envs/primeqa/lib/python3.7/site-packages/datasets/arrow_dataset.py:2216(map)
        4    0.007    0.002   26.864    6.716 /dccstor/jsmc-nmt-01/anaconda3/envs/primeqa/lib/python3.7/site-packages/datasets/arrow_dataset.py:2512(_map_single)
     2002    0.009    0.000   21.757    0.011 /dccstor/jsmc-nmt-01/anaconda3/envs/primeqa/lib/python3.7/site-packages/datasets/arrow_dataset.py:2642(apply_function_on_filtered_inputs)
     2002    0.493    0.000   21.745    0.011 /dccstor/jsmc-nmt-01/anaconda3/envs/primeqa/lib/python3.7/site-packages/datasets/arrow_dataset.py:2340(decorated)
        1    0.000    0.000   17.776   17.776 /dccstor/jsmc-nmt-01/bool/git/primeqa/primeqa/mrc/processors/preprocessors/base.py:24(inner)
        1    0.001    0.001   17.758   17.758 /dccstor/jsmc-nmt-01/bool/git/primeqa/primeqa/mrc/processors/preprocessors/base.py:114(_process_batch)
```

@jsmcibm (Collaborator, Author) commented Sep 15, 2022

Regarding the NQ dataset, I ran:

```
export TOKENIZERS_PARALLELISM=1
python -m cProfile -o profile.stats /dccstor/jsmc-nmt-01/bool/git/primeqa/primeqa/mrc/run_mrc.py \
--model_name_or_path roberta-large \
--num_train_epochs 1 \
--output_dir foo \
--overwrite_output_dir \
--do_train \
--do_eval \
--evaluation_strategy no \
--cache_dir foo \
--per_device_train_batch_size 32 \
--per_device_eval_batch_size 64 \
--gradient_accumulation_steps 4 \
--warmup_ratio 0.1 \
--save_steps 50000 \
--learning_rate 3e-5 \
--eval_metrics NQF1 \
--postprocessor primeqa.mrc.processors.postprocessors.natural_questions.NaturalQuestionsPostProcessor \
--preprocessor primeqa.mrc.processors.preprocessors.natural_questions.NaturalQuestionsPreProcessor \
--preprocessing_num_workers 10 \
--dataset_name natural_questions \
--dataset_config_name default \
--no_cuda \
--max_train_samples 1000 \
--max_eval_samples 1000 \
--beam_runner DirectRunner
```

It turns out that a huge amount of time was spent in `load_dataset`:

```
        1    0.000    0.000 44748.003 44748.003 /dccstor/jsmc-nmt-01/bool/git/primeqa/primeqa/mrc/run_mrc.py:299(main)
        1    0.000    0.000 44215.938 44215.938 /dccstor/jsmc-nmt-01/anaconda3/envs/primeqa/lib/python3.7/site-packages/datasets/load.py:1527(load_dataset)
        1    0.000    0.000  521.567  521.567 /dccstor/jsmc-nmt-01/bool/git/primeqa/primeqa/mrc/processors/preprocessors/base.py:86(process_train)
        1    0.001    0.001  521.567  521.567 /dccstor/jsmc-nmt-01/bool/git/primeqa/primeqa/mrc/processors/preprocessors/base.py:92(_process)
```

I can't find the corresponding `_process_batch` in the profiling output. But

```
315203    0.441    0.000   36.971    0.000 /dccstor/jsmc-nmt-01/anaconda3/envs/primeqa/lib/python3.7/site-packages/datasets/features/features.py:349(cast_to_python_objects)
11373842/315203   28.428    0.000   36.530    0.000 /dccstor/jsmc-nmt-01/anaconda3/envs/primeqa/lib/python3.7/site-packages/datasets/features/features.py:260(_cast_to_python_objects)
```

is still suspicious: over 11 million recursive `_cast_to_python_objects` calls suggest datasets is casting nested Python lists element by element.
But if 12 hours is spent in `load_dataset` before the preprocessor is even called, it may not be worth dealing with the inefficiency in the tokenizer.
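For illustration, a sketch of the kind of side-by-side timing behind this suspicion (hypothetical, not the PR's actual test case; as noted later in the thread, a faithful minimal reproduction proved harder than expected):

```python
# Hypothetical micro-benchmark: time datasets.map when the mapped function
# returns nested Python lists vs. a dict of np.ndarrays. Illustrative only.
import time
import numpy as np
from datasets import Dataset

ds = Dataset.from_dict({"text": ["a passage"] * 10_000})

def as_lists(batch):
    # Nested lists force per-element casting in datasets' Arrow writer.
    return {"input_ids": [[1] * 512 for _ in batch["text"]]}

def as_numpy(batch):
    # A single 2-D ndarray is handed to pyarrow in one shot.
    return {"input_ids": np.ones((len(batch["text"]), 512), dtype=np.int64)}

for fn in (as_lists, as_numpy):
    start = time.time()
    ds.map(fn, batched=True, load_from_cache_file=False)
    print(f"{fn.__name__}: {time.time() - start:.1f}s")
```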

@avisil (Collaborator) commented Dec 6, 2022

Should this PR be closed, @jsmcibm?

@jsmcibm (Collaborator, Author) commented Dec 7, 2022

During code review, Bhavani was unable to replicate the problem. I suspect there is some additional factor (Python version, etc.) that we haven't identified that is influencing the behavior in Datasets.
Note that the real problem is in Datasets itself: I tried to produce a simple test case to report upstream, but that turned out to be much more complicated than I expected.
Can we icebox this rather than close it?
