diff --git a/imageai/Detection/Custom/__init__.py b/imageai/Detection/Custom/__init__.py index 64079e63..e40c711d 100644 --- a/imageai/Detection/Custom/__init__.py +++ b/imageai/Detection/Custom/__init__.py @@ -113,12 +113,13 @@ def __load_model(self) -> None: # to differ. new_state_dict = {k:v for k,v in state_dict.items() if k in self.__model.state_dict().keys() and v.shape==self.__model.state_dict()[k].shape} self.__model.load_state_dict(new_state_dict, strict=False) + print("="*20) + print("Pretrained YOLOv3 model loaded to initialize weights") + print("="*20) except Exception as e: + print("="*20) print("pretrained weight loading failed. Defaulting to using random weight.") - - print("="*20) - print("Pretrained YOLOv3 model loaded to initialize weights") - print("="*20) + print("="*20) def __load_data(self) -> None: self.__num_classes = len(self.__classes) diff --git a/imageai/yolov3/utils.py b/imageai/yolov3/utils.py index d7deac3e..aedb0129 100644 --- a/imageai/yolov3/utils.py +++ b/imageai/yolov3/utils.py @@ -4,6 +4,8 @@ import torch import numpy as np import cv2 as cv +from torchvision.ops import batched_nms + def draw_bbox_and_label(x : torch.Tensor, label : str, img : np.ndarray) -> np.ndarray: """ @@ -180,6 +182,9 @@ def get_predictions( -------- The prediction with reasonable bounding boxes. """ + nB = pred.shape[0] # number of batches + bbox_attr = pred.shape[2] # center_x, center_y, height, width, class_probabilites + nBBOX = pred.shape[1] # number of bounding boxes conf_mask = (pred[:, :, 4] > objectness_confidence).float().unsqueeze(2) pred = pred * conf_mask @@ -192,58 +197,16 @@ def get_predictions( bbox_corner[:, :, 3] = (pred[:, :, 1] + (pred[:, :, 3] / 2)) # bottom_right_y pred[:, :, :4] = bbox_corner[:, :, :4] - # each image in the batch will have varying numbers of true detections - output = None - for idx in range(pred.shape[0]): - img_pred = pred[idx] - - # pick the class with maximum score, add the score and the index - # to the prediction. - max_conf, max_idx = torch.max(img_pred[:, 5:5+num_classes], 1) - max_conf = max_conf.float().unsqueeze(1).to(device) - max_idx = max_idx.float().unsqueeze(1).to(device) - img_pred = torch.cat([img_pred[:, :5], max_conf, max_idx], 1) - - non_zero_idx = torch.nonzero(img_pred[:, 4]).to(device) - img_pred = img_pred[non_zero_idx.squeeze(), :].view(-1, 7).to(device) - if not img_pred.shape[0]: - continue - - # get the unique classes detected in the image. - img_classes = torch.unique(img_pred[:, -1]).to(device) - - # for each object in the image, get the one true bounding box that - # contains the object. - for cls in img_classes: - class_mask = img_pred * (img_pred[:, -1] == cls).float().unsqueeze(1) - class_mask_idx = torch.nonzero(class_mask[:, -2]).squeeze() - img_pred_class = img_pred[class_mask_idx].view(-1, 7) - - # sort the detections in decreasing order of the objectness score - conf_sort_idx = torch.sort(img_pred_class[:, 4], descending=True)[1] - img_pred_class = img_pred_class[conf_sort_idx] - - # since the bounding boxes have been sorted in decreasing order of the - # objectness score, pick the one with the maximum objectness score and - # use non-maximum suppression to remove all other boxes that might be - # detecting the same object as the one with max objectness score. - for d_idx in range(img_pred_class.shape[0]): - try: - ious = bbox_iou(img_pred_class[d_idx].unsqueeze(0), img_pred_class[d_idx+1:], device=device) - except (IndexError, ValueError): - break - - # remove overlapping bounding boxes - iou_mask = (ious < nms_confidence_level).float().unsqueeze(1) - img_pred_class[d_idx+1:] *= iou_mask - non_zero_idx = torch.nonzero(img_pred_class[:, 4]).squeeze() - img_pred_class = img_pred_class[non_zero_idx].view(-1, 7) - - batch_idx = img_pred_class.new(img_pred_class.shape[0], 1).fill_(idx) - if isinstance(output, torch.Tensor): - out = torch.cat([batch_idx, img_pred_class], 1) - output = torch.cat([output, out]) - else: - output = torch.cat([batch_idx, img_pred_class], 1) - return output + n_pred = pred.view(-1, bbox_attr) + idxs = torch.arange(nB).reshape(-1,1).repeat(1, nBBOX).view(-1).to(device) # image indices + + max_conf, max_idx = torch.max(n_pred[:, 5:5+num_classes], 1) # maximum class score and the index + max_conf = max_conf.float().unsqueeze(1).to(device) + max_idx = max_idx.float().unsqueeze(1).to(device) + n_pred = torch.cat([idxs.unsqueeze(1), n_pred[:, :5], max_conf, max_idx], 1) # batch_idx, x1, y1, x2, y2, objectness_score, class_score, class_idx + + valid_bbox_indices = batched_nms(n_pred[:, 1:5].clone(), n_pred[:, 5].clone(), n_pred[:, 7].clone(), nms_confidence_level) + if len(valid_bbox_indices): + return n_pred[valid_bbox_indices, :] + return None