From 7d7d43e661ed063b071f473747c2960cf7cc18bd Mon Sep 17 00:00:00 2001 From: austinvhuang Date: Sat, 2 Mar 2024 08:11:55 -0500 Subject: [PATCH] converter transformations (wip) --- util/convert_weights.py | 54 ++++++++++++++++++++++++++++------------- 1 file changed, 37 insertions(+), 17 deletions(-) diff --git a/util/convert_weights.py b/util/convert_weights.py index 0749cad..2e96edd 100644 --- a/util/convert_weights.py +++ b/util/convert_weights.py @@ -1,37 +1,53 @@ # WIP - DO NOT MERGE +from collections import defaultdict import torch from gemma import config from gemma import model as gemma_model import numpy as np +TRANSFORMATIONS = defaultdict(lambda: lambda x: x, { + "embedder.weight": lambda x: np.concatenate([np.zeros([128, 2048]), x], 0), + "self_attn.qkv_proj.weight": lambda x: x, + "mlp.up_proj.weight" : lambda x: x, + "mlp.down_proj.weight" : lambda x: x, +}) + def param_names(): """Return parameter names in the order they are expected for deserialization.""" - names = ["embedder.weight", "model.norm.weight"] - # note *weight_scaler params are ignored in the forward computation unless quantization is being used. - # since we are working with the full precision weights as input, don't include these in the parameters being iterated over - layer_params = [ - "self_attn.qkv_proj.weight", # attn_vec_einsum_w - "self_attn.o_proj.weight", # qkv_einsum_w - "mlp.gate_proj.weight", # qkv_einsum_w - "mlp.up_proj.weight", # gating_einsum_w - "mlp.down_proj.weight", # linear_w - "input_layernorm.weight", # pre_attention_norm_scale - "post_attention_layernorm.weight", # pre_ffw_norm_scale + + # note *weight_scaler params are ignored in the forward computation unless + # quantization is being used. + # + # since we are working with the full precision weights as input, don't + # include these in the parameters being iterated over. + + # fmt: off + names = [ + "embedder.weight", # embedder_input_embedding (vocab=256000, model_dim=2048) -> (vocab=256128, model_dim=2048) + "model.norm.weight" # final_norm_scale (model_dim=2048) ] + layer_params = [ + # TODO(austinvhuang): transpositions here ... + "self_attn.qkv_proj.weight", # attn_vec_einsum_w (2560, 2048) -> (heads=8, model_dim=2048, qkv_dim=256) + "self_attn.o_proj.weight", # qkv_einsum_w (2048, 2048) -> (heads=8, qkv=3, qkv_dim=256, model_dim=2048) + + # these are the same without any change + "mlp.gate_proj.weight", # gating_einsum_w (16384, 2048) => (gate/up=2, hidden=16384, model_dim=2048) + "mlp.up_proj.weight", + "mlp.down_proj.weight", # linear_w (2048, 16384) => (model_dim=2048, hidden=16384) + "input_layernorm.weight", # pre_attention_norm_scale (model_dim=2048) + "post_attention_layernorm.weight", # pre_ffw_norm_scale (model_dim=2048) + ] + # fmt: on for layer in range(18): for layer_param in layer_params: - names = names + ["model.layers." + str(layer) + "." + layer_param] + names = names + [f"model.layers.{layer}.{layer_param}"] return names def convert_weights(): - # TODO(austinvhuang): move code in here - pass - - -if __name__ == "__main__": # TODO(austinvhuang): parameterize paths output_file = "2bit-f32.sbs" model_config = config.get_model_config("2b") @@ -48,8 +64,12 @@ if __name__ == "__main__": with open(output_file, "wb") as bin_handle: for name in param_order: arr = model_dict[name].detach().numpy() + arr = TRANSFORMATIONS[name](arr) # TODO(austinvhuang): reshapes print(f" {name : <60}{str(arr.shape)}") arr.flatten().astype(np.float32).tofile(bin_handle) + +if __name__ == "__main__": + convert_weights() print("Done")