From b6831a2256ac4a7c73b945aee2375af8fea8d234 Mon Sep 17 00:00:00 2001
From: Phil Culliton
Date: Tue, 12 Mar 2024 21:12:28 +0000
Subject: [PATCH] Fixed 7B conversion.

---
 util/convert_weights.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/util/convert_weights.py b/util/convert_weights.py
index d9cdc15..187cd2f 100644
--- a/util/convert_weights.py
+++ b/util/convert_weights.py
@@ -96,7 +96,7 @@ TRANSFORMATIONS = {
         {
             "embedder.weight": lambda x: np.concatenate([x, np.zeros([128, 2048])], 0),
             "self_attn.qkv_proj.weight": expand_qkv,
-            "self_attn.o_proj.weight": lambda x: x.reshape(2048, 8, 256).transpose([1,0,2]),
+            "self_attn.o_proj.weight": lambda x: x.reshape((2048, 8, 256)).transpose([1,0,2]),
             "mlp.gate_proj.weight": lambda x: x[np.newaxis, :, :],
             "mlp.up_proj.weight": lambda x: x[np.newaxis, :, :],
             "mlp.down_proj.weight": lambda x: x,
@@ -106,8 +106,8 @@ TRANSFORMATIONS = {
         lambda: lambda x: x,
         {
             "embedder.weight": lambda x: np.concatenate([x, np.zeros([128, 3072])], 0),
-            "self_attn.qkv_proj.weight": lambda x: x.reshape((16, 3, 256, 3072)),
-            "self_attn.o_proj.weight": lambda x: x.reshape(3072, 16, 256).transpose([1,0,2]),
+            "self_attn.qkv_proj.weight": lambda x: x.reshape((3, 16, 256, 3072)).transpose([1,0,2,3]),
+            "self_attn.o_proj.weight": lambda x: x.reshape((3072, 16, 256)).transpose([1,0,2]),
             "mlp.gate_proj.weight": lambda x: x[np.newaxis, :, :],
             "mlp.up_proj.weight": lambda x: x[np.newaxis, :, :],
             "mlp.down_proj.weight": lambda x: x,
@@ -183,6 +183,8 @@ def convert_weights():
     model = gemma_model.GemmaForCausalLM(model_config)
     model.load_weights(args.weights)
 
+    model.to(device).eval()
+
     model_dict = dict(model.named_parameters())
     param_order = param_names(model_config.num_hidden_layers)
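
Note (illustrative sketch, not part of the patch): the core of the 7B fix is the qkv reshape order. The sketch below assumes the fused qkv_proj weight stores q, k, and v stacked along the leading axis, i.e. shape (3 * num_heads * head_dim, hidden) = (3 * 16 * 256, 3072) for 7B. Under that assumption, the old reshape((16, 3, 256, 3072)) groups the wrong 256-row blocks under each head, while the new reshape((3, 16, 256, 3072)).transpose([1, 0, 2, 3]) splits off the q/k/v axis first and keeps each block intact.

# Illustrative numpy sketch of the 7B qkv reshape (assumed checkpoint layout:
# q, k, v stacked along the leading axis of the fused weight).
import numpy as np

num_heads, head_dim, hidden = 16, 256, 3072
qkv = np.arange(3 * num_heads * head_dim * hidden, dtype=np.float32).reshape(
    3 * num_heads * head_dim, hidden)

# Old (buggy) ordering: groups the wrong 256-row blocks under each head.
wrong = qkv.reshape((num_heads, 3, head_dim, hidden))

# Fixed ordering: split off the q/k/v axis first, then move it behind the
# head axis, giving (num_heads, 3, head_dim, hidden) with intact blocks.
right = qkv.reshape((3, num_heads, head_dim, hidden)).transpose([1, 0, 2, 3])

# Head 1's query block should equal rows [head_dim:2*head_dim] of the original
# q weight; only the fixed ordering preserves that.
assert np.array_equal(right[1, 0], qkv[head_dim:2 * head_dim])
assert not np.array_equal(wrong[1, 0], qkv[head_dim:2 * head_dim])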