# Copyright 2024 Google LLC # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from collections import defaultdict import torch from gemma import config from gemma import model as gemma_model import numpy as np import argparse import os # Requires torch 2.2 and gemma package from https://github.com/google/gemma_pytorch def check_file_exists(value): if not os.path.exists(str(value)): raise argparse.ArgumentTypeError("The file %s does not appear to exist." % value) return value def check_model_types(value): if str(value).lower() not in ["2b", "7b"]: raise argparse.ArgumentTypeError("Model type value %s is not in [2b, 7b]." % value) return value parser = argparse.ArgumentParser() parser.add_argument( "--tokenizer", dest="tokenizer", default="models/tokenizer.spm", help="Location of tokenizer file (.model or .spm)", type=check_file_exists, ) parser.add_argument( "--weights", dest="weights", default="models/gemma-2b-it.ckpt", help="Location of input checkpoint file (.ckpt)", type=check_file_exists, ) parser.add_argument( "--output_file", dest="output_file", default="2bit-f32.sbs", help="Location to write converted weights", type=str, ) parser.add_argument( "--model_type", dest="model_type", default="2b", help="Model size / type (2b, 7b)", type=check_model_types, ) args = parser.parse_args() def expand_qkv(qkv_proj: np.array) -> np.array: """This won't be needed anymore when MQA is implemented""" assert qkv_proj.shape == (2560, 2048) qkv = qkv_proj.reshape((10, 256, 2048)) q_proj = qkv[:8].reshape((1,8,256,2048)) kv_proj = qkv[8:] kv_proj = kv_proj[:, np.newaxis, :, :] kv_proj = np.repeat(kv_proj, 8, axis=1) qkv = np.concatenate([q_proj, kv_proj]) qkv = np.transpose(qkv, axes=[1,0,2,3]) return qkv TRANSFORMATIONS = { "2b":defaultdict( lambda: lambda x: x, { "embedder.weight": lambda x: np.concatenate([x, np.zeros([128, 2048])], 0), "self_attn.qkv_proj.weight": expand_qkv, "self_attn.o_proj.weight": lambda x: x.reshape((2048, 8, 256)).transpose([1,0,2]), "mlp.gate_proj.weight": lambda x: x[np.newaxis, :, :], "mlp.up_proj.weight": lambda x: x[np.newaxis, :, :], "mlp.down_proj.weight": lambda x: x, } ), "7b":defaultdict( lambda: lambda x: x, { "embedder.weight": lambda x: np.concatenate([x, np.zeros([128, 3072])], 0), "self_attn.qkv_proj.weight": lambda x: x.reshape((3, 16, 256, 3072)).transpose([1,0,2,3]), "self_attn.o_proj.weight": lambda x: x.reshape((3072, 16, 256)).transpose([1,0,2]), "mlp.gate_proj.weight": lambda x: x[np.newaxis, :, :], "mlp.up_proj.weight": lambda x: x[np.newaxis, :, :], "mlp.down_proj.weight": lambda x: x, } ), } VALIDATIONS = { "2b": { "embedder.weight": lambda x: x.shape == (256128, 2048), "model.norm.weight": lambda x: x.shape == (2048,), "self_attn.qkv_proj.weight": lambda x: x.shape == (8, 3, 256, 2048), "self_attn.o_proj.weight": lambda x: x.shape == (8, 2048, 256), "mlp.gate_proj.weight": lambda x: x.shape == (1, 16384, 2048), "mlp.up_proj.weight": lambda x: x.shape == (1, 16384, 2048), "mlp.down_proj.weight": lambda x: x.shape == (2048, 16384), "input_layernorm.weight": lambda x: x.shape == (2048,), "post_attention_layernorm.weight": lambda x: x.shape == (2048,), }, "7b": { "embedder.weight": lambda x: x.shape == (256128, 3072), "model.norm.weight": lambda x: x.shape == (3072,), "self_attn.qkv_proj.weight": lambda x: x.shape == (16, 3, 256, 3072), "self_attn.o_proj.weight": lambda x: x.shape == (16, 3072, 256), "mlp.gate_proj.weight": lambda x: x.shape == (1, 24576, 3072), "mlp.up_proj.weight": lambda x: x.shape == (1, 24576, 3072), "mlp.down_proj.weight": lambda x: x.shape == (3072, 24576), "input_layernorm.weight": lambda x: x.shape == (3072,), "post_attention_layernorm.weight": lambda x: x.shape == (3072,), }, } def param_names(num_hidden_layers: int): """Return parameter names in the order they are expected for deserialization.""" # note *weight_scaler params are ignored in the forward computation unless # quantization is being used. # # since we are working with the full precision weights as input, don't # include these in the parameters being iterated over. # fmt: off names = [ ("embedder.weight", ) * 2, # embedder_input_embedding ("model.norm.weight", ) * 2 # final_norm_scale ] layer_params = [ "self_attn.o_proj.weight", # attn_vec_einsum_w "self_attn.qkv_proj.weight", # qkv_einsum_w "mlp.gate_proj.weight", # gating_einsum_w "mlp.up_proj.weight", "mlp.down_proj.weight", # linear_w "input_layernorm.weight", # pre_attention_norm_scale "post_attention_layernorm.weight", # pre_ffw_norm_scale ] # fmt: on for layer in range(num_hidden_layers): for layer_param in layer_params: names = names + [(f"model.layers.{layer}.{layer_param}", layer_param)] return names def convert_weights(): model_type = args.model_type output_file = args.output_file model_config = config.get_model_config(model_type) model_config.dtype = "float32" model_config.tokenizer = args.tokenizer device = torch.device("cpu") torch.set_default_dtype(torch.float) model = gemma_model.GemmaForCausalLM(model_config) model.load_weights(args.weights) model.to(device).eval() model_dict = dict(model.named_parameters()) param_order = param_names(model_config.num_hidden_layers) all_ok = True print("Checking transformations ...") for name, layer_name in param_order: arr = model_dict[name].detach().numpy() arr = TRANSFORMATIONS[model_type][layer_name](arr) check = "OK" if VALIDATIONS[model_type][layer_name](arr) else "FAILED" if check == "FAILED": all_ok = False print(f" {name : <60}{str(arr.shape) : <20}{check}") if all_ok: print("Writing parameters ...") gate = None with open(output_file, "wb") as bin_handle: for name, layer_name in param_order: arr = model_dict[name].detach().numpy() arr = TRANSFORMATIONS[model_type][layer_name](arr) check = "OK" if VALIDATIONS[model_type][layer_name](arr) else "FAILED" print(f" {name : <60}{str(arr.shape) : <20}{check}") arr.flatten().astype(np.float32).tofile(bin_handle) if __name__ == "__main__": convert_weights() print("Done")