Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement stratified K-fold options #182

Open
iancze opened this issue Apr 3, 2023 · 0 comments
Open

Implement stratified K-fold options #182

iancze opened this issue Apr 3, 2023 · 0 comments

Comments

@iancze
Copy link
Collaborator

iancze commented Apr 3, 2023

This feature was originally drafted by @hgrzy in PR #93. The codebase has evolved significantly since that PR was opened and our thinking about how to run cross validation loops has evolved slightly. I'm closing PR #93 and excerpting the relevant draft code below. The hope is that this new issue could serve as a starting point for future stratified K-fold implementations.

In the original PR, Hannah said,

"Perform k fold cross validation where each k fold has the same ratio of low : high spatial frequency data points as the overall dataset.

Currently working on the for loop. I'm skeptical with the number of rows in the np array "pairs." While iterating through the while loop in the init of a stratkfold object it stops producing any flags that I've incorporated though jupyter notebook says it is still running the cell."

This most likely ties into the GriddedDataset and crossval discussions in #163 #166, too.

class StratKCV:
    def __init__(self, gridder, griddedDataset, k, npseed=None):

        self.griddedDataset = griddedDataset

        self.gridder = gridder

        self.cartesian_us = self.griddedDataset.coords.packed_u_centers_2D
        self.cartesian_vs = self.griddedDataset.coords.packed_v_centers_2D

        assert k > 0, "k must be a positive integer"
        self.k = k

        # 2d mask for any UV cells that contain visibilities
        # in *any* channel

        stacked_mask = np.any(self.griddedDataset.mask.detach().numpy(), axis=0)
        # add
        self.stacked_mask = stacked_mask

        # get u's and v's from dataset amd turn into 1D lists
        uu = self.gridder.uu
        vv = self.gridder.vv

        if npseed is not None:
            np.random.seed(npseed)

        # pairing u's and v's
        pairs = np.vstack((uu, vv)).T
        self.pairs = pairs

        # splitting
        l5000 = np.empty(2)
        g5000 = np.empty(2)

        pair_ind = 0
        while pair_ind < len(pairs):
            q = np.sqrt((pairs[pair_ind, 0]) ** 2 + (pairs[pair_ind, 1]) ** 2)
            if q < 5000:
                l5000 = np.vstack([l5000, pairs[pair_ind]])
            if q > 5000:
                g5000 = np.vstack([g5000, pairs[pair_ind]])
            pair_ind += 1
            print(pair_ind)

        # Doing this to get split function to work - should change in the long run
        # With the current dataset and criterion the l5000 array will be 3 too long to split evenly
        print("G1")
        l5000 = l5000[:-3]
        g5000 = g5000[:-1]

        self.l5000 = l5000
        self.g5000 = g5000
        print("G2")
        numSamplesFold = len(pairs) / k
        percentSmallSpatFreqFold = len(l5000) / len(pairs)
        numl5000perFold = numSamplesFold * percentSmallSpatFreqFold

        print("G3")
        percentLargeSpatFreqFold = len(g5000) / len(pairs)
        numg5000perFold = numSamplesFold * percentLargeSpatFreqFold

        # randomize each list while maintaining [uu, vv] pairs
        np.random.shuffle(l5000)
        np.random.shuffle(g5000)

        # get how many sections to partition l5000 into based on
        # numl5000perFold are necessary
        print("g4")
        numPairsl5000 = len(l5000)
        numSectionsl5000 = round(numPairsl5000 / numl5000perFold)
        self.numPairsl5000 = numPairsl5000
        self.numSectionsl5000 = numSectionsl5000

        # same for g5000
        print("G5")
        numPairsg5000 = len(g5000)
        numSectionsg5000 = round(numPairsg5000 / numg5000perFold)
        self.numPairsg5000 = numPairsg5000
        self.numSectionsg5000 = numSectionsg5000
        print("g6")
        # Partition low and high spat freq lists into groups
        lowSpatGroups = np.vsplit(l5000, numSectionsl5000)
        highSpatGroups = np.vsplit(g5000, numSectionsg5000)
        self.lowSpatGroups = lowSpatGroups
        self.highSpatGroups = highSpatGroups
        print("g7")
        # Create sets for kfolds
        # np array 7 X 983 X 2
        # 7 kfolds, 983 visibilities per fold, 2 coordinates [u, v]
        self.k_split_cell_list = np.hstack([lowSpatGroups, highSpatGroups])

    def build_grid_mask_from_cells(self, cell_index_list):
        # need to add to init self.cartesian_us = self.coords.pack_u_centers_2D
        mask = np.zeros_like(self.cartesian_us, dtype="bool")

        # problem, the following has actual u,v readings but does grid have problem with this?
        # for cell_index in cell_index_list:
        #   u, v = cell_index
        #   mask[u,v] = True

        for cell_index in cell_index_list:
            u, v = cell_index
            # u_min, u_max = self.u_edges[u : u + 2] #change bc u_edges is from coordinates
            # v_min, v_max = self.v_edges[v : v  + 2]
            u_min, u_max = self.coords.u_bin_min, self.coords.u_bin_max
            v_min, v_max = self.coords.v_bin_min, self.coords.v_bin_max

            # whether or not the u and v values of the coordinate array
            # fit in the u cell and v cell

            ind = (
                (self.cartesian_us >= u_min)
                & (self.cartesian_us < u_max)
                & (self.cartesian_vs >= v_min)
                & (self.cartesian_vs < v_max)
            )

            mask[ind] = True

        return mask

    def __iter__(self):
        self.n = 0  # the current k-slice we're on
        return self

    def __next__(self):
        print("Entered next")
        if self.n < self.k:
            print("Entered if")
            k_list = self.k_split_cell_list.copy()
            cell_list_test = k_list[self.n]
            self.cell_list_test = cell_list_test

            # put remaining indices back into a full list
            print("F1")
            cell_list_train = np.vstack(
                (k_list[: self.n, :, :], k_list[self.n + 1 :, :, :])
            )
            self.cell_list_train = cell_list_train

            # create the masks for each cell list
            print("F2")
            train_mask = self.build_grid_mask_from_cells(cell_list_train)
            test_mask = self.build_grid_mask_from_cells(cell_list_test)

            # copy origial dataset
            print("F3")
            train = copy.deepcopy(self.griddedDataset)
            test = copy.deepcopy(self.griddedDataset)

            # use these masks to limit new datasets to only unmasked cells
            print("F4")
            train.add_mask(train_mask)
            test.add_mask(test_mask)

            self.n += 1
            print(self.n)
            return train, test
        else:
            raise StopIteration
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant