Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Using Per Class IoU to evaluate trained model #465

Open
MjdMahasneh opened this issue Oct 26, 2023 · 0 comments
Open

Using Per Class IoU to evaluate trained model #465

MjdMahasneh opened this issue Oct 26, 2023 · 0 comments

Comments

@MjdMahasneh
Copy link

MjdMahasneh commented Oct 26, 2023

First, let me thank you for the amazing repo — many thanks :)

In my project, I needed to evaluate using IoU (for consistency) and to get per-class scores. Here is my evaluate_using_IoU.py (I have tested it, and to the best of my knowledge it works as expected):

import logging
import os
import torch
from pathlib import Path

from matplotlib import pyplot as plt
from torch.utils.data import DataLoader
from tqdm import tqdm

from evaluate import evaluate
from unet import UNet
from utils.data_loading import BasicDataset


def iou_score(output, target):
    """Return the smoothed binary IoU (Jaccard index) of ``output`` vs ``target``.

    Args:
        output: Prediction. A torch tensor is treated as logits: it is passed
            through a sigmoid and thresholded at 0.5. Any other input (e.g. a
            NumPy array) is thresholded at 0.5 directly. A tensor that is
            already a {0, 1} mask is handled correctly, because
            sigmoid(0) == 0.5 is not above the threshold while sigmoid(1) is.
        target: Ground-truth mask as a torch tensor or NumPy array. Values are
            truncated to int, and every non-zero value counts as foreground.
            Multi-valued labels such as 255 are therefore counted once per
            element instead of inflating the union.

    Returns:
        ``(intersection + smooth) / (union + smooth)`` over all elements, so
        two empty masks score 1.0.
    """
    # Small constant added to the numerator and denominator so the ratio is
    # defined (and equals 1.0) when both masks are empty.
    smooth = 1e-5

    if torch.is_tensor(output):
        output = torch.sigmoid(output.detach()).cpu().numpy()
    output = output > 0.5

    if torch.is_tensor(target):
        target = target.detach().cpu().numpy()
    # astype(int) keeps the original truncation of fractional values.
    # Comparing with 0 yields a boolean mask, so '&' and '|' count elements
    # instead of combining the bit patterns of label values.
    target = target.astype(int) != 0

    intersection = (output & target).sum()
    union = (output | target).sum()

    return (intersection + smooth) / (union + smooth)





@torch.inference_mode()
def evaluate_with_iou(net, dataloader, device, amp):
    """Evaluate ``net`` on ``dataloader`` with per-class and mean IoU.

    Per-class intersections and unions are summed over the whole dataloader
    and divided once at the end, giving the dataset-level IoU. Unlike an
    average of per-batch IoUs, this does not depend on the batch size or on
    a smaller last batch. It also does not award a near-perfect score to
    every batch in which a class happens to be absent.

    Args:
        net: Segmentation model exposing ``n_classes``; called as ``net(image)``.
        dataloader: Iterable of dicts with ``'image'`` and ``'mask'`` entries.
            ``len(dataloader)`` is used for the progress bar.
        device: ``torch.device`` to run on.
        amp: Whether to enable ``torch.autocast``.

    Returns:
        ``(mean_iou, classwise_iou)``: the mean of the per-class IoUs, and a
        list with one IoU per class.

        - Multi-class models: the list covers labels ``0 .. n_classes - 1``,
          background included.
        - Single-output models (``n_classes == 1``): the list holds the IoU
          of label 1, which is predicted where ``sigmoid(logit) > 0.5``.

        A class that never occurs and is never predicted scores 1.0 (same
        smoothing as ``iou_score``). If the dataloader yields no batches,
        ``(0.0, [0.0] * n_classes)`` is returned.
    """
    net.eval()
    num_val_batches = len(dataloader)
    smooth = 1e-5  # same smoothing constant as iou_score

    # argmax over a single output channel is always 0, so a one-output model
    # is scored on its foreground label instead of on label 0.
    labels = [1] if net.n_classes == 1 else list(range(net.n_classes))
    intersections = [0] * len(labels)
    unions = [0] * len(labels)
    seen_batches = 0

    with torch.autocast(device.type if device.type != 'mps' else 'cpu', enabled=amp):
        for batch in tqdm(dataloader, total=num_val_batches, desc='IoU evaluation', unit='batch', leave=False):
            image, mask_true = batch['image'], batch['mask']

            image = image.to(device=device, dtype=torch.float32, memory_format=torch.channels_last)
            mask_true = mask_true.to(device=device, dtype=torch.long)

            mask_pred = net(image)
            if net.n_classes == 1:
                # assumes logits shaped (B, 1, H, W); drop the channel axis to
                # line up with mask_true -- TODO confirm against UNet output
                pred_labels = (torch.sigmoid(mask_pred.squeeze(1)) > 0.5).long()
            else:
                # computed once per batch rather than once per class
                pred_labels = mask_pred.argmax(dim=1)

            for idx, cls in enumerate(labels):
                pred_cls = pred_labels == cls
                true_cls = mask_true == cls
                intersections[idx] += (pred_cls & true_cls).sum().item()
                unions[idx] += (pred_cls | true_cls).sum().item()
            seen_batches += 1

    if seen_batches == 0:
        return 0.0, [0.0] * net.n_classes

    classwise_iou = [(inter + smooth) / (union + smooth) for inter, union in zip(intersections, unions)]
    return sum(classwise_iou) / len(classwise_iou), classwise_iou


class Config:
    '''Configuration for the IoU evaluation script.

    Usage:
        args = Config()
        print(vars(args))
        print(args.batch_size)

    Args:
        dir_root: Dataset root containing ``train/`` and ``val/`` folders,
            each with ``images/`` and ``masks/`` subfolders.
        model: Path to the checkpoint to evaluate.
    '''
    def __init__(self, dir_root='G:/Datasets', model='./checkpoints/checkpoint_epoch5.pth'):
        self.batch_size = 2
        self.bilinear = False
        # Must match the class count the checkpoint was trained with, since
        # the UNet is built with it before the weights are loaded.
        self.classes = 3
        self.target_size = (512, 512)  # (height, width)

        self.dir_root = Path(dir_root)
        self.train_images_dir = self.dir_root / 'train' / 'images'
        self.train_mask_dir = self.dir_root / 'train' / 'masks'
        self.val_images_dir = self.dir_root / 'val' / 'images'
        self.val_mask_dir = self.dir_root / 'val' / 'masks'

        self.model = model








if __name__ == '__main__':

    args = Config()
    print('args : ', vars(args))

    logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f'Using device {device}')

    model = UNet(n_channels=3, n_classes=args.classes, bilinear=args.bilinear)

    logging.info(f'Loading model {args.model}')
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(args.model, map_location=device)
    # The checkpoint may carry a 'mask_values' entry that is not a model
    # parameter; drop it if present so load_state_dict accepts the dict.
    state_dict.pop('mask_values', None)
    model.load_state_dict(state_dict)
    model.to(device=device)

    logging.info(f'Network:\n'
                 f'\t{model.n_channels} input channels\n'
                 f'\t{model.n_classes} output channels (classes)\n'
                 f'\t{"Bilinear" if model.bilinear else "Transposed conv"} upscaling')

    # Explicit check instead of assert: asserts are stripped under -O, and
    # Path objects are never None, so the old check could not fail.
    for directory in (args.val_images_dir, args.val_mask_dir):
        if directory is None or not Path(directory).is_dir():
            raise FileNotFoundError(f'Validation directory not found: {directory}')

    # Create datasets
    val_dataset = BasicDataset(args.val_images_dir, args.val_mask_dir, mask_suffix='', target_size=args.target_size, stage='val')
    n_val = len(val_dataset)
    # os.cpu_count() may return None; pinned memory only helps CUDA transfers.
    loader_args = dict(batch_size=args.batch_size, num_workers=os.cpu_count() or 0, pin_memory=device.type == 'cuda')

    # Create data loaders. drop_last=False so every validation image is scored.
    val_loader = DataLoader(val_dataset, shuffle=False, drop_last=False, **loader_args)

    logging.info(f'''Starting IoU evaluation:
        Batch size:      {args.batch_size}
        Validation size: {n_val}
        Device:          {device.type}
        ''')

    ## uncomment this to evaluate with Dice score
    # val_score = evaluate(model, val_loader, device, amp=False)
    # logging.info('Validation Dice score: {}'.format(val_score))

    val_score, classwise_scores = evaluate_with_iou(model, val_loader, device, amp=False)
    logging.info('Validation IoU score: {}'.format(val_score))
    for i, cls_iou in enumerate(classwise_scores):
        logging.info(f'Class {i} IoU score: {cls_iou}')



To run it, just adjust the `Config` class to your setup and run the script.

You could also include it in your repo if you think it's useful.

Hope this helps.

It would be nice to get the per-class Dice Score too, maybe at some point in the future.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant