QAT causes conversion error #106

Open
Raychen0617 opened this issue Aug 12, 2022 · 13 comments
Labels
question Further information is requested

Comments

@Raychen0617

Sorry for bothering you again. When I use QAT on yolov5, converting the model introduces significant errors. I compared the outputs layer by layer and found that the error starts from the very first layer, and I have no idea how to fix it.

The generated script is here autoshape_qat.zip

device = "cpu"
dummy_input = torch.rand(1, 3, 640, 640)
model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True)

for k, m in model.named_modules():
    if isinstance(m, Detect):
        m.inplace = False
        m.onnx_dynamic = True
        m.export = True

quantizer = QATQuantizer(model, dummy_input, work_dir='out', config={'asymmetric': True, 'per_tensor': True})
qat_model = quantizer.quantize()
qat_model.to(device=device)

with torch.no_grad():  
    qat_model.cpu()
    qat_model = torch.quantization.convert(qat_model)
    torch.backends.quantized.engine = quantizer.backend
    convert_and_compare(qat_model, './output/qat_model.tflite', dummy_input)
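
For reference, a rough manual check of the converted model could look like the following (a sketch only, continuing from the snippet above; convert_and_compare already prints a comparison, and the tf.lite.Interpreter usage below assumes the .tflite file keeps float32 inputs/outputs):

import numpy as np
import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path='./output/qat_model.tflite')
interpreter.allocate_tensors()
inp = interpreter.get_input_details()[0]
out = interpreter.get_output_details()[0]

x = dummy_input.numpy()
if inp['shape'][-1] == 3:  # the converted model may expect NHWC layout
    x = x.transpose(0, 2, 3, 1)
interpreter.set_tensor(inp['index'], x)
interpreter.invoke()
tfl_out = interpreter.get_tensor(out['index'])

with torch.no_grad():
    pt_out = qat_model(dummy_input)
pt_out = pt_out[0] if isinstance(pt_out, (tuple, list)) else pt_out  # adjust to the model's actual outputs
print('max abs diff:', np.abs(tfl_out - pt_out.numpy()).max())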
@peterjc123 added the question (Further information is requested) label on Aug 12, 2022
@peterjc123
Collaborator

peterjc123 commented Aug 12, 2022

Wow, you didn't even train the QAT model. I don't think that's gonna work, because if you don't train that model, the error caused by quantization (mainly due to the differences in rounding methods) between different frameworks could be larger than expected.
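
For context, the usual flow is to insert at least a short fine-tuning pass between quantizer.quantize() and torch.quantization.convert(); a minimal sketch (train_loader and criterion below are placeholders, not part of the tool):

optimizer = torch.optim.SGD(qat_model.parameters(), lr=1e-4, momentum=0.9)
qat_model.train()
for imgs, targets in train_loader:  # placeholder loader
    optimizer.zero_grad()
    loss = criterion(qat_model(imgs), targets)  # placeholder loss
    loss.backward()
    optimizer.step()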

@Raychen0617
Author

I tried training the model for 100 epochs on coco128 before, but the error did not seem to get smaller. So maybe the error is caused by the training process?

@peterjc123
Collaborator

Could you please share how you performed the training process with our tool?

@Raychen0617
Author

Raychen0617 commented Aug 15, 2022

The code is slightly modified from https://github.com/ultralytics/yolov5/blob/master/train.py; I tried to make it simpler and easier to read. In short, I commented out the AMP scaler and EMA in the training code and added a new function, ComputeLossQuant, which is similar to the original loss function but uses the quantized model's output (a rough sketch of such a wrapper is included after the training code below). Hope it is not too messy to read, thanks in advance!!

# Model 
ckpt = torch.load(weights, map_location='cpu')
model = ckpt['model'].float().cuda()
exclude = ['anchor'] if (cfg or hyp.get('anchors')) and not resume else []  # exclude keys
csd = ckpt['model'].float().state_dict()  # checkpoint state_dict as FP32
csd = intersect_dicts(csd, model.state_dict(), exclude=exclude)  # intersect      
model.load_state_dict(csd)  # load

# Image size
gs = max(int(model.stride.max()), 32)  # grid size (max stride)
imgsz = check_img_size(opt.imgsz, gs, floor=gs * 2)  # verify imgsz is gs-multiple

nl = de_parallel(model).model[-1].nl  # number of detection layers (to scale hyps)
hyp['box'] *= 3 / nl  # scale to layers
hyp['cls'] *= nc / 80 * 3 / nl  # scale to classes and layers
hyp['obj'] *= (imgsz / 640) ** 2 * 3 / nl  # scale to image size and layers
hyp['label_smoothing'] = opt.label_smoothing

# TinyNN qat 
for k, m in model.named_modules():
    if isinstance(m, Detect):
        m.inplace = False
        m.onnx_dynamic = True
        m.export = True

dummy_input = torch.rand(1, 3, 640, 640)
quantizer = QATQuantizer(model, dummy_input, work_dir='./quantization', config={'asymmetric': True, 'per_tensor': True})
qat_model = quantizer.quantize()
device = "cuda:0"
qat_model.to(device=device)

# Batch size
if RANK == -1 and batch_size == -1:  # single-GPU only, estimate best batch size
    batch_size = check_train_batch_size(model, imgsz, amp)
    loggers.on_params_update({"batch_size": batch_size})

# Optimizer
nbs = 64  # nominal batch size
accumulate = max(round(nbs / batch_size), 1)  # accumulate loss before optimizing
hyp['weight_decay'] *= batch_size * accumulate / nbs  # scale weight_decay
optimizer = smart_optimizer(qat_model, opt.optimizer, hyp['lr0'], hyp['momentum'], hyp['weight_decay'])

# Scheduler
lf = lambda x: (1 - x / epochs) * (1.0 - hyp['lrf']) + hyp['lrf']  # linear
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)  # plot_lr_scheduler(optimizer, scheduler, epochs)

# Trainloader
train_loader, dataset = create_dataloader(train_path,
                                          imgsz,
                                          batch_size // WORLD_SIZE,
                                          gs,
                                          single_cls,
                                          hyp=hyp,
                                          augment=True,
                                          cache=None if opt.cache == 'val' else opt.cache,
                                          rect=opt.rect,
                                          rank=LOCAL_RANK,
                                          workers=workers,
                                          image_weights=opt.image_weights,
                                          quad=opt.quad,
                                          prefix=colorstr('train: '),
                                          shuffle=True)
labels = np.concatenate(dataset.labels, 0)

# Training 
compute_loss = ComputeLossQuant(model, qat_model)
for epoch in range(start_epoch, epochs):  # epoch ------------------------------------------------------------------
      
      qat_model.train()

      mloss = torch.zeros(3, device=device)  # mean losses
      if RANK != -1:
          train_loader.sampler.set_epoch(epoch)
      pbar = enumerate(train_loader)

      if RANK in {-1, 0}:
          pbar = tqdm(pbar, total=nb, bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}')  # progress bar
     
      optimizer.zero_grad()
      for i, (imgs, targets, paths, _) in pbar:  # batch -------------------------------------------------------------
          
          ni = i + nb * epoch  # number integrated batches (since train start)
          imgs = imgs.to(device, non_blocking=True).float() / 255  # uint8 to float32, 0-255 to 0.0-1.0

          # Warmup
          if ni <= nw:
              xi = [0, nw]  # x interp
              # compute_loss.gr = np.interp(ni, xi, [0.0, 1.0])  # iou loss ratio (obj_loss = 1.0 or iou)
              accumulate = max(1, np.interp(ni, xi, [1, nbs / batch_size]).round())
              for j, x in enumerate(optimizer.param_groups):
                  # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
                  x['lr'] = np.interp(ni, xi, [hyp['warmup_bias_lr'] if j == 0 else 0.0, x['initial_lr'] * lf(epoch)])
                  if 'momentum' in x:
                      x['momentum'] = np.interp(ni, xi, [hyp['warmup_momentum'], hyp['momentum']])

          # Forward
          #with torch.cuda.amp.autocast(amp):
          pred = qat_model(imgs)  # forward
          loss, loss_items = compute_loss(pred, targets.to(device))  # loss scaled by batch_size
          if RANK != -1:
              loss *= WORLD_SIZE  # gradient averaged between devices in DDP mode

          # Backward
          loss.backward()

          # Optimize - https://pytorch.org/docs/master/notes/amp_examples.html
          if ni - last_opt_step >= accumulate:
              torch.nn.utils.clip_grad_norm_(qat_model.parameters(), max_norm=10.0)  # clip gradients
              optimizer.step()
              optimizer.zero_grad()
              last_opt_step = ni

          # Log
          if RANK in {-1, 0}:
              mloss = (mloss * i + loss_items) / (i + 1)  # update mean losses
              mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G'  # (GB)
              if callbacks.stop_training:
                  return
          # end batch ------------------------------------------------------------------------------------------------

      # Scheduler
      lr = [x['lr'] for x in optimizer.param_groups]  # for loggers
      scheduler.step()
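
For reference, ComputeLossQuant is not part of upstream yolov5; it is roughly a thin wrapper along the following lines (a simplified sketch, not the exact code; any reshaping of the QAT model's outputs back to the per-layer training format is omitted):

from utils.loss import ComputeLoss  # yolov5's loss implementation

class ComputeLossQuant:
    # Reuse yolov5's ComputeLoss, built from the float model (which still carries
    # the Detect head, anchors and hyperparameters), and feed it the predictions
    # produced by the rewritten QAT model.
    def __init__(self, model, qat_model):
        self.base_loss = ComputeLoss(model)
        self.qat_model = qat_model

    def __call__(self, pred, targets):
        # pred is the output of qat_model(imgs) in the training loop above
        return self.base_loss(pred, targets)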

@peterjc123
Collaborator

peterjc123 commented Aug 15, 2022

@Raychen0617 The training process looks correct. So I took a look at your model, and it seems that the QAT rewrite in our repo doesn't handle nn.SiLU correctly, so there are a lot of QuantStubs and DeQuantStubs around those activations. Take the following code as an example: as you can see, a Quantize and a Dequantize are added before and after the activation, which (1) slows down actual inference on real-world devices and (2) makes the difference between TFLite and PyTorch larger, because the rounding method differs between the two frameworks (for PyTorch the rounding mode is RTNE, i.e. nearbyint() in C, but for TF/TFLite it is RTN, i.e. round() in C), and accumulating these layers makes things worse.

class AutoShape_qat(torch.nn.Module):
    def __init__(self):
        super().__init__()

        # ...
        self.model_model_model_0_conv = torch.nn.Conv2d(3, 32, kernel_size=(6, 6), stride=(2, 2), padding=(2, 2))
        self.fake_dequant_inner_0_0 = torch.quantization.DeQuantStub()
        self.model_model_model_0_act = torch.nn.SiLU(inplace=True)
        self.fake_quant_inner_56_0 = torch.quantization.QuantStub()
        # ...

    def forward(self, input_0_f):
        # ...
        model_model_model_0_conv = self.model_model_model_0_conv(type_as_0_f)
        type_as_0_f = None
        fake_dequant_inner_0_0 = self.fake_dequant_inner_0_0(model_model_model_0_conv)
        model_model_model_0_conv = None
        model_model_model_0_act = self.model_model_model_0_act(fake_dequant_inner_0_0)
        fake_dequant_inner_0_0 = None
        fake_quant_inner_56_0 = self.fake_quant_inner_56_0(model_model_model_0_act)
        model_model_model_0_act = None
        # ...
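
A toy illustration of the rounding difference (half-a-ULP discrepancies like these accumulate across stacked Quantize/Dequantize pairs):

import torch

x = torch.tensor([0.5, 1.5, 2.5, -0.5, -1.5])

# PyTorch quantization rounds half to even (RTNE, nearbyint() in C)
print(torch.round(x))                         # -> [0., 2., 2., -0., -2.]

# round() in C rounds half away from zero; emulated here via floor(|x| + 0.5)
print(torch.floor(x.abs() + 0.5) * x.sign())  # -> [1., 2., 3., -1., -2.]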

@peterjc123
Collaborator

@Raychen0617 You may try out #107, which is the fix for the problem I described above.

@Raychen0617
Author

Hi, thanks for helping. After the modification the error did become lower, and the problem described above is solved in the newly generated Python script autoshape_qat.zip. However, the conversion error is still too large to ignore. In addition, if I replace all the SiLU activations with ReLU, the error gets even lower, but it is still large and strongly affects the result.

@peterjc123
Collaborator

peterjc123 commented Aug 16, 2022

@Raychen0617 So my question is: is the quantization-aware training itself performing well? If it is and the difference is still too large, could you please provide your model (the generated script together with the trained weight file) and one simple input (a preprocessed input would be preferred, or you may tell us how to preprocess the image) so that we can take a look?

@Raychen0617
Author

QAT can run without error, but it doesn't seem to lower the conversion error. The trained weights and the generated script are here: model zip. For a simple input, you can use the code from here, which basically loads an image from the source (the coco dataset in my case) and resizes it to the correct shape (1, 3, 640, 640). Hope it is clear enough; if there is any problem, please contact me, thank you very much!!!
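
In case the link is unclear, the preprocessing is roughly the following (a simplified sketch; yolov5 normally letterboxes instead of plain resizing, and the file path is just a placeholder):

import cv2
import numpy as np
import torch

img = cv2.imread('sample.jpg')                    # placeholder path; HWC, BGR, uint8
img = cv2.resize(img, (640, 640))
img = img[:, :, ::-1].transpose(2, 0, 1)          # BGR -> RGB, HWC -> CHW
img = np.ascontiguousarray(img, dtype=np.float32) / 255.0
dummy_input = torch.from_numpy(img).unsqueeze(0)  # (1, 3, 640, 640)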

@peterjc123
Collaborator

QAT can run without error

I mean, what about the accuracy of the model? Does it drop a lot compared to the floating-point one?

@Raychen0617
Author

Thanks for replying. As the FAQ mentions, yolov5 has different outputs during training and evaluation. However, when using the TinyNN quantizer to convert the evaluation-mode model, it shows the error "rewrite supports torch.stack with nodes with exact one input", so I think some operations in yolov5's validation mode are not supported. Thus, during QAT I tried loading the qat_train_model weights back into the original model and using that model for validation. I know it seems quite strange, but the validation score did increase during the QAT process, so I guess QAT is working well (not sure if that makes sense).

@peterjc123
Collaborator

peterjc123 commented Aug 18, 2022

That way of doing validation is not right, I guess. It is pointless to run the original model with the weights that were learned during quantization-aware training.
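
Roughly, the idea would be to evaluate the fake-quantized QAT model directly instead of copying its weights back into the float model; a minimal sketch (val_imgs is a placeholder for a preprocessed validation batch):

qat_model.eval()
with torch.no_grad():
    preds = qat_model(val_imgs)  # placeholder preprocessed validation batch
# feed `preds` into the usual yolov5 postprocessing / mAP computation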

However, when using the TinyNN quantizer to convert the evaluation-mode model, it shows the error "rewrite supports torch.stack with nodes with exact one input"

How can I reproduce this? Could you please upload a piece of code for this?

@Raychen0617
Author

I changed this line to make sure the model goes into evaluation mode, and then used the converter to convert it. And I did set inplace to False.

    def forward(self, x):
        z = []  # inference output
        for i in range(self.nl):
            x[i] = self.m[i](x[i])  # conv
            bs, _, ny, nx = x[i].shape  # x(bs,255,20,20) to x(bs,3,20,20,85)
            x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
            if True:
            #if not self.training:  # inference
                if self.onnx_dynamic or self.grid[i].shape[2:4] != x[i].shape[2:4]:
                    self.grid[i], self.anchor_grid[i] = self._make_grid(nx, ny, i)

                y = x[i].sigmoid()
                if self.inplace:
                    y[..., 0:2] = (y[..., 0:2] * 2 + self.grid[i]) * self.stride[i]  # xy
                    y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
                else:  # for YOLOv5 on AWS Inferentia https://github.com/ultralytics/yolov5/pull/2953
                    xy, wh, conf = y.split((2, 2, self.nc + 1), 4)  # y.tensor_split((2, 4, 5), 4)  # torch 1.8.0
                    xy = (xy * 2 + self.grid[i]) * self.stride[i]  # xy
                    wh = (wh * 2) ** 2 * self.anchor_grid[i]  # wh
                    y = torch.cat((xy, wh, conf), 4)
                z.append(y.view(bs, -1, self.no))
