-
Notifications
You must be signed in to change notification settings - Fork 3
/
main.py
58 lines (40 loc) · 1.5 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import torch
import torchvision
from torchvision import datasets,transforms, models
import os
import numpy as np
# import matplotlib.pyplot as plt
from torch.autograd import Variable
import time
# Fine-tune only a new classification head on top of a frozen, ImageNet-
# pretrained ResNet-18, using an ImageFolder dataset (one sub-directory per
# class) and reporting per-epoch mean loss and accuracy on the training set.

# Dataset root; override with the DATA_DIR environment variable. The default
# is the original author's local path.
folder_path = os.environ.get(
    "DATA_DIR",
    "/Users/krshrimali/Documents/krshrimali-blogs/dataset/train/train_python/",
)

# torchvision's pretrained models expect inputs normalized with the ImageNet
# channel statistics below; without Normalize the frozen backbone sees inputs
# on a different scale than the one it was trained on.
transform = transforms.Compose([
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
data = datasets.ImageFolder(root=folder_path, transform=transform)
batch_size = 4
data_loader = torch.utils.data.DataLoader(dataset=data, batch_size=batch_size, shuffle=True)

# Use a GPU when one is available; otherwise behave exactly as on CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE(review): newer torchvision deprecates `pretrained=` in favour of
# `weights=`; kept here for compatibility with older installs.
model = models.resnet18(pretrained=True)
# Freeze the whole pretrained backbone...
for param in model.parameters():
    param.requires_grad = False
# ...and replace the head, sized from the backbone's feature width and the
# number of class folders ImageFolder discovered (instead of hard-coded 512/2).
model.fc = torch.nn.Linear(model.fc.in_features, len(data.classes))
for param in model.fc.parameters():
    param.requires_grad = True
model = model.to(device)

cost = torch.nn.CrossEntropyLoss()
# Only the new head's parameters are optimized.
optimizer = torch.optim.Adam(model.fc.parameters())
n_epochs = 15
for epoch in range(n_epochs):
    model.train()
    running_loss = 0.0
    correct = 0
    for image, label in data_loader:
        image, label = image.to(device), label.to(device)
        optimizer.zero_grad()
        output = model(image)
        loss = cost(output, label)
        loss.backward()
        optimizer.step()
        # CrossEntropyLoss returns the *mean* over the batch, so weight it by
        # the batch size (the last batch may be smaller than batch_size) so
        # that dividing by len(data) below yields a true per-sample mean.
        running_loss += loss.item() * image.size(0)
        predicted_label = output.argmax(dim=1)
        # Accumulate a plain Python int to avoid integer tensor division.
        correct += (predicted_label == label).sum().item()
    epoch_loss = running_loss / len(data)
    epoch_acc = 100.0 * correct / len(data)
    print("Epoch: {}/{}, Loss: {:.4f}, Accuracy: {:.4f}".format(
        epoch + 1, n_epochs, epoch_loss, epoch_acc))