From 5be9a2243f9b4e6c3d3194e12a353e753c281324 Mon Sep 17 00:00:00 2001 From: austinvhuang Date: Fri, 1 Mar 2024 15:52:51 -0500 Subject: [PATCH 1/7] initial (wip) convert_weights script from pytorch --- util/convert_weights.py | 55 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 55 insertions(+) create mode 100644 util/convert_weights.py diff --git a/util/convert_weights.py b/util/convert_weights.py new file mode 100644 index 0000000..0749cad --- /dev/null +++ b/util/convert_weights.py @@ -0,0 +1,55 @@ +# WIP - DO NOT MERGE + +import torch +from gemma import config +from gemma import model as gemma_model +import numpy as np + + +def param_names(): + """Return parameter names in the order they are expected for deserialization.""" + names = ["embedder.weight", "model.norm.weight"] + # note *weight_scaler params are ignored in the forward computation unless quantization is being used. + # since we are working with the full precision weights as input, don't include these in the parameters being iterated over + layer_params = [ + "self_attn.qkv_proj.weight", # attn_vec_einsum_w + "self_attn.o_proj.weight", # qkv_einsum_w + "mlp.gate_proj.weight", # qkv_einsum_w + "mlp.up_proj.weight", # gating_einsum_w + "mlp.down_proj.weight", # linear_w + "input_layernorm.weight", # pre_attention_norm_scale + "post_attention_layernorm.weight", # pre_ffw_norm_scale + ] + for layer in range(18): + for layer_param in layer_params: + names = names + ["model.layers." + str(layer) + "." + layer_param] + return names + + +def convert_weights(): + # TODO(austinvhuang): move code in here + pass + + +if __name__ == "__main__": + # TODO(austinvhuang): parameterize paths + output_file = "2bit-f32.sbs" + model_config = config.get_model_config("2b") + model_config.dtype = "float32" + model_config.quant = "store_true" + model_config.tokenizer = "models/tokenizer.spm" + device = torch.device("cpu") + torch.set_default_dtype(torch.float) + model = gemma_model.GemmaForCausalLM(model_config) + model.load_weights("models/gemma-2b-it.ckpt") + model_dict = dict(model.named_parameters()) + param_order = param_names() + print("Writing parameters ...") + with open(output_file, "wb") as bin_handle: + for name in param_order: + arr = model_dict[name].detach().numpy() + # TODO(austinvhuang): reshapes + print(f" {name : <60}{str(arr.shape)}") + arr.flatten().astype(np.float32).tofile(bin_handle) + + print("Done") From 7d7d43e661ed063b071f473747c2960cf7cc18bd Mon Sep 17 00:00:00 2001 From: austinvhuang Date: Sat, 2 Mar 2024 08:11:55 -0500 Subject: [PATCH 2/7] converter transformations (wip) --- util/convert_weights.py | 54 ++++++++++++++++++++++++++++------------- 1 file changed, 37 insertions(+), 17 deletions(-) diff --git a/util/convert_weights.py b/util/convert_weights.py index 0749cad..2e96edd 100644 --- a/util/convert_weights.py +++ b/util/convert_weights.py @@ -1,37 +1,53 @@ # WIP - DO NOT MERGE +from collections import defaultdict import torch from gemma import config from gemma import model as gemma_model import numpy as np +TRANSFORMATIONS = defaultdict(lambda: lambda x: x, { + "embedder.weight": lambda x: np.concatenate([np.zeros([128, 2048]), x], 0), + "self_attn.qkv_proj.weight": lambda x: x, + "mlp.up_proj.weight" : lambda x: x, + "mlp.down_proj.weight" : lambda x: x, +}) + def param_names(): """Return parameter names in the order they are expected for deserialization.""" - names = ["embedder.weight", "model.norm.weight"] - # note *weight_scaler params are ignored in the forward computation unless quantization is 
being used. - # since we are working with the full precision weights as input, don't include these in the parameters being iterated over - layer_params = [ - "self_attn.qkv_proj.weight", # attn_vec_einsum_w - "self_attn.o_proj.weight", # qkv_einsum_w - "mlp.gate_proj.weight", # qkv_einsum_w - "mlp.up_proj.weight", # gating_einsum_w - "mlp.down_proj.weight", # linear_w - "input_layernorm.weight", # pre_attention_norm_scale - "post_attention_layernorm.weight", # pre_ffw_norm_scale + + # note *weight_scaler params are ignored in the forward computation unless + # quantization is being used. + # + # since we are working with the full precision weights as input, don't + # include these in the parameters being iterated over. + + # fmt: off + names = [ + "embedder.weight", # embedder_input_embedding (vocab=256000, model_dim=2048) -> (vocab=256128, model_dim=2048) + "model.norm.weight" # final_norm_scale (model_dim=2048) ] + layer_params = [ + # TODO(austinvhuang): transpositions here ... + "self_attn.qkv_proj.weight", # attn_vec_einsum_w (2560, 2048) -> (heads=8, model_dim=2048, qkv_dim=256) + "self_attn.o_proj.weight", # qkv_einsum_w (2048, 2048) -> (heads=8, qkv=3, qkv_dim=256, model_dim=2048) + + # these are the same without any change + "mlp.gate_proj.weight", # gating_einsum_w (16384, 2048) => (gate/up=2, hidden=16384, model_dim=2048) + "mlp.up_proj.weight", + "mlp.down_proj.weight", # linear_w (2048, 16384) => (model_dim=2048, hidden=16384) + "input_layernorm.weight", # pre_attention_norm_scale (model_dim=2048) + "post_attention_layernorm.weight", # pre_ffw_norm_scale (model_dim=2048) + ] + # fmt: on for layer in range(18): for layer_param in layer_params: - names = names + ["model.layers." + str(layer) + "." + layer_param] + names = names + [f"model.layers.{layer}.{layer_param}"] return names def convert_weights(): - # TODO(austinvhuang): move code in here - pass - - -if __name__ == "__main__": # TODO(austinvhuang): parameterize paths output_file = "2bit-f32.sbs" model_config = config.get_model_config("2b") @@ -48,8 +64,12 @@ if __name__ == "__main__": with open(output_file, "wb") as bin_handle: for name in param_order: arr = model_dict[name].detach().numpy() + arr = TRANSFORMATIONS[name](arr) # TODO(austinvhuang): reshapes print(f" {name : <60}{str(arr.shape)}") arr.flatten().astype(np.float32).tofile(bin_handle) + +if __name__ == "__main__": + convert_weights() print("Done") From 3c69695c1e49b958adfb6e4f567a5080c01cc39f Mon Sep 17 00:00:00 2001 From: austinvhuang Date: Sat, 2 Mar 2024 14:46:51 -0500 Subject: [PATCH 3/7] transformations and validations (wip) --- util/convert_weights.py | 74 ++++++++++++++++++++++++++++++++--------- 1 file changed, 58 insertions(+), 16 deletions(-) diff --git a/util/convert_weights.py b/util/convert_weights.py index 2e96edd..a5231b8 100644 --- a/util/convert_weights.py +++ b/util/convert_weights.py @@ -1,4 +1,20 @@ +# Copyright 2024 Google LLC +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ # WIP - DO NOT MERGE +# Requires torch 2.2 and gemma package from https://github.com/google/gemma_pytorch from collections import defaultdict import torch @@ -6,12 +22,38 @@ from gemma import config from gemma import model as gemma_model import numpy as np -TRANSFORMATIONS = defaultdict(lambda: lambda x: x, { - "embedder.weight": lambda x: np.concatenate([np.zeros([128, 2048]), x], 0), - "self_attn.qkv_proj.weight": lambda x: x, - "mlp.up_proj.weight" : lambda x: x, - "mlp.down_proj.weight" : lambda x: x, -}) + +def expand_qkv(qkv_proj: np.array) -> np.array: + """This won't be needed anymore when MQA is implemented""" + assert qkv_proj.shape == (2560, 2048) + qkv_proj.reshape((10, 256, 2048)) + # TODO : repeat dimensions ... + return qkv_proj + + +TRANSFORMATIONS = defaultdict( + lambda: lambda x: x, + { + "embedder.weight": lambda x: np.concatenate([np.zeros([128, 2048]), x], 0), + "self_attn.qkv_proj.weight": expand_qkv, + "self_attn.o_proj.weight": lambda x: x, # TODO: which of the 2048 is unpacked to 8 x 256, and which is model_dim? + "mlp.gate_proj.weight": lambda x: x, + "mlp.up_proj.weight": lambda x: x, + "mlp.down_proj.weight": lambda x: x, + }, +) + +VALIDATIONS = { + "embedder.weight": lambda x: x.shape == (256128, 2048), + "model.norm.weight": lambda x: x.shape == (2048,), + "self_attn.qkv_proj.weight": lambda x: x.shape == (8, 3, 256, 2048), + "self_attn.o_proj.weight": lambda x: x.shape == (8, 2048, 256), + "mlp.gate_proj.weight": lambda x: x.shape == (1, 16384, 2048), + "mlp.up_proj.weight": lambda x: x.shape == (1, 16384, 2048), + "mlp.down_proj.weight": lambda x: x.shape == (2048, 16384), + "input_layernorm.weight": lambda x: x.shape == (2048,), + "post_attention_layernorm.weight": lambda x: x.shape == (2048,), +} def param_names(): @@ -25,14 +67,14 @@ def param_names(): # fmt: off names = [ - "embedder.weight", # embedder_input_embedding (vocab=256000, model_dim=2048) -> (vocab=256128, model_dim=2048) - "model.norm.weight" # final_norm_scale (model_dim=2048) + ("embedder.weight", ) * 2, # embedder_input_embedding (vocab=256000, model_dim=2048) -> (vocab=256128, model_dim=2048) + ("model.norm.weight", ) * 2 # final_norm_scale (model_dim=2048) ] layer_params = [ # TODO(austinvhuang): transpositions here ... 
- "self_attn.qkv_proj.weight", # attn_vec_einsum_w (2560, 2048) -> (heads=8, model_dim=2048, qkv_dim=256) - "self_attn.o_proj.weight", # qkv_einsum_w (2048, 2048) -> (heads=8, qkv=3, qkv_dim=256, model_dim=2048) - + # ( q_heads = 8 + kv = 2 ) x qkv_dim = 2560 + "self_attn.qkv_proj.weight", # qkv_einsum_w (2560, 2048) -> (heads=8, qkv=3, qkv_dim=256, model_dim=2048) + "self_attn.o_proj.weight", # attn_vec_einsum_w (2048, 2048) -> (heads=8, model_dim=2048, qkv_dim=256) # these are the same without any change "mlp.gate_proj.weight", # gating_einsum_w (16384, 2048) => (gate/up=2, hidden=16384, model_dim=2048) "mlp.up_proj.weight", @@ -43,12 +85,12 @@ def param_names(): # fmt: on for layer in range(18): for layer_param in layer_params: - names = names + [f"model.layers.{layer}.{layer_param}"] + names = names + [(f"model.layers.{layer}.{layer_param}", layer_param)] return names def convert_weights(): - # TODO(austinvhuang): parameterize paths + # TODO: parameterize paths as CLI args instead of hard coding them output_file = "2bit-f32.sbs" model_config = config.get_model_config("2b") model_config.dtype = "float32" @@ -62,11 +104,11 @@ def convert_weights(): param_order = param_names() print("Writing parameters ...") with open(output_file, "wb") as bin_handle: - for name in param_order: + for name, layer_name in param_order: arr = model_dict[name].detach().numpy() arr = TRANSFORMATIONS[name](arr) - # TODO(austinvhuang): reshapes - print(f" {name : <60}{str(arr.shape)}") + check = "OK" if VALIDATIONS[layer_name](arr) else "FAILED" + print(f" {name : <60}{str(arr.shape) : <20}{check}") arr.flatten().astype(np.float32).tofile(bin_handle) From c93e1a1e4d815b5b0b43f17c8eef100a7ca20bf8 Mon Sep 17 00:00:00 2001 From: Phil Culliton Date: Tue, 5 Mar 2024 17:54:55 +0000 Subject: [PATCH 4/7] Resolved layer ordering, reshaping, MQA->MHA, and quantization. Works only for 2B. --- util/convert_weights.py | 83 +++++++++++++++++++++++++++++++---------- 1 file changed, 64 insertions(+), 19 deletions(-) diff --git a/util/convert_weights.py b/util/convert_weights.py index a5231b8..c9856d7 100644 --- a/util/convert_weights.py +++ b/util/convert_weights.py @@ -22,23 +22,35 @@ from gemma import config from gemma import model as gemma_model import numpy as np - def expand_qkv(qkv_proj: np.array) -> np.array: """This won't be needed anymore when MQA is implemented""" + ## this will only be true for 2b assert qkv_proj.shape == (2560, 2048) - qkv_proj.reshape((10, 256, 2048)) - # TODO : repeat dimensions ... - return qkv_proj + qkv = qkv_proj.reshape((10, 256, 2048)) + ## based on line 230 of + ## https://github.com/google/gemma_pytorch/blob/main/gemma/model.py + q_proj = qkv[:8].reshape((1,8,256,2048)) + kv_proj = qkv[8:] + kv_proj = kv_proj[:, np.newaxis, :, :] + kv_proj = np.repeat(kv_proj, 8, axis=1) + + qkv = np.concatenate([q_proj, kv_proj]) + qkv = np.transpose(qkv, axes=[1,0,2,3]) + return qkv TRANSFORMATIONS = defaultdict( lambda: lambda x: x, { - "embedder.weight": lambda x: np.concatenate([np.zeros([128, 2048]), x], 0), + ## padding goes at end per discussion + "embedder.weight": lambda x: np.concatenate([x, np.zeros([128, 2048])], 0), "self_attn.qkv_proj.weight": expand_qkv, - "self_attn.o_proj.weight": lambda x: x, # TODO: which of the 2048 is unpacked to 8 x 256, and which is model_dim? 
- "mlp.gate_proj.weight": lambda x: x, - "mlp.up_proj.weight": lambda x: x, + + ## based on line 234 of + ## https://github.com/google/gemma_pytorch/blob/main/gemma/model.py + "self_attn.o_proj.weight": lambda x: x.reshape(2048, 8, 256).transpose([1,0,2]), # TODO: which of the 2048 is unpacked to 8 x 256, and which is model_dim? + "mlp.gate_proj.weight": lambda x: x[np.newaxis, :, :], + "mlp.up_proj.weight": lambda x: x[np.newaxis, :, :], "mlp.down_proj.weight": lambda x: x, }, ) @@ -72,9 +84,9 @@ def param_names(): ] layer_params = [ # TODO(austinvhuang): transpositions here ... - # ( q_heads = 8 + kv = 2 ) x qkv_dim = 2560 - "self_attn.qkv_proj.weight", # qkv_einsum_w (2560, 2048) -> (heads=8, qkv=3, qkv_dim=256, model_dim=2048) "self_attn.o_proj.weight", # attn_vec_einsum_w (2048, 2048) -> (heads=8, model_dim=2048, qkv_dim=256) + # # ( q_heads = 8 + kv = 2 ) x qkv_dim = 2560 + "self_attn.qkv_proj.weight", # qkv_einsum_w (2560, 2048) -> (heads=8, qkv=3, qkv_dim=256, model_dim=2048) # these are the same without any change "mlp.gate_proj.weight", # gating_einsum_w (16384, 2048) => (gate/up=2, hidden=16384, model_dim=2048) "mlp.up_proj.weight", @@ -86,6 +98,7 @@ def param_names(): for layer in range(18): for layer_param in layer_params: names = names + [(f"model.layers.{layer}.{layer_param}", layer_param)] + print("names:", names) return names @@ -94,22 +107,54 @@ def convert_weights(): output_file = "2bit-f32.sbs" model_config = config.get_model_config("2b") model_config.dtype = "float32" - model_config.quant = "store_true" + + ## this turns on int8 quantization + # model_config.quant = "store_true" model_config.tokenizer = "models/tokenizer.spm" device = torch.device("cpu") torch.set_default_dtype(torch.float) model = gemma_model.GemmaForCausalLM(model_config) model.load_weights("models/gemma-2b-it.ckpt") model_dict = dict(model.named_parameters()) + + for layer_name in model_dict: + ## Make sure we're not silently having int8 quantization turned on. 
+ print(layer_name, model_dict[layer_name].max()) + assert(model_dict[layer_name].max() > 0.0) + param_order = param_names() - print("Writing parameters ...") - with open(output_file, "wb") as bin_handle: - for name, layer_name in param_order: - arr = model_dict[name].detach().numpy() - arr = TRANSFORMATIONS[name](arr) - check = "OK" if VALIDATIONS[layer_name](arr) else "FAILED" - print(f" {name : <60}{str(arr.shape) : <20}{check}") - arr.flatten().astype(np.float32).tofile(bin_handle) + + all_ok = True + print("Checking transformations ...") + for name, layer_name in param_order: + arr = model_dict[name].detach().numpy() + arr = TRANSFORMATIONS[layer_name](arr) + check = "OK" if VALIDATIONS[layer_name](arr) else "FAILED" + + if check == "FAILED": + all_ok = False + + print(f" {name : <60}{str(arr.shape) : <20}{check}") + + if all_ok: + print("Writing parameters ...") + gate = None + with open(output_file, "wb") as bin_handle: + for name, layer_name in param_order: + arr = model_dict[name].detach().numpy() + arr = TRANSFORMATIONS[layer_name](arr) + check = "OK" if VALIDATIONS[layer_name](arr) else "FAILED" + print(f" {name : <60}{str(arr.shape) : <20}{check}") + + if "gate_proj" in name: + gate = arr + elif "up_proj" in name: + up = arr + f = np.concatenate([gate, up]) + print (f.shape) + f.flatten().astype(np.float32).tofile(bin_handle) + else: + arr.flatten().astype(np.float32).tofile(bin_handle) if __name__ == "__main__": From 2161908f50b4ac5f773643e85dc92a4092ad8b49 Mon Sep 17 00:00:00 2001 From: Phil Culliton Date: Thu, 7 Mar 2024 22:34:14 +0000 Subject: [PATCH 5/7] Added 7B support and args parsing. Still todo: more testing of 7B conversion. --- util/convert_weights.py | 168 ++++++++++++++++++++++++++-------------- 1 file changed, 110 insertions(+), 58 deletions(-) diff --git a/util/convert_weights.py b/util/convert_weights.py index c9856d7..d9cdc15 100644 --- a/util/convert_weights.py +++ b/util/convert_weights.py @@ -13,23 +13,74 @@ # See the License for the specific language governing permissions and # limitations under the License. -# WIP - DO NOT MERGE -# Requires torch 2.2 and gemma package from https://github.com/google/gemma_pytorch from collections import defaultdict import torch from gemma import config from gemma import model as gemma_model import numpy as np +import argparse +import os + +# WIP - DO NOT MERGE +# Requires torch 2.2 and gemma package from https://github.com/google/gemma_pytorch + +## parameters +## model, tokenizer, model type, + +def check_file_exists(value): + if not os.path.exists(str(value)): + raise argparse.ArgumentTypeError("The file %s does not appear to exist." % value) + return value + + +def check_model_types(value): + if str(value).lower() not in ["2b", "7b"]: + raise argparse.ArgumentTypeError("Model type value %s is not in [2b, 7b]." 
% value) + return value + + +parser = argparse.ArgumentParser() +parser.add_argument( + "--tokenizer", + dest="tokenizer", + default="models/tokenizer.spm", + help="Location of tokenizer file (.model or .spm)", + type=check_file_exists, +) + +parser.add_argument( + "--weights", + dest="weights", + default="models/gemma-2b-it.ckpt", + help="Location of input checkpoint file (.ckpt)", + type=check_file_exists, +) + +parser.add_argument( + "--output_file", + dest="output_file", + default="2bit-f32.sbs", + help="Location to write converted weights", + type=str, +) + +parser.add_argument( + "--model_type", + dest="model_type", + default="2b", + help="Model size / type (2b, 7b)", + type=check_model_types, +) + +args = parser.parse_args() + def expand_qkv(qkv_proj: np.array) -> np.array: """This won't be needed anymore when MQA is implemented""" - ## this will only be true for 2b assert qkv_proj.shape == (2560, 2048) qkv = qkv_proj.reshape((10, 256, 2048)) - ## based on line 230 of - ## https://github.com/google/gemma_pytorch/blob/main/gemma/model.py q_proj = qkv[:8].reshape((1,8,256,2048)) kv_proj = qkv[8:] kv_proj = kv_proj[:, np.newaxis, :, :] @@ -39,23 +90,33 @@ def expand_qkv(qkv_proj: np.array) -> np.array: qkv = np.transpose(qkv, axes=[1,0,2,3]) return qkv -TRANSFORMATIONS = defaultdict( +TRANSFORMATIONS = { + "2b":defaultdict( lambda: lambda x: x, { - ## padding goes at end per discussion "embedder.weight": lambda x: np.concatenate([x, np.zeros([128, 2048])], 0), "self_attn.qkv_proj.weight": expand_qkv, - - ## based on line 234 of - ## https://github.com/google/gemma_pytorch/blob/main/gemma/model.py - "self_attn.o_proj.weight": lambda x: x.reshape(2048, 8, 256).transpose([1,0,2]), # TODO: which of the 2048 is unpacked to 8 x 256, and which is model_dim? 
+ "self_attn.o_proj.weight": lambda x: x.reshape(2048, 8, 256).transpose([1,0,2]), "mlp.gate_proj.weight": lambda x: x[np.newaxis, :, :], "mlp.up_proj.weight": lambda x: x[np.newaxis, :, :], "mlp.down_proj.weight": lambda x: x, - }, -) + } + ), + "7b":defaultdict( + lambda: lambda x: x, + { + "embedder.weight": lambda x: np.concatenate([x, np.zeros([128, 3072])], 0), + "self_attn.qkv_proj.weight": lambda x: x.reshape((16, 3, 256, 3072)), + "self_attn.o_proj.weight": lambda x: x.reshape(3072, 16, 256).transpose([1,0,2]), + "mlp.gate_proj.weight": lambda x: x[np.newaxis, :, :], + "mlp.up_proj.weight": lambda x: x[np.newaxis, :, :], + "mlp.down_proj.weight": lambda x: x, + } + ), +} VALIDATIONS = { + "2b": { "embedder.weight": lambda x: x.shape == (256128, 2048), "model.norm.weight": lambda x: x.shape == (2048,), "self_attn.qkv_proj.weight": lambda x: x.shape == (8, 3, 256, 2048), @@ -65,10 +126,22 @@ VALIDATIONS = { "mlp.down_proj.weight": lambda x: x.shape == (2048, 16384), "input_layernorm.weight": lambda x: x.shape == (2048,), "post_attention_layernorm.weight": lambda x: x.shape == (2048,), + }, + "7b": { + "embedder.weight": lambda x: x.shape == (256128, 3072), + "model.norm.weight": lambda x: x.shape == (3072,), + "self_attn.qkv_proj.weight": lambda x: x.shape == (16, 3, 256, 3072), + "self_attn.o_proj.weight": lambda x: x.shape == (16, 3072, 256), + "mlp.gate_proj.weight": lambda x: x.shape == (1, 24576, 3072), + "mlp.up_proj.weight": lambda x: x.shape == (1, 24576, 3072), + "mlp.down_proj.weight": lambda x: x.shape == (3072, 24576), + "input_layernorm.weight": lambda x: x.shape == (3072,), + "post_attention_layernorm.weight": lambda x: x.shape == (3072,), + }, } -def param_names(): +def param_names(num_hidden_layers: int): """Return parameter names in the order they are expected for deserialization.""" # note *weight_scaler params are ignored in the forward computation unless @@ -79,62 +152,50 @@ def param_names(): # fmt: off names = [ - ("embedder.weight", ) * 2, # embedder_input_embedding (vocab=256000, model_dim=2048) -> (vocab=256128, model_dim=2048) - ("model.norm.weight", ) * 2 # final_norm_scale (model_dim=2048) + ("embedder.weight", ) * 2, # embedder_input_embedding + ("model.norm.weight", ) * 2 # final_norm_scale ] layer_params = [ - # TODO(austinvhuang): transpositions here ... 
- "self_attn.o_proj.weight", # attn_vec_einsum_w (2048, 2048) -> (heads=8, model_dim=2048, qkv_dim=256) - # # ( q_heads = 8 + kv = 2 ) x qkv_dim = 2560 - "self_attn.qkv_proj.weight", # qkv_einsum_w (2560, 2048) -> (heads=8, qkv=3, qkv_dim=256, model_dim=2048) - # these are the same without any change - "mlp.gate_proj.weight", # gating_einsum_w (16384, 2048) => (gate/up=2, hidden=16384, model_dim=2048) + "self_attn.o_proj.weight", # attn_vec_einsum_w + "self_attn.qkv_proj.weight", # qkv_einsum_w + "mlp.gate_proj.weight", # gating_einsum_w "mlp.up_proj.weight", - "mlp.down_proj.weight", # linear_w (2048, 16384) => (model_dim=2048, hidden=16384) - "input_layernorm.weight", # pre_attention_norm_scale (model_dim=2048) - "post_attention_layernorm.weight", # pre_ffw_norm_scale (model_dim=2048) + "mlp.down_proj.weight", # linear_w + "input_layernorm.weight", # pre_attention_norm_scale + "post_attention_layernorm.weight", # pre_ffw_norm_scale ] # fmt: on - for layer in range(18): + for layer in range(num_hidden_layers): for layer_param in layer_params: names = names + [(f"model.layers.{layer}.{layer_param}", layer_param)] - print("names:", names) return names def convert_weights(): - # TODO: parameterize paths as CLI args instead of hard coding them - output_file = "2bit-f32.sbs" - model_config = config.get_model_config("2b") + model_type = args.model_type + output_file = args.output_file + + model_config = config.get_model_config(model_type) model_config.dtype = "float32" - - ## this turns on int8 quantization - # model_config.quant = "store_true" - model_config.tokenizer = "models/tokenizer.spm" + model_config.tokenizer = args.tokenizer device = torch.device("cpu") torch.set_default_dtype(torch.float) model = gemma_model.GemmaForCausalLM(model_config) - model.load_weights("models/gemma-2b-it.ckpt") - model_dict = dict(model.named_parameters()) - - for layer_name in model_dict: - ## Make sure we're not silently having int8 quantization turned on. 
- print(layer_name, model_dict[layer_name].max()) - assert(model_dict[layer_name].max() > 0.0) - param_order = param_names() + model.load_weights(args.weights) + model_dict = dict(model.named_parameters()) + param_order = param_names(model_config.num_hidden_layers) all_ok = True print("Checking transformations ...") for name, layer_name in param_order: arr = model_dict[name].detach().numpy() - arr = TRANSFORMATIONS[layer_name](arr) - check = "OK" if VALIDATIONS[layer_name](arr) else "FAILED" + arr = TRANSFORMATIONS[model_type][layer_name](arr) + check = "OK" if VALIDATIONS[model_type][layer_name](arr) else "FAILED" if check == "FAILED": all_ok = False - - print(f" {name : <60}{str(arr.shape) : <20}{check}") + print(f" {name : <60}{str(arr.shape) : <20}{check}") if all_ok: print("Writing parameters ...") @@ -142,19 +203,10 @@ def convert_weights(): with open(output_file, "wb") as bin_handle: for name, layer_name in param_order: arr = model_dict[name].detach().numpy() - arr = TRANSFORMATIONS[layer_name](arr) - check = "OK" if VALIDATIONS[layer_name](arr) else "FAILED" + arr = TRANSFORMATIONS[model_type][layer_name](arr) + check = "OK" if VALIDATIONS[model_type][layer_name](arr) else "FAILED" print(f" {name : <60}{str(arr.shape) : <20}{check}") - - if "gate_proj" in name: - gate = arr - elif "up_proj" in name: - up = arr - f = np.concatenate([gate, up]) - print (f.shape) - f.flatten().astype(np.float32).tofile(bin_handle) - else: - arr.flatten().astype(np.float32).tofile(bin_handle) + arr.flatten().astype(np.float32).tofile(bin_handle) if __name__ == "__main__": From b6831a2256ac4a7c73b945aee2375af8fea8d234 Mon Sep 17 00:00:00 2001 From: Phil Culliton Date: Tue, 12 Mar 2024 21:12:28 +0000 Subject: [PATCH 6/7] Fixed 7B conversion. --- util/convert_weights.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/util/convert_weights.py b/util/convert_weights.py index d9cdc15..187cd2f 100644 --- a/util/convert_weights.py +++ b/util/convert_weights.py @@ -96,7 +96,7 @@ TRANSFORMATIONS = { { "embedder.weight": lambda x: np.concatenate([x, np.zeros([128, 2048])], 0), "self_attn.qkv_proj.weight": expand_qkv, - "self_attn.o_proj.weight": lambda x: x.reshape(2048, 8, 256).transpose([1,0,2]), + "self_attn.o_proj.weight": lambda x: x.reshape((2048, 8, 256)).transpose([1,0,2]), "mlp.gate_proj.weight": lambda x: x[np.newaxis, :, :], "mlp.up_proj.weight": lambda x: x[np.newaxis, :, :], "mlp.down_proj.weight": lambda x: x, @@ -106,8 +106,8 @@ TRANSFORMATIONS = { lambda: lambda x: x, { "embedder.weight": lambda x: np.concatenate([x, np.zeros([128, 3072])], 0), - "self_attn.qkv_proj.weight": lambda x: x.reshape((16, 3, 256, 3072)), - "self_attn.o_proj.weight": lambda x: x.reshape(3072, 16, 256).transpose([1,0,2]), + "self_attn.qkv_proj.weight": lambda x: x.reshape((3, 16, 256, 3072)).transpose([1,0,2,3]), + "self_attn.o_proj.weight": lambda x: x.reshape((3072, 16, 256)).transpose([1,0,2]), "mlp.gate_proj.weight": lambda x: x[np.newaxis, :, :], "mlp.up_proj.weight": lambda x: x[np.newaxis, :, :], "mlp.down_proj.weight": lambda x: x, @@ -183,6 +183,8 @@ def convert_weights(): model = gemma_model.GemmaForCausalLM(model_config) model.load_weights(args.weights) + model.to(device).eval() + model_dict = dict(model.named_parameters()) param_order = param_names(model_config.num_hidden_layers) From f520e5c25c47fb5bf7e9c6289a11da51fc240531 Mon Sep 17 00:00:00 2001 From: pculliton Date: Wed, 13 Mar 2024 11:36:19 -0400 Subject: [PATCH 7/7] Remove WIP messages. 
--- util/convert_weights.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/util/convert_weights.py b/util/convert_weights.py index 187cd2f..bd6750a 100644 --- a/util/convert_weights.py +++ b/util/convert_weights.py @@ -22,12 +22,8 @@ import numpy as np import argparse import os -# WIP - DO NOT MERGE # Requires torch 2.2 and gemma package from https://github.com/google/gemma_pytorch -## parameters -## model, tokenizer, model type, - def check_file_exists(value): if not os.path.exists(str(value)): raise argparse.ArgumentTypeError("The file %s does not appear to exist." % value)
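
Usage note: with the argparse defaults introduced in PATCH 5/7, a 2B conversion would be invoked roughly as

    python util/convert_weights.py \
        --tokenizer models/tokenizer.spm \
        --weights models/gemma-2b-it.ckpt \
        --output_file 2bit-f32.sbs \
        --model_type 2b

where the flag values shown are simply the script's own defaults; the tokenizer and checkpoint paths need to point at wherever the Gemma 2B files live locally.

The least obvious transformation is the 2B MQA-to-MHA expansion in expand_qkv, which repeats the single shared K and V projection across all 8 query heads. Below is a minimal, self-contained numpy sketch of that reshape (the helper name expand_qkv_2b is illustrative; shapes follow the 2B entries in VALIDATIONS):

    import numpy as np

    def expand_qkv_2b(qkv_proj: np.ndarray) -> np.ndarray:
        # (q_heads=8 + k=1 + v=1) * qkv_dim=256 = 2560 rows, model_dim=2048 columns
        assert qkv_proj.shape == (2560, 2048)
        qkv = qkv_proj.reshape((10, 256, 2048))
        q = qkv[:8].reshape((1, 8, 256, 2048))                    # (qkv=1, heads=8, 256, 2048)
        kv = np.repeat(qkv[8:][:, np.newaxis, :, :], 8, axis=1)   # share K and V across the 8 heads
        return np.concatenate([q, kv]).transpose([1, 0, 2, 3])    # (heads=8, qkv=3, 256, 2048)

    print(expand_qkv_2b(np.zeros((2560, 2048))).shape)  # (8, 3, 256, 2048)

This mirrors the expand_qkv added in PATCH 4/7 and passes the (8, 3, 256, 2048) check in VALIDATIONS["2b"]; the 7B checkpoint stores all 16 K/V heads explicitly, so its qkv_proj only needs the plain reshape and transpose from PATCH 6/7.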