diff --git a/util/convert_weights.py b/util/convert_weights.py index 2e96edd..a5231b8 100644 --- a/util/convert_weights.py +++ b/util/convert_weights.py @@ -1,4 +1,20 @@ +# Copyright 2024 Google LLC +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + # WIP - DO NOT MERGE +# Requires torch 2.2 and gemma package from https://github.com/google/gemma_pytorch from collections import defaultdict import torch @@ -6,12 +22,38 @@ from gemma import config from gemma import model as gemma_model import numpy as np -TRANSFORMATIONS = defaultdict(lambda: lambda x: x, { - "embedder.weight": lambda x: np.concatenate([np.zeros([128, 2048]), x], 0), - "self_attn.qkv_proj.weight": lambda x: x, - "mlp.up_proj.weight" : lambda x: x, - "mlp.down_proj.weight" : lambda x: x, -}) + +def expand_qkv(qkv_proj: np.array) -> np.array: + """This won't be needed anymore when MQA is implemented""" + assert qkv_proj.shape == (2560, 2048) + qkv_proj.reshape((10, 256, 2048)) + # TODO : repeat dimensions ... + return qkv_proj + + +TRANSFORMATIONS = defaultdict( + lambda: lambda x: x, + { + "embedder.weight": lambda x: np.concatenate([np.zeros([128, 2048]), x], 0), + "self_attn.qkv_proj.weight": expand_qkv, + "self_attn.o_proj.weight": lambda x: x, # TODO: which of the 2048 is unpacked to 8 x 256, and which is model_dim? + "mlp.gate_proj.weight": lambda x: x, + "mlp.up_proj.weight": lambda x: x, + "mlp.down_proj.weight": lambda x: x, + }, +) + +VALIDATIONS = { + "embedder.weight": lambda x: x.shape == (256128, 2048), + "model.norm.weight": lambda x: x.shape == (2048,), + "self_attn.qkv_proj.weight": lambda x: x.shape == (8, 3, 256, 2048), + "self_attn.o_proj.weight": lambda x: x.shape == (8, 2048, 256), + "mlp.gate_proj.weight": lambda x: x.shape == (1, 16384, 2048), + "mlp.up_proj.weight": lambda x: x.shape == (1, 16384, 2048), + "mlp.down_proj.weight": lambda x: x.shape == (2048, 16384), + "input_layernorm.weight": lambda x: x.shape == (2048,), + "post_attention_layernorm.weight": lambda x: x.shape == (2048,), +} def param_names(): @@ -25,14 +67,14 @@ def param_names(): # fmt: off names = [ - "embedder.weight", # embedder_input_embedding (vocab=256000, model_dim=2048) -> (vocab=256128, model_dim=2048) - "model.norm.weight" # final_norm_scale (model_dim=2048) + ("embedder.weight", ) * 2, # embedder_input_embedding (vocab=256000, model_dim=2048) -> (vocab=256128, model_dim=2048) + ("model.norm.weight", ) * 2 # final_norm_scale (model_dim=2048) ] layer_params = [ # TODO(austinvhuang): transpositions here ... - "self_attn.qkv_proj.weight", # attn_vec_einsum_w (2560, 2048) -> (heads=8, model_dim=2048, qkv_dim=256) - "self_attn.o_proj.weight", # qkv_einsum_w (2048, 2048) -> (heads=8, qkv=3, qkv_dim=256, model_dim=2048) - + # ( q_heads = 8 + kv = 2 ) x qkv_dim = 2560 + "self_attn.qkv_proj.weight", # qkv_einsum_w (2560, 2048) -> (heads=8, qkv=3, qkv_dim=256, model_dim=2048) + "self_attn.o_proj.weight", # attn_vec_einsum_w (2048, 2048) -> (heads=8, model_dim=2048, qkv_dim=256) # these are the same without any change "mlp.gate_proj.weight", # gating_einsum_w (16384, 2048) => (gate/up=2, hidden=16384, model_dim=2048) "mlp.up_proj.weight", @@ -43,12 +85,12 @@ def param_names(): # fmt: on for layer in range(18): for layer_param in layer_params: - names = names + [f"model.layers.{layer}.{layer_param}"] + names = names + [(f"model.layers.{layer}.{layer_param}", layer_param)] return names def convert_weights(): - # TODO(austinvhuang): parameterize paths + # TODO: parameterize paths as CLI args instead of hard coding them output_file = "2bit-f32.sbs" model_config = config.get_model_config("2b") model_config.dtype = "float32" @@ -62,11 +104,11 @@ def convert_weights(): param_order = param_names() print("Writing parameters ...") with open(output_file, "wb") as bin_handle: - for name in param_order: + for name, layer_name in param_order: arr = model_dict[name].detach().numpy() arr = TRANSFORMATIONS[name](arr) - # TODO(austinvhuang): reshapes - print(f" {name : <60}{str(arr.shape)}") + check = "OK" if VALIDATIONS[layer_name](arr) else "FAILED" + print(f" {name : <60}{str(arr.shape) : <20}{check}") arr.flatten().astype(np.float32).tofile(bin_handle)