
Pytorch-LARS

Objective

Implement the LARS (Layer-wise Adaptive Rate Scaling) optimizer in PyTorch and analyze how large-batch training of ResNet50 on CIFAR10 behaves with and without it.

Requirements

  • python == 3.6.8
  • pytorch >= 1.1.0
  • cuda >= 10
  • matplotlib >= 3.1.0 (optional)
  • etc.

Usage

  • Train
$ git clone https://github.com/cmpark0126/pytorch-LARS.git
$ cd pytorch-LARS/
$ vi hyperparams.py # edit the Base and Hyperparams classes to configure training (a sketch of these classes follows below)
$ python train.py # start training on CIFAR10
  • Evaluate
$ vi hyperparams.py # adjust the Hyperparams_for_val class to inspect training results; a specific checkpoint can be selected
$ python val.py # inspect training results; shows the history of test-accuracy updates recorded during training
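
For orientation, here is a minimal sketch of the kind of configuration classes hyperparams.py exposes, assembled from the field list in the next section. The values shown are illustrative (taken from the attempts below); the actual file in the repository may differ.

```python
# Illustrative sketch of hyperparams.py, not the repository's exact contents.
class Base:
    batch_size = 128   # reference batch size; all batch sizes are multiples of this
    lr = 0.05          # reference learning rate for the linear scaling rule
    multiples = 1      # exponent used to compute k = 2 ** (multiples - 1)

k = 2 ** (Base.multiples - 1)

class Hyperparams:
    batch_size = Base.batch_size * k   # batch size actually used for training
    lr = Base.lr * k                   # initial LR, per the linear scaling rule
    momentum = 0.9
    weight_decay = 5e-3
    trust_coef = 0.1
    warmup_multiplier = 2
    warmup_epoch = 5
    max_decay_epoch = 200
    end_learning_rate = 1e-5 * k
    num_of_epoch = 200
    with_lars = True

class Hyperparams_for_val:
    checkpoint_folder_name = 'checkpoint-attempt1'
    with_lars = True
    batch_size = Hyperparams.batch_size
    device = 'cuda:0'
```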

Hyperparams (hyperparams.py)

  • Base (class)

    • batch_size: the reference batch size. Every batch size used in the experiments is a multiple of this value.

    • lr: the reference learning rate, used as the base value for the linear scaling rule.

    • multiples: used as the exponent when computing k (see Terminology below).

  • Hyperparams (class)

    • batch_size: the batch size actually used for training

    • lr: the initial learning rate actually used for training

    • momentum

    • weight_decay

    • trust_coef: the trust coefficient; how much LARS trusts the local LR it computes internally (see the sketch at the end of this section)

    • warmup_multiplier

    • warmup_epoch

    • max_decay_epoch: the number of epochs over which polynomial decay runs

    • end_learning_rate: the value the learning rate converges to once decay is complete

    • num_of_epoch: total number of epochs to train

    • with_lars

  • Hyperparams_for_val (class)

    • checkpoint_folder_name: a folder of saved checkpoints must exist in the same directory as hyperparams.py; set this to the name of one of those folders (e.g. checkpoint_folder_name = 'checkpoint-attempt1')

    • with_lars: choose between checkpoints trained with or without LARS

    • batch_size: the batch size the checkpoint was trained with

    • device: the CUDA device to run the model on during evaluation
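
For reference, here is a minimal sketch of the per-layer ("local") learning rate that LARS computes, following the formula in the LARS paper; the function name is illustrative and not the repository's actual API.

```python
import torch

def lars_local_lr(weight: torch.Tensor, grad: torch.Tensor,
                  trust_coef: float, weight_decay: float) -> float:
    # LARS paper: local_lr = trust_coef * ||w|| / (||grad|| + weight_decay * ||w||).
    # The per-layer update then uses global_lr * local_lr instead of global_lr alone.
    w_norm = weight.norm()
    g_norm = grad.norm()
    if w_norm == 0 or g_norm == 0:
        return 1.0  # degenerate layer: fall back to the global LR
    return (trust_coef * w_norm / (g_norm + weight_decay * w_norm)).item()
```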

Demonstration

  • Terminology
    • k
      • the factor by which we increase the batch size B
      • the starting batch size is 128
      • e.g. if we use a batch size of 256, then k is 2
      • k = (2 ** (multiples - 1)) (worked example below)
    • base line
      • the target accuracy we want to reach when training the model with a large batch size and LARS
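
A quick worked example of k and the linear scaling rule; the base values are illustrative, taken from the attempts below.

```python
base_batch_size = 128
base_lr = 0.05  # base LR from Attempt 3, shown here for illustration

for multiples in range(1, 8):
    k = 2 ** (multiples - 1)           # 1, 2, 4, ..., 64
    batch_size = base_batch_size * k   # 128, 256, ..., 8192
    lr = base_lr * k                   # linear scaling rule: scale LR with k
    print(f"batch={batch_size:5d}  k={k:2d}  lr={lr:.2f}")
```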

Attempt 1

  • Configuration

    • Hyperparams

      • momentum = 0.9

      • weight_decay

        • noLars -> 5e-04
        • withLARS -> 5e-03
      • warm-up for 5 epochs

        • warmup_multiplier = k
        • target lr follows the linear scaling rule
      • polynomial decay (power=2) LR policy (after warm-up); a schedule sketch follows this list

        • for 200 epochs
        • minimum lr = 1.5e-05 * k
      • number of epochs = 200
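
Below is a minimal sketch of the warm-up plus polynomial-decay policy described above. It assumes the warm-up ramps linearly from the base LR to warmup_multiplier * base LR over warmup_epoch epochs, and that decay then runs down to end_learning_rate over max_decay_epoch epochs; the repository's actual schedule code may differ in these details.

```python
def lr_at_epoch(epoch: int, base_lr: float, warmup_multiplier: float,
                warmup_epoch: int, max_decay_epoch: int,
                end_lr: float, power: int = 2) -> float:
    # Illustrative warm-up + polynomial decay schedule, not the repo's exact code.
    target_lr = base_lr * warmup_multiplier
    if epoch < warmup_epoch:
        # linear warm-up from base_lr toward target_lr
        return base_lr + (target_lr - base_lr) * epoch / warmup_epoch
    # polynomial (power=2) decay from target_lr down to end_lr
    progress = min(epoch - warmup_epoch, max_decay_epoch) / max_decay_epoch
    return (target_lr - end_lr) * (1.0 - progress) ** power + end_lr
```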

  • Without LARS

| Batch | Base LR | Top-1 accuracy (%) | Time to train (sec) |
|-------|---------|--------------------|---------------------|
| 128   | 0.15    | 89.15 (base line)  | 2113.52             |
| 256   | 0.15    | 88.43              | 1433.38             |
| 512   | 0.15    | 88.72              | 1820.35             |
| 1024  | 0.15    | 87.96              | 1303.54             |
| 2048  | 0.15    | 87.05              | 1827.90             |
| 4096  | 0.15    | 78.03              | 2083.24             |
| 8192  | 0.15    | 14.59              | 1459.81             |

  • With LARS (closest one to base line, for comparing time to train)

| Batch | Base LR | Top-1 accuracy (%) | Time to train (sec) |
|-------|---------|--------------------|---------------------|
| 128   | 0.15    | 89.16              | 3203.54             |
| 256   | 0.15    | 89.19              | 2147.74             |
| 512   | 0.15    | 89.29              | 1677.25             |
| 1024  | 0.15    | 89.17              | 1604.91             |
| 2048  | 0.15    | 88.70              | 1413.10             |
| 4096  | 0.15    | 86.78              | 1609.08             |
| 8192  | 0.15    | 80.85              | 1629.48             |

  • With LARS (best accuracy)

| Batch | Base LR | Top-1 accuracy (%) | Time to train (sec) |
|-------|---------|--------------------|---------------------|
| 128   | 0.15    | 89.62              | 3606.08             |
| 256   | 0.15    | 89.78              | 2675.04             |
| 512   | 0.15    | 89.38              | 1712.90             |
| 1024  | 0.15    | 89.22              | 1967.92             |
| 2048  | 0.15    | 88.70              | 1413.10             |
| 4096  | 0.15    | 86.78              | 1609.08             |
| 8192  | 0.15    | 80.85              | 1629.48             |

Attempt 2

  • Configuration

    • Hyperparams

      • momentum = 0.9

      • weight_decay

        • noLars -> 5e-04
        • withLARS -> 5e-03
      • trust coefficient = 0.1

      • warm-up for 5 epochs

        • warmup_multiplier = 2 * k
        • target lr follows the linear scaling rule
      • polynomial decay (power=2) LR policy (after warm-up)

        • for 200 epochs
        • minimum lr = 1e-05
      • number of epochs = 200

  • Without LARS

| Batch | Base LR | Top-1 accuracy (%) | Time to train (sec) |
|-------|---------|--------------------|---------------------|
| 128   | 0.05    | 90.40 (base line)  | 4232.56             |
| 256   | 0.05    | 90.00              | 2968.43             |
| 512   | 0.05    | 89.50              | 2707.79             |
| 1024  | 0.05    | 89.27              | 2627.22             |
| 2048  | 0.05    | 89.21              | 2500.02             |
| 4096  | 0.05    | 84.73              | 2872.25             |
| 8192  | 0.05    | 20.85              | 2923.95             |

  • With LARS (closest one to base line, for comparing time to train)

| Batch | Base LR | Top-1 accuracy (%) | Time to train (sec) |
|-------|---------|--------------------|---------------------|
| 128   | 0.05    | 90.21              | 6792.61             |
| 256   | 0.05    | 90.28              | 4871.68             |
| 512   | 0.05    | 90.41              | 3581.32             |
| 1024  | 0.05    | 90.27              | 3030.45             |
| 2048  | 0.05    | 90.19              | 2773.21             |
| 4096  | 0.05    | 88.49              | 2866.02             |
| 8192  | 0.05    | 62.20              | 1312.98             |

  • With LARS (best accuracy)

| Batch | Base LR | Top-1 accuracy (%) | Time to train (sec) |
|-------|---------|--------------------|---------------------|
| 128   | 0.05    | 90.21              | 6792.61             |
| 256   | 0.05    | 90.28              | 4871.68             |
| 512   | 0.05    | 90.41              | 3581.32             |
| 1024  | 0.05    | 90.27              | 3030.45             |
| 2048  | 0.05    | 90.19              | 2773.21             |
| 4096  | 0.05    | 88.49              | 2866.02             |
| 8192  | 0.05    | 62.20              | 1312.98             |

Attempt 3

  • Configuration

    • Hyperparams

      • momentum = 0.9

      • weight_decay

        • noLars -> 5e-04
        • withLARS -> 5e-03
      • trust coefficient = 0.1

      • warm-up for 5 epochs

        • warmup_multiplier = 2
      • polynomial decay (power=2) LR policy (after warm-up)

        • for 200 epochs
        • minimum lr = 1e-05 * k
      • number of epochs = 200

    • Additional Jobs

      • Use He initialization (see the sketch after this list)

      • base lr is adjusted according to the linear scaling rule
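
A minimal sketch of He (Kaiming) initialization using the standard PyTorch API; which layers the repository actually re-initializes, and with which mode, is an assumption here.

```python
import torch.nn as nn

def apply_he_init(model: nn.Module) -> None:
    # He initialization is the recommended scheme for ReLU networks
    # such as ResNet50 (He et al., 2015).
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.constant_(m.weight, 1.0)
            nn.init.constant_(m.bias, 0.0)
```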

  • Without LARS

| Batch | Base LR | Top-1 accuracy (%) | Time to train (sec) |
|-------|---------|--------------------|---------------------|
| 128   | 0.05    | 89.76              | 3983.89             |
| 256   | 0.1     | 90.08 (base line)  | 3095.91             |
| 512   | 0.2     | 89.34              | 2674.38             |
| 1024  | 0.4     | 88.82              | 2581.19             |
| 2048  | 0.8     | 89.29              | 2660.56             |
| 4096  | 1.6     | 85.02              | 2871.04             |
| 8192  | 3.2     | 77.72              | 3195.90             |

  • With LARS (closest one to base line, for comparing time to train)

| Batch | Base LR | Top-1 accuracy (%) | Time to train (sec) |
|-------|---------|--------------------|---------------------|
| 128   | 0.05    | 90.11              | 6880.76             |
| 256   | 0.1     | 90.12              | 4262.83             |
| 512   | 0.2     | 90.11              | 3548.07             |
| 1024  | 0.4     | 90.02              | 2760.31             |
| 2048  | 0.8     | 90.09              | 2877.81             |
| 4096  | 1.6     | 88.38              | 2946.53             |
| 8192  | 3.2     | 86.40              | 3260.45             |

  • With LARS (best accuracy)

| Batch | Base LR | Top-1 accuracy (%) | Time to train (sec) |
|-------|---------|--------------------|---------------------|
| 128   | 0.05    | 90.37              | 7338.71             |
| 256   | 0.1     | 90.32              | 4590.58             |
| 512   | 0.2     | 90.11              | 3548.07             |
| 1024  | 0.4     | 90.50              | 2897.45             |
| 2048  | 0.8     | 90.09              | 2877.81             |
| 4096  | 1.6     | 88.38              | 2946.53             |
| 8192  | 3.2     | 86.40              | 3260.45             |

Attempt 4

  • Configuration

    • Hyperparams

      • momentum = 0.9

      • weight_decay

        • noLars -> 5e-04
        • withLARS -> 5e-03
      • trust coefficient = 0.1

      • warm-up for 5 epochs

        • warmup_multiplier = 5
      • polynomial decay (power=2) LR policy (after warm-up)

        • for 200 epochs
        • minimum lr = 1e-05 * k
      • number of epochs = 200

    • Additional Jobs

      • Use He initialization

      • base lr is adjusted according to the linear scaling rule

  • Without LARS

| Batch | Base LR | Top-1 accuracy (%) | Time to train (sec) |
|-------|---------|--------------------|---------------------|
| 128   | 0.02    | 89.84              | 4146.52             |
| 256   | 0.04    | 90.22 (base line)  | 3023.48             |
| 512   | 0.08    | 89.42              | 2588.01             |
| 1024  | 0.16    | 89.41              | 2494.35             |
| 2048  | 0.32    | 88.97              | 2616.32             |
| 4096  | 0.64    | 85.13              | 2872.76             |
| 8192  | 1.28    | 75.99              | 3226.53             |

  • With LARS (closest one to base line, for comparing time to train)

| Batch | Base LR | Top-1 accuracy (%) | Time to train (sec) |
|-------|---------|--------------------|---------------------|
| 128   | 0.02    | 90.20              | 6740.03             |
| 256   | 0.04    | 90.25              | 4662.09             |
| 512   | 0.08    | 90.24              | 3381.99             |
| 1024  | 0.16    | 90.07              | 2929.32             |
| 2048  | 0.32    | 89.82              | 2908.37             |
| 4096  | 0.64    | 88.09              | 2980.63             |
| 8192  | 1.28    | 86.56              | 3314.60             |

  • With LARS (best accuracy)

| Batch | Base LR | Top-1 accuracy (%) | Time to train (sec) |
|-------|---------|--------------------|---------------------|
| 128   | 0.02    | 90.69              | 7003.00             |
| 256   | 0.04    | 90.32              | 4808.80             |
| 512   | 0.08    | 90.40              | 3615.13             |
| 1024  | 0.16    | 90.07              | 2929.32             |
| 2048  | 0.32    | 89.82              | 2908.37             |
| 4096  | 0.64    | 88.09              | 2980.63             |
| 8192  | 1.28    | 86.56              | 3314.60             |

Attempt 5

  • Configuration

    • Hyperparams

      • momentum = 0.9

      • weight_decay

        • noLars -> 5e-04
        • withLARS -> 5e-03
      • trust coefficient = 0.1

      • warm-up for 5 epochs

        • warmup_multiplier = 2
      • polynomial decay (power=2) LR policy (after warm-up)

        • for 175 epochs
        • minimum lr = 1e-05 * k
      • number of epochs = 175

    • Additional Jobs

      • Use He initialization

      • base lr is adjusted according to the linear scaling rule

  • Without LARS

| Batch | Base LR | Top-1 accuracy (%) | Time to train (sec) |
|-------|---------|--------------------|---------------------|
| 128   | 0.05    | 89.50 (base line)  | 3682.72             |
| 256   | 0.1     | 89.22              | 2678.24             |
| 512   | 0.2     | 89.12              | 2337.15             |
| 1024  | 0.4     | 88.70              | 2282.48             |
| 2048  | 0.8     | 88.89              | 2316.96             |
| 4096  | 1.6     | 86.87              | 2515.56             |
| 8192  | 3.2     | 15.50              | 2783.00             |

  • With LARS (closest one to base line, for comparing time to train)

| Batch | Base LR | Top-1 accuracy (%) | Time to train (sec) |
|-------|---------|--------------------|---------------------|
| 128   | 0.05    | 89.56              | 5445.55             |
| 256   | 0.1     | 89.52              | 3461.59             |
| 512   | 0.2     | 89.60              | 2738.91             |
| 1024  | 0.4     | 89.50              | 2410.23             |
| 2048  | 0.8     | 89.42              | 2474.93             |
| 4096  | 1.6     | 88.43              | 2618.97             |
| 8192  | 3.2     | 74.96              | 1835.32             |

  • With LARS (best accuracy)

| Batch | Base LR | Top-1 accuracy (%) | Time to train (sec) |
|-------|---------|--------------------|---------------------|
| 128   | 0.05    | 90.36              | 6377.71             |
| 256   | 0.1     | 90.18              | 4219.26             |
| 512   | 0.2     | 90.08              | 3130.41             |
| 1024  | 0.4     | 89.94              | 2578.00             |
| 2048  | 0.8     | 89.42              | 2474.93             |
| 4096  | 1.6     | 88.43              | 2618.97             |
| 8192  | 3.2     | 74.96              | 1835.32             |

Visualization

<Fig1. Attempt4, Without LARS, Batch size = 8192>

<Fig2. Attempt4, With LARS, Batch size = 8192>

  • Comparing <Fig1> and <Fig2>, training with LARS starts more stably and accuracy rises more smoothly.

  • The accuracy-curve graphs produced while running Attempts 3, 4, and 5 can be viewed at the link below.

Analysis of ResNet50 Training With Large Batch (CIFAR10)

  • Confirmed that with LARS, the model can be trained to base-line accuracy with batch sizes up to 1024.

  • Confirmed that combining LARS with other techniques, such as He initialization, matters more than using LARS alone.

  • Confirmed that LARS can not only match the base line but sometimes exceed it.

    • This appears to be a side effect of the local learning rate mitigating the vanishing and exploding gradient problems, as noted in the paper.

Open Issue

  • Training with LARS takes roughly twice as long as training without it; investigate ways to reduce the training time.

Reference

  • Y. You, I. Gitman, and B. Ginsburg, "Large Batch Training of Convolutional Networks" (LARS), arXiv:1708.03888.

Appendix

Screen output of running val.py

  • Shows the history of best-accuracy updates recorded over the course of training.
