diff --git a/util/convert_weights.py b/util/convert_weights.py index a5231b8..c9856d7 100644 --- a/util/convert_weights.py +++ b/util/convert_weights.py @@ -22,23 +22,35 @@ from gemma import config from gemma import model as gemma_model import numpy as np - def expand_qkv(qkv_proj: np.array) -> np.array: """This won't be needed anymore when MQA is implemented""" + ## this will only be true for 2b assert qkv_proj.shape == (2560, 2048) - qkv_proj.reshape((10, 256, 2048)) - # TODO : repeat dimensions ... - return qkv_proj + qkv = qkv_proj.reshape((10, 256, 2048)) + ## based on line 230 of + ## https://github.com/google/gemma_pytorch/blob/main/gemma/model.py + q_proj = qkv[:8].reshape((1,8,256,2048)) + kv_proj = qkv[8:] + kv_proj = kv_proj[:, np.newaxis, :, :] + kv_proj = np.repeat(kv_proj, 8, axis=1) + + qkv = np.concatenate([q_proj, kv_proj]) + qkv = np.transpose(qkv, axes=[1,0,2,3]) + return qkv TRANSFORMATIONS = defaultdict( lambda: lambda x: x, { - "embedder.weight": lambda x: np.concatenate([np.zeros([128, 2048]), x], 0), + ## padding goes at end per discussion + "embedder.weight": lambda x: np.concatenate([x, np.zeros([128, 2048])], 0), "self_attn.qkv_proj.weight": expand_qkv, - "self_attn.o_proj.weight": lambda x: x, # TODO: which of the 2048 is unpacked to 8 x 256, and which is model_dim? - "mlp.gate_proj.weight": lambda x: x, - "mlp.up_proj.weight": lambda x: x, + + ## based on line 234 of + ## https://github.com/google/gemma_pytorch/blob/main/gemma/model.py + "self_attn.o_proj.weight": lambda x: x.reshape(2048, 8, 256).transpose([1,0,2]), # TODO: which of the 2048 is unpacked to 8 x 256, and which is model_dim? + "mlp.gate_proj.weight": lambda x: x[np.newaxis, :, :], + "mlp.up_proj.weight": lambda x: x[np.newaxis, :, :], "mlp.down_proj.weight": lambda x: x, }, ) @@ -72,9 +84,9 @@ def param_names(): ] layer_params = [ # TODO(austinvhuang): transpositions here ... - # ( q_heads = 8 + kv = 2 ) x qkv_dim = 2560 - "self_attn.qkv_proj.weight", # qkv_einsum_w (2560, 2048) -> (heads=8, qkv=3, qkv_dim=256, model_dim=2048) "self_attn.o_proj.weight", # attn_vec_einsum_w (2048, 2048) -> (heads=8, model_dim=2048, qkv_dim=256) + # # ( q_heads = 8 + kv = 2 ) x qkv_dim = 2560 + "self_attn.qkv_proj.weight", # qkv_einsum_w (2560, 2048) -> (heads=8, qkv=3, qkv_dim=256, model_dim=2048) # these are the same without any change "mlp.gate_proj.weight", # gating_einsum_w (16384, 2048) => (gate/up=2, hidden=16384, model_dim=2048) "mlp.up_proj.weight", @@ -86,6 +98,7 @@ def param_names(): for layer in range(18): for layer_param in layer_params: names = names + [(f"model.layers.{layer}.{layer_param}", layer_param)] + print("names:", names) return names @@ -94,22 +107,54 @@ def convert_weights(): output_file = "2bit-f32.sbs" model_config = config.get_model_config("2b") model_config.dtype = "float32" - model_config.quant = "store_true" + + ## this turns on int8 quantization + # model_config.quant = "store_true" model_config.tokenizer = "models/tokenizer.spm" device = torch.device("cpu") torch.set_default_dtype(torch.float) model = gemma_model.GemmaForCausalLM(model_config) model.load_weights("models/gemma-2b-it.ckpt") model_dict = dict(model.named_parameters()) + + for layer_name in model_dict: + ## Make sure we're not silently having int8 quantization turned on. + print(layer_name, model_dict[layer_name].max()) + assert(model_dict[layer_name].max() > 0.0) + param_order = param_names() - print("Writing parameters ...") - with open(output_file, "wb") as bin_handle: - for name, layer_name in param_order: - arr = model_dict[name].detach().numpy() - arr = TRANSFORMATIONS[name](arr) - check = "OK" if VALIDATIONS[layer_name](arr) else "FAILED" - print(f" {name : <60}{str(arr.shape) : <20}{check}") - arr.flatten().astype(np.float32).tofile(bin_handle) + + all_ok = True + print("Checking transformations ...") + for name, layer_name in param_order: + arr = model_dict[name].detach().numpy() + arr = TRANSFORMATIONS[layer_name](arr) + check = "OK" if VALIDATIONS[layer_name](arr) else "FAILED" + + if check == "FAILED": + all_ok = False + + print(f" {name : <60}{str(arr.shape) : <20}{check}") + + if all_ok: + print("Writing parameters ...") + gate = None + with open(output_file, "wb") as bin_handle: + for name, layer_name in param_order: + arr = model_dict[name].detach().numpy() + arr = TRANSFORMATIONS[layer_name](arr) + check = "OK" if VALIDATIONS[layer_name](arr) else "FAILED" + print(f" {name : <60}{str(arr.shape) : <20}{check}") + + if "gate_proj" in name: + gate = arr + elif "up_proj" in name: + up = arr + f = np.concatenate([gate, up]) + print (f.shape) + f.flatten().astype(np.float32).tofile(bin_handle) + else: + arr.flatten().astype(np.float32).tofile(bin_handle) if __name__ == "__main__":