From 5be9a2243f9b4e6c3d3194e12a353e753c281324 Mon Sep 17 00:00:00 2001
From: austinvhuang
Date: Fri, 1 Mar 2024 15:52:51 -0500
Subject: [PATCH] initial (wip) convert_weights script from pytorch

---
 util/convert_weights.py | 55 +++++++++++++++++++++++++++++++++++++++++
 1 file changed, 55 insertions(+)
 create mode 100644 util/convert_weights.py

diff --git a/util/convert_weights.py b/util/convert_weights.py
new file mode 100644
index 0000000..0749cad
--- /dev/null
+++ b/util/convert_weights.py
@@ -0,0 +1,55 @@
# WIP - DO NOT MERGE
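#
# Converts a full-precision PyTorch Gemma checkpoint into a single flat
# float32 file by concatenating tensors in the order given by param_names().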