diff --git a/util/convert_weights.py b/util/convert_weights.py new file mode 100644 index 0000000..bd6750a --- /dev/null +++ b/util/convert_weights.py @@ -0,0 +1,212 @@ +# Copyright 2024 Google LLC +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from collections import defaultdict +import torch +from gemma import config +from gemma import model as gemma_model +import numpy as np +import argparse +import os + +# Requires torch 2.2 and gemma package from https://github.com/google/gemma_pytorch + +def check_file_exists(value): + if not os.path.exists(str(value)): + raise argparse.ArgumentTypeError("The file %s does not appear to exist." % value) + return value + + +def check_model_types(value): + if str(value).lower() not in ["2b", "7b"]: + raise argparse.ArgumentTypeError("Model type value %s is not in [2b, 7b]." % value) + return value + + +parser = argparse.ArgumentParser() +parser.add_argument( + "--tokenizer", + dest="tokenizer", + default="models/tokenizer.spm", + help="Location of tokenizer file (.model or .spm)", + type=check_file_exists, +) + +parser.add_argument( + "--weights", + dest="weights", + default="models/gemma-2b-it.ckpt", + help="Location of input checkpoint file (.ckpt)", + type=check_file_exists, +) + +parser.add_argument( + "--output_file", + dest="output_file", + default="2bit-f32.sbs", + help="Location to write converted weights", + type=str, +) + +parser.add_argument( + "--model_type", + dest="model_type", + default="2b", + help="Model size / type (2b, 7b)", + type=check_model_types, +) + +args = parser.parse_args() + + +def expand_qkv(qkv_proj: np.array) -> np.array: + """This won't be needed anymore when MQA is implemented""" + assert qkv_proj.shape == (2560, 2048) + qkv = qkv_proj.reshape((10, 256, 2048)) + + q_proj = qkv[:8].reshape((1,8,256,2048)) + kv_proj = qkv[8:] + kv_proj = kv_proj[:, np.newaxis, :, :] + kv_proj = np.repeat(kv_proj, 8, axis=1) + + qkv = np.concatenate([q_proj, kv_proj]) + qkv = np.transpose(qkv, axes=[1,0,2,3]) + return qkv + +TRANSFORMATIONS = { + "2b":defaultdict( + lambda: lambda x: x, + { + "embedder.weight": lambda x: np.concatenate([x, np.zeros([128, 2048])], 0), + "self_attn.qkv_proj.weight": expand_qkv, + "self_attn.o_proj.weight": lambda x: x.reshape((2048, 8, 256)).transpose([1,0,2]), + "mlp.gate_proj.weight": lambda x: x[np.newaxis, :, :], + "mlp.up_proj.weight": lambda x: x[np.newaxis, :, :], + "mlp.down_proj.weight": lambda x: x, + } + ), + "7b":defaultdict( + lambda: lambda x: x, + { + "embedder.weight": lambda x: np.concatenate([x, np.zeros([128, 3072])], 0), + "self_attn.qkv_proj.weight": lambda x: x.reshape((3, 16, 256, 3072)).transpose([1,0,2,3]), + "self_attn.o_proj.weight": lambda x: x.reshape((3072, 16, 256)).transpose([1,0,2]), + "mlp.gate_proj.weight": lambda x: x[np.newaxis, :, :], + "mlp.up_proj.weight": lambda x: x[np.newaxis, :, :], + "mlp.down_proj.weight": lambda x: x, + } + ), +} + +VALIDATIONS = { + "2b": { + "embedder.weight": lambda x: x.shape == (256128, 2048), + "model.norm.weight": lambda x: x.shape == (2048,), + "self_attn.qkv_proj.weight": lambda x: x.shape == (8, 3, 256, 2048), + "self_attn.o_proj.weight": lambda x: x.shape == (8, 2048, 256), + "mlp.gate_proj.weight": lambda x: x.shape == (1, 16384, 2048), + "mlp.up_proj.weight": lambda x: x.shape == (1, 16384, 2048), + "mlp.down_proj.weight": lambda x: x.shape == (2048, 16384), + "input_layernorm.weight": lambda x: x.shape == (2048,), + "post_attention_layernorm.weight": lambda x: x.shape == (2048,), + }, + "7b": { + "embedder.weight": lambda x: x.shape == (256128, 3072), + "model.norm.weight": lambda x: x.shape == (3072,), + "self_attn.qkv_proj.weight": lambda x: x.shape == (16, 3, 256, 3072), + "self_attn.o_proj.weight": lambda x: x.shape == (16, 3072, 256), + "mlp.gate_proj.weight": lambda x: x.shape == (1, 24576, 3072), + "mlp.up_proj.weight": lambda x: x.shape == (1, 24576, 3072), + "mlp.down_proj.weight": lambda x: x.shape == (3072, 24576), + "input_layernorm.weight": lambda x: x.shape == (3072,), + "post_attention_layernorm.weight": lambda x: x.shape == (3072,), + }, +} + + +def param_names(num_hidden_layers: int): + """Return parameter names in the order they are expected for deserialization.""" + + # note *weight_scaler params are ignored in the forward computation unless + # quantization is being used. + # + # since we are working with the full precision weights as input, don't + # include these in the parameters being iterated over. + + # fmt: off + names = [ + ("embedder.weight", ) * 2, # embedder_input_embedding + ("model.norm.weight", ) * 2 # final_norm_scale + ] + layer_params = [ + "self_attn.o_proj.weight", # attn_vec_einsum_w + "self_attn.qkv_proj.weight", # qkv_einsum_w + "mlp.gate_proj.weight", # gating_einsum_w + "mlp.up_proj.weight", + "mlp.down_proj.weight", # linear_w + "input_layernorm.weight", # pre_attention_norm_scale + "post_attention_layernorm.weight", # pre_ffw_norm_scale + ] + # fmt: on + for layer in range(num_hidden_layers): + for layer_param in layer_params: + names = names + [(f"model.layers.{layer}.{layer_param}", layer_param)] + return names + + +def convert_weights(): + model_type = args.model_type + output_file = args.output_file + + model_config = config.get_model_config(model_type) + model_config.dtype = "float32" + model_config.tokenizer = args.tokenizer + device = torch.device("cpu") + torch.set_default_dtype(torch.float) + model = gemma_model.GemmaForCausalLM(model_config) + + model.load_weights(args.weights) + model.to(device).eval() + + model_dict = dict(model.named_parameters()) + param_order = param_names(model_config.num_hidden_layers) + + all_ok = True + print("Checking transformations ...") + for name, layer_name in param_order: + arr = model_dict[name].detach().numpy() + arr = TRANSFORMATIONS[model_type][layer_name](arr) + check = "OK" if VALIDATIONS[model_type][layer_name](arr) else "FAILED" + + if check == "FAILED": + all_ok = False + print(f" {name : <60}{str(arr.shape) : <20}{check}") + + if all_ok: + print("Writing parameters ...") + gate = None + with open(output_file, "wb") as bin_handle: + for name, layer_name in param_order: + arr = model_dict[name].detach().numpy() + arr = TRANSFORMATIONS[model_type][layer_name](arr) + check = "OK" if VALIDATIONS[model_type][layer_name](arr) else "FAILED" + print(f" {name : <60}{str(arr.shape) : <20}{check}") + arr.flatten().astype(np.float32).tofile(bin_handle) + + +if __name__ == "__main__": + convert_weights() + print("Done")