TCLResearchEurope/ptdeco

ptdeco

ptdeco is a library for model optimization by decomposition built on top of PyTorch.

Introduction

Currently, ptdeco implements the following methods:

  • lockd - method based on local knowledge distillation, tested on vision models (lockd = LOCal Knowledge Distillation)

  • falor - method based on low-rank decomposition of features, inspired by "Compressing Transformers: Features Are Low-Rank, but Weights Are Not!" by Yu Hao and Wu Jianxin (2023), tested on vision models (falor = Features Are LOw Rank)

  • dwain - iterative method based on low-rank decomposition of features, tested on Large Language Models (dwain = Decomposing Weights Algorithm - an Iterative techNique)

The lockd method requires a short knowledge distillation pretraining (about 10 ImageNet epochs) before the model is decomposed. It can decompose linear layers and convolutions.

The falor method does not require pretraining. Model decomposition takes less than 1 GPU hour, depending on model size and parameters. It can decompose linear layers and 1x1 convolutions.

The dwain method does not require pretraining. It can decompose linear layers and 1x1 convolutions.
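
To give a rough feel for why decomposing a linear layer can shrink a model, the sketch below counts parameters before and after replacing one dense layer with two thinner ones (`d_in -> rank -> d_out`). This is the usual shape of a low-rank factorization and not necessarily how ptdeco arranges its layers; the function names, the layer sizes, and the choice to keep the bias only on the second layer are all assumptions made for illustration.

```python
def linear_params(d_in: int, d_out: int, bias: bool = True) -> int:
    """Parameter count of a dense layer mapping d_in -> d_out."""
    return d_in * d_out + (d_out if bias else 0)


def factorized_params(d_in: int, d_out: int, rank: int, bias: bool = True) -> int:
    """Parameter count after replacing the layer with d_in -> rank -> d_out.

    Assumption of this sketch: only the second layer keeps a bias.
    """
    first = linear_params(d_in, rank, bias=False)
    second = linear_params(rank, d_out, bias=bias)
    return first + second


def break_even_rank(d_in: int, d_out: int) -> float:
    """Rank below which the two weight matrices are smaller than the original one."""
    return d_in * d_out / (d_in + d_out)


# Hypothetical layer sizes, chosen only for illustration
d_in, d_out, rank = 768, 3072, 256
dense = linear_params(d_in, d_out)
low_rank = factorized_params(d_in, d_out, rank)

print(dense)     # 2362368
print(low_rank)  # 986112
print(break_even_rank(d_in, d_out))
```

The factorized form only saves weights when the rank is below `d_in * d_out / (d_in + d_out)`, so layers that need a rank above that threshold are better left undecomposed.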

Installation

pip install ptdeco

Saving and loading a decomposed model

Saving a decomposed model

Decomposition produces a decompose_config dictionary. Serialize it, for example to JSON, so that you can later recreate the structure of the decomposed model. In addition, save the model's state_dict so that you can recover the weights of the decomposed model. The code below illustrates the procedure:

import json
import pathlib

import torch

# Your decomposition code goes here; it produces `model` and `decompose_config`

output_path = pathlib.Path("YOUR/CHECKPOINT/DIRECTORY")

# Save the structure of the decomposed model
out_decompose_config_path = output_path / "decompose_config.json"
with open(out_decompose_config_path, "wt") as f:
    json.dump(decompose_config, f)

# Save the weights of the decomposed model
out_decompose_state_dict_path = output_path / "decompose_state_dict.pt"
torch.save(model.state_dict(), out_decompose_state_dict_path)
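
Because the structure of the decomposed model is rebuilt from the JSON file, it can be worth checking that the dictionary survives a JSON round trip unchanged. The sketch below does this with a made-up `decompose_config`; its keys and values are hypothetical stand-ins, since the real dictionary comes from your decomposition run.

```python
import json
import pathlib
import tempfile

# Hypothetical stand-in: the real dictionary comes from your decomposition run
decompose_config = {"layer_a": {"rank": 16}, "layer_b": {"rank": 8}}

with tempfile.TemporaryDirectory() as tmp:
    config_path = pathlib.Path(tmp) / "decompose_config.json"

    # Write the config exactly as in the saving procedure
    with open(config_path, "wt") as f:
        json.dump(decompose_config, f)

    # Read it back and compare with the in-memory version
    with open(config_path, "rt") as f:
        restored = json.load(f)

round_trip_ok = restored == decompose_config
print(round_trip_ok)
```

JSON converts tuples to lists and non-string dictionary keys to strings, so a failed comparison usually points to one of those conversions.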

Loading a decomposed model

To load the model, first recreate the original model. Next, load and apply the decompose_config. Finally, load the state_dict. The state dict fits the decomposed model, not the original one, so loading it must be the last step. The code below illustrates the procedure:

import json
import pathlib

import ptdeco
import torch

model = ...   # Build the original model
device = ...  # Specify the device the original model uses

output_path = pathlib.Path("YOUR/CHECKPOINT/DIRECTORY")

with open(output_path / "decompose_config.json", "rt") as f:
    decompose_config = json.load(f)

# Rebuild the decomposed structure before loading the weights
ptdeco.utils.apply_decompose_config_in_place(model, decompose_config)

# `map_location` is an argument of `torch.load`, not of `load_state_dict`
sd = torch.load(output_path / "decompose_state_dict.pt", map_location=device)

model.load_state_dict(sd)

# Now `model` is decomposed and contains appropriate weights
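
If the state dict is loaded before the decompose config has been applied, the parameter names in the checkpoint will not match those of the model. The helper below is a small diagnostic sketch for that situation: it compares two collections of key names and reports the differences. The function name and the example key names are hypothetical; actual names depend on your model and on how ptdeco names the layers it creates.

```python
from typing import Iterable, Set, Tuple


def diff_state_dict_keys(
    model_keys: Iterable[str], checkpoint_keys: Iterable[str]
) -> Tuple[Set[str], Set[str]]:
    """Return (missing, unexpected) key sets for a model/checkpoint pair."""
    model_set, ckpt_set = set(model_keys), set(checkpoint_keys)
    missing = model_set - ckpt_set     # expected by the model, absent from the checkpoint
    unexpected = ckpt_set - model_set  # present in the checkpoint, unknown to the model
    return missing, unexpected


# Hypothetical key names, for illustration only
original_keys = ["fc.weight", "fc.bias"]
decomposed_keys = ["fc.0.weight", "fc.1.weight", "fc.1.bias"]

missing, unexpected = diff_state_dict_keys(original_keys, decomposed_keys)
print(sorted(missing))     # ['fc.bias', 'fc.weight']
print(sorted(unexpected))  # ['fc.0.weight', 'fc.1.bias', 'fc.1.weight']
```

With a real model you would pass `model.state_dict().keys()` and `sd.keys()`; two empty sets mean the model structure and the checkpoint agree.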