Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GroupedSplitter #2809

Open
wants to merge 27 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
8b1ecf4
adding grouped-splitter in notebook
Sep 16, 2020
3b9c7fe
update grouped splitter
YSaxon Sep 16, 2020
c629414
Merge remote-tracking branch 'origin/master'
YSaxon Sep 16, 2020
2c6fef5
minor (change word subgroups to groups)
YSaxon Sep 16, 2020
1aa1139
clean file with nb-clean
YSaxon Sep 17, 2020
a49c7d0
updated wording
YSaxon Sep 17, 2020
15472ad
change apply to applymap so it works correctly for elementwise funcs
YSaxon Sep 21, 2020
49cb3f8
Merge remote-tracking branch 'origin/master' into groupSplitter
YSaxon Oct 5, 2020
2d53a6f
undo most of nbclean to fix conflict
YSaxon Oct 5, 2020
a861599
Merge remote-tracking branch 'origin/master' into groupSplitter
YSaxon Oct 5, 2020
b6413fe
Merge remote-tracking branch 'origin/master' into groupSplitter
YSaxon Nov 5, 2020
b75c1da
enhanced with multiple splits and warnings
YSaxon Nov 5, 2020
491c3a3
docs
YSaxon Nov 5, 2020
bc30c16
more docs and one less test
YSaxon Nov 5, 2020
a7a9cd9
nbdev_build_lib
YSaxon Nov 5, 2020
a94864c
nbdev_clean_nbs
YSaxon Nov 5, 2020
d5d4d91
minor renames and doc changes
YSaxon Nov 12, 2020
fdbae86
changes to alice bob charlie example
YSaxon Nov 12, 2020
e65b9fd
zebra finch example
YSaxon Nov 12, 2020
dbc224d
Merge remote-tracking branch 'origin/master' into enhancedGroupedSpli…
YSaxon Nov 12, 2020
fa8bba1
removing warning since it isn't working
YSaxon Nov 12, 2020
a87158a
update checks.txt
YSaxon Nov 25, 2020
5f11f1e
Merge remote-tracking branch 'origin/master' into enhancedGroupedSpli…
YSaxon Nov 25, 2020
c2b3f0a
added #slow to zebra_finch tests
YSaxon Nov 25, 2020
a62a372
Split GroupedSplitter into two, one for lists, one for dfs
YSaxon Nov 25, 2020
6c83829
improve documentation
YSaxon Nov 25, 2020
6b13149
sync
hamelsmu Apr 8, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions fastai/_nbdev.py
Expand Up @@ -196,6 +196,7 @@
"FileSplitter": "05_data.transforms.ipynb",
"ColSplitter": "05_data.transforms.ipynb",
"RandomSubsetSplitter": "05_data.transforms.ipynb",
"GroupedSplitter": "05_data.transforms.ipynb",
"parent_label": "05_data.transforms.ipynb",
"RegexLabeller": "05_data.transforms.ipynb",
"ColReader": "05_data.transforms.ipynb",
Expand Down
33 changes: 29 additions & 4 deletions fastai/data/transforms.py
Expand Up @@ -2,10 +2,10 @@

__all__ = ['get_files', 'FileGetter', 'image_extensions', 'get_image_files', 'ImageGetter', 'get_text_files',
'ItemGetter', 'AttrGetter', 'RandomSplitter', 'TrainTestSplitter', 'IndexSplitter', 'GrandparentSplitter',
'FuncSplitter', 'MaskSplitter', 'FileSplitter', 'ColSplitter', 'RandomSubsetSplitter', 'parent_label',
'RegexLabeller', 'ColReader', 'CategoryMap', 'Categorize', 'Category', 'MultiCategorize', 'MultiCategory',
'OneHotEncode', 'EncodedMultiCategorize', 'RegressionSetup', 'get_c', 'ToTensor', 'IntToFloatTensor',
'broadcast_vec', 'Normalize']
'FuncSplitter', 'MaskSplitter', 'FileSplitter', 'ColSplitter', 'RandomSubsetSplitter', 'GroupedSplitter',
'parent_label', 'RegexLabeller', 'ColReader', 'CategoryMap', 'Categorize', 'Category', 'MultiCategorize',
'MultiCategory', 'OneHotEncode', 'EncodedMultiCategorize', 'RegressionSetup', 'get_c', 'ToTensor',
'IntToFloatTensor', 'broadcast_vec', 'Normalize']

# Cell
from ..torch_basics import *
Expand Down Expand Up @@ -164,6 +164,31 @@ def _inner(o):
return idxs[:train_len],idxs[train_len:train_len+valid_len]
return _inner

# Cell
def GroupedSplitter(groupkey, valid_pct=0.2, seed=None):
    """Split groups of items between train/valid randomly, so that valid holds close to
    `valid_pct` of the total number of items (similar to `RandomSplitter`) while never
    splitting a group across the two sets.

    `groupkey` is either a function/lambda applied to each item to compute its group key,
    or a column name when the input is a `DataFrame`. `seed` fixes the group shuffle for
    reproducibility. Returns a function of `o` producing `(train_idxs, valid_idxs)`.
    """
    def _inner(o):
        if callable(groupkey):
            # Wrap items in a one-column frame; applymap applies groupkey elementwise
            ids = pd.DataFrame(o)
            ids['group_keys'] = ids.applymap(groupkey)
            keycol = 'group_keys'
        else:
            # fix: original message had invalid escape "\l" ("function\lambda")
            assert isinstance(o, pd.DataFrame), "o is not a DataFrame, so groupkey must be a function/lambda that extracts a group key from an item"
            assert groupkey in o, "groupkey is not a colname in the DataFrame o"
            keycol = groupkey
            ids = o
        gk = ids.groupby(keycol).count()                     # item count per group
        shuffled_gk = gk.sample(frac=1, random_state=seed)   # random group order
        cumsum = shuffled_gk.cumsum()
        desired_valid = len(o)*valid_pct
        # Cut the shuffled groups where the cumulative item count is closest to the
        # target valid size; at least one group always goes to valid (+1).
        abs_diff = abs(cumsum - desired_valid)
        valid_rows = abs_diff.iloc[:,0].argmin() + 1
        shuffled_gk['is_valid'] = ([True] * valid_rows +
                                   [False] * (len(shuffled_gk) - valid_rows))
        # Propagate each group's flag back onto its items, then reuse ColSplitter
        split_df = ids.join(shuffled_gk.loc[:,'is_valid'], on=keycol)
        return ColSplitter()(split_df)
    return _inner

# Cell
def parent_label(o):
"Label `item` with the parent folder name."
Expand Down
76 changes: 76 additions & 0 deletions nbs/05_data.transforms.ipynb
Expand Up @@ -671,6 +671,82 @@
"test_eq(len(splits[1]), 10)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# export\n",
"def GroupedSplitter(groupkey,valid_pct=0.2, seed=None):\n",
" \"Splits groups of items between train/val randomly, such that val should have close to `valid_pct` of the total number of items (similar to RandomSplitter). Groups are defined by a `groupkey`, a function/lambda to apply to individual items, or a colname if `o` is a DataFrame\"\n",
" def _inner(o):\n",
" if callable(groupkey):\n",
" ids=pd.DataFrame(o)\n",
" ids['group_keys']=ids.applymap(groupkey)\n",
" keycol='group_keys'\n",
" else:\n",
" assert isinstance(o, pd.DataFrame), \"o is not a DataFrame, so groupkey must be a function\\lambda that extracts a group key from an item\"\n",
" assert groupkey in o, \"groupkey is not a colname in the DataFrame o\"\n",
" keycol=groupkey\n",
" ids=o\n",
" gk=ids.groupby(keycol).count()\n",
" shuffled_gk=gk.sample(frac=1,random_state=seed)\n",
" cumsum=shuffled_gk.cumsum()\n",
" desired_valid=len(o)*valid_pct\n",
" abs_diff=abs(cumsum-desired_valid)\n",
" valid_rows=abs_diff.iloc[:,0].argmin()+1\n",
" shuffled_gk['is_valid']=([True] * valid_rows + \n",
" [False]*(len(shuffled_gk) - valid_rows))\n",
" split_df=ids.join(shuffled_gk.loc[:,'is_valid'],on=keycol)\n",
" return ColSplitter()(split_df)\n",
" return _inner"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"src = list(range(10000))\n",
"key_f=lambda x:x%100\n",
"f = GroupedSplitter(key_f,seed=42)\n",
"trn,val = f(src)\n",
"assert 0<len(trn)<len(src)\n",
"assert all(o not in val for o in trn)\n",
"k_trn=np.unique([key_f(o) for o in trn])\n",
"k_val=np.unique([key_f(o) for o in val])\n",
"assert all(k not in k_val for k in k_trn)\n",
"test_eq(len(trn), len(src)-len(val))\n",
"# # test random seed consistency\n",
"test_eq(f(src)[0], trn)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"f = GroupedSplitter('keys',seed=41)\n",
"src = list(range(1000))\n",
"key_f=lambda x:x%10\n",
"src2=pd.DataFrame(src)\n",
"src2['keys']=src2.apply(key_f)\n",
"src2['conconfounding_col_1']='test'\n",
"src2.insert(0,'confounding_col_0','test')\n",
"trn,val=f(src2)\n",
"assert 0<len(trn)<len(src2)\n",
"assert all(o not in val for o in trn)\n",
"k_trn=np.unique([key_f(o) for o in trn])\n",
"k_val=np.unique([key_f(o) for o in val])\n",
"assert all(k not in k_val for k in k_trn)\n",
"test_eq(len(trn), len(src2)-len(val))\n",
"# # test random seed consistency\n",
"test_eq(f(src2)[0], trn)"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down