I used ActNN in DeepLabV3+ with ResNet50, and it degrades the IoU by 2% #29

Open
hewiew opened this issue Oct 29, 2021 · 9 comments

@hewiew

hewiew commented Oct 29, 2021

No description provided.

@hewiew hewiew changed the title from "I used Actnn in Deeplab v3 plus ResNet50, and the it degrade the iou 边缘%2" to "I used Actnn in Deeplab v3 plus ResNet50, and the it degrade the iou by %2" on Oct 29, 2021
@hewiew
Author

hewiew commented Oct 29, 2021

Here is part of my code:

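import os
import random
import time

import numpy as np
import torch
import torch.nn as nn
from torch.utils import data

import actnn
from actnn import QModule, QScheme

# Project-specific names used below (network, utils, StreamSegMetrics,
# GradualWarmupScheduler, validate, opts, vis) are not shown in this excerpt.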

actnn.set_optimization_level("L3")

def main():

os.environ['CUDA_VISIBLE_DEVICES'] = opts.gpu_id
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Device: %s" % device)

# Setup random seed
torch.manual_seed(opts.random_seed)
torch.cuda.manual_seed(opts.random_seed)
# torch.cuda.manual_seed_all(opts.random_seed)
np.random.seed(opts.random_seed)
random.seed(opts.random_seed)

# Setup dataloader
if opts.dataset=='voc' and not opts.crop_val:
    opts.val_batch_size = 1

train_dst, val_dst = get_dataset(opts)
train_loader = data.DataLoader(
    train_dst, batch_size=opts.batch_size, shuffle=True, num_workers=4)
val_loader = data.DataLoader(
    val_dst, batch_size=opts.val_batch_size, shuffle=True, num_workers=4)
print("Dataset: %s, Train set: %d, Val set: %d" %
      (opts.dataset, len(train_dst), len(val_dst)))

step_per_epoch = int(len(train_dst) / opts.batch_size) + 1
if opts.lr_policy == 'cos':
    T_max = int(opts.batch_size * opts.total_itrs / (len(train_dst) + 1e-6)) + 1
    print("lr_policy: cos, T_max: %d" % (T_max))

# Set up model
model_map = {
    'deeplabv3_resnet50': network.deeplabv3_resnet50,
    'deeplabv3plus_resnet50': network.deeplabv3plus_resnet50,
    'deeplabv3_resnet101': network.deeplabv3_resnet101,
    'deeplabv3plus_resnet101': network.deeplabv3plus_resnet101,
    'deeplabv3_mobilenet': network.deeplabv3_mobilenet,
    'deeplabv3plus_mobilenet': network.deeplabv3plus_mobilenet,
    'deeplabv3plus_fuse_mobilenet': network.deeplabv3plus_fuse_mobilenet
}
if opts.norm_layer == 'GN':
    # norm_layer = nn.GroupNorm
    norm_layer = network.GroupNorm32
elif opts.norm_layer == 'FRN':
    norm_layer = network.FilterResponseNorm2d
else:
    norm_layer = None
model = model_map[opts.model](input_channels=opts.input_channels,
                              num_classes=opts.num_classes,
                              output_stride=opts.output_stride,
                              crop_size=opts.crop_size,
                              stochastic_depth_rate=opts.stochastic_depth_rate,
                              norm_layer=norm_layer)

if opts.use_actnn:
    model = actnn.QModule(model)
    # if isinstance(model, QModule):
        # model = model.model

if opts.separable_conv and 'plus' in opts.model:
    network.convert_to_separable_conv(model.model.classifier)
utils.set_bn_momentum(model.model.backbone, momentum=0.01)

# Set up metrics
metrics = StreamSegMetrics(opts.num_classes)

# Set up optimizer
optimizer = torch.optim.SGD(params=[
    {'params': model.model.backbone.parameters(), 'lr': 0.1*opts.lr},
    {'params': model.model.classifier.parameters(), 'lr': opts.lr},
], lr=opts.lr, momentum=0.9, weight_decay=opts.weight_decay)
#optimizer = torch.optim.SGD(params=model.parameters(), lr=opts.lr, momentum=0.9, weight_decay=opts.weight_decay)
#torch.optim.lr_scheduler.StepLR(optimizer, step_size=opts.lr_decay_step, gamma=opts.lr_decay_factor)

if opts.lr_policy=='poly':
    scheduler = utils.PolyLR(optimizer, opts.total_itrs - opts.warm_up_epoch * step_per_epoch, power=0.9)
    scheduler_warmup = GradualWarmupScheduler(optimizer, multiplier=1,
                                              total_epoch=opts.warm_up_epoch * step_per_epoch + 1, after_scheduler=scheduler)
elif opts.lr_policy=='step':
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=opts.step_size * (
                opts.total_itrs - opts.warm_up_epoch * step_per_epoch - 1) / opts.step_size, gamma=0.1)
    scheduler_warmup = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=opts.warm_up_epoch * step_per_epoch + 1,
                                              after_scheduler=scheduler)
elif opts.lr_policy=='cos':
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=opts.total_itrs - opts.warm_up_epoch * step_per_epoch)
    scheduler_warmup = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=opts.warm_up_epoch * step_per_epoch + 1, after_scheduler=scheduler)

# Set up criterion
#criterion = utils.get_loss(opts.loss_type)
if opts.loss_type == 'focal_loss':
    criterion = utils.FocalLoss(ignore_index=255, size_average=True)
elif opts.loss_type == 'cross_entropy':
    criterion = nn.CrossEntropyLoss(ignore_index=255, reduction='mean')
elif opts.loss_type == 'gce_loss':
    criterion = utils.TruncatedLoss(trainset_size=len(train_dst))
elif opts.loss_type == 'sce_loss':
    criterion = utils.SCELoss()
elif opts.loss_type == 'bitl_loss':
    criterion = utils.BiTemperedLogisticLoss(t1=1.0, t2=1.0, label_smoothing=0.0)
elif opts.loss_type == 'dmi_loss':
    criterion = utils.DMI_loss()

def save_ckpt(path):
    """ save current model
    """
    torch.save({
        "cur_itrs": cur_itrs,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "scheduler_state": scheduler_warmup.state_dict(),
        "best_score": best_score,
    }, path)
    print("Model saved as %s" % path)

utils.mkdir('checkpoints')
# Restore
best_score = 0.0
best_class_score = {'0': 0, '1': 0}
best_class_recall = {'0': 0, '1': 0}
best_class_precision = {'0': 0, '1': 0}
best_class_f1 = {'0': 0, '1': 0}
cur_itrs = 0
cur_epochs = 0

if opts.ckpt is not None and os.path.isfile(opts.ckpt):
    # https://github.com/VainF/DeepLabV3Plus-Pytorch/issues/8#issuecomment-605601402, @PytaichukBohdan
    checkpoint = torch.load(opts.ckpt, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint["model_state"])
    model = nn.DataParallel(model)
    model.to(device)
    if opts.continue_training:
        optimizer.load_state_dict(checkpoint["optimizer_state"])
        scheduler_warmup.load_state_dict(checkpoint["scheduler_state"])
        cur_itrs = checkpoint["cur_itrs"]
        best_score = checkpoint['best_score']
        print("Training state restored from %s" % opts.ckpt)
    print("Model restored from %s" % opts.ckpt)
    del checkpoint  # free memory
else:
    print("[!] Retrain")
    model = nn.DataParallel(model)
    model.to(device)

#==========   Train Loop   ==========#
vis_sample_id = np.random.randint(0, len(val_loader), opts.vis_num_samples,
                                  np.int32) if opts.enable_vis else None  # sample idxs for visualization

denorm = utils.Denormalize(mean=opts.mean, std=opts.std)

# denorm = utils.Denormalize(mean = [0.283, 0.284, 0.254], std = [0.181, 0.160, 0.144])  # denormalization for ori images

if opts.test_only:
    model.eval()
    val_score, ret_samples = validate(
        opts=opts, model=model, loader=val_loader, device=device, metrics=metrics, ret_samples_ids=vis_sample_id)
    print(metrics.to_str(val_score))
    return

interval_loss = 0
# train_start_time = time.time()
while True: #cur_itrs < opts.total_itrs:
    # =====  Train  =====
    model.train()
    cur_epochs += 1
    for (images, labels) in train_loader:
        # time1 = time.time()
        if images.shape[0] != opts.batch_size:
            continue
        start_time = time.time()
        cur_itrs += 1

        images = images.to(device, dtype=torch.float32)
        labels = labels.to(device, dtype=torch.long)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()
        scheduler_warmup.step()

        np_loss = loss.detach().cpu().numpy()
        interval_loss += np_loss


        if (cur_itrs) % opts.print_interval == 0:
            if vis is not None:
                vis.vis_scalar('Loss', cur_itrs, np_loss)
                vis.vis_scalar('Learning Rate', cur_itrs, optimizer.param_groups[1]['lr'])
            interval_loss = interval_loss/opts.print_interval
            print("Epoch %d, Itrs %d/%d, Loss=%f, Lr=%.6f [%.4f s]" %
                  (cur_epochs, cur_itrs, opts.total_itrs, interval_loss, optimizer.param_groups[1]['lr'], time.time() - start_time))

            interval_loss = 0.0
        if (cur_itrs) % opts.val_interval == 0:
            save_ckpt('checkpoints/latest_%s_%s_os%d.pth' %
                      (opts.model, opts.dataset, opts.output_stride))
            print("validation...")
            val_start_time = time.time()
            model.eval()
            val_score, ret_samples = validate(
                opts=opts, model=model, loader=val_loader, device=device, metrics=metrics, ret_samples_ids=vis_sample_id)
            print(metrics.to_str(val_score))
            print("Validating[%.4f s]" % (time.time() - val_start_time))

            if val_score['Class IoU'][1] > best_score:  # save best model
                best_score = val_score['Class IoU'][1]
                best_class_score = val_score['Class IoU']
                best_class_recall = val_score['Class Recall']
                best_class_precision = val_score['Class Precision']
                best_class_f1 = val_score['Class F1-score']
                save_ckpt('checkpoints/best_%s_%s_os%d.pth' %
                          (opts.model, opts.dataset,opts.output_stride))

            model.train()

        if cur_itrs >=  opts.total_itrs:
            return

    if vis is not None and vis.env != 'main':
        vis.vis.save(envs=[vis.env])

if __name__ == '__main__':
    main()

@cjf00000
Collaborator

ActNN is a lossy algorithm, so it is possible that it does not work with 2 bits for some models.

Please try using more warmup iterations, and

actnn.set_optimization_level("L2")

@hewiew
Author

hewiew commented Oct 29, 2021

> ActNN is a lossy algorithm, so it is possible that it does not work with 2 bits for some models.
>
> Please try using more warmup iterations, and
>
> actnn.set_optimization_level("L2")

I've tried set_optimization_level with "L0", "L1", "L2", and "L3", but the results show little difference between them, so I'm confused.

As for warmup iterations, I used a 4-epoch warmup. Is that enough for training?
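For reference, here is a minimal sketch of how a warmup length given in epochs translates into iterations, following the `step_per_epoch` and `total_epoch` expressions in the script above. The dataset size and batch size are placeholder values, and the linear ramp in `warmup_lr` is an assumption about how a gradual warmup scheduler with `multiplier=1` behaves, not something confirmed in this thread.

```python
def warmup_iterations(num_train_samples, batch_size, warm_up_epochs):
    """Warmup length in iterations, mirroring the expressions in the script above."""
    step_per_epoch = int(num_train_samples / batch_size) + 1
    return warm_up_epochs * step_per_epoch + 1


def warmup_lr(base_lr, cur_itr, total_warmup_itrs):
    """Assumed linear ramp from 0 to base_lr over the warmup iterations."""
    if cur_itr >= total_warmup_itrs:
        return base_lr
    return base_lr * cur_itr / total_warmup_itrs


# Placeholder numbers, not taken from the thread.
total = warmup_iterations(num_train_samples=1000, batch_size=16, warm_up_epochs=4)
print("warmup iterations:", total)
print("lr halfway through warmup:", warmup_lr(0.01, total // 2, total))
```

Comparing this iteration count against `opts.total_itrs` gives a quick sense of how large a share of training the warmup covers.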

@cjf00000
Collaborator

ActNN L0 does exactly the same thing as full-precision training. Is the 2% accuracy loss within random error?
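One way to check this, sketched below, is to repeat both configurations with several random seeds and compare the gap in mean IoU against the run-to-run spread. The scores are placeholders for illustration only, not results from this issue, and the "gap larger than twice the standard deviation" rule is a rough heuristic rather than a formal test.

```python
import statistics


def gap_vs_noise(baseline_scores, treated_scores):
    """Return (mean gap, larger of the two sample standard deviations)."""
    gap = statistics.mean(baseline_scores) - statistics.mean(treated_scores)
    noise = max(statistics.stdev(baseline_scores), statistics.stdev(treated_scores))
    return gap, noise


# Placeholder IoU values from hypothetical repeated runs with different seeds.
full_precision_iou = [0.800, 0.805, 0.795]
actnn_iou = [0.780, 0.785, 0.775]

gap, noise = gap_vs_noise(full_precision_iou, actnn_iou)
likely_real = gap > 2 * noise  # rough heuristic, not a statistical test
print("gap=%.4f noise=%.4f likely_real=%s" % (gap, noise, likely_real))
```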

@cjf00000
Collaborator

Could you print(model) before the training loop and check whether the model is converted correctly?

ActNN replaces nn.Modules with its own modules, and I noticed there are additional model converters applied after actnn.QModule. If these converters look for the original nn.Modules (e.g., nn.BatchNorm), they may not find the corresponding modules. You can try moving actnn.QModule after these converters.

if opts.separable_conv and 'plus' in opts.model:
    network.convert_to_separable_conv(model.model.classifier)
utils.set_bn_momentum(model.model.backbone, momentum=0.01)
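To illustrate the ordering concern on a toy model, here is a small sketch. `FakeQBatchNorm2d` and `swap_bn` are hypothetical stand-ins, not ActNN classes, and whether ActNN's own layers subclass the original `nn` layers is not assumed here. `set_bn_momentum` is written in the style such a helper might take (a type-based search) and is likewise an assumption.

```python
from collections import Counter

import torch.nn as nn


class FakeQBatchNorm2d(nn.Module):
    """Hypothetical replacement layer that does NOT subclass nn.BatchNorm2d."""

    def __init__(self, num_features, momentum):
        super().__init__()
        self.num_features = num_features
        self.momentum = momentum

    def forward(self, x):
        return x


def swap_bn(model):
    """Recursively replace nn.BatchNorm2d children with the stand-in layer."""
    for name, child in list(model.named_children()):
        if isinstance(child, nn.BatchNorm2d):
            setattr(model, name, FakeQBatchNorm2d(child.num_features, child.momentum))
        else:
            swap_bn(child)
    return model


def set_bn_momentum(model, momentum):
    """Type-based converter: only touches modules that are nn.BatchNorm2d."""
    changed = 0
    for module in model.modules():
        if isinstance(module, nn.BatchNorm2d):
            module.momentum = momentum
            changed += 1
    return changed


def build():
    return nn.Sequential(nn.Conv2d(4, 8, 3), nn.BatchNorm2d(8), nn.ReLU())


# Replacement first, converter second: the converter finds nothing to change.
converted_first = swap_bn(build())
n_late = set_bn_momentum(converted_first, 0.01)

# Converter first, replacement second: the new momentum is carried over.
converter_first = build()
n_early = set_bn_momentum(converter_first, 0.01)
swap_bn(converter_first)

# Tally of module types, a compact alternative to reading print(model).
type_counts = Counter(type(m).__name__ for m in converted_first.modules())
print(n_late, n_early, converter_first[1].momentum, dict(type_counts))
```

The same `Counter` tally can be run on the real model to confirm which layer types are present after conversion.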

@hewiew
Author

hewiew commented Oct 29, 2021

> ActNN L0 does exactly the same thing as full-precision training. Is the 2% accuracy loss within random error?

I'm afraid not. When I don't use the actnn.QModule wrapper, I always get a model whose IoU is 2% higher than with the wrapper. Incidentally, my model is a 2-class segmentation model. Does this have an impact?

@hewiew
Author

hewiew commented Oct 29, 2021

> Could you print(model) before the training loop and check whether the model is converted correctly?
>
> ActNN replaces nn.Modules with its own modules, and I noticed there are additional model converters applied after actnn.QModule. If these converters look for the original nn.Modules (e.g., nn.BatchNorm), they may not find the corresponding modules. You can try moving actnn.QModule after these converters.
>
> if opts.separable_conv and 'plus' in opts.model:
>     network.convert_to_separable_conv(model.model.classifier)
> utils.set_bn_momentum(model.model.backbone, momentum=0.01)

The code you listed is always skipped during my training. I printed my model, and all layers seem to be converted correctly:

DataParallel(
(module): QModule(
(model): DeepLabV3(
(backbone): IntermediateLayerGetter(
(conv1): QConv2d(4, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
(bn1): QBatchNorm2d(64, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
(relu): QReLU()
(maxpool): QMaxPool2d(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), dilation=(1, 1), ceil_mode=False)
(layer1): Sequential(
(0): Bottleneck(
(conv1): QConv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): QBatchNorm2d(64, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
(conv2): QConv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): QBatchNorm2d(64, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
(conv3): QConv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): QBatchNorm2d(256, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
(relu): QReLU()
(downsample): Sequential(
(0): QConv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): QBatchNorm2d(256, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
)
)
(1): Bottleneck(
(conv1): QConv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): QBatchNorm2d(64, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
(conv2): QConv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): QBatchNorm2d(64, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
(conv3): QConv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): QBatchNorm2d(256, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
(relu): QReLU()
)
(2): Bottleneck(
(conv1): QConv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): QBatchNorm2d(64, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
(conv2): QConv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): QBatchNorm2d(64, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
(conv3): QConv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): QBatchNorm2d(256, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
(relu): QReLU()
)
)
(layer2): Sequential(
(0): Bottleneck(
(conv1): QConv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): QBatchNorm2d(128, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
(conv2): QConv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn2): QBatchNorm2d(128, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
(conv3): QConv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): QBatchNorm2d(512, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
(relu): QReLU()
(downsample): Sequential(
(0): QConv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): QBatchNorm2d(512, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
)
)
(1): Bottleneck(
(conv1): QConv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): QBatchNorm2d(128, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
(conv2): QConv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): QBatchNorm2d(128, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
(conv3): QConv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): QBatchNorm2d(512, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
(relu): QReLU()
)
(2): Bottleneck(
(conv1): QConv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): QBatchNorm2d(128, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
(conv2): QConv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): QBatchNorm2d(128, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
(conv3): QConv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): QBatchNorm2d(512, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
(relu): QReLU()
)
(3): Bottleneck(
(conv1): QConv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): QBatchNorm2d(128, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
(conv2): QConv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): QBatchNorm2d(128, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
(conv3): QConv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): QBatchNorm2d(512, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
(relu): QReLU()
)
)
(layer3): Sequential(
(0): Bottleneck(
(conv1): QConv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): QBatchNorm2d(256, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
(conv2): QConv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn2): QBatchNorm2d(256, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
(conv3): QConv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): QBatchNorm2d(1024, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
(relu): QReLU()
(downsample): Sequential(
(0): QConv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): QBatchNorm2d(1024, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
)
)
(1): Bottleneck(
(conv1): QConv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): QBatchNorm2d(256, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
(conv2): QConv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): QBatchNorm2d(256, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
(conv3): QConv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): QBatchNorm2d(1024, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
(relu): QReLU()
)
(2): Bottleneck(
(conv1): QConv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): QBatchNorm2d(256, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
(conv2): QConv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): QBatchNorm2d(256, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
(conv3): QConv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): QBatchNorm2d(1024, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
(relu): QReLU()
)
(3): Bottleneck(
(conv1): QConv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): QBatchNorm2d(256, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
(conv2): QConv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): QBatchNorm2d(256, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
(conv3): QConv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): QBatchNorm2d(1024, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
(relu): QReLU()
)
(4): Bottleneck(
(conv1): QConv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): QBatchNorm2d(256, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
(conv2): QConv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): QBatchNorm2d(256, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
(conv3): QConv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): QBatchNorm2d(1024, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
(relu): QReLU()
)
(5): Bottleneck(
(conv1): QConv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): QBatchNorm2d(256, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
(conv2): QConv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): QBatchNorm2d(256, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
(conv3): QConv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): QBatchNorm2d(1024, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
(relu): QReLU()
)
)
(layer4): Sequential(
(0): Bottleneck(
(conv1): QConv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): QBatchNorm2d(512, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
(conv2): QConv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): QBatchNorm2d(512, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
(conv3): QConv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): QBatchNorm2d(2048, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
(relu): QReLU()
(downsample): Sequential(
(0): QConv2d(1024, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): QBatchNorm2d(2048, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
)
)
(1): Bottleneck(
(conv1): QConv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): QBatchNorm2d(512, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
(conv2): QConv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2), bias=False)
(bn2): QBatchNorm2d(512, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
(conv3): QConv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): QBatchNorm2d(2048, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
(relu): QReLU()
)
(2): Bottleneck(
(conv1): QConv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): QBatchNorm2d(512, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
(conv2): QConv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2), bias=False)
(bn2): QBatchNorm2d(512, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
(conv3): QConv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): QBatchNorm2d(2048, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
(relu): QReLU()
)
)
)
(classifier): DeepLabHeadV3Plus(
(project): Sequential(
(0): QConv2d(256, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): QBatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): QReLU()
)
(aspp): ASPP(
(convs): ModuleList(
(0): Sequential(
(0): QConv2d(2048, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): QBatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): QReLU()
)
(1): ASPPConv(
(0): QConv2d(2048, 256, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2), bias=False)
(1): QBatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): QReLU()
)
(2): ASPPConv(
(0): QConv2d(2048, 256, kernel_size=(3, 3), stride=(1, 1), padding=(5, 5), dilation=(5, 5), bias=False)
(1): QBatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): QReLU()
)
(3): ASPPConv(
(0): QConv2d(2048, 256, kernel_size=(3, 3), stride=(1, 1), padding=(8, 8), dilation=(8, 8), bias=False)
(1): QBatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): QReLU()
)
(4): ASPPPooling(
(0): AdaptiveAvgPool2d(output_size=1)
(1): QConv2d(2048, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(2): QBatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): QReLU()
)
)
(project): Sequential(
(0): QConv2d(1280, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): QBatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): QReLU()
(3): QDropout(p=0.1, inplace=False)
)
)
(classifier): Sequential(
(0): QConv2d(304, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): QBatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): QReLU()
(3): QConv2d(256, 2, kernel_size=(1, 1), stride=(1, 1))
)
)
)
)
)

@cjf00000
Collaborator

That's strange. ActNN L0 and full precision training should have identical behavior.

Could you try debugging with the following steps:

  1. Prepare a model checkpoint and a batch of (data, label).
  2. Compute the gradient with full-precision training.
  3. Compute the gradient with ActNN.
  4. Check whether the gradients are identical. If not, replace ActNN layers with nn.Module layers one by one, and observe which layer causes the gradient discrepancy.

If you spot a bug in our implementation, please create a PR for us.
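The steps above can be sketched roughly as follows. This is a minimal illustration in plain PyTorch, not ActNN's own test code: `named_grads` and `compare_grads` are hypothetical helpers, the tiny network and random batch are placeholders, and the second model is only a deep copy standing in for the ActNN-wrapped model. In a real check it would be the wrapped copy at level "L0", with `strip_prefix="model."` assumed from the `model.model.*` accesses in the script above.

```python
import copy

import torch
import torch.nn as nn


def named_grads(model, images, labels, criterion, strip_prefix=""):
    """Run one forward/backward pass and return {parameter name: gradient}."""
    model.zero_grad()
    loss = criterion(model(images), labels)
    loss.backward()
    grads = {}
    for name, param in model.named_parameters():
        if param.grad is None:
            continue
        if strip_prefix and name.startswith(strip_prefix):
            name = name[len(strip_prefix):]
        grads[name] = param.grad.detach().clone()
    return grads


def compare_grads(ref, other, rtol=1e-4, atol=1e-6):
    """Return (name, max abs difference, matches) for each reference gradient."""
    rows = []
    for name, g_ref in ref.items():
        g_other = other.get(name)
        if g_other is None:
            rows.append((name, float("inf"), False))
            continue
        max_abs = (g_ref - g_other).abs().max().item()
        rows.append((name, max_abs, torch.allclose(g_ref, g_other, rtol=rtol, atol=atol)))
    return rows


torch.manual_seed(0)
# Placeholder network and batch; the real check would load the checkpoint.
reference = nn.Sequential(
    nn.Conv2d(4, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU(), nn.Conv2d(8, 2, 1)
)
candidate = copy.deepcopy(reference)  # stand-in for the ActNN-wrapped copy
images = torch.randn(2, 4, 16, 16)
labels = torch.randint(0, 2, (2, 16, 16))
criterion = nn.CrossEntropyLoss(ignore_index=255)

rows = compare_grads(
    named_grads(reference, images, labels, criterion),
    named_grads(candidate, images, labels, criterion),
)
mismatches = [name for name, _, ok in rows if not ok]
print("mismatching parameters:", mismatches)
```

Parameters listed in `mismatches` point to the layers worth swapping back first in step 4.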

@hewiew
Author

hewiew commented Nov 1, 2021

> That's strange. ActNN L0 and full precision training should have identical behavior.
>
> Could you try to debug by the following:
>
> 1. prepare a model checkpoint, and a batch of (data, label)
> 2. compute the gradient with full-precision training
> 3. compute the gradient with actnn
> 4. check if the gradients are identical, if not, replace actnn layers with nn.module layers one by one, and observe which layer is causing the gradient discrepancy.
>
> If you spot a bug in our implementation, please create a PR for us.

Thanks for your advice last week. I've followed the steps you listed, but it seems we have run into trouble:

I restored my model (only a single nn.Conv2d layer) from a fixed checkpoint, and I found that the gradients with ActNN (level set to "L0") and with full-precision training are totally different.

I've sent an email to your Gmail mailbox with my experimental code and detailed printed results. You may check it when you have free time.
