Added brute force method to preserved site k-fold split #166
Hi Rohan - please see my comments here. There are a couple of phases we'll take for this review - python basics, overall algorithm design, and then optimizations. I've focused mostly on python basics here. With your next update, I hope to move to overall algorithm design.
I would recommend reading about list (and dictionary) comprehension - it's Python magic that both speeds up code and significantly improves readability. Many of the loops you have could probably be replaced with some form of list/dictionary comprehension.
We can talk about the overall algorithm next, but I wanted to work on improving the readability a bit, as it becomes MUCH easier to optimize an algorithm if the code is already pretty lean.
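As a rough illustration of the comprehension advice above, here is a minimal sketch on made-up data. The `folds` list and the variable names are hypothetical and are not taken from the PR; the point is only how a loop that builds a list or dictionary collapses into one expression.

```python
# Hypothetical data: each fold is a list of site names (not from the PR).
folds = [["site0", "site1"], ["site2"], ["site3", "site4"]]

# Loop version: build a list of fold sizes one append at a time.
fold_sizes_loop = []
for fold in folds:
    fold_sizes_loop.append(len(fold))

# List comprehension: the same list in a single expression.
fold_sizes = [len(fold) for fold in folds]

# Dictionary comprehension: map each site to the index of the fold holding it.
site_to_fold = {site: i for i, fold in enumerate(folds) for site in fold}
```

Both versions produce the same list; the comprehension just makes the intent ("one length per fold") visible at a glance.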
# test using fake data generator
# data = generate_test_data()
# list_of_split = generate(data[0], data[1], data[2], data[3])
# print(list_of_split)
# dictionary = generate_brute_force(data[0], data[1], data[2], data[3])
# print(dictionary)
See above regarding test functions in the main module.
Additionally, you shouldn't leave commented out code in a committed file - it should either be deleted, or put into a separate function.
def generate_test_data():

    k = 3
    patients = [f'pt{p}' for p in range(200)]
    sites = [f'site{s}' for s in range(5)]
    outcomes = list(range(4))
    category = 'outcome_label'
    df = pd.DataFrame({
        'patient': pd.Series([random.choice(patients) for _ in range(100)]),
        'site': pd.Series([random.choice(sites) for _ in range(100)]),
        'outcome_label': pd.Series([random.choice(outcomes) for _ in range(100)])
    })
    unique_labels = df['outcome_label'].unique()
    return [df, category, unique_labels, k]
Test functions shouldn't be in the main module - rather, they should be in separate files used for unit testing. In this case, the site splits have unit tests under slideflow/test/dataset_test.py: https://github.com/jamesdolezal/slideflow/blob/master/slideflow/test/dataset_test.py
You could add an additional unit test in the TestSplits class, which tests your new method. The TestSplits.setUpClass() method prepares some test data, which is then used by each of the unit tests.
For this exercise, I'll have you hold off on implementing unit tests, as I recently made changes to that module that haven't yet been integrated. We'll circle back to the unit tests later.
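For when we do circle back, here is a minimal sketch of the general pattern, using only the standard `unittest` module. The class name `TestSplitsSketch`, its data layout, and the test method are hypothetical; this is not the actual `TestSplits` class in slideflow, just an illustration of how `setUpClass()` prepares data once for every test.

```python
import unittest


class TestSplitsSketch(unittest.TestCase):
    """Hypothetical test case; NOT the real slideflow TestSplits class."""

    @classmethod
    def setUpClass(cls):
        # Runs once, before any test method; the data is shared by all tests.
        cls.k = 3
        # Made-up (patient, site) records, assigned deterministically.
        cls.records = [(f'pt{p}', f'site{p % 5}') for p in range(100)]

    def test_enough_sites_for_folds(self):
        # Site-preserved splitting needs at least as many sites as folds.
        unique_sites = {site for _, site in self.records}
        self.assertGreaterEqual(len(unique_sites), self.k)

    def test_patients_are_unique(self):
        patients = [patient for patient, _ in self.records]
        self.assertEqual(len(patients), len(set(patients)))
```

A file containing a class like this can be run with `python -m unittest <module name>`.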
# list of possible combinations of sites in folds; also built in check for use case when someone chooses the wrong number of folds
if crossfolds > len(unique_sites):
    print("choose less number of crossfolds")
Error handling should raise exceptions, which callers can catch with try...except, rather than printing a message from an if...else branch, as this allows large applications to properly report and handle errors.
In this case, if the number of crossfolds is greater than the number of unique sites, we would want to raise an error, since it's not possible to perform site-preserved splitting. The type of error raised depends on the surrounding application and should make sense in context. Here, the best error to raise would probably be slideflow.errors.DatasetSplitError. So you would change that block to:
if crossfolds > len(unique_sites):
    raise sf.errors.DatasetSplitError(
        "Insufficient number of sites ({}) for number of crossfolds ({})".format(
            len(unique_sites),
            crossfolds))
That also enables you to get rid of the following else statement and de-indent the next code block by a level.
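To show how the raise and the try...except fit together, here is a self-contained sketch. The `DatasetSplitError` class below is a local stand-in defined only so the example runs without slideflow installed, and `check_crossfolds` is a hypothetical helper name; neither is the project's actual code.

```python
class DatasetSplitError(Exception):
    """Local stand-in for slideflow.errors.DatasetSplitError (assumption)."""


def check_crossfolds(unique_sites, crossfolds):
    """Raise if there are more crossfolds than sites; otherwise return None."""
    if crossfolds > len(unique_sites):
        raise DatasetSplitError(
            "Insufficient number of sites ({}) for number of crossfolds ({})".format(
                len(unique_sites),
                crossfolds))


# The caller, not the splitting function, decides how to handle the failure.
try:
    check_crossfolds(['site0', 'site1'], crossfolds=3)
except DatasetSplitError as err:
    message = str(err)
```

Because the function raises instead of printing, a larger application can log the message, retry with fewer folds, or let the error propagate.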
    print("choose less number of crossfolds")
else:
    most_possible_sites_in_one_fold = 1 + int(len(unique_sites)) - crossfolds
    list_of_all_combos = list()
Instead of setting variables to list() and dict(), Python convention is to use [] and {}.
if crossfolds > len(unique_sites):
    print("choose less number of crossfolds")
else:
    most_possible_sites_in_one_fold = 1 + int(len(unique_sites)) - crossfolds
The int() is redundant here, as len() always returns an integer.
list_of_proper_combos = list(combinations(list_of_all_combos, crossfolds))
removal_list = list()
for item in list_of_proper_combos:
    item_length = 0
Rather than looping and counting lengths with the item_length counter, the clean Pythonic way of doing this would be:
sum([len(i) for i in item])
list_of_items_in_a_combo = list()
for item2 in item:
    item_length += len(item2)
    list_of_items_in_a_combo.extend(item2)
What is the role of list_of_items_in_a_combo here? Is it not the same as item after the loop on lines 52-54?
Since item is a list of lists and I need the sites from all of those lists, I changed it to a list comprehension instead:
sites_in_a_possible_crossfold = [site for site in possible_crossfold]
where
- possible_crossfold was item
- sites_in_a_possible_crossfold was list_of_items_in_a_combo
@jamesdolezal does that work better?
This line:
sites_in_a_possible_crossfold = [site for site in possible_crossfold]
is equivalent to:
sites_in_a_possible_crossfold = possible_crossfold
If this is what you are intending, you should consider removing sites_in_a_possible_crossfold entirely and just use possible_crossfold. Does that make sense, or am I missing something?
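If the goal was instead to collect every site from every fold in the combination (an assumption based on the earlier reply, which describes a list of lists), a single-level comprehension will not do that; it only copies the outer container. A minimal sketch on made-up data, with a hypothetical `possible_crossfold` value:

```python
from itertools import chain

# Hypothetical combination: a tuple of folds, each fold a tuple of sites.
possible_crossfold = (('site0', 'site1'), ('site2',), ('site3', 'site4'))

# Single-level comprehension: copies the outer container; elements are still folds.
shallow_copy = [site for site in possible_crossfold]

# Nested comprehension: visits every site inside every fold (flattening).
all_sites = [site for fold in possible_crossfold for site in fold]

# itertools.chain.from_iterable produces the same flattened sequence.
all_sites_chain = list(chain.from_iterable(possible_crossfold))

# Total number of sites across folds, without a manual counter.
n_sites = sum(len(fold) for fold in possible_crossfold)
```

Here `shallow_copy` still holds three folds, while `all_sites` holds the five individual site names.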
data_dict[site] = dict_of_values

# error associated to each possible combo
per_fold_size_target_ratio = float(1)/crossfolds
Instead of float(1), use the float literal 1. (equivalently 1.0).
mean_square_error = 0
count = 0
for fold in combo:
    sum2 = 0
sum2 is not a descriptive variable name.
per_combo_errors[combo] = mean_square_error

# isolate best combo by error
min = 100000000000000000000000000000000000000000000
What is the purpose of min here? What are you actually trying to check for with the if value < min statement?
Additionally, min() is a Python standard function (it finds the minimum of a list, similar to max()), so you should name your variable something else.
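A small sketch of why reusing a built-in name is risky. The function `shadowing_demo` is hypothetical and exists only to keep the shadowing contained in a local scope:

```python
def shadowing_demo():
    """Show what happens once a variable named `min` hides the built-in."""
    min = 10 ** 44  # this local name now shadows the built-in min()
    try:
        # `min` is an int here, so calling it fails.
        return min([3, 1, 2])
    except TypeError as err:
        return err


outcome = shadowing_demo()      # a TypeError instance, not 1
still_works = min([3, 1, 2])    # outside the function the built-in is intact
```

The same applies to any other built-in name, such as sum, max, or list.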
So I am using that as an initial value to start comparing the errors in the rest of the dictionary; the value < min check tests whether each value is less than that. If so, it resets min to that value and continues looking for the smallest error.
I see. In that case, you should remove the min variable (having a manually set large value is clunky). Instead, you can just look up the key with the lowest value directly. So instead of:
min = 100000000000000000000000000000000000000000000
best_combo = None
for key, value in per_combo_errors.items():
    if value < min:
        best_combo = key
        min = value
    else:
        pass
do
best_combo = min(per_combo_errors, key=per_combo_errors.get)
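To confirm the two forms agree, here is a minimal sketch on made-up error values. The dictionary contents are hypothetical, and the loop uses float('inf') as the starting value rather than an arbitrary large number:

```python
# Hypothetical errors keyed by combination (values invented for illustration).
per_combo_errors = {
    ('A', 'B'): 0.12,
    ('A', 'C'): 0.03,
    ('B', 'C'): 0.40,
}

# Manual scan, starting from infinity so any real error is smaller.
best_combo_loop, lowest_error = None, float('inf')
for combo, error in per_combo_errors.items():
    if error < lowest_error:
        best_combo_loop, lowest_error = combo, error

# Built-in min() with a key function: iterates over the dictionary's keys
# and returns the key whose value (looked up via .get) is smallest.
best_combo = min(per_combo_errors, key=per_combo_errors.get)
```

Note that min() raises ValueError on an empty dictionary, so an empty per_combo_errors would need to be handled separately.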
Fixes bug caused by an invalid variable name (`sum` should not be used as a variable name, as it will override the python default sum() function)
In a083554, I add a unit test that ensures the generated splits are valid (all patients are used, no site is present in multiple cross-folds). The test generated data was copied from your original submission. Running the test only requires:
This includes formatting changes, updated variable names, and refactoring only to improve readability. The underlying algorithm should be close to identical.
In 57ac792, I made formatting changes with minor refactoring only to improve readability - the underlying algorithm is the same. Code that is easier to read is also easier to refactor and optimize. The kinds of changes I made include:
With the unit testing now available, we can ensure that the splits that are being generated are valid. Running the unit test both pre- and post-refactor raises an error, indicating an issue with the algorithm. For this next step, I'll have you familiarize yourself with the modified code, and track down the cause of the failed unit test. Once the unit test passes, indicating that the algorithm is complete, we will move on to optimization.
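When tracking down the failure, it may help to check the two validity conditions by hand: all patients are used, and no site is present in multiple cross-folds. The sketch below is not the unit test from the commit; the function name `check_site_preserved_splits` and the `folds` layout (one list of (patient, site) pairs per fold) are assumptions made for illustration.

```python
def check_site_preserved_splits(folds, all_patients):
    """Return a list of problems found; an empty list means the split is valid.

    `folds` is a hypothetical structure: one entry per cross-fold, each entry
    a list of (patient, site) pairs.
    """
    problems = []

    # Condition 1: every patient must appear in some fold.
    used = {patient for fold in folds for patient, _ in fold}
    missing = set(all_patients) - used
    if missing:
        problems.append(f"patients missing from all folds: {sorted(missing)}")

    # Condition 2: no site may appear in more than one fold.
    first_fold_for_site = {}
    for index, fold in enumerate(folds):
        for site in sorted({site for _, site in fold}):
            if site in first_fold_for_site:
                problems.append(
                    f"{site} appears in folds {first_fold_for_site[site]} and {index}")
            else:
                first_fold_for_site[site] = index
    return problems
```

Printing the returned list for a failing split shows whether patients are being dropped or a site is leaking across folds.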
…to do with the join it may be leaving out some patient data