Performance issue in bispectrum method and gradient #293

Open
Hackasteroid142 opened this issue Oct 4, 2023 · 10 comments

Comments

@Hackasteroid142

  • dask-ms version: 0.2.16
  • Python version: 3.10
  • Operating System: 18.04.6

Description

Hello, some time ago I opened an issue asking for help with implementing the bispectrum. Building on the responses I received, I started developing a gradient and an optimizer for reconstructing images with this method. However, I've run into memory and runtime problems. I tried to address them by changing the default chunk size when reading the dataset, but then I ran into shape mismatches.

Currently, I am reading the dataset with the xds_from_ms function from dask-ms. However, if I change the row chunk size to another value, such as 1500, I get the following error:

operands could not be broadcast together with shapes () (6903,2,3,1,2) (6903,1500,3,1,2) (6903,2,3,1,2)

Based on my tests, I believe the issue is related to the adjust_chunks parameter of the blockwise call I use for the bispectrum, but none of my attempts to edit this parameter have solved the problem. Is there a way to resolve this? Additionally, do you think reading the dataset this way and computing the bispectrum like this could lead to performance problems?

Here's an excerpt from the code I'm using. If anything is unclear, I'd be happy to provide additional information, and any help is greatly appreciated. Thank you in advance!

import itertools

import dask.array as da
import numpy as np
from daskms import xds_from_ms
from numba import njit


@njit(cache=True, nogil=True)
def get_bispectrum_data(data, ant1, ant2, comb, utime_inv, n_row, data_shape):
    cnt = np.zeros(data_shape, dtype=np.int8)
    data_bis = np.ones(data_shape, dtype=data.dtype)

    for row in range(n_row):

        ut = utime_inv[row]

        a1 = ant1[row]
        a2 = ant2[row]

        for ic, c in enumerate(comb):
            if (a1 == c[0]) and (a2 == c[1]):
                data_bis[ic, ut, 0] = data[row]
                cnt[ic, ut, :] += 1
            elif (a1 == c[1]) and (a2 == c[-1]):
                data_bis[ic, ut, 1] = data[row]
                cnt[ic, ut, :] += 1
            elif (a1 == c[0]) and (a2 == c[-1]):
                data_bis[ic, ut, 2] = data[row].conjugate()
                cnt[ic, ut, :] += 1

    return data_bis, cnt


def get_bispectrum(data, ant1, ant2, comb, time, type):

    utime, utime_inv = np.unique(time, return_inverse=True)
    n_utime = utime.size
    n_comb = len(comb)

    n_row, n_chan, n_corr = data.shape
    shape_data = (n_comb, n_utime, 3, n_chan, n_corr)

    bis, cnt = get_bispectrum_data(data, ant1, ant2, comb, utime_inv, n_row, shape_data)

    bis[cnt != 3] = 0

    return bis


input_name = "/home/datasets/ms_name.ms"

ms = xds_from_ms(
    input_name,
    group_cols=("SCAN_NUMBER", "FIELD_ID", "DATA_DESC_ID"),
    index_cols=("TIME", "ANTENNA1", "ANTENNA2"),
    chunks={'row': 1500}
)
# antennas, antenna_reference, time, data, antenna1, antenna2 and type are
# defined elsewhere in the original code (not shown in this excerpt).
comb = itertools.combinations(range(antennas.shape[0]), 3)
filter_comb = np.array([i for i in comb if antenna_reference in i])
utime_size = np.unique(time).size

bis = da.blockwise(
    get_bispectrum,
    ('triangle', 't', 'p', 'f', 'c'),
    data,
    ('t', 'f', 'c'),
    antenna1, ('t'),
    antenna2, ('t'),
    filter_comb,
    None,
    time.data, ('t'),
    type,
    None,
    align_arrays=False,
    adjust_chunks={'t': (utime_size, )},
    new_axes={
        'triangle': len(filter_comb),
        'p': 3
    },
    dtype=data.dtype
)
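On the question of whether this layout could cause performance problems: for every block, get_bispectrum_data allocates two dense arrays of shape (n_comb, n_utime, 3, n_chan, n_corr). One has the data dtype and the other is an int8 counter. Below is a rough size estimate. bispectrum_nbytes is a hypothetical helper. complex64 is assumed because MS DATA columns are usually single-precision complex. n_utime = 1000 is made up for illustration.

```python
import numpy as np


def bispectrum_nbytes(n_comb, n_utime, n_chan, n_corr, dtype=np.complex64):
    """Rough size in bytes of the dense arrays built by get_bispectrum_data.

    The bispectrum array has shape (n_comb, n_utime, 3, n_chan, n_corr) and
    the int8 counter array ``cnt`` has the same shape, so both are counted.
    """
    n_elem = n_comb * n_utime * 3 * n_chan * n_corr
    return n_elem * np.dtype(dtype).itemsize + n_elem * np.dtype(np.int8).itemsize


# 6903 triangles, 1 channel and 2 correlations match the shapes in the error
# message; n_utime = 1000 is a purely hypothetical number of unique times.
nbytes = bispectrum_nbytes(6903, 1000, 1, 2, dtype=np.complex64)
print(f"{nbytes / 2**30:.2f} GiB")
```

With these numbers the two arrays come to roughly 0.35 GiB. The leading 6903 in the error shapes is the triangle dimension, so memory grows linearly with both the number of triangles and the number of unique times held at once.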
@sjperkins
Member

Hi @Hackasteroid142. Thanks for the very detailed issue. I think it may be an easy fix. Try changing the following in your blockwise call: adjust_chunks={"t": tuple(utime_size)}.

@Hackasteroid142
Author

Hi @sjperkins, thank you for replying. I followed your advice, but unfortunately, it didn't resolve the issue. However, another error has become more frequent than the one I mentioned before, and it is related to adjust_chunks. Here are the details of that error:

Dimension 1 has 56 blocks, adjust_chunks specified with 1 blocks

This error appears whenever I set a chunk size other than -1 when reading the dataset. Is there an error in my implementation, or is something missing?

@JSKenyon
Collaborator

Hi @Hackasteroid142. The following code should work. Note that I am using NaN chunk sizes, as we do not know the number of unique times in each chunk until the graph executes. There is another way of doing this if you need those chunks to be of known size; let me know if that fits your use case. Regarding memory usage, the bispectrum results are large: if you hold them all in memory (as is done by calling compute on results below), it will likely cause memory issues on larger datasets. The solution will be implementation specific, but typically involves either writing results to disk so that there is no need to hold onto them in memory, or further processing the chunks so that they become smaller or no longer need to be held in memory.

import dask.array as da
import numpy as np
from numba import njit
from daskms import xds_from_ms
import itertools


@njit(cache=True, nogil=True)
def get_bispectrum_data(data, ant1, ant2, comb, utime_inv, n_row, data_shape):
    cnt = np.zeros(data_shape, dtype=np.int8)
    data_bis = np.ones(data_shape, dtype=data.dtype)

    for row in range(n_row):

        ut = utime_inv[row]

        a1 = ant1[row]
        a2 = ant2[row]

        for ic, c in enumerate(comb):
            if (a1 == c[0]) and (a2 == c[1]):
                data_bis[ic, ut, 0] = data[row]
                cnt[ic, ut, :] += 1
            elif (a1 == c[1]) and (a2 == c[-1]):
                data_bis[ic, ut, 1] = data[row]
                cnt[ic, ut, :] += 1
            elif (a1 == c[0]) and (a2 == c[-1]):
                data_bis[ic, ut, 2] = data[row].conjugate()
                cnt[ic, ut, :] += 1

    return data_bis, cnt


def get_bispectrum(data, ant1, ant2, comb, time, type):

    utime, utime_inv = np.unique(time, return_inverse=True)
    n_utime = utime.size
    n_comb = len(comb)

    n_row, n_chan, n_corr = data.shape
    shape_data = (n_comb, n_utime, 3, n_chan, n_corr)

    bis, cnt = get_bispectrum_data(data, ant1, ant2, comb, utime_inv, n_row, shape_data)

    bis[cnt != 3] = 0

    return bis


if __name__ == "__main__":

    input_name = "path/to/ms"

    xdsl = xds_from_ms(
        input_name,
        group_cols=("SCAN_NUMBER", "FIELD_ID", "DATA_DESC_ID"),
        index_cols=("TIME", "ANTENNA1", "ANTENNA2"),
        chunks={'row': 10000}
    )
    n_ant = 28
    antenna_reference = 0
    comb = itertools.combinations(range(n_ant), 3)
    filter_comb = np.array([i for i in comb if antenna_reference in i])

    results = []

    for xds in xdsl:
        bis = da.blockwise(
            get_bispectrum,
            ('triangle', 't', 'p', 'f', 'c'),
            xds.DATA.data,
            ('t', 'f', 'c'),
            xds.ANTENNA1.data, ('t'),
            xds.ANTENNA2.data, ('t'),
            filter_comb,
            None,
            xds.TIME.data, ('t'),
            type,
            None,
            align_arrays=False,
            adjust_chunks={'t': (np.nan,)*xds.DATA.data.numblocks[0]},
            new_axes={
                'triangle': len(filter_comb),
                'p': 3
            },
            dtype=xds.DATA.data.dtype
        )
        results.append(bis)

    results = da.compute(results)
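The following small illustration uses a toy TIME column (hypothetical data, not read from an MS). After a blockwise call with NaN adjust_chunks, dask knows how many blocks there are but not their lengths. compute_chunk_sizes() fills the lengths in, but only by executing every task that produces those blocks. That is why it can cost about as much as the computation itself.

```python
import dask.array as da
import numpy as np

# Toy TIME column: 12 rows in 3 row-chunks of 4, with 2, 1 and 3 unique times.
time = da.from_array(
    np.array([0., 0., 1., 1., 2., 2., 2., 2., 3., 4., 4., 5.]), chunks=4
)


def unique_times(t):
    # The length of each output block is only known once the block runs.
    return np.unique(t)


utime = da.blockwise(
    unique_times, "t", time, "t",
    adjust_chunks={"t": (np.nan,) * time.numblocks[0]},
    dtype=time.dtype,
)
print(utime.chunks)  # one NaN per block: sizes are unknown at this point

utime.compute_chunk_sizes()  # executes the graph to measure every block
print(utime.chunks)
```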

@Hackasteroid142
Author

Hi @sjperkins, I apologize for the delay in my response. I tried the solution you provided, but as you mentioned, I need the chunk sizes to be known: later operations fail when the chunks contain NaN values. I also tried compute_chunk_sizes() to work around the error, but it slowed the operations down significantly compared to the normal execution time.

I also tried modifying your solution slightly by using (1, ) * xds.DATA.data.numblocks[0], but this led to errors in subsequent operations, such as

non-broadcastable output operand with shape (6903,1,3,1,2) doesn't match the broadcast shape (6903,3,3,1,2)
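The error with (1, ) * numblocks can be reproduced in miniature. adjust_chunks must give exactly one size per block along a dimension; otherwise blockwise raises the "Dimension ... has N blocks, adjust_chunks specified with M blocks" error quoted earlier. However, dask does not check those sizes against what the function actually returns. The declared sizes become the array's shape, so a wrong value only surfaces later, for example as a broadcast mismatch when the array is combined with another one. Below is a toy sketch with a hypothetical TIME column.

```python
import dask.array as da
import numpy as np

# Two row-blocks, each containing 2 unique times.
time = da.from_array(np.array([0., 0., 1., 1., 2., 2., 3., 3.]), chunks=4)

# Wrongly declare 1 unique time per block: dask accepts this without checking.
utime = da.blockwise(
    np.unique, "t", time, "t",
    adjust_chunks={"t": (1,) * time.numblocks[0]},
    dtype=time.dtype,
)
print(utime.shape)            # the declared shape, used by later operations
print(utime.compute().shape)  # the shape the blocks actually produce
```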

@JSKenyon
Collaborator

The following will make sure that the chunks on the bispectrum are correct. However, this will not guarantee that all rows associated with a single time will be in the same chunk. That is also possible but requires a fair amount of additional code. You can see an example of the approach here.

import dask.array as da
import numpy as np
from numba import njit
from daskms import xds_from_ms
import itertools


@njit(cache=True, nogil=True)
def get_bispectrum_data(data, ant1, ant2, comb, utime_inv, n_row, data_shape):
    cnt = np.zeros(data_shape, dtype=np.int8)
    data_bis = np.ones(data_shape, dtype=data.dtype)

    for row in range(n_row):

        ut = utime_inv[row]

        a1 = ant1[row]
        a2 = ant2[row]

        for ic, c in enumerate(comb):
            if (a1 == c[0]) and (a2 == c[1]):
                data_bis[ic, ut, 0] = data[row]
                cnt[ic, ut, :] += 1
            elif (a1 == c[1]) and (a2 == c[-1]):
                data_bis[ic, ut, 1] = data[row]
                cnt[ic, ut, :] += 1
            elif (a1 == c[0]) and (a2 == c[-1]):
                data_bis[ic, ut, 2] = data[row].conjugate()
                cnt[ic, ut, :] += 1

    return data_bis, cnt


def get_bispectrum(data, ant1, ant2, comb, time, type):

    utime, utime_inv = np.unique(time, return_inverse=True)
    n_utime = utime.size
    n_comb = len(comb)

    n_row, n_chan, n_corr = data.shape
    shape_data = (n_comb, n_utime, 3, n_chan, n_corr)

    bis, cnt = get_bispectrum_data(data, ant1, ant2, comb, utime_inv, n_row, shape_data)

    bis[cnt != 3] = 0

    return bis


def compute_utime_chunks(xdsl):

    utime_chunks = []

    for xds in xdsl:
        time = xds.TIME.data
        utimes_per_chunk = da.blockwise(
            lambda t: np.array(len(np.unique(t))), "t",
            time, "t",
            adjust_chunks={'t': (np.nan,)*time.numblocks[0]}
        )
        utime_chunks.append(utimes_per_chunk)

    return [tuple(utc) for utc in da.compute(utime_chunks)[0]]


if __name__ == "__main__":

    input_name = "path/to/ms"

    xdsl = xds_from_ms(
        input_name,
        group_cols=("SCAN_NUMBER", "FIELD_ID", "DATA_DESC_ID"),
        index_cols=("TIME", "ANTENNA1", "ANTENNA2"),
        chunks={'row': 10000}
    )
    n_ant = 28
    antenna_reference = 0
    comb = itertools.combinations(range(n_ant), 3)
    filter_comb = np.array([i for i in comb if antenna_reference in i])

    utime_chunks = compute_utime_chunks(xdsl)

    results = []

    for xds, utc in zip(xdsl, utime_chunks):
        bis = da.blockwise(
            get_bispectrum,
            ('triangle', 't', 'p', 'f', 'c'),
            xds.DATA.data,
            ('t', 'f', 'c'),
            xds.ANTENNA1.data, ('t'),
            xds.ANTENNA2.data, ('t'),
            filter_comb,
            None,
            xds.TIME.data, ('t'),
            type,
            None,
            align_arrays=False,
            adjust_chunks={'t': utc},
            new_axes={
                'triangle': len(filter_comb),
                'p': 3
            },
            dtype=xds.DATA.data.dtype
        )
        results.append(bis)

    results = da.compute(results)

@JSKenyon
Collaborator

@Hackasteroid142 I took the liberty of adding the chunking functionality as it may be useful to other users. Here is the example, which now allows chunking by unique time.

import dask.array as da
import numpy as np
from numba import njit
from daskms import xds_from_ms
import itertools


@njit(cache=True, nogil=True)
def get_bispectrum_data(data, ant1, ant2, comb, utime_inv, n_row, data_shape):
    cnt = np.zeros(data_shape, dtype=np.int8)
    data_bis = np.ones(data_shape, dtype=data.dtype)

    for row in range(n_row):

        ut = utime_inv[row]

        a1 = ant1[row]
        a2 = ant2[row]

        for ic, c in enumerate(comb):
            if (a1 == c[0]) and (a2 == c[1]):
                data_bis[ic, ut, 0] = data[row]
                cnt[ic, ut, :] += 1
            elif (a1 == c[1]) and (a2 == c[-1]):
                data_bis[ic, ut, 1] = data[row]
                cnt[ic, ut, :] += 1
            elif (a1 == c[0]) and (a2 == c[-1]):
                data_bis[ic, ut, 2] = data[row].conjugate()
                cnt[ic, ut, :] += 1

    return data_bis, cnt


def get_bispectrum(data, ant1, ant2, comb, time, type):

    utime, utime_inv = np.unique(time, return_inverse=True)
    n_utime = utime.size
    n_comb = len(comb)

    n_row, n_chan, n_corr = data.shape
    shape_data = (n_comb, n_utime, 3, n_chan, n_corr)

    bis, cnt = get_bispectrum_data(data, ant1, ant2, comb, utime_inv, n_row, shape_data)

    bis[cnt != 3] = 0

    return bis


def utime_and_row_chunks(time, req_utime=1):
    """Internals of compute_utime_and_row_chunks."""

    utime, utime_counts = np.unique(time, return_counts=True)
    n_utime = utime.size
    req_utime = req_utime or n_utime  # Catch zero.

    chunk_starts = np.arange(0, n_utime, req_utime)

    utime_chunks = np.array(
        [
            req_utime if i + req_utime < n_utime else n_utime - i
            for i in chunk_starts
        ]
    )

    row_chunks = np.add.reduceat(utime_counts, chunk_starts)

    return np.stack([utime_chunks, row_chunks], axis=0)



def compute_utime_and_row_chunks(indexing_xdsl, req_utime=1):
    """Figure out the chunking in unique time and row. Triggers compute."""

    chunking = []

    for xds in indexing_xdsl:

        chunking.append(
            xds.TIME.data.map_blocks(
                utime_and_row_chunks,
                req_utime,
                chunks=((2,), (np.nan,)),
                new_axis=0,
                dtype=int
            )
        )

    result = da.compute(chunking)[0]

    utime_chunks = [tuple(arr[0]) for arr in result]
    row_chunks = [tuple(arr[1]) for arr in result]

    return utime_chunks, row_chunks


if __name__ == "__main__":

    input_name = "~/reductions/3C147/msdir/C147_unflagged.MS"

    # Set up TIME only datasets which we can use to establish chunking. Note
    # that we use only a single chunk per dataset.
    indexing_xdsl = xds_from_ms(
        input_name,
        group_cols=("SCAN_NUMBER", "FIELD_ID", "DATA_DESC_ID"),
        index_cols=("TIME", "ANTENNA1", "ANTENNA2"),
        columns=("TIME",),
        chunks={'row': -1}
    )

    # req_utime controls the number of unique times that you want in a
    # single chunk. This triggers some early but very lightweight compute
    # to figure out the required chunking - much cheaper than
    # compute_chunk_sizes.
    utime_chunks_list, row_chunks_list = compute_utime_and_row_chunks(
        indexing_xdsl, req_utime=30
    )

    # Now that we know the desired chunking, load the data we want to
    # manipulate with the required chunking.
    xdsl = xds_from_ms(
        input_name,
        group_cols=("SCAN_NUMBER", "FIELD_ID", "DATA_DESC_ID"),
        index_cols=("TIME", "ANTENNA1", "ANTENNA2"),
        columns=("TIME", "ANTENNA1", "ANTENNA2", "DATA"),
        chunks=[{'row': rcs} for rcs in row_chunks_list]
    )

    n_ant = 28
    antenna_reference = 0
    comb = itertools.combinations(range(n_ant), 3)
    filter_comb = np.array([i for i in comb if antenna_reference in i])

    results = []

    for xds, utc in zip(xdsl, utime_chunks_list):
        bis = da.blockwise(
            get_bispectrum,
            ('triangle', 't', 'p', 'f', 'c'),
            xds.DATA.data,
            ('t', 'f', 'c'),
            xds.ANTENNA1.data, ('t'),
            xds.ANTENNA2.data, ('t'),
            filter_comb,
            None,
            xds.TIME.data, ('t'),
            type,
            None,
            align_arrays=False,
            adjust_chunks={'t': utc},
            new_axes={
                'triangle': len(filter_comb),
                'p': 3
            },
            dtype=xds.DATA.data.dtype
        )
        results.append(bis)

    results = da.compute(results)
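To make the chunking helper concrete, here is utime_and_row_chunks applied to a small, hypothetical TIME column. The function body is copied from the example above so that the snippet runs standalone. The row chunks are obtained by summing the per-time row counts. That assumes the rows are ordered by TIME, which is what the index_cols argument in the example is intended to arrange.

```python
import numpy as np


def utime_and_row_chunks(time, req_utime=1):
    """Same logic as the helper above: group unique times into chunks."""
    utime, utime_counts = np.unique(time, return_counts=True)
    n_utime = utime.size
    req_utime = req_utime or n_utime  # Catch zero.

    # Index of the first unique time in each chunk.
    chunk_starts = np.arange(0, n_utime, req_utime)

    # Unique times per chunk (the last chunk may be shorter).
    utime_chunks = np.array(
        [req_utime if i + req_utime < n_utime else n_utime - i
         for i in chunk_starts]
    )

    # Rows per chunk: add up the row counts of the times in each chunk.
    row_chunks = np.add.reduceat(utime_counts, chunk_starts)

    return np.stack([utime_chunks, row_chunks], axis=0)


# Sorted toy TIME column: 3, 2, 4 and 1 rows at times 0, 1, 2 and 3.
time = np.array([0., 0., 0., 1., 1., 2., 2., 2., 2., 3.])

print(utime_and_row_chunks(time, req_utime=2))
# [[2 2]
#  [5 5]]
print(utime_and_row_chunks(time, req_utime=3))
# [[3 1]
#  [9 1]]
```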

@Hackasteroid142
Author

Thank you for the previous code; it was very helpful for working with different chunk sizes in the Bispectrum. However, I've been trying to use this to improve the performance of my Bispectrum method and the optimization method I'm working on. Despite implementing the code you provided, I haven't seen any improvement in time or memory usage. I'm working with an extract of the HD163296 dataset and a dirty image of this extract, which are available at this link: https://we.tl/t-R4nVGgd2vN. The image is used to calculate the model visibilities for the dataset.

In the optimization method, I use the gradient of the bispectrum chi-square, as given in this article (Equation 3 of the appendix) by Andrew A. Chael et al. Given this code and the dataset I mentioned earlier, is there anything I should take into consideration? Is there any way to optimize my code further?

As background, I ran a small experiment to measure how long it takes to compute the gradient. Calling the gradient function takes a few minutes, since Dask only builds the graph at that point, but computing the result can take up to 2 days. The code also contains a class named "Mask", which restricts the calculation to a portion of the image to reduce execution time. It is like a matrix surrounded by zeros, where the zeros mark the areas that do not need to be calculated. I hope I have explained myself clearly and that the code is understandable. If you have any questions, I'd be happy to answer them.
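The Mask idea described above amounts to evaluating the gradient only at the selected pixels and scattering the results back into a full image. Below is a minimal NumPy sketch. It assumes the mask is a boolean array small enough to hold in memory. masked_pixel_coords and scatter_to_image are hypothetical helpers, not part of the original code. The indices here are in C order, whereas the original code transposes to account for F order.

```python
import numpy as np


def masked_pixel_coords(mask):
    """Flat indices and (row, col) coordinates of the pixels where mask is True."""
    flat_idx = np.flatnonzero(mask)
    rows, cols = np.unravel_index(flat_idx, mask.shape)
    return flat_idx, rows, cols


def scatter_to_image(values, flat_idx, shape):
    """Place per-pixel values back into a zero-filled image of the given shape."""
    image = np.zeros(int(np.prod(shape)), dtype=values.dtype)
    image[flat_idx] = values
    return image.reshape(shape)


# Hypothetical 4x4 image where only the central 2x2 block is of interest.
mask = np.zeros((4, 4), dtype=bool)
mask[1:3, 1:3] = True

flat_idx, rows, cols = masked_pixel_coords(mask)
# Stand-in for the per-pixel gradient: 4 values computed instead of 16.
values = (rows * 10 + cols).astype(np.float32)
grad = scatter_to_image(values, flat_idx, mask.shape)
```

Because the indices come straight from NumPy, their count is known up front, so no compute_chunk_sizes() call is needed to use them with dask arrays.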

import astropy.units as un
import dask.array as da
import numpy as np
from daskms import xds_from_ms
from numba import njit, prange

def gradient(dataset, mask):

    if mask is None:
        dchi2_1d = __gradient_no_mask(dataset)
    else:
        dchi2_1d = __gradient_mask(dataset, mask=mask)

    _grad_value = gradient_image_reconstruction(dchi2_1d, mask)

    return _grad_value

def gradient_image_reconstruction(image, mask):
    # Array full with zeros to assign only valid indices with the values of the gradient
    dchi2_1d = da.zeros((np.prod(image.data.shape), ), dtype=np.float32)
    # Get ravel indices of the mask where its values are True
    ravel_mask_idx = da.where(mask.data.data.ravel())[0]
    ravel_mask_idx.compute_chunk_sizes()
    # In the array, assign only valid indices (ravel_mask_idx) with the gradient
    dchi2_1d[(ravel_mask_idx, )] = image
    # Reshape gradient (transposed because indices are in F order)
    image_2d = dchi2_1d.reshape(_grad_value.shape).T
    # Flip needed from fft
    flipped = da.flip(image_2d, axis=[0, 1])
    return flipped.rechunk(_grad_value.chunksize)

def __gradient_no_mask(dataset):
    delta_x, delta_y = image.cellsize.to(un.rad).value

    x_ind_2d, y_ind_2d, *z_ind_2d = da.indices(
        image.data.data.shape, dtype=np.int32, chunks=image.data.data.chunksize
    )

    x_ind = x_ind_2d.ravel()
    y_ind = y_ind_2d.ravel()

    x_cell = __cell_index_delta(x_ind, delta_x, np.float32)
    y_cell = __cell_index_delta(y_ind, delta_y, np.float32)

    dchi2_1d = da.zeros_like(x_cell, dtype=np.float32, chunks=x_cell.chunks)
    for i, ms in enumerate(dataset.ms_list):
        dchi2_1d += __ms_gradient(ms, x_cell, y_cell, delta_x, delta_y, mask=None)

    return dchi2_1d

def __gradient_mask(dataset, mask):

    delta_x, delta_y = image.cellsize.to(un.rad).value

    x_ind, y_ind, *z_ind = mask.indices

    x_cell = __cell_index_delta(x_ind, delta_x, np.float32)
    y_cell = __cell_index_delta(y_ind, delta_y, np.float32)

    dchi2_1d_masked = da.zeros_like(x_cell, dtype=np.float32, chunks=x_cell.chunks)
    for i, ms in enumerate(dataset.ms_list):
        dchi2_1d_masked += __ms_gradient(ms, x_cell, y_cell, delta_x, delta_y, mask=mask)

    return dchi2_1d_masked

def __ms_gradient(ms, x_cell, y_cell, delta_x, delta_y, *, mask):
    bis_obs = ms.visibilities.cal_data  # (ncomb, ntime, nchans, ncorrs)
    bis_model = ms.visibilities.cal_model
    bis_weight = ms.visibilities.cal_weight

    bis_r = bis_obs - bis_model  # Residual bispectrum (observed minus model)
    vis = ms.visibilities.bis_data.data  # (ncomb, ntime, nant, nchans, ncorrs)
    uvw = ms.visibilities.bis_uvw.data.astype(np.float32) * un.m  # (ncomb, ntime, nant, uvw)

    pol_id = ms.polarization_id
    corr_names = dataset.polarization.corrs_string[pol_id]
    corr_idx = [x in _corr for x in corr_names]
    bis_weight = bis_weight[:, :, :, corr_idx]  # Filter by correlation
    bis_r = bis_r[:, :, :, corr_idx]  # Filter by correlation

    spw_id = ms.spw_id
    nchans = dataset.spws.nchans[spw_id]
    chans = (dataset.spws.dataset[spw_id].CHAN_FREQ.data.squeeze(axis=0) *
                un.Hz).rechunk(bis_obs.chunksize[-2])

    uvw_lambdas = _uvw_lambdas(uvw, chans, nchans)

    uv_lambdas = uvw_lambdas[:, :, :, :, :2]
    uv_lambdas = uv_lambdas.map_blocks(
        lambda x: x.value if isinstance(x, un.Quantity) else x, dtype=np.float32
    )

    phase_dirs_x = (image.data.shape[0] // 2) * delta_x
    phase_dirs_y = (image.data.shape[1] // 2) * delta_y

    x = x_cell - np.float32(phase_dirs_x)
    y = y_cell - np.float32(phase_dirs_y)

    chans = chans.map_blocks(lambda x: x.value if isinstance(x, un.Quantity) else x)
    beam = __primary_beam(
        chans,
        image.data.shape,
        chunks=mask.data.chunks if mask is not None else image.data.data.chunksize
    )

    return _array_gradient(
        x, y, uv_lambdas, bis_weight, vis, bis_r, beam, bis_model, mask=mask
    )

def _uvw_lambdas(uvw, chans, nchans):

    chans_broadcast = chans[np.newaxis, :, np.newaxis]
    uvw_broadcast = da.repeat(uvw[:, :, :, np.newaxis, :], nchans, axis=3)
    uvw_lambdas = array_unit_conversion(
        array=uvw_broadcast,
        unit=un.lambdas,
        equivalencies=lambdas_equivalencies(restfreq=chans_broadcast),
    )
    return uvw_lambdas

def _array_gradient(x, y, uv, w, vis, vr, pb, bm, *, mask):

    data_br = da.blockwise(
        _block_gradient,
        ("ncomb", "chan", "corr", "idx"),
        x,
        ("idx", ),
        y,
        ("idx", ),
        uv,
        ("ncomb", "nutime", "nant", "chan", "corr"),
        w,
        ("ncomb", "nutime", "chan", "corr"),
        vis,
        ("ncomb", "nutime", "nant", "chan", "corr"),
        vr,
        ("ncomb", "nutime", "chan", "corr"),
        bm,
        ("ncomb", "nutime", "chan", "corr"),
        adjust_chunks={
            "ncomb": 1,
            "corr": 1
        },
        dtype=np.float64,
    )

    data = data_br.sum(axis=(0, 2))

    if mask is not None:
        # Broadcast mask to match the shape of the PB, and obtain the masked primary beam values
        # via bool indexing
        broadcasted_mask = da.broadcast_to(mask.data.data, pb.shape)
        # Get the Primary Beam masked values
        pb_fitted = pb[broadcasted_mask]
        pb_fitted.compute_chunk_sizes()
    else:
        pb_fitted = pb  # If there is no mask, there is no need to filter the primary beam

    # Reshape to 2-d where the second dim is the cell idx and the first dim is the channel
    # (frequency) dim
    pb_2d = pb_fitted.reshape((-1, x.shape[0]))

    data = da.einsum('in,in->n', data, pb_2d)

    return data

def _block_gradient(x, y, uv, w, vis, vr, bm):
    uv = uv[0][0].astype(np.float64)
    dchi2_broad = _ms_memory_gradient(x, y, uv, w[0], vis[0][0], vr[0], bm[0])
    return dchi2_broad[None, :, None, :]

@njit(nogil=True, cache=True, parallel=True)
def _ms_memory_gradient(x, y, uv, w, vis, vr, bm):
    nidx = x.shape[0]
    out_dtype = vis.real.dtype

    ncomb, utime, nant, nchan, ncorr = vis.shape

    ms_gradient = np.zeros((nchan, nidx), dtype=out_dtype)
    for i in prange(nidx):
        for c in range(ncomb):
            for t in range(utime):
                for a in range(nant):
                    for f in range(nchan):
                        u, v = uv[c, t, a, f]
                        uv_r = u * x[i] + v * y[i]
                        uv_r *= 2 * np.pi * 1j
                        a_ij = np.exp(uv_r)

                        for r in range(ncorr):
                            vm = vis[c, t, a, f, r].conjugate()
                            ed = a_ij / vm if vm != 0 else 0

                            i_sum = vr[c, t, f, r] * bm[c, t, f, r].conjugate() * ed

                            s_sum = -w[c, t, f, r] * i_sum.real

                            data = s_sum / ncomb

                            ms_gradient[f, i] += data
    return ms_gradient

def __cell_index_delta(index_array, delta, dtype=np.float32):
    """
    Scale an array by a delta, so that each cell/pixel has the increment in radians specified
    by delta.
    """
    return (index_array * delta).astype(dtype)

def __primary_beam(chans, shape, antenna=np.array([0]), chunks="auto"):
    """
    Get the primary beam image for every frequency.
    """
    beam = dataset.antenna.primary_beam.beam(
        chans, shape, antenna=antenna, imchunks=chunks
    )
    beam = beam[0].astype(np.float32)  # temporal indexing
    return beam

input_name = "/home/datasets/ms_name.ms"

ms = xds_from_ms(
    input_name,
    group_cols=("SCAN_NUMBER", "FIELD_ID", "DATA_DESC_ID"),
    index_cols=("TIME", "ANTENNA1", "ANTENNA2"),
    chunks={'row': 1500}
)

mask = Mask()

res = gradient(ms, mask)
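To see where the time goes, it may help to write out what _ms_memory_gradient accumulates. The expression below is read off the loops in the code above. It is not copied from Chael et al., and the symbol names are my own. B^obs, B^mod and w are bis_obs, bis_model and bis_weight. V is vis, (u, v) are the entries of uv and (x_i, y_i) are the pixel offsets. N_c is the number of triangles in the block, N_t the number of unique times, N_r the number of correlations, and P the primary beam. Terms with V = 0 contribute nothing.

```latex
g_{f,i} = -\frac{1}{N_c} \sum_{c=1}^{N_c} \sum_{t=1}^{N_t} \sum_{a=1}^{3} \sum_{r=1}^{N_r}
  w_{ctfr}\,
  \operatorname{Re}\!\left[
    \frac{\left(B^{\mathrm{obs}}_{ctfr} - B^{\mathrm{mod}}_{ctfr}\right)
          \overline{B^{\mathrm{mod}}_{ctfr}}\,
          e^{2\pi i \left(u_{ctaf}\, x_i + v_{ctaf}\, y_i\right)}}
         {\overline{V_{ctafr}}}
  \right],
\qquad
G_i = \sum_{f} P_{f,i} \sum_{k} g^{(k)}_{f,i}
```

Here k runs over the triangle blocks that are summed by data_br.sum(axis=(0, 2)). Written this way, the complex exponential alone is evaluated once per (pixel, triangle, time, baseline, channel) combination. The cost therefore scales with the product N_pix × N_tri × N_t × 3 × N_chan, and shrinking any one factor, for example the number of pixels via the mask, reduces it proportionally.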

@JSKenyon
Collaborator

JSKenyon commented Oct 18, 2023

Hi @Hackasteroid142! This problem is now entering the more general territory of optimising dask code. Doing so requires an in-depth understanding of dask graphs (https://docs.dask.org/en/latest/graphs.html). A good starting point is to inspect the graph (https://docs.dask.org/en/stable/graphviz.html). This can show you whether your implementation contains problematic many-to-many mappings, which dask is known to struggle with.
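One place where such many-to-many mappings arise is a blockwise call whose output index drops an input index. dask then treats the dropped index as a contraction and passes the function a list of all blocks along that dimension, so each output task depends on all of them. If I read the gradient code above correctly, "nutime" and "nant" are dropped there. _block_gradient then indexes uv[0][0], vis[0][0] and w[0], which would use only the first block whenever those dimensions span several chunks. Below is a toy sketch of that behaviour. count_time_blocks, sum_all_blocks and sum_first_block are hypothetical, and meta is passed so that dask does not call the functions on placeholder arrays.

```python
import dask.array as da
import numpy as np

# (time, chan) array: 3 blocks along "t", 2 blocks along "f".
vis = da.ones((6, 4), chunks=(2, 2))


def count_time_blocks(blocks):
    # "t" is missing from the output index, so blockwise contracts it and
    # passes a list with every time-block for this channel block.
    return np.array([len(blocks)])


def sum_all_blocks(blocks):
    # Correct reduction over the contracted dimension.
    return sum(b.sum(axis=0) for b in blocks)


def sum_first_block(blocks):
    # Mirrors indexing like vis[0][0]: the other blocks are silently ignored.
    return blocks[0].sum(axis=0)


n_blocks = da.blockwise(
    count_time_blocks, "f", vis, "tf",
    adjust_chunks={"f": 1}, meta=np.empty((0,), dtype=int),
)
total = da.blockwise(sum_all_blocks, "f", vis, "tf",
                     meta=np.empty((0,), dtype=vis.dtype))
first = da.blockwise(sum_first_block, "f", vis, "tf",
                     meta=np.empty((0,), dtype=vis.dtype))

print(n_blocks.compute())  # [3 3]: every output block reads all 3 time blocks
print(total.compute())     # [6. 6. 6. 6.]
print(first.compute())     # [2. 2. 2. 2.]
```

Calling n_blocks.visualize() (which requires graphviz) should show each output task fanning in from every time block.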

Regarding performance, remember that dask does little or no computation until compute is called, i.e., graph construction should be fast. Most performance problems will only manifest after calling compute.
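A quick way to see this split is to time graph construction and execution separately on a toy array. The sizes below are arbitrary. The "few minutes" spent in the gradient function corresponds to the first timing, and the multi-day runtime to the second.

```python
import time

import dask.array as da

x = da.random.random((2000, 2000), chunks=(500, 500))

t0 = time.perf_counter()
y = (x @ x.T).sum()  # builds the task graph only; no arithmetic yet
t_build = time.perf_counter() - t0

t0 = time.perf_counter()
value = y.compute()  # the graph is executed here
t_compute = time.perf_counter() - t0

print(f"tasks in graph: {len(dict(y.__dask_graph__()))}")
print(f"graph construction: {t_build:.4f} s, compute: {t_compute:.2f} s")
```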

In general, while it is useful to see your code, I would encourage you to post code which can be run. This makes it much easier for us to provide useful feedback. As it stands, I cannot give you more precise advice.

@Hackasteroid142
Author

Hello, I've been reviewing your advice, but I haven't achieved good results; the runtime is still very high. As you suggested, here is runnable code that includes the bispectrum chi-squared gradient, together with the components needed to calculate the bispectrum and the chunking code you gave me. Any advice or help to improve my code would be greatly appreciated. If there are any issues with the code, I'm happy to help.

from daskms import xds_from_ms
from numba import njit, prange 
import astropy.units as un
import dask.array as da
import numpy as np
import itertools

def gradient(dataset, imageshape):
    delta_x, delta_y = [-2.42406841e-08,  2.42406841e-08]
    x_ind_2d, y_ind_2d, *z_ind_2d = da.indices(
            imageshape, dtype=np.int32
    )

    x_ind = x_ind_2d.ravel()
    y_ind = y_ind_2d.ravel()

    x_cell = x_ind * delta_x
    y_cell = y_ind * delta_y

    grad = da.zeros_like(x_cell, dtype=np.float32)
    for i, ms in enumerate(dataset):
        grad += ms_gradient(ms, x_cell, y_cell, delta_x, delta_y, imageshape)

    grad_2d = grad.reshape(imageshape).T
    image = da.flip(grad_2d, axis=[0, 1])
    return image

def ms_gradient(ms, x_cell, y_cell, delta_x, delta_y, imageshape):
    bis_obs = ms.CAL_DATA.data  # (ncomb, ntime, nchans, ncorrs)
    bis_model = ms.CAL_MODEL.data
    bis_weight = ms.CAL_WEIGHT.data

    bis_r = bis_obs - bis_model 
    vis = ms.BIS_DATA.data  # (ncomb, ntime, nant, nchans, ncorrs)
    uvw = ms.BIS_UVW.data.astype(np.float32) * un.m  # (ncomb, ntime, nant, uvw)

    nchans = ms.DATA.shape[1]
    uvw_lambdas = da.repeat(uvw[:, :, :, np.newaxis, :], nchans, axis=3)

    uv_lambdas = uvw_lambdas[:, :, :, :, :2]
    uv_lambdas = uv_lambdas.map_blocks(
        lambda x: x.value if isinstance(x, un.Quantity) else x, dtype=np.float32
    )

    phase_dirs_x = (imageshape[0] // 2) * delta_x
    phase_dirs_y = (imageshape[1] // 2) * delta_y

    x = x_cell - np.float32(phase_dirs_x)
    y = y_cell - np.float32(phase_dirs_y)

   
    data_br = da.blockwise(
        _block_gradient,
        ("ncomb", "chan", "corr", "idx"),
        x,
        ("idx", ),
        y,
        ("idx", ),
        uv_lambdas,
        ("ncomb", "nutime", "nant", "chan", "corr"),
        bis_weight,
        ("ncomb", "nutime", "chan", "corr"),
        vis,
        ("ncomb", "nutime", "nant", "chan", "corr"),
        bis_r,
        ("ncomb", "nutime", "chan", "corr"),
        bis_model,
        ("ncomb", "nutime", "chan", "corr"),
        dtype=np.float64,
    )

    data = data_br.sum(axis=(0, 2))

    pb = da.ones_like(data)

    data = da.einsum('in,in->n', data, pb)

    return data

def _block_gradient(x, y, uv, w, vis, vr, bm):
    uv = uv[0][0].astype(np.float64)
    dchi2_broad = _ms_memory_gradient(x, y, uv, w[0], vis[0][0], vr[0], bm[0])
    return dchi2_broad[None, :, None, :]

@njit(nogil=True, cache=True, parallel=True)
def _ms_memory_gradient(x, y, uv, w, vis, vr, bm):
    nidx = x.shape[0]
    out_dtype = vis.real.dtype

    ncomb, utime, nant, nchan, ncorr = vis.shape

    ms_gradient = np.zeros((nchan, nidx), dtype=out_dtype)
    for i in prange(nidx):
        for c in range(ncomb):
            for t in range(utime):
                for a in range(nant):
                    for f in range(nchan):
                        u, v = uv[c, t, a, f]
                        uv_r = u * x[i] + v * y[i]
                        uv_r *= 2 * np.pi * 1j
                        a_ij = np.exp(uv_r)

                        for r in range(ncorr):
                            vm = vis[c, t, a, f, r].conjugate()
                            ed = a_ij / vm if vm != 0 else 0

                            i_sum = vr[c, t, f, r] * bm[c, t, f, r].conjugate() * ed

                            s_sum = -w[c, t, f, r] * i_sum.real

                            data = s_sum / ncomb

                            ms_gradient[f, i] += data
    return ms_gradient


def utime_and_row_chunks(time, req_utime=1):
    """Internals of compute_utime_and_row_chunks."""

    utime, utime_counts = np.unique(time, return_counts=True)
    n_utime = utime.size
    req_utime = req_utime or n_utime  # Catch zero.

    chunk_starts = np.arange(0, n_utime, req_utime)

    utime_chunks = np.array(
        [
            req_utime if i + req_utime < n_utime else n_utime - i
            for i in chunk_starts
        ]
    )

    row_chunks = np.add.reduceat(utime_counts, chunk_starts)

    return np.stack([utime_chunks, row_chunks], axis=0)



def compute_utime_and_row_chunks(indexing_xdsl, req_utime=1):
    """Figure out the chunking in unique time and row. Triggers compute."""

    chunking = []

    for xds in indexing_xdsl:

        chunking.append(
            xds.TIME.data.map_blocks(
                utime_and_row_chunks,
                req_utime,
                chunks=((2,), (np.nan,)),
                new_axis=0,
                dtype=int
            )
        )

    result = da.compute(chunking)[0]

    utime_chunks = [tuple(arr[0]) for arr in result]
    row_chunks = [tuple(arr[1]) for arr in result]

    return utime_chunks, row_chunks

@njit(cache=True, nogil=True)
def get_bispectrum_data(data, ant1, ant2, comb, utime_inv, n_row, data_shape):
    cnt = np.zeros(data_shape, dtype=np.int8)
    data_bis = np.ones(data_shape, dtype=data.dtype)

    for row in range(n_row):

        ut = utime_inv[row]

        a1 = ant1[row]
        a2 = ant2[row]

        for ic, c in enumerate(comb):
            if (a1 == c[0]) and (a2 == c[1]):
                data_bis[ic, ut, 0] = data[row]
                cnt[ic, ut, :] += 1
            elif (a1 == c[1]) and (a2 == c[-1]):
                data_bis[ic, ut, 1] = data[row]
                cnt[ic, ut, :] += 1
            elif (a1 == c[0]) and (a2 == c[-1]):
                data_bis[ic, ut, 2] = data[row].conjugate()
                cnt[ic, ut, :] += 1

    return data_bis, cnt

def get_bispectrum(data, ant1, ant2, comb, time, type):

    utime, utime_inv = np.unique(time, return_inverse=True)
    n_utime = utime.size
    n_comb = len(comb)

    if type == 'UVW':
        n_row, n_id = data.shape
        shape_data = (n_comb, n_utime, 3, n_id)
    else:
        n_row, n_chan, n_corr = data.shape
        shape_data = (n_comb, n_utime, 3, n_chan, n_corr)

    bis, cnt = get_bispectrum_data(data, ant1, ant2, comb, utime_inv, n_row, shape_data)

    bis[cnt != 3] = 0

    return bis

def bispectrum_data(data, antenna1, antenna2, filter_comb, time, type, utime_chunks_list):

    if type == 'UVW':
        bis_shape = ('triangle', 't', 'p', 'id')
        data_shape = ('t', 'id')
    else:
        bis_shape = ('triangle', 't', 'p', 'f', 'c')
        data_shape = ('t', 'f', 'c')

    bis = da.blockwise(
        get_bispectrum,
        bis_shape,
        data, data_shape,
        antenna1, ('t',),
        antenna2, ('t',),
        filter_comb, None,
        time.data, ('t',),
        type, None,
        adjust_chunks={'t': utime_chunks_list},
        new_axes={
            'triangle': len(filter_comb),
            'p': 3
        },
        dtype=data.dtype
    )

    return bis

def bispectrum(dataset, utime_chunks_list):
    n_ant = 28
    antenna_reference = 0
    comb = itertools.combinations(range(n_ant), 3)
    filter_comb = np.array([i for i in comb if antenna_reference in i])    

    for i, xds in enumerate(dataset):
        data = xds.DATA.data
        model = xds.MODEL.data
        weight = xds.WEIGHT.data
        time = xds.TIME
        antenna1 = xds.ANTENNA1.data
        antenna2 = xds.ANTENNA2.data
        uvw = xds.UVW.data

        flags = xds.FLAG.data

        data = data * ~flags
        model = model * ~flags
        weight = weight * ~flags

        bm = bispectrum_data(model, antenna1, antenna2, filter_comb, time, 'DATA', utime_chunks_list[i])
        uvw_bis = bispectrum_data(uvw, antenna1, antenna2, filter_comb, time, 'UVW', utime_chunks_list[i])
        bw = bispectrum_data(weight, antenna1, antenna2, filter_comb, time, 'DATA', utime_chunks_list[i])
        bo = bispectrum_data(data, antenna1, antenna2, filter_comb, time, 'DATA', utime_chunks_list[i])

        cal_data = da.prod(bo, axis=2)
        cal_model = da.prod(bm, axis=2)

        weighted_vis = (da.absolute(bm)**2) * bw
        aux = da.divide(
            1, weighted_vis, out=da.zeros_like(weighted_vis), where=weighted_vis != 0
        )
        bis_sigma_squared = (da.absolute(cal_data)**2) * da.sum(aux, axis=2)
        bis_weight = da.divide(
            1,
            bis_sigma_squared,
            out=da.zeros_like(bis_sigma_squared),
            where=bis_sigma_squared != 0
        )

        uvw_bis[:, :, 2] *= -1

        dataset[i] = xds.assign({
            "BIS_MODEL": (('triangle', 't', 'p', 'f', 'c'), bm),
            "BIS_DATA": (('triangle', 't', 'p', 'f', 'c'), bo),
            "BIS_WEIGHT": (('triangle', 't', 'p', 'f', 'c'), bw),
            "BIS_UVW" : (('triangle', 't', 'p', 'id'), uvw_bis),
            "CAL_DATA": (('triangle', 't', 'f', 'c'), cal_data),
            "CAL_MODEL": (('triangle', 't', 'f', 'c'), cal_model),
            "CAL_WEIGHT": (('triangle', 't', 'f', 'c'), bis_weight),
        })

    return dataset


if __name__ == "__main__":
  
    input_name = "/home/datasets/ms_name.ms"

    dataset = xds_from_ms(
                input_name,
                group_cols=("SCAN_NUMBER", "FIELD_ID", "DATA_DESC_ID"),
                index_cols=("TIME", "ANTENNA1", "ANTENNA2"),
                chunks={'row': -1}
            )

    utime_chunks_list, row_chunks_list = compute_utime_and_row_chunks(
            dataset, req_utime=30
        )

    xdsl = xds_from_ms(
            input_name,
            group_cols=("SCAN_NUMBER", "FIELD_ID", "DATA_DESC_ID"),
            index_cols=("TIME", "ANTENNA1", "ANTENNA2"),
            chunks=[{'row': rcs} for rcs in row_chunks_list]
        )

    xdsl = [
        xds.assign(
            {
                "MODEL": (
                    xds.DATA.dims,
                    da.ones(
                        xds.DATA.data.shape,
                        dtype=np.complex64,
                        chunks=xds.DATA.data.chunksize
                    ) / 2
                ), 
                "WEIGHT": (
                    xds.DATA.dims,
                    da.tile(xds.WEIGHT.data, xds.DATA.shape[1]).reshape(len(xds.WEIGHT.data),xds.DATA.shape[1], xds.DATA.shape[2])
                )
            }
        ) for xds in xdsl
    ] 

    dataset_bis = bispectrum(xdsl, utime_chunks_list)
    grad = gradient(dataset_bis, (512,512))
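To cross-check the numba kernel at the top of this excerpt (the loop that fills `ms_gradient`), a vectorized NumPy reference can be useful on small inputs. This is a sketch, not the author's implementation. It assumes `uv[..., 0]` and `uv[..., 1]` hold u and v, `x` and `y` are the pixel coordinates, `w` is real-valued (which lets the sum over correlations move inside `.real`), and the third axis of `vis` (unpacked as `nant` in the kernel) indexes the three baselines of each triangle. `gradient_loop` and `gradient_vectorized` are hypothetical names.

```python
import numpy as np


def gradient_loop(vis, uv, vr, bm, w, x, y):
    """Plain-Python transcription of the kernel's loops (no numba, no prange)."""
    ncomb, utime, nant, nchan, ncorr = vis.shape
    grad = np.zeros((nchan, x.size))
    for i in range(x.size):
        for c in range(ncomb):
            for t in range(utime):
                for a in range(nant):
                    for f in range(nchan):
                        u, v = uv[c, t, a, f]
                        a_ij = np.exp(2j * np.pi * (u * x[i] + v * y[i]))
                        for r in range(ncorr):
                            vm = vis[c, t, a, f, r].conjugate()
                            ed = a_ij / vm if vm != 0 else 0
                            i_sum = vr[c, t, f, r] * bm[c, t, f, r].conjugate() * ed
                            grad[f, i] += -w[c, t, f, r] * i_sum.real / ncomb
    return grad


def gradient_vectorized(vis, uv, vr, bm, w, x, y):
    """Same sum using broadcasting and einsum; assumes w is real-valued."""
    ncomb = vis.shape[0]
    # 1 / conj(vis), left at 0 where vis == 0 (mirrors the `vm != 0` guard).
    inv_conj = np.zeros_like(vis)
    nonzero = vis != 0
    inv_conj[nonzero] = 1 / vis[nonzero].conjugate()
    # Broadcast the (c, t, f, r) terms over the baseline axis, then sum over r.
    coeff = (w * vr * bm.conjugate())[:, :, None] * inv_conj  # (c, t, a, f, r)
    k = coeff.sum(axis=-1)                                    # (c, t, a, f)
    # exp(2*pi*i*(u*x + v*y)) for every pixel: shape (c, t, a, f, npix).
    phase = np.exp(2j * np.pi * (uv[..., 0, None] * x + uv[..., 1, None] * y))
    return -np.einsum("ctaf,ctafi->fi", k, phase).real / ncomb
```

The `phase` array has one entry per (triangle, time, baseline, channel, pixel), so for a 512×512 image it is far too large to materialize; this is meant as a correctness check on toy shapes, not a replacement for the kernel.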
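With `adjust_chunks={'t': ...}`, dask trusts the supplied sizes, but `get_bispectrum` actually returns one `t` entry per unique timestamp in its row block. If a fixed row chunk size (such as 1500) splits a timestamp across two blocks, that timestamp is counted in both blocks, and triangles whose baselines land in different blocks are zeroed by the `cnt != 3` mask. A hedged way to check a row chunking before building the graph is to compute, with plain NumPy, what the `t` chunks must be. `row_blocks`, `expected_utime_chunks` and `split_timestamps` are hypothetical helpers, not dask-ms functions.

```python
import numpy as np


def row_blocks(time, row_chunks):
    """Split a 1-D TIME column into the row blocks dask would hand to each task."""
    bounds = np.cumsum(row_chunks)[:-1]
    return np.split(np.asarray(time), bounds)


def expected_utime_chunks(time, row_chunks):
    """Unique-time count per row block, i.e. the 't' size get_bispectrum returns."""
    return tuple(int(np.unique(block).size) for block in row_blocks(time, row_chunks))


def split_timestamps(time, row_chunks):
    """Timestamps whose rows are spread over more than one row block."""
    seen, split = set(), set()
    for block in row_blocks(time, row_chunks):
        present = set(np.unique(block).tolist())
        split |= seen & present  # already seen in an earlier block
        seen |= present
    return sorted(split)
```

For a toy TIME column `[0, 0, 0, 1, 1, 2, 2, 2, 3]`, `utime_and_row_chunks(time, req_utime=2)` from the excerpt gives utime chunks `(2, 2)` and row chunks `(5, 4)`; `expected_utime_chunks(time, (5, 4))` agrees with `(2, 2)` and no timestamp is split. Row chunks `(4, 5)` instead give `(2, 3)` and split timestamp `1.0` across the boundary.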
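In `get_bispectrum_data`, every row is compared against every triangle in `comb`, so the work per chunk scales as rows × triangles (351 triangles for 28 antennas with reference antenna 0, as in `bispectrum`). A possible alternative, sketched below in plain Python/NumPy with the same conventions (position 0 is baseline `(c[0], c[1])`, position 1 is `(c[1], c[2])`, position 2 is the conjugate of `(c[0], c[2])`), is to build a baseline-to-slot lookup once so each row only touches the triangles that contain its baseline. `build_baseline_map` and `fill_bispectrum` are hypothetical names. The sketch is not numba-compiled; inside numba the lookup would need something like `numba.typed.Dict` or a dense `(n_ant, n_ant)` index table.

```python
from collections import defaultdict

import numpy as np


def build_baseline_map(comb):
    """Map (ant1, ant2) -> list of (triangle index, position, conjugate?)."""
    slots = defaultdict(list)
    for ic, (i, j, k) in enumerate(np.asarray(comb).tolist()):
        slots[(i, j)].append((ic, 0, False))
        slots[(j, k)].append((ic, 1, False))
        slots[(i, k)].append((ic, 2, True))
    return slots


def fill_bispectrum(data, ant1, ant2, comb, utime_inv, n_utime):
    """Scatter rows into (triangle, time, 3, ...) and zero incomplete triangles."""
    slots = build_baseline_map(comb)
    shape = (len(comb), n_utime, 3) + data.shape[1:]
    bis = np.ones(shape, dtype=data.dtype)
    cnt = np.zeros(shape, dtype=np.int8)
    for row, (a1, a2) in enumerate(zip(ant1.tolist(), ant2.tolist())):
        ut = utime_inv[row]
        for ic, p, conj in slots.get((a1, a2), ()):
            bis[ic, ut, p] = data[row].conjugate() if conj else data[row]
            cnt[ic, ut] += 1  # same as cnt[ic, ut, :] += 1 in the original
    bis[cnt != 3] = 0  # keep only triangles with all three baselines present
    return bis
```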
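In `bispectrum`, `da.divide(1, x, out=da.zeros_like(x), where=x != 0)` builds its `out` blocks from the chunk sizes dask declares for `x`, not from the blocks that are actually computed. If the declared `t` chunks of `weighted_vis` differ from what `get_bispectrum` returns, this is one place where the mismatch can surface as a broadcast error like the one in the description. A hedged alternative is to do the guarded division per block, deriving the zeros from the computed block itself, e.g. `weighted_vis.map_blocks(safe_reciprocal, dtype=weighted_vis.dtype)`. `safe_reciprocal` is a hypothetical helper for floating-point or complex input.

```python
import numpy as np


def safe_reciprocal(x):
    """Elementwise 1 / x, with 0 wherever x == 0 (no divide-by-zero warnings)."""
    # np.zeros_like(x) is built from the block actually passed in,
    # so its shape always matches x.
    return np.divide(1, x, out=np.zeros_like(x), where=x != 0)
```

This only makes the reciprocal robust. Wrongly declared chunks are still wrong, and other operations (or assigning the result back into the dataset) may still fail on them.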

@sjperkins
Member

Hi @Hackasteroid142. I had an offline chat with @JSKenyon about the code above. Running it appears to use all available cores, so, without fully understanding how optimal the per-chunk code is, the dask side seems to be doing its job. Additionally, the computation appears to be a variant of the DFT, which will always be much slower than an FFT.
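To make the DFT/FFT comparison concrete: a direct DFT evaluates every output from every input, which costs O(N²) for N samples, while the FFT produces the same result for regularly spaced samples in O(N log N). The NumPy sketch below checks that a direct 1-D DFT matches `np.fft.fft`. It is only an illustration: uv points in a Measurement Set are not on a regular grid, so using an FFT there requires gridding or a non-uniform FFT, which this sketch does not show.

```python
import numpy as np


def direct_dft(x):
    """O(N^2) DFT: X[k] = sum_n x[n] * exp(-2*pi*i*k*n/N)."""
    n = np.arange(x.size)
    kernel = np.exp(-2j * np.pi * np.outer(n, n) / x.size)  # N x N matrix
    return kernel @ x
```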

For example, https://github.com/caracal-pipeline/crystalball does a DFT predict of a WSClean sky model using dask.

This can take up to a week for large Measurement Sets and sources.

DFTs tend to be embarrassingly parallel, so your algorithm will probably benefit from distributing the problem across a compute cluster. Unfortunately, we don't have the resources to help set this up and debug it; I suggest consulting some of the dask resources available online.
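As an illustration of why DFTs parallelize well: each image pixel depends only on the visibilities, not on other pixels, so the pixel axis can be split into independent chunks that are evaluated anywhere (threads, processes, or dask workers on a cluster) and then concatenated. The sketch below uses a simplified, unweighted dirty-image DFT, `I(x, y) = sum_k Re(V_k exp(2*pi*i*(u_k x + v_k y)))`, rather than the gradient from this issue, and a `ThreadPoolExecutor` as a local stand-in for a cluster. `dft_pixels` and `chunked_dft` are hypothetical names.

```python
from concurrent.futures import ThreadPoolExecutor

import numpy as np


def dft_pixels(vis, u, v, x, y):
    """Direct DFT of visibilities onto the given pixel coordinates."""
    phase = np.exp(2j * np.pi * (np.outer(x, u) + np.outer(y, v)))  # (npix, nvis)
    return (phase @ vis).real


def chunked_dft(vis, u, v, x, y, n_chunks=4, max_workers=4):
    """Split the pixel axis into independent chunks and evaluate them in parallel."""
    x_chunks = np.array_split(x, n_chunks)
    y_chunks = np.array_split(y, n_chunks)
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        # map preserves input order, so the chunks concatenate back in place.
        parts = pool.map(lambda xy: dft_pixels(vis, u, v, *xy), zip(x_chunks, y_chunks))
        return np.concatenate(list(parts))
```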
