mirror of https://github.com/google/gemma.cpp.git
transformations and validations (wip)
parent 7d7d43e661
commit 3c69695c1e
@@ -1,4 +1,20 @@
+# Copyright 2024 Google LLC
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 # WIP - DO NOT MERGE
 # Requires torch 2.2 and gemma package from https://github.com/google/gemma_pytorch
+
 from collections import defaultdict
 import torch
@@ -6,12 +22,38 @@ from gemma import config
 from gemma import model as gemma_model
 import numpy as np
 
-TRANSFORMATIONS = defaultdict(lambda: lambda x: x, {
-    "embedder.weight": lambda x: np.concatenate([np.zeros([128, 2048]), x], 0),
-    "self_attn.qkv_proj.weight": lambda x: x,
-    "mlp.up_proj.weight" : lambda x: x,
-    "mlp.down_proj.weight" : lambda x: x,
-})
+
+def expand_qkv(qkv_proj: np.array) -> np.array:
+    """This won't be needed anymore when MQA is implemented"""
+    assert qkv_proj.shape == (2560, 2048)
+    qkv_proj.reshape((10, 256, 2048))
+    # TODO : repeat dimensions ...
+    return qkv_proj
+
+
+TRANSFORMATIONS = defaultdict(
+    lambda: lambda x: x,
+    {
+        "embedder.weight": lambda x: np.concatenate([np.zeros([128, 2048]), x], 0),
+        "self_attn.qkv_proj.weight": expand_qkv,
+        "self_attn.o_proj.weight": lambda x: x,  # TODO: which of the 2048 is unpacked to 8 x 256, and which is model_dim?
+        "mlp.gate_proj.weight": lambda x: x,
+        "mlp.up_proj.weight": lambda x: x,
+        "mlp.down_proj.weight": lambda x: x,
+    },
+)
+
+VALIDATIONS = {
+    "embedder.weight": lambda x: x.shape == (256128, 2048),
+    "model.norm.weight": lambda x: x.shape == (2048,),
+    "self_attn.qkv_proj.weight": lambda x: x.shape == (8, 3, 256, 2048),
+    "self_attn.o_proj.weight": lambda x: x.shape == (8, 2048, 256),
+    "mlp.gate_proj.weight": lambda x: x.shape == (1, 16384, 2048),
+    "mlp.up_proj.weight": lambda x: x.shape == (1, 16384, 2048),
+    "mlp.down_proj.weight": lambda x: x.shape == (2048, 16384),
+    "input_layernorm.weight": lambda x: x.shape == (2048,),
+    "post_attention_layernorm.weight": lambda x: x.shape == (2048,),
+}
 
 
 def param_names():
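The "# TODO : repeat dimensions ..." in expand_qkv presumably means broadcasting the single shared K and V projections across the 8 query heads so the result matches the (8, 3, 256, 2048) shape that VALIDATIONS expects. A minimal sketch, assuming the (2560, 2048) matrix is packed as 8 query heads followed by one K and one V head of width 256 each (the packing order and the expand_qkv_sketch name are assumptions, not part of this commit):

import numpy as np

def expand_qkv_sketch(qkv_proj: np.ndarray) -> np.ndarray:
    # Assumed layout: rows are [8 query heads | 1 shared K head | 1 shared V head],
    # each head a (256, 2048) block.
    assert qkv_proj.shape == (2560, 2048)
    packed = qkv_proj.reshape(10, 256, 2048)
    q = packed[:8]                           # (8, 256, 2048)
    k = np.repeat(packed[8:9], 8, axis=0)    # repeat shared K for every query head
    v = np.repeat(packed[9:10], 8, axis=0)   # repeat shared V for every query head
    return np.stack([q, k, v], axis=1)       # (heads=8, qkv=3, qkv_dim=256, model_dim=2048)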
@@ -25,14 +67,14 @@ def param_names():
     # fmt: off
     names = [
-        "embedder.weight",  # embedder_input_embedding (vocab=256000, model_dim=2048) -> (vocab=256128, model_dim=2048)
-        "model.norm.weight"  # final_norm_scale (model_dim=2048)
+        ("embedder.weight", ) * 2,  # embedder_input_embedding (vocab=256000, model_dim=2048) -> (vocab=256128, model_dim=2048)
+        ("model.norm.weight", ) * 2  # final_norm_scale (model_dim=2048)
     ]
     layer_params = [
         # TODO(austinvhuang): transpositions here ...
-        "self_attn.qkv_proj.weight",  # attn_vec_einsum_w (2560, 2048) -> (heads=8, model_dim=2048, qkv_dim=256)
-        "self_attn.o_proj.weight",  # qkv_einsum_w (2048, 2048) -> (heads=8, qkv=3, qkv_dim=256, model_dim=2048)
+        # ( q_heads = 8 + kv = 2 ) x qkv_dim = 2560
+        "self_attn.qkv_proj.weight",  # qkv_einsum_w (2560, 2048) -> (heads=8, qkv=3, qkv_dim=256, model_dim=2048)
+        "self_attn.o_proj.weight",  # attn_vec_einsum_w (2048, 2048) -> (heads=8, model_dim=2048, qkv_dim=256)
         # these are the same without any change
         "mlp.gate_proj.weight",  # gating_einsum_w (16384, 2048) => (gate/up=2, hidden=16384, model_dim=2048)
         "mlp.up_proj.weight",
@@ -43,12 +85,12 @@ def param_names():
     # fmt: on
     for layer in range(18):
         for layer_param in layer_params:
-            names = names + [f"model.layers.{layer}.{layer_param}"]
+            names = names + [(f"model.layers.{layer}.{layer_param}", layer_param)]
     return names
 
 
 def convert_weights():
-    # TODO(austinvhuang): parameterize paths
+    # TODO: parameterize paths as CLI args instead of hard coding them
    output_file = "2bit-f32.sbs"
     model_config = config.get_model_config("2b")
     model_config.dtype = "float32"
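With this change param_names() returns (checkpoint name, per-layer key) pairs instead of bare strings; the second element is what the writing loop below uses to look up VALIDATIONS. For illustration, the first few entries would look roughly like this (assuming layer_params goes on to list mlp.down_proj.weight and the two layer-norm weights that VALIDATIONS covers):

("embedder.weight", "embedder.weight")
("model.norm.weight", "model.norm.weight")
("model.layers.0.self_attn.qkv_proj.weight", "self_attn.qkv_proj.weight")
("model.layers.0.self_attn.o_proj.weight", "self_attn.o_proj.weight")
...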
@@ -62,11 +104,11 @@ def convert_weights():
     param_order = param_names()
     print("Writing parameters ...")
     with open(output_file, "wb") as bin_handle:
-        for name in param_order:
+        for name, layer_name in param_order:
             arr = model_dict[name].detach().numpy()
             arr = TRANSFORMATIONS[name](arr)
-            # TODO(austinvhuang): reshapes
-            print(f"  {name : <60}{str(arr.shape)}")
+            check = "OK" if VALIDATIONS[layer_name](arr) else "FAILED"
+            print(f"  {name : <60}{str(arr.shape) : <20}{check}")
             arr.flatten().astype(np.float32).tofile(bin_handle)
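Since the writing loop simply concatenates each tensor's float32 values into 2bit-f32.sbs, the finished file can be sanity-checked against the expected element counts. A sketch under the assumption that every key in VALIDATIONS is written once per layer, plus the two global tensors; the helper name and shape tables are assumptions that mirror the shapes asserted above:

import os
import numpy as np

# Shapes mirror the VALIDATIONS entries above.
GLOBAL_SHAPES = {
    "embedder.weight": (256128, 2048),
    "model.norm.weight": (2048,),
}
LAYER_SHAPES = {
    "self_attn.qkv_proj.weight": (8, 3, 256, 2048),
    "self_attn.o_proj.weight": (8, 2048, 256),
    "mlp.gate_proj.weight": (1, 16384, 2048),
    "mlp.up_proj.weight": (1, 16384, 2048),
    "mlp.down_proj.weight": (2048, 16384),
    "input_layernorm.weight": (2048,),
    "post_attention_layernorm.weight": (2048,),
}

def expected_sbs_bytes(num_layers: int = 18) -> int:
    # Total element count across global tensors plus all per-layer tensors.
    n = sum(int(np.prod(s)) for s in GLOBAL_SHAPES.values())
    n += num_layers * sum(int(np.prod(s)) for s in LAYER_SHAPES.values())
    return 4 * n  # 4 bytes per float32 value

print(expected_sbs_bytes() == os.path.getsize("2bit-f32.sbs"))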