Skip to content

Commit 9dbdc19

Browse files
lrFinder_way_more_doc
1 parent 0c919be commit 9dbdc19

14 files changed

+3574
-72
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
- Check features.md for a full list
1212

1313
## Documentation
14+
- Okay so writing examples will take forever, so please check function documentation and the demos folder for most of what you need
1415
- Check the syntax folder for documentation
1516
- Check demos folder
1617
- [Function documentation](https://subhadityamukherjee.github.io/sprintdl/)
Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
# ---
2+
# jupyter:
3+
# jupytext:
4+
# text_representation:
5+
# extension: .py
6+
# format_name: light
7+
# format_version: '1.5'
8+
# jupytext_version: 1.10.1
9+
# kernelspec:
10+
# display_name: Python 3
11+
# language: python
12+
# name: python3
13+
# ---
14+
15+
# +
16+
# %load_ext autoreload
17+
# %autoreload 2
18+
19+
# %matplotlib inline
20+
21+
import os
22+
23+
os.environ["TORCH_HOME"] = "/media/hdd/Datasets/"
24+
import sys
25+
26+
sys.path.append("../")
27+
# -
28+
29+
from sprintdl.main import *
30+
from sprintdl.nets import *
31+
32+
device = torch.device("cuda", 0)
33+
import math
34+
35+
import torch
36+
from torch.nn import init
37+
38+
# # Define required
39+
40+
# +
41+
fpath = Path("/media/hdd/Datasets/ArtClass/")
42+
43+
tfms = [make_rgb, ResizeFixed(128), to_byte_tensor, to_float_tensor]
44+
bs = 256
45+
# -
46+
47+
# # Actual process
48+
49+
il = ImageList.from_files(fpath, tfms=tfms)
50+
51+
il
52+
53+
tm = Path(
54+
"/media/hdd/Datasets/ArtClass/Unpopular/mimang.art/69030963_140928767119437_3621699865915593113_n.jpg"
55+
)
56+
57+
str(tm).split("/")[-3]
58+
59+
sd = SplitData.split_by_func(il, partial(random_splitter, p_valid=0.2))
60+
ll = label_by_func(sd, lambda x: str(x).split("/")[-3], proc_y=CategoryProcessor())
61+
62+
n_classes = len(set(ll.train.y.items))
63+
64+
data = ll.to_databunch(bs, c_in=3, c_out=2)
65+
66+
show_batch(data, 4)
67+
68+
# +
69+
lr = 0.001
70+
pct_start = 0.5
71+
phases = create_phases(pct_start)
72+
sched_lr = combine_scheds(phases, cos_1cycle_anneal(lr / 10.0, lr, lr / 1e5))
73+
sched_mom = combine_scheds(phases, cos_1cycle_anneal(0.95, 0.85, 0.95))
74+
75+
cbfs = [
76+
partial(AvgStatsCallback, accuracy),
77+
partial(ParamScheduler, "lr", sched_lr),
78+
partial(ParamScheduler, "mom", sched_mom),
79+
partial(BatchTransformXCallback, norm_imagenette),
80+
ProgressCallback,
81+
Recorder,
82+
# MixUp,
83+
partial(CudaCallback, device),
84+
]
85+
86+
loss_func = LabelSmoothingCrossEntropy()
87+
# arch = partial(xresnet34, n_classes)
88+
arch = get_vision_model("resnet34", n_classes=n_classes, pretrained=True)
89+
90+
# opt_func = partial(sgd_mom_opt, wd=0.01)
91+
opt_func = adam_opt(mom=0.9, mom_sqr=0.99, eps=1e-6, wd=1e-2)
92+
# opt_func = lamb
93+
# -
94+
95+
# # Training
96+
97+
clear_memory()
98+
99+
# learn = get_learner(nfs, data, lr, conv_layer, cb_funcs=cbfs)
100+
learn = Learner(arch, data, loss_func, lr=lr, cb_funcs=cbfs, opt_func=opt_func)
101+
102+
# +
103+
# model_summary(learn, data)
104+
# -
105+
106+
learn.fit(1)
107+
108+
save_model(learn, "m1", fpath)
109+
110+
# +
111+
temp = Path(
112+
"/media/hdd/Datasets/ArtClass/Popular/artgerm/10004370_1657536534486515_1883801324_n.jpg"
113+
)
114+
115+
get_class_pred(temp, learn, ll, 128)
116+
# -
117+
118+
temp = Path("/home/eragon/Downloads/Telegram Desktop/IMG_1800.PNG")
119+
120+
get_class_pred(temp, learn, ll, 128)
121+
122+
temp = Path("/home/eragon/Downloads/Telegram Desktop/IMG_20210106_180731.jpg")
123+
124+
get_class_pred(temp, learn, ll, 128)
125+
126+
# # Digging in
127+
128+
# +
129+
# classification_report(learn, n_classes, device)
130+
# -
131+
132+
learn.recorder.plot_lr()
133+
134+
learn.recorder.plot_loss()
135+
136+
# # Model vis
137+
138+
run_with_act_vis(1, learn)
139+
140+
# # Multiple runs with model saving
141+
142+
dict_runner = {
143+
"xres18": [
144+
1,
145+
partial(xresnet18, c_out=n_classes)(),
146+
data,
147+
loss_func,
148+
0.001,
149+
cbfs,
150+
opt_func,
151+
],
152+
"xres34": [
153+
1,
154+
partial(xresnet34, c_out=n_classes)(),
155+
data,
156+
loss_func,
157+
0.001,
158+
cbfs,
159+
opt_func,
160+
],
161+
"xres50": [
162+
1,
163+
partial(xresnet50, c_out=n_classes)(),
164+
data,
165+
loss_func,
166+
0.001,
167+
cbfs,
168+
opt_func,
169+
],
170+
}
171+
172+
learn = Learner(arch(), data, loss_func, lr=lr, cb_funcs=cbfs, opt_func=opt_func)
173+
174+
multiple_runner(dict_runner, fpath)

0 commit comments

Comments
 (0)