Skip to content

Commit

Permalink
use sample data for debugging
Browse files Browse the repository at this point in the history
  • Loading branch information
niksirbi committed Apr 5, 2024
1 parent e15f494 commit 2128d09
Show file tree
Hide file tree
Showing 3 changed files with 153 additions and 35 deletions.
77 changes: 44 additions & 33 deletions examples/convert_tiffs_to_zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,56 +14,67 @@
from stack_to_chunk._io_helpers import _load_env_var_as_path

# If True, any existing zarr output at the target path is deleted and rebuilt.
OVERWRITE_EXISTING_ZARR = True
# If True, run on a small generated sample stack instead of the real subject
# tiffs (useful for debugging without access to the full dataset).
USE_SAMPLE_DATA = True


# Paths to the Google Drive folder containing tiffs for all subjects & channels
# and the output folder for the zarr files (both set as environment variables)
input_dir = _load_env_var_as_path("ATLAS_PROJECT_TIFF_INPUT_DIR")
output_dir = _load_env_var_as_path("ATLAS_PROJECT_ZARR_OUTPUT_DIR")
# Chunk size for the zarr file
chunk_size = 16

# Choose the data source: a generated sample stack (debugging) or the real
# subject/channel tiff folders addressed via the environment variables above.
# NOTE(review): this is a markerless diff paste — removed (old) and added
# (new) lines appear interleaved below, and the if/else body indentation has
# been lost; reconstruct against the actual commit before editing.
if USE_SAMPLE_DATA:
from stack_to_chunk.sample_data import SampleDaskStack

# Define subject ID and check that the corresponding folder exists
subject_id = "topro35"
assert (input_dir / subject_id).is_dir()
# Sample data lives under the output dir; 128 slices keeps the run small.
cat_data = SampleDaskStack(output_dir / "sample_data", n_slices=128)
cat_data.get_images()

# Define channel (by wavelength) and check that there is exactly one folder
# containing the tiff files for this channel in the subject folder
channel = "488"
channel_dirs = sorted(input_dir.glob(f"{subject_id}/*{channel}*"))
assert len(channel_dirs) == 1
channel_dir = channel_dirs[0]
else:
# Define subject ID and check that the corresponding folder exists
subject_id = "topro35"
assert (input_dir / subject_id).is_dir()

# Select chunk size
chunk_size = 64
# Define channel (by wavelength) and check that there is exactly one folder
# containing the tiff files for this channel in the subject folder
channel = "488"
channel_dirs = sorted(input_dir.glob(f"{subject_id}/*{channel}*"))
assert len(channel_dirs) == 1
channel_dir = channel_dirs[0]


if __name__ == "__main__":
# NOTE(review): markerless diff paste — the segment from here down to the
# first `if USE_SAMPLE_DATA:` appears to be the PRE-change code (it builds a
# MultiScaleGroup with name=f"{subject_id}_{channel}", duplicated further
# below with name="stack"); only one version belongs in the real file.
# Create folders for the subject and channel in the output directory
zarr_file_path = output_dir / subject_id / f"{subject_id}_{channel}.zarr"

# Create a MultiScaleGroup object (zarr group)
zarr_group = MultiScaleGroup(
zarr_file_path,
name=f"{subject_id}_{channel}",
spatial_unit="micrometer",
voxel_size=(3, 3, 3),
)

# Read the tiff stack into a dask array
# Passing only the first tiff is enough
# (because the rest of the stack is referenced in metadata)
tiff_files = sorted(channel_dir.glob("*.tif"))
da_arr = dask_image.imread.imread(tiff_files[0]).T
logger.info(
f"Read tiff stack into Dask array with shape {da_arr.shape}, "
f"dtype {da_arr.dtype}, and chunk sizes {da_arr.chunksize}"
)
# Post-change code: pick the zarr path and dask array from the sample stack
# or from the real tiff files, depending on the debug flag.
if USE_SAMPLE_DATA:
zarr_file_path = cat_data.zarr_file_path
da_arr = cat_data.generate_stack()
else:
# Create folders for the subject and channel in the output directory
zarr_file_path = output_dir / subject_id / f"{subject_id}_{channel}.zarr"

# Read the tiff stack into a dask array
# Passing only the first tiff is enough
# (because the rest of the stack is referenced in metadata)
tiff_files = sorted(channel_dir.glob("*.tif"))
da_arr = dask_image.imread.imread(tiff_files[0]).T
logger.info(
f"Read tiff stack into Dask array with shape {da_arr.shape}, "
f"dtype {da_arr.dtype}, and chunk sizes {da_arr.chunksize}"
)

# Delete existing zarr file if it exists and we want to overwrite it
if OVERWRITE_EXISTING_ZARR and zarr_file_path.exists():
logger.info(f"Deleting existing {zarr_file_path}")
shutil.rmtree(zarr_file_path)

# Add full resolution data to zarr group 0
if OVERWRITE_EXISTING_ZARR or not zarr_file_path.exists():
# Create a MultiScaleGroup object (zarr group)
zarr_group = MultiScaleGroup(
zarr_file_path,
name="stack",
spatial_unit="micrometer",
voxel_size=(3, 3, 3),
)
# Add full resolution data to zarr group 0
zarr_group.add_full_res_data(
da_arr,
n_processes=5,
[diff truncated here — hunk @@ -73,5 +84,5 @@: the remaining arguments to add_full_res_data are not shown]

# Add downsampled levels
# Each level corresponds to downsampling by a factor of 2**i
for i in range(1, 3):
for i in [1]:
zarr_group.add_downsample_level(i)
41 changes: 39 additions & 2 deletions src/stack_to_chunk/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from loguru import logger
from numcodecs import blosc
from numcodecs.abc import Codec
from scipy.ndimage import zoom

from stack_to_chunk._array_helpers import _copy_slab
from stack_to_chunk.ome_ngff import SPATIAL_UNIT
Expand Down Expand Up @@ -191,18 +192,54 @@ def add_downsample_level(self, level: int) -> None:
msg = f"Level {level_str} already found in zarr group"
raise RuntimeError(msg)

if (level_minus_one := str(int(level) - 1)) not in self._group:
level_minus_one = str(int(level) - 1)
if level_minus_one not in self._group:
msg = f"Level below (level={level_minus_one}) not present in group."
raise RuntimeError(
msg,
)

logger.info(f"Downsampling level {level_minus_one} to level {level_str}...")
# Create the new level in the zarr group.
source_data = self._group[level_minus_one]
new_shape = np.array(source_data.shape) // 2

self._group[level_str] = zarr.create(
new_shape,
chunks=source_data.chunks,
dtype=source_data.dtype,
compressor=source_data.compressor,
)

# Lazily take each dask chunk as a block and downsample it.

@staticmethod
def downsample_slab(
slab: np.ndarray, factor: int = 2, order: int = 1
) -> np.ndarray:
"""
Downsample a single chunk of data using linear interpolation.
Parameters
----------
chunk : numpy.ndarray
The chunk of data to downsample.
factor : int, optional
The downsampling factor, by default 2.
order : int, optional
The order of the spline interpolation, by default 1 (linear).
Returns
-------
numpy.ndarray
The downsampled chunk.
Notes
-----
This function uses ``scipy.ndimage.zoom`` to perform the downsampling.
"""
new_shape = np.maximum(1, np.array(slab.shape) // factor)
if np.any(new_shape < 1):
logger.warning("Chunk too small to downsample. Returning original.")
return slab
return zoom(slab, zoom=1 / factor, order=order, mode="nearest")
70 changes: 70 additions & 0 deletions src/stack_to_chunk/sample_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
"""Module to generate sample data for testing purposes."""

from pathlib import Path

import dask_image.imread
import skimage.color
import skimage.data
import tifffile
from dask.array.core import Array
from loguru import logger


class SampleDaskStack:
    """
    Generate a sample 3D Dask stack for testing purposes.
    """

    def __init__(self, data_dir: Path, n_slices: int = 135) -> None:
        """
        Generate a sample Dask stack for testing purposes.

        Parameters
        ----------
        data_dir : Path
            Directory to save the sample stack.
            The 2D images will be saved in a subdirectory called "slices".
            The output zarr file will be saved in the root of the ``data_dir``
            as "stack.zarr".
        n_slices : int
            Number of slices to generate.
        """
        self.data_dir = data_dir
        # parents=True so a nested data_dir (e.g. output_dir / "sample_data")
        # does not require the intermediate folders to already exist.
        self.data_dir.mkdir(parents=True, exist_ok=True)

        self.n_slices = n_slices

        self.slice_dir = self.data_dir / "slices"
        self.slice_dir.mkdir(exist_ok=True)

        self.zarr_file_path = self.data_dir / "stack.zarr"

    def get_images(self) -> None:
        """
        Download the cat image and write multiple 2D images to the slice
        directory.

        Only missing slices are written: if the directory already contains
        ``n_slices`` tiffs, nothing is generated.
        """
        # Check how many images are already written
        existing_images = list(self.slice_dir.glob("*.tif"))
        n_existing_images = len(existing_images)
        n_missing_images = self.n_slices - n_existing_images
        logger.info(f"Found {n_existing_images} existing images in {self.slice_dir}")

        # Download the cat image and write missing images (if any).
        # Log inside the guard: the original logged the count unconditionally,
        # emitting a misleading "Generating 0 (or negative) missing images".
        if n_missing_images > 0:
            logger.info(f"Generating {n_missing_images} missing images...")
            data_2d = skimage.color.rgb2gray(skimage.data.cat())

            for i in range(n_existing_images, self.n_slices):
                tifffile.imwrite(self.slice_dir / f"{str(i).zfill(3)}.tif", data_2d)

    def generate_stack(self) -> Array:
        """
        Generate a 3D Dask stack from the 2D images in the slice directory.

        Returns
        -------
        dask.array.core.Array
            The image stack; ``.T`` puts the slice index on the last axis.
        """
        stack = dask_image.imread.imread(str(self.slice_dir / "*.tif")).T
        logger.info(
            f"Read {stack.shape[2]} images into Dask array "
            f"with shape {stack.shape}, and chunk sizes {stack.chunksize}"
        )
        return stack

0 comments on commit 2128d09

Please sign in to comment.