Skip to content

Commit

Permalink
Fix multiprocessing lambda pickling (#311)
Browse files Browse the repository at this point in the history
* Fix running plots in parallel

The plots were running slower before this change because I was calling the
plot function directly instead of passing it to `submit`. So it was essentially
running in serial, but worse, because it was still spinning the processes up
and down.

* Fix multiprocessing lambda pickling (#20)

* Refactor process_futures to be a dict

This makes debugging much easier because you can associate the arguments
passed to each future with that future's results.

* Fix the pickling error when running in multiprocessing

Only top-level functions (not lambdas) can be pickled for use in multiprocessing
pools, so the lambdas are converted to regular functions.

* Further fixes to pickling multiprocessing error (#21)

* Refactor process_futures to be a dict

This makes debugging much easier because you can associate the arguments
passed to each future with that future's results.

* Fix the pickling error when running in multiprocessing

Only top-level functions (not lambdas) can be pickled for use in multiprocessing
pools, so the lambdas are converted to regular functions.

* Use Counter instead of defaultdict in CRISPRessoCORE

* Update process_futures to dict in Batch and Aggregate
  • Loading branch information
Colelyman committed Jul 6, 2023
1 parent ebb016d commit de05533
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 16 deletions.
2 changes: 1 addition & 1 deletion CRISPResso2/CRISPRessoAggregateCORE.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def main():

if n_processes > 1:
process_pool = ProcessPoolExecutor(n_processes)
process_futures = []
process_futures = {}
else:
process_pool = None
process_futures = None
Expand Down
2 changes: 1 addition & 1 deletion CRISPResso2/CRISPRessoBatchCORE.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ def main():

if n_processes_for_batch > 1:
process_pool = ProcessPoolExecutor(n_processes_for_batch)
process_futures = []
process_futures = {}
else:
process_pool = None
process_futures = None
Expand Down
22 changes: 10 additions & 12 deletions CRISPResso2/CRISPRessoCORE.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
running_python3 = True

import argparse
from collections import defaultdict
from collections import Counter
from copy import deepcopy
from concurrent.futures import ProcessPoolExecutor, wait
from functools import partial
Expand Down Expand Up @@ -88,8 +88,6 @@ def check_program(binary_name,download_url=None):
error('You can download it here:%s' % download_url)
sys.exit(1)



def get_avg_read_length_fastq(fastq_filename):
cmd=('z' if fastq_filename.endswith('.gz') else '' ) +('cat < \"%s\"' % fastq_filename)+\
r''' | awk 'BN {n=0;s=0;} NR%4 == 2 {s+=length($0);n++;} END { printf("%d\n",s/n)}' '''
Expand Down Expand Up @@ -2456,14 +2454,14 @@ def get_prime_editing_guides(this_amp_seq, this_amp_name, ref0_seq, prime_edited
deletion_length_vectors [ref_name] = np.zeros(this_len_amplicon)


inserted_n_dicts [ref_name] = defaultdict(int)
deleted_n_dicts [ref_name] = defaultdict(int)
substituted_n_dicts [ref_name] = defaultdict(int)
effective_len_dicts [ref_name] = defaultdict(int)
inserted_n_dicts [ref_name] = Counter()
deleted_n_dicts [ref_name] = Counter()
substituted_n_dicts [ref_name] = Counter()
effective_len_dicts [ref_name] = Counter()

hists_inframe [ref_name] = defaultdict(int)
hists_inframe [ref_name] = Counter()
hists_inframe [ref_name][0] = 0
hists_frameshift [ref_name] = defaultdict(int)
hists_frameshift [ref_name] = Counter()
hists_frameshift [ref_name][0] = 0
#end initialize data structures for each ref
def get_allele_row(reference_name, variant_count, aln_ref_names_str, aln_ref_scores_str, variant_payload, write_detailed_allele_table):
Expand Down Expand Up @@ -3405,7 +3403,7 @@ def count_alternate_alleles(sub_base_vectors, ref_name, ref_sequence, ref_total_

if n_processes > 1:
process_pool = ProcessPoolExecutor(n_processes)
process_futures = []
process_futures = {}
else:
process_pool = None
process_futures = None
Expand Down Expand Up @@ -4408,9 +4406,9 @@ def count_alternate_alleles(sub_base_vectors, ref_name, ref_sequence, ref_total_
global_NON_MODIFIED_NON_FRAMESHIFT = 0
global_SPLICING_SITES_MODIFIED = 0

global_hists_frameshift = defaultdict(lambda :0)
global_hists_frameshift = Counter()
global_hists_frameshift[0] = 0 # fill with at least the zero value (in case there are no others)
global_hists_inframe = defaultdict(lambda :0)
global_hists_inframe = Counter()
global_hists_inframe[0] = 0

global_count_total = 0
Expand Down
3 changes: 1 addition & 2 deletions CRISPResso2/CRISPRessoMultiProcessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,6 @@ def run_plot(plot_func, plot_args, num_processes, process_futures, process_pool)
None
"""
if num_processes > 1:
process_futures.append(process_pool.submit(plot_func, **plot_args))

process_futures[process_pool.submit(plot_func, **plot_args)] = (plot_func, plot_args)
else:
plot_func(**plot_args)

0 comments on commit de05533

Please sign in to comment.