Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added brute force method to preserved site k-fold split #166

Open
wants to merge 8 commits into
base: master
Choose a base branch
from

Conversation

rsethi21
Copy link

No description provided.

@jamesdolezal jamesdolezal self-requested a review June 20, 2022 23:10
Copy link
Owner

@jamesdolezal jamesdolezal left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi Rohan - please see my comments here. There are a couple of phases we'll take for this review - python basics, overall algorithm design, and then optimizations. I've focused mostly on python basics here. With your next update, I hope to move to overall algorithm design.

I would recommend reading about list (and dictionary) comprehension - it's Python magic that both speeds up code and significantly improves readability. Many of the loops you have could probably be replaced with some form of list/dictionary comprehension.

We can talk about the overall algorithm next, but I wanted to work on improving the readability a bit, as it becomes MUCH easier to optimize an algorithm if the code is already pretty lean.

Comment on lines 201 to 206
# test using fake data generator
# data = generate_test_data()
# list_of_split = generate(data[0], data[1], data[2], data[3])
# print(list_of_split)
# dictionary = generate_brute_force(data[0], data[1], data[2], data[3])
# print(dictionary)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See above regarding test functions in the main module.

Additionally, you shouldn't leave commented out code in a committed file - it should either be deleted, or put into a separate function.

Comment on lines 11 to 24
def generate_test_data():

k = 3
patients = [f'pt{p}' for p in range(200)]
sites = [f'site{s}' for s in range(5)]
outcomes = list(range(4))
category = 'outcome_label'
df = pd.DataFrame({
'patient': pd.Series([random.choice(patients) for _ in range(100)]),
'site': pd.Series([random.choice(sites) for _ in range(100)]),
'outcome_label': pd.Series([random.choice(outcomes) for _ in range(100)])
})
unique_labels = df['outcome_label'].unique()
return [df, category, unique_labels, k]
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Test functions shouldn't be in the main module - rather, they should be in separate files used for unit testing. In this case, the site splits have unit tests under slideflow/test/dataset_test.py:

https://github.com/jamesdolezal/slideflow/blob/master/slideflow/test/dataset_test.py

You could add an additional unit test in the TestSplits class, which tests your new method. The TestSplits.setUpClass() method prepares some test data, which is then used by each of the unit tests.

For this exercise, I'll have you hold off on implementing unit tests, as I recently made changes to that module that haven't yet been integrated. We'll circle back to the unit tests later.


# list of possible combinations of sites in folds; also built in check for use case when someone chooses the wrong number of folds
if crossfolds > len(unique_sites):
print("choose less number of crossfolds")
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Error handling should use the try...except Python functionality rather than if...then statements, as it allows for large applications to properly report and handle errors.

In this case, is the number of crossfolds is greater than the number of unique sites, we would want to raise an error, since it's not possible to perform site-preserved splitting. The type of error that should be raised will depend on the surrounding application and should make sense. In this case, the best error to raise would probably be slideflow.errors.DatasetSplitError. So you would change that block to:

if crossfolds > len(unique_sites):
    raise sf.errors.DatasetSplitError(
        "Insufficient number of sites ({}) for number of crossfolds ({})".format(
            len(unique_sites),
            crossfolds))

That also enables you to get rid of the following else statement and de-indent the next code block by a level.

print("choose less number of crossfolds")
else:
most_possible_sites_in_one_fold = 1 + int(len(unique_sites)) - crossfolds
list_of_all_combos = list()
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of setting variables to list() and dict(), python convention is to use [] and {}

if crossfolds > len(unique_sites):
print("choose less number of crossfolds")
else:
most_possible_sites_in_one_fold = 1 + int(len(unique_sites)) - crossfolds
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the int() is redundant here, as len() always returns an integer

list_of_proper_combos = list(combinations(list_of_all_combos, crossfolds))
removal_list = list()
for item in list_of_proper_combos:
item_length = 0
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rather than looping and counting lengths with the item_length counter, the clean pythonic way of doing this would be:

sum([len(i) for i in item])

list_of_items_in_a_combo = list()
for item2 in item:
item_length += len(item2)
list_of_items_in_a_combo.extend(item2)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the role of list_of_items_in_a_combo here? Is it not the same as item after the loop on lines 52-54?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since the item is a list of lists and I need the list of sites from all of those lists; I changed it to a list comprehension instead.

sites_in_a_possible_crossfold = [site for site in possible_crossfold] where

  • possible_crossfold was item
  • where sites_in_a_possible_crossfold was list_of_items_in_a_combo
    @jamesdolezal does that work better

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line:

sites_in_a_possible_crossfold = [site for site in possible_crossfold]

is equivalent to:

sites_in_a_possible_crossfold = possible_crossfold

If this is what you are intending, you should consider removing the sites_in_a_possible_crossfold entirely and just use possible_crossfold. Does that make sense, or am I missing something?

data_dict[site] = dict_of_values

# error associated to each possible combo
per_fold_size_target_ratio = float(1)/crossfolds
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of float(1), do 1.

mean_square_error = 0
count = 0
for fold in combo:
sum2 = 0
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sum2 is not a descriptive variable name

per_combo_errors[combo] = mean_square_error

# isolate best combo by error
min = 100000000000000000000000000000000000000000000
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the purpose of min here? What are you actually trying to check for with the if value < min statement?

Additionally, min() is a python standard function (finds the minimum of a list, similar to max(), so you should name your variable something else.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so i am using that as an initial value to start comparing the errors in the rest of the dictionary; the value<min will check to see if the values are less than that. If so, it resets min to be value and continues looking for the smallest error.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. In that case, you should remove the min variable (having a manually set large value is clunky). Instead, you can just sort the dictionary and find the lowest value. So instead of:

min = 100000000000000000000000000000000000000000000
best_combo = None
for key, value in per_combo_errors.items():
    if value < min:
        best_combo = key
        min = value
    else:
        pass

do

best_combo = min(per_combo_errors, key=per_combo_errors.get)

rsethi21 and others added 4 commits June 22, 2022 09:47
Fixes bug caused by an invalid variable name (`sum` should not be used as a variable name, as it will override the python default sum() function)
@jamesdolezal
Copy link
Owner

jamesdolezal commented Jun 23, 2022

In a083554, I add a unit test that ensures the generated splits are valid (all patients are used, no site is present in multiple cross-folds). The test generated data was copied from your original submission.

Running the test only requires:

python3 crossfolds_test.py

This includes formatting changes, updated variable
names, and refactoring only to improve readability.
The underlying algorithm should be close to identical.
@jamesdolezal
Copy link
Owner

In 57ac792, I made formatting changes with minor refactoring only to improve readability - the underlying algorithm is the same. Code that is easier to read is also easier to refactor and optimize. The kinds of changes I made include:

  • Added typing and docstring to the function declaration, to make it clear what the input arguments are and what the function does.
  • Broke long code blocks into smaller discrete sections, with accompanying comments to explain what each section does.
  • More succinct and easily interpretable variable names
  • Line length of 80
  • List comprehension to reduce the number of nested loops

With the unit testing now available, we can ensure that the splits that are being generated are valid. Running the unit test both pre- and post-refactor raises an error, indicating an issue with the algorithm.

For this next step, I'll have you familiarize yourself with the modified code, and track down the cause of the failed unit test. Once the unit test passes indicating that the algorithm is complete, we will move on to optimization.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants