Linter cleanup patch.

PiperOrigin-RevId: 621331426
The gemma.cpp Authors authored and Copybara-Service committed Apr 4, 2024
1 parent 7122afe commit f91040c
Showing 2 changed files with 154 additions and 128 deletions.
11 changes: 11 additions & 0 deletions BUILD.bazel
@@ -1,6 +1,7 @@
# gemma.cpp is a lightweight, standalone C++ inference engine for the Gemma
# foundation models from Google.

load("//devtools/python/blaze:pytype.bzl", "pytype_strict_library")
load("@rules_license//rules:license.bzl", "license")

package(
@@ -132,3 +133,13 @@ cc_binary(
"@hwy//:thread_pool",
],
)

pytype_strict_library(
name = "util/convert_weights",
srcs = ["util/convert_weights.py"],
deps = [
"//third_party/py/gemma",
"//third_party/py/numpy",
"//third_party/py/torch:pytorch",
],
)
271 changes: 143 additions & 128 deletions util/convert_weights.py
@@ -13,28 +13,36 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Convert model weights from Python library formats to the gemma_cpp format."""

from collections import defaultdict
import torch

import argparse
import collections
import os

# Requires torch 2.2 and gemma package from:
# https://github.com/google/gemma_pytorch
from gemma import config
from gemma import model as gemma_model
import numpy as np
import argparse
import os
import torch


def check_file_exists(path):
if not os.path.exists(str(path)):
raise argparse.ArgumentTypeError(
f"The file {path} does not appear to exist."
)
return path

# Requires torch 2.2 and gemma package from https://github.com/google/gemma_pytorch

def check_file_exists(value):
if not os.path.exists(str(value)):
raise argparse.ArgumentTypeError("The file %s does not appear to exist." % value)
return value

def check_model_types(path):
if str(path).lower() not in ["2b", "7b"]:
raise argparse.ArgumentTypeError(
f"Model type path {path} is not in [2b, 7b]."
)
return path

def check_model_types(value):
if str(value).lower() not in ["2b", "7b"]:
raise argparse.ArgumentTypeError("Model type value %s is not in [2b, 7b]." % value)
return value
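
The two check_* helpers above are argparse validators: each either returns the value unchanged or raises ArgumentTypeError, which argparse reports as a clean command-line error. Presumably they are attached to the flags via argparse's type= hook. A minimal sketch follows, with flag names inferred from the args.* attributes used in convert_weights() below; treat those names as assumptions, since the actual flag definitions are elided from this diff.

import argparse

# Hypothetical wiring of the validators; the real parser is defined just
# below, but its add_argument calls are not shown in this diff.
sketch_parser = argparse.ArgumentParser()
sketch_parser.add_argument("--tokenizer", type=check_file_exists)
sketch_parser.add_argument("--weights", type=check_file_exists)
sketch_parser.add_argument("--model_type", type=check_model_types)
sketch_parser.add_argument("--output_file")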


parser = argparse.ArgumentParser()
parser.add_argument(
@@ -73,126 +81,133 @@ def check_model_types(value):


TRANSFORMATIONS = {
"2b":defaultdict(
lambda: lambda x: x,
{
"embedder.weight": lambda x: x,
"self_attn.qkv_proj.weight": lambda x: x.reshape((10, 256, 2048)),
"self_attn.o_proj.weight": lambda x: x.reshape((2048, 8, 256)).transpose([1,0,2]),
"mlp.gate_proj.weight": lambda x: x[np.newaxis, :, :],
"mlp.up_proj.weight": lambda x: x[np.newaxis, :, :],
"mlp.down_proj.weight": lambda x: x,
}
),
"7b":defaultdict(
lambda: lambda x: x,
{
"embedder.weight": lambda x: x,
"self_attn.qkv_proj.weight": lambda x: x.reshape((3, 16, 256, 3072)).transpose([1,0,2,3]),
"self_attn.o_proj.weight": lambda x: x.reshape((3072, 16, 256)).transpose([1,0,2]),
"mlp.gate_proj.weight": lambda x: x[np.newaxis, :, :],
"mlp.up_proj.weight": lambda x: x[np.newaxis, :, :],
"mlp.down_proj.weight": lambda x: x,
}
),
"2b": collections.defaultdict(
lambda: lambda x: x,
{
"embedder.weight": lambda x: x,
"self_attn.qkv_proj.weight": lambda x: x.reshape((10, 256, 2048)),
"self_attn.o_proj.weight": lambda x: x.reshape(
(2048, 8, 256)
).transpose([1, 0, 2]),
"mlp.gate_proj.weight": lambda x: x[np.newaxis, :, :],
"mlp.up_proj.weight": lambda x: x[np.newaxis, :, :],
"mlp.down_proj.weight": lambda x: x,
},
),
"7b": collections.defaultdict(
lambda: lambda x: x,
{
"embedder.weight": lambda x: x,
"self_attn.qkv_proj.weight": lambda x: x.reshape(
(3, 16, 256, 3072)
).transpose([1, 0, 2, 3]),
"self_attn.o_proj.weight": lambda x: x.reshape(
(3072, 16, 256)
).transpose([1, 0, 2]),
"mlp.gate_proj.weight": lambda x: x[np.newaxis, :, :],
"mlp.up_proj.weight": lambda x: x[np.newaxis, :, :],
"mlp.down_proj.weight": lambda x: x,
},
),
}

VALIDATIONS = {
"2b": {
"embedder.weight": lambda x: x.shape == (256000, 2048),
"model.norm.weight": lambda x: x.shape == (2048,),
"self_attn.qkv_proj.weight": lambda x: x.shape == (10, 256, 2048),
"self_attn.o_proj.weight": lambda x: x.shape == (8, 2048, 256),
"mlp.gate_proj.weight": lambda x: x.shape == (1, 16384, 2048),
"mlp.up_proj.weight": lambda x: x.shape == (1, 16384, 2048),
"mlp.down_proj.weight": lambda x: x.shape == (2048, 16384),
"input_layernorm.weight": lambda x: x.shape == (2048,),
"post_attention_layernorm.weight": lambda x: x.shape == (2048,),
},
"7b": {
"embedder.weight": lambda x: x.shape == (256000, 3072),
"model.norm.weight": lambda x: x.shape == (3072,),
"self_attn.qkv_proj.weight": lambda x: x.shape == (16, 3, 256, 3072),
"self_attn.o_proj.weight": lambda x: x.shape == (16, 3072, 256),
"mlp.gate_proj.weight": lambda x: x.shape == (1, 24576, 3072),
"mlp.up_proj.weight": lambda x: x.shape == (1, 24576, 3072),
"mlp.down_proj.weight": lambda x: x.shape == (3072, 24576),
"input_layernorm.weight": lambda x: x.shape == (3072,),
"post_attention_layernorm.weight": lambda x: x.shape == (3072,),
},
"2b": {
"embedder.weight": lambda x: x.shape == (256000, 2048),
"model.norm.weight": lambda x: x.shape == (2048,),
"self_attn.qkv_proj.weight": lambda x: x.shape == (10, 256, 2048),
"self_attn.o_proj.weight": lambda x: x.shape == (8, 2048, 256),
"mlp.gate_proj.weight": lambda x: x.shape == (1, 16384, 2048),
"mlp.up_proj.weight": lambda x: x.shape == (1, 16384, 2048),
"mlp.down_proj.weight": lambda x: x.shape == (2048, 16384),
"input_layernorm.weight": lambda x: x.shape == (2048,),
"post_attention_layernorm.weight": lambda x: x.shape == (2048,),
},
"7b": {
"embedder.weight": lambda x: x.shape == (256000, 3072),
"model.norm.weight": lambda x: x.shape == (3072,),
"self_attn.qkv_proj.weight": lambda x: x.shape == (16, 3, 256, 3072),
"self_attn.o_proj.weight": lambda x: x.shape == (16, 3072, 256),
"mlp.gate_proj.weight": lambda x: x.shape == (1, 24576, 3072),
"mlp.up_proj.weight": lambda x: x.shape == (1, 24576, 3072),
"mlp.down_proj.weight": lambda x: x.shape == (3072, 24576),
"input_layernorm.weight": lambda x: x.shape == (3072,),
"post_attention_layernorm.weight": lambda x: x.shape == (3072,),
},
}
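
Each TRANSFORMATIONS table is a defaultdict whose fallback is the identity function, so tensors without an explicit entry (for example the norm scales) pass through unchanged; the explicit entries reshape the fused PyTorch projection matrices into the per-head layout that VALIDATIONS then checks. A small sketch of that round trip for the 2B attention weights, assuming the gemma_pytorch checkpoint stores qkv_proj.weight as (2560, 2048), i.e. 8 query heads plus one K and one V head of head_dim 256 over a hidden size of 2048:

import numpy as np

# Stand-in array with the assumed 2B qkv_proj.weight input shape.
qkv = np.zeros((2560, 2048), dtype=np.float32)
qkv = TRANSFORMATIONS["2b"]["self_attn.qkv_proj.weight"](qkv)
print(qkv.shape)                                            # (10, 256, 2048)
print(VALIDATIONS["2b"]["self_attn.qkv_proj.weight"](qkv))  # True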


def param_names(num_hidden_layers: int):
"""Return parameter names in the order they are expected for deserialization."""

# note *weight_scaler params are ignored in the forward computation unless
# quantization is being used.
#
# since we are working with the full precision weights as input, don't
# include these in the parameters being iterated over.

# fmt: off
names = [
("embedder.weight", ) * 2, # embedder_input_embedding
("model.norm.weight", ) * 2 # final_norm_scale
]
layer_params = [
"self_attn.o_proj.weight", # attn_vec_einsum_w
"self_attn.qkv_proj.weight", # qkv_einsum_w
"mlp.gate_proj.weight", # gating_einsum_w
"mlp.up_proj.weight",
"mlp.down_proj.weight", # linear_w
"input_layernorm.weight", # pre_attention_norm_scale
"post_attention_layernorm.weight", # pre_ffw_norm_scale
]
# fmt: on
for layer in range(num_hidden_layers):
for layer_param in layer_params:
names = names + [(f"model.layers.{layer}.{layer_param}", layer_param)]
return names


def convert_weights():
model_type = args.model_type
output_file = args.output_file

model_config = config.get_model_config(model_type)
model_config.dtype = "float32"
model_config.tokenizer = args.tokenizer
device = torch.device("cpu")
torch.set_default_dtype(torch.float)
model = gemma_model.GemmaForCausalLM(model_config)

model.load_weights(args.weights)
model.to(device).eval()

model_dict = dict(model.named_parameters())
param_order = param_names(model_config.num_hidden_layers)

all_ok = True
print("Checking transformations ...")
def param_names(num_hidden_layers: int) -> list[str]:
"""Return parameter names in the order they are expected for deserialization."""

# note *weight_scaler params are ignored in the forward computation unless
# quantization is being used.
#
# since we are working with the full precision weights as input, don't
# include these in the parameters being iterated over.

names = [
("embedder.weight",) * 2, # embedder_input_embedding
("model.norm.weight",) * 2, # final_norm_scale
]
layer_params = [
"self_attn.o_proj.weight", # attn_vec_einsum_w
"self_attn.qkv_proj.weight", # qkv_einsum_w
"mlp.gate_proj.weight", # gating_einsum_w
"mlp.up_proj.weight",
"mlp.down_proj.weight", # linear_w
"input_layernorm.weight", # pre_attention_norm_scale
"post_attention_layernorm.weight", # pre_ffw_norm_scale
]

for layer in range(num_hidden_layers):
for layer_param in layer_params:
names = names + [(f"model.layers.{layer}.{layer_param}", layer_param)]
return names
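
As the function above implies, each entry pairs the full checkpoint key with the short per-layer key used to index TRANSFORMATIONS and VALIDATIONS. For a hypothetical one-layer model the ordering would be:

# param_names(1) yields (checkpoint key, lookup key) pairs in this order:
#   ("embedder.weight", "embedder.weight")
#   ("model.norm.weight", "model.norm.weight")
#   ("model.layers.0.self_attn.o_proj.weight", "self_attn.o_proj.weight")
#   ("model.layers.0.self_attn.qkv_proj.weight", "self_attn.qkv_proj.weight")
#   ... and so on through "post_attention_layernorm.weight".
for name, layer_name in param_names(1):
    print(name, "->", layer_name)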


def convert_weights() -> None:
"""Convert model weights from Python library to gemma_cpp format."""
model_type = args.model_type
output_file = args.output_file

model_config = config.get_model_config(model_type)
model_config.dtype = "float32"
model_config.tokenizer = args.tokenizer
device = torch.device("cpu")
torch.set_default_dtype(torch.float)
model = gemma_model.GemmaForCausalLM(model_config)

model.load_weights(args.weights)
model.to(device).eval()

model_dict = dict(model.named_parameters())
param_order = param_names(model_config.num_hidden_layers)

any_errors = False
print("Checking transformations ...")
for name, layer_name in param_order:
arr = model_dict[name].detach().numpy()
arr = TRANSFORMATIONS[model_type][layer_name](arr)
check = "OK" if VALIDATIONS[model_type][layer_name](arr) else "FAILED"

if check == "FAILED":
any_errors = True
print(f" {name : <60}{str(arr.shape) : <20}{check}")

if any_errors:
return None

print("Writing parameters ...")
with open(output_file, "wb") as bin_handle:
for name, layer_name in param_order:
arr = model_dict[name].detach().numpy()
arr = TRANSFORMATIONS[model_type][layer_name](arr)
check = "OK" if VALIDATIONS[model_type][layer_name](arr) else "FAILED"

if check == "FAILED":
all_ok = False
print(f" {name : <60}{str(arr.shape) : <20}{check}")

if all_ok:
print("Writing parameters ...")
gate = None
with open(output_file, "wb") as bin_handle:
for name, layer_name in param_order:
arr = model_dict[name].detach().numpy()
arr = TRANSFORMATIONS[model_type][layer_name](arr)
check = "OK" if VALIDATIONS[model_type][layer_name](arr) else "FAILED"
print(f" {name : <60}{str(arr.shape) : <20}{check}")
arr.flatten().astype(np.float32).tofile(bin_handle)
arr = model_dict[name].detach().numpy()
arr = TRANSFORMATIONS[model_type][layer_name](arr)
check = "OK" if VALIDATIONS[model_type][layer_name](arr) else "FAILED"
print(f" {name : <60}{str(arr.shape) : <20}{check}")
arr.flatten().astype(np.float32).tofile(bin_handle)


if __name__ == "__main__":
convert_weights()
print("Done")
convert_weights()
print("Done")
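
Note that the output is headerless: every tensor is transformed, flattened, cast to float32 and appended to the file in param_names() order, so gemma.cpp has to assume the same order and shapes when loading. A quick sanity-check sketch for a freshly converted 2B file (the path below is made up):

import numpy as np

# embedder.weight is written first, as 256000 * 2048 float32 values with
# nothing in front of it, so it can be read straight back off the start of
# the file.
with open("2b.sbs", "rb") as f:  # hypothetical output path
    emb = np.fromfile(f, dtype=np.float32, count=256000 * 2048)
print(emb.reshape(256000, 2048).shape)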
