Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GroupedSplitter #2809

Open
wants to merge 27 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
8b1ecf4
adding grouped-splitter in notebook
Sep 16, 2020
3b9c7fe
update grouped splitter
YSaxon Sep 16, 2020
c629414
Merge remote-tracking branch 'origin/master'
YSaxon Sep 16, 2020
2c6fef5
minor (change word subgroups to groups)
YSaxon Sep 16, 2020
1aa1139
clean file with nb-clean
YSaxon Sep 17, 2020
a49c7d0
updated wording
YSaxon Sep 17, 2020
15472ad
change apply to applymap so it works correctly for elementwise funcs
YSaxon Sep 21, 2020
49cb3f8
Merge remote-tracking branch 'origin/master' into groupSplitter
YSaxon Oct 5, 2020
2d53a6f
undo most of nbclean to fix conflict
YSaxon Oct 5, 2020
a861599
Merge remote-tracking branch 'origin/master' into groupSplitter
YSaxon Oct 5, 2020
b6413fe
Merge remote-tracking branch 'origin/master' into groupSplitter
YSaxon Nov 5, 2020
b75c1da
enhanced with multiple splits and warnings
YSaxon Nov 5, 2020
491c3a3
docs
YSaxon Nov 5, 2020
bc30c16
more docs and one less test
YSaxon Nov 5, 2020
a7a9cd9
nbdev_build_lib
YSaxon Nov 5, 2020
a94864c
nbdev_clean_nbs
YSaxon Nov 5, 2020
d5d4d91
minor renames and doc changes
YSaxon Nov 12, 2020
fdbae86
changes to alice bob charlie example
YSaxon Nov 12, 2020
e65b9fd
zebra finch example
YSaxon Nov 12, 2020
dbc224d
Merge remote-tracking branch 'origin/master' into enhancedGroupedSpli…
YSaxon Nov 12, 2020
fa8bba1
removing warning since it isn't working
YSaxon Nov 12, 2020
a87158a
update checks.txt
YSaxon Nov 25, 2020
5f11f1e
Merge remote-tracking branch 'origin/master' into enhancedGroupedSpli…
YSaxon Nov 25, 2020
c2b3f0a
added #slow to zebra_finch tests
YSaxon Nov 25, 2020
a62a372
Split GroupedSplitter into two, one for lists, one for dfs
YSaxon Nov 25, 2020
6c83829
improve documentation
YSaxon Nov 25, 2020
6b13149
sync
hamelsmu Apr 8, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions fastai/_nbdev.py
Expand Up @@ -199,6 +199,8 @@
"FileSplitter": "05_data.transforms.ipynb",
"ColSplitter": "05_data.transforms.ipynb",
"RandomSubsetSplitter": "05_data.transforms.ipynb",
"GroupedDataframeSplitter": "05_data.transforms.ipynb",
"GroupedListSplitter": "05_data.transforms.ipynb",
"parent_label": "05_data.transforms.ipynb",
"RegexLabeller": "05_data.transforms.ipynb",
"ColReader": "05_data.transforms.ipynb",
Expand Down
2 changes: 1 addition & 1 deletion fastai/data/checks.txt
Expand Up @@ -224,7 +224,7 @@
"44fec3950e61d6a898f16fe30bc9c88d"
],
"https://storage.googleapis.com/ml-animal-sounds-datasets/zebra_finch.zip": [
83886080,
186624896,
"91fa9c4ebfc986b9babc2a805a10e281"
],
"https://s3.amazonaws.com/fast-ai-sample/mnist_sample.tgz": [
Expand Down
59 changes: 55 additions & 4 deletions fastai/data/transforms.py
Expand Up @@ -2,10 +2,11 @@

__all__ = ['get_files', 'FileGetter', 'image_extensions', 'get_image_files', 'ImageGetter', 'get_text_files',
'ItemGetter', 'AttrGetter', 'RandomSplitter', 'TrainTestSplitter', 'IndexSplitter', 'GrandparentSplitter',
'FuncSplitter', 'MaskSplitter', 'FileSplitter', 'ColSplitter', 'RandomSubsetSplitter', 'parent_label',
'RegexLabeller', 'ColReader', 'CategoryMap', 'Categorize', 'Category', 'MultiCategorize', 'MultiCategory',
'OneHotEncode', 'EncodedMultiCategorize', 'RegressionSetup', 'get_c', 'ToTensor', 'IntToFloatTensor',
'broadcast_vec', 'Normalize']
'FuncSplitter', 'MaskSplitter', 'FileSplitter', 'ColSplitter', 'RandomSubsetSplitter',
'GroupedDataframeSplitter', 'GroupedListSplitter', 'parent_label', 'RegexLabeller', 'ColReader',
'CategoryMap', 'Categorize', 'Category', 'MultiCategorize', 'MultiCategory', 'OneHotEncode',
'EncodedMultiCategorize', 'RegressionSetup', 'get_c', 'ToTensor', 'IntToFloatTensor', 'broadcast_vec',
'Normalize']

# Cell
from ..torch_basics import *
Expand Down Expand Up @@ -165,6 +166,56 @@ def _inner(o):
return idxs[:train_len],idxs[train_len:train_len+valid_len]
return _inner

# Cell
def _grouped_dataframe_splitter(df,group_col,valid_pct=0.2,seed=None,n_tries=3):
    """Split `df` so that every `group_col` group lands wholly in train or wholly in valid.

    The groups are shuffled up to `n_tries` times (always at least once). For each
    shuffle, the prefix of groups whose cumulative row count is closest to
    `valid_pct` of the rows is the candidate validation set; the closest candidate
    over all tries wins. Finding the best possible subset is NP-hard, so a few
    random tries are used instead of an exact search.

    `seed` seeds the `RandomState` that drives the shuffles.
    Returns the result of `ColSplitter()` applied to `df` with a boolean `is_valid` column.
    """
    rng=np.random.RandomState(seed)
    target=round(len(df)*valid_pct)
    # `size()` counts rows per group. `count()` counts non-null cells per column, which
    # under-counts groups containing missing values and produces no columns at all
    # when `df` holds nothing but `group_col`.
    group_sizes=df.groupby(group_col).size()
    best_order,best_dist,best_n=None,None,None
    for _ in range(max(n_tries,1)):
        order=group_sizes.sample(frac=1,random_state=rng) #shuffle the groups
        dist=(order.cumsum()-target).abs().values #distance from the target for every possible cut
        cut=int(dist.argmin()) #position of the best cut for this shuffle
        if best_dist is None or dist[cut]<best_dist:
            best_order,best_dist,best_n=order,dist[cut],cut+1
        if best_dist==0: break #perfect split, no need to keep trying
    flags=best_order.to_frame('n_items')
    flags['is_valid']=np.arange(len(flags))<best_n #first `best_n` shuffled groups go to valid
    # a pre-existing `is_valid` column would make `join` fail on overlapping names
    base=df.drop(columns='is_valid') if 'is_valid' in df.columns else df
    split_df=base.join(flags['is_valid'],on=group_col) #apply the group split to the actual items
    return ColSplitter()(split_df)

# Cell
def GroupedDataframeSplitter(group_col,colval2groupname=None,valid_pct=0.2,seed=None,n_tries=3):
    """Splits items randomly without breaking up groups, to help ensure generalizability to unseen groups.

    Groups are defined by the values of column `group_col`. If `colval2groupname` is
    callable it is applied to each value of that column and the results are used as
    the group names instead. `valid_pct`, `seed` and `n_tries` are forwarded to
    `_grouped_dataframe_splitter`.

    Returns a function that takes a DataFrame and returns the split produced by
    `_grouped_dataframe_splitter`. It asserts that its input is a DataFrame
    containing `group_col`.
    """
    def _inner(o):
        assert isinstance(o,pd.DataFrame), 'This splitter is meant for Dataframes, please use GroupedListSplitter'
        assert group_col in o, "`group_col` is not a valid column name in the DataFrame o"
        if not callable(colval2groupname):
            return _grouped_dataframe_splitter(o,group_col,valid_pct,seed,n_tries)
        # `assign` returns a new frame; `pd.DataFrame(o)` does not guarantee a copy, so
        # setting a column on it could leak a `group_keys` column into the caller's data
        keyed=o.assign(group_keys=o[group_col].apply(colval2groupname))
        return _grouped_dataframe_splitter(keyed,'group_keys',valid_pct,seed,n_tries)
    return _inner

# Cell
def GroupedListSplitter(item2group,valid_pct=0.2,seed=None,n_tries=3):
    """Splits items randomly without breaking up groups, to help ensure generalizability to unseen groups.

    `item2group` should be a function (eg a regex) that returns the group name for each item.
    `valid_pct`, `seed` and `n_tries` are forwarded to `_grouped_dataframe_splitter`.

    Returns a function that takes a collection of items (not a DataFrame) and returns
    the split produced by `_grouped_dataframe_splitter`.
    """
    def _inner(o):
        assert not isinstance(o,pd.DataFrame), 'Please use GroupedDataframeSplitter instead'
        assert callable(item2group), "You must pass in a callable `item2group` that extracts a group name from each item"
        # call `item2group` once per item; building a DataFrame from the items first would
        # spread tuple/list items over several columns and map the function over every cell
        keys=[item2group(item) for item in o]
        # `item_pos` is a never-null companion column so per-group row counts are available
        df=pd.DataFrame({'item_pos':range(len(keys)),'group_keys':keys})
        return _grouped_dataframe_splitter(df,'group_keys',valid_pct,seed,n_tries)
    return _inner

# Cell
def parent_label(o):
"Label `item` with the parent folder name."
Expand Down