Skip to content

Commit

Permalink
Merge pull request #797 from Yodeman/master
Browse files Browse the repository at this point in the history
Fix slow validation and correct the log message in the weight-loading method.
  • Loading branch information
OlafenwaMoses committed Mar 3, 2023
2 parents 5840670 + c7f7805 commit 64e6da4
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 58 deletions.
9 changes: 5 additions & 4 deletions imageai/Detection/Custom/__init__.py
Expand Up @@ -113,12 +113,13 @@ def __load_model(self) -> None:
# to differ.
new_state_dict = {k:v for k,v in state_dict.items() if k in self.__model.state_dict().keys() and v.shape==self.__model.state_dict()[k].shape}
self.__model.load_state_dict(new_state_dict, strict=False)
print("="*20)
print("Pretrained YOLOv3 model loaded to initialize weights")
print("="*20)
except Exception as e:
print("="*20)
print("pretrained weight loading failed. Defaulting to using random weight.")

print("="*20)
print("Pretrained YOLOv3 model loaded to initialize weights")
print("="*20)
print("="*20)

def __load_data(self) -> None:
self.__num_classes = len(self.__classes)
Expand Down
71 changes: 17 additions & 54 deletions imageai/yolov3/utils.py
Expand Up @@ -4,6 +4,8 @@
import torch
import numpy as np
import cv2 as cv
from torchvision.ops import batched_nms


def draw_bbox_and_label(x : torch.Tensor, label : str, img : np.ndarray) -> np.ndarray:
"""
Expand Down Expand Up @@ -180,6 +182,9 @@ def get_predictions(
--------
The prediction with reasonable bounding boxes.
"""
nB = pred.shape[0] # batch size (number of images in the batch)
bbox_attr = pred.shape[2] # attributes per box: center_x, center_y, width, height, objectness, class probabilities
nBBOX = pred.shape[1] # number of bounding boxes per image
conf_mask = (pred[:, :, 4] > objectness_confidence).float().unsqueeze(2)
pred = pred * conf_mask

Expand All @@ -192,58 +197,16 @@ def get_predictions(
bbox_corner[:, :, 3] = (pred[:, :, 1] + (pred[:, :, 3] / 2)) # bottom_right_y
pred[:, :, :4] = bbox_corner[:, :, :4]

# each image in the batch will have varying numbers of true detections
output = None
for idx in range(pred.shape[0]):
img_pred = pred[idx]

# pick the class with maximum score, add the score and the index
# to the prediction.
max_conf, max_idx = torch.max(img_pred[:, 5:5+num_classes], 1)
max_conf = max_conf.float().unsqueeze(1).to(device)
max_idx = max_idx.float().unsqueeze(1).to(device)
img_pred = torch.cat([img_pred[:, :5], max_conf, max_idx], 1)

non_zero_idx = torch.nonzero(img_pred[:, 4]).to(device)
img_pred = img_pred[non_zero_idx.squeeze(), :].view(-1, 7).to(device)
if not img_pred.shape[0]:
continue

# get the unique classes detected in the image.
img_classes = torch.unique(img_pred[:, -1]).to(device)

# for each object in the image, get the one true bounding box that
# contains the object.
for cls in img_classes:
class_mask = img_pred * (img_pred[:, -1] == cls).float().unsqueeze(1)
class_mask_idx = torch.nonzero(class_mask[:, -2]).squeeze()
img_pred_class = img_pred[class_mask_idx].view(-1, 7)

# sort the detections in decreasing order of the objectness score
conf_sort_idx = torch.sort(img_pred_class[:, 4], descending=True)[1]
img_pred_class = img_pred_class[conf_sort_idx]

# since the bounding boxes have been sorted in decreasing order of the
# objectness score, pick the one with the maximum objectness score and
# use non-maximum suppression to remove all other boxes that might be
# detecting the same object as the one with max objectness score.
for d_idx in range(img_pred_class.shape[0]):
try:
ious = bbox_iou(img_pred_class[d_idx].unsqueeze(0), img_pred_class[d_idx+1:], device=device)
except (IndexError, ValueError):
break

# remove overlapping bounding boxes
iou_mask = (ious < nms_confidence_level).float().unsqueeze(1)
img_pred_class[d_idx+1:] *= iou_mask
non_zero_idx = torch.nonzero(img_pred_class[:, 4]).squeeze()
img_pred_class = img_pred_class[non_zero_idx].view(-1, 7)

batch_idx = img_pred_class.new(img_pred_class.shape[0], 1).fill_(idx)
if isinstance(output, torch.Tensor):
out = torch.cat([batch_idx, img_pred_class], 1)
output = torch.cat([output, out])
else:
output = torch.cat([batch_idx, img_pred_class], 1)
return output
n_pred = pred.view(-1, bbox_attr)
idxs = torch.arange(nB).reshape(-1,1).repeat(1, nBBOX).view(-1).to(device) # image indices

max_conf, max_idx = torch.max(n_pred[:, 5:5+num_classes], 1) # maximum class score and the index
max_conf = max_conf.float().unsqueeze(1).to(device)
max_idx = max_idx.float().unsqueeze(1).to(device)
n_pred = torch.cat([idxs.unsqueeze(1), n_pred[:, :5], max_conf, max_idx], 1) # batch_idx, x1, y1, x2, y2, objectness_score, class_score, class_idx

valid_bbox_indices = batched_nms(n_pred[:, 1:5].clone(), n_pred[:, 5].clone(), n_pred[:, 7].clone(), nms_confidence_level)

if len(valid_bbox_indices):
return n_pred[valid_bbox_indices, :]
return None

0 comments on commit 64e6da4

Please sign in to comment.