Rename maxbuf to bufsize
nvictus committed Mar 20, 2024
1 parent f42fe38 commit 68a90b0
Showing 1 changed file with 17 additions and 15 deletions.
32 changes: 17 additions & 15 deletions src/cooler/reduce.py
@@ -45,7 +45,7 @@

def merge_breakpoints(
indexes: list[np.ndarray | h5py.Dataset],
- maxbuf: int
+ bufsize: int
) -> tuple[np.ndarray, np.ndarray]:
"""
Given ``k`` bin1_offset indexes, determine how to partition the data from
@@ -54,17 +54,17 @@ def merge_breakpoints(
The partition is a subsequence of bin1 IDs, defining the bounds of chunks
of data that will be loaded into memory from each table in a single "epoch"
of merging data. The bounds are calculated such that no single epoch will
- load more than ``maxbuf`` records into memory.
+ load more than ``bufsize`` records into memory.
- However, the ``maxbuf`` condition is not guaranteed and a warning will be
+ However, the ``bufsize`` condition is not guaranteed and a warning will be
raised if it cannot be satisfied for one or more epochs (see Notes).
Parameters
----------
- indexes : sequence of 1D arrays of equal length
+ indexes : sequence of 1D array-like of equal length
Offset arrays that map bin1 IDs to their offset locations in a
corresponding pixel table.
- maxbuf : int
+ bufsize : int
Maximum number of pixel records loaded into memory in a single merge
epoch.
@@ -79,11 +79,13 @@
Notes
-----
The one exception to the post-condition is when a single bin1 increment in
- a table contains more than ``maxbuf`` records.
+ a table contains more than ``bufsize`` records.
"""
# This is a "virtual" cumulative index if all the tables were concatenated
# and sorted but no pixel records were aggregated. It helps us track how
# many records would be processed at each merge epoch.
+ # NOTE: We sum these incrementally in case the indexes are lazy to avoid
+ # loading all indexes into memory at once.
combined_index = np.zeros(indexes[0].shape)
for i in range(len(indexes)):
combined_index += indexes[i]
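
For illustration (outside the diff itself), here is a minimal sketch of how this helper partitions toy offset indexes. The import path, array values, and buffer size are assumptions; merge_breakpoints is an internal function of cooler.reduce and its location may change.

import numpy as np
from cooler.reduce import merge_breakpoints  # internal helper; path assumed

# Two toy bin1_offset indexes (length n_bins + 1), one per input pixel table.
# index[i] is the row offset of the first pixel record with bin1_id == i.
idx_a = np.array([0, 2, 5, 7, 10])
idx_b = np.array([0, 1, 4, 6, 8])

# Request merge epochs that buffer at most ~10 combined pixel records each.
bin1_partition, cum_nrecords = merge_breakpoints([idx_a, idx_b], bufsize=10)

# bin1_partition is a subsequence of bin1 IDs bounding each epoch;
# np.diff(cum_nrecords) gives the combined records loaded per epoch, each
# at most bufsize unless a single bin1 increment alone exceeds it.
print(bin1_partition, np.diff(cum_nrecords))
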
@@ -97,12 +97,12 @@
# Find the next bin1 ID from the combined index
hi = bisect_right(
combined_index,
- min(combined_start + maxbuf, combined_nnz),
+ min(combined_start + bufsize, combined_nnz),
lo=lo
) - 1

if hi == lo:
- # This means number of records to nearest mark exceeds `maxbuf`.
+ # This means number of records to nearest mark exceeds `bufsize`.
# Check for oversized chunks afterwards.
hi += 1

@@ -119,10 +119,10 @@
cum_nrecords = np.array(cum_nrecords)

nrecords_per_epoch = np.diff(cum_nrecords)
- n_over = (nrecords_per_epoch > maxbuf).sum()
+ n_over = (nrecords_per_epoch > bufsize).sum()
if n_over > 0:
warnings.warn(
f"{n_over} merge epochs will require buffering more than {maxbuf} "
f"{n_over} merge epochs will require buffering more than {bufsize} "
f"pixel records, with as many as {nrecords_per_epoch.max()}."
)

@@ -138,12 +138,12 @@ class CoolerMerger(ContactBinner):
def __init__(
self,
coolers: list[Cooler],
- maxbuf: int,
+ mergebuf: int,
columns: list[str] | None = None,
agg: dict[str, Any] | None = None
):
self.coolers = list(coolers)
- self.maxbuf = maxbuf
+ self.mergebuf = mergebuf
self.columns = ["count"] if columns is None else columns
self.agg = {col: "sum" for col in self.columns}
if agg is not None:
@@ -171,7 +173,7 @@ def __iter__(self) -> Iterator[dict[str, np.ndarray]]:

# Calculate the common partition of bin1 offsets that define the epochs
# of merging data.
- bin1_partition, cum_nrecords = merge_breakpoints(indexes, self.maxbuf)
+ bin1_partition, cum_nrecords = merge_breakpoints(indexes, self.mergebuf)
nrecords_per_epoch = np.diff(cum_nrecords)
logger.info(f"n_merge_epochs: {len(nrecords_per_epoch)}")
logger.debug(f"bin1_partition: {bin1_partition}")
@@ -183,7 +185,7 @@ def __iter__(self) -> Iterator[dict[str, np.ndarray]]:
starts = [0] * len(self.coolers)
for bp in bin1_partition[1:]:
stops = [index[bp] for index in indexes]
logger.info(f"records merged: {stops}")
logger.info(f"records consumed: {stops}")

# extract, concat
combined = pd.concat(
@@ -292,7 +294,7 @@ def merge_coolers(

bins = clrs[0].bins()[["chrom", "start", "end"]][:]
assembly = clrs[0].info.get("genome-assembly", None)
- iterator = CoolerMerger(clrs, maxbuf=mergebuf, columns=columns, agg=agg)
+ iterator = CoolerMerger(clrs, mergebuf=mergebuf, columns=columns, agg=agg)

create(
output_uri,
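
Downstream, the public merge_coolers function forwards its mergebuf setting to CoolerMerger, as the last hunk shows. A hedged usage sketch follows; the file paths and buffer value are hypothetical, and the inputs are assumed to share an identical bin table.

import cooler

# Merge two single-resolution coolers into one output file. `mergebuf`
# caps how many pixel records are buffered in memory per merge epoch.
cooler.merge_coolers(
    "merged.cool",
    ["rep1.cool", "rep2.cool"],
    mergebuf=20_000_000,
    columns=["count"],
    agg={"count": "sum"},
)
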
