mirror of https://github.com/google/gemma.cpp.git
Resolved layer ordering, reshaping, MQA->MHA, and quantization. Works only for 2B.
This commit is contained in:
parent 3c69695c1e
commit c93e1a1e4d
@@ -22,23 +22,35 @@ from gemma import config
from gemma import model as gemma_model
import numpy as np


def expand_qkv(qkv_proj: np.array) -> np.array:
    """This won't be needed anymore when MQA is implemented"""
    ## this will only be true for 2b
    assert qkv_proj.shape == (2560, 2048)
    qkv_proj.reshape((10, 256, 2048))
    # TODO : repeat dimensions ...
    return qkv_proj
    qkv = qkv_proj.reshape((10, 256, 2048))

    ## based on line 230 of
    ## https://github.com/google/gemma_pytorch/blob/main/gemma/model.py
    q_proj = qkv[:8].reshape((1, 8, 256, 2048))
    kv_proj = qkv[8:]
    kv_proj = kv_proj[:, np.newaxis, :, :]
    kv_proj = np.repeat(kv_proj, 8, axis=1)

    qkv = np.concatenate([q_proj, kv_proj])
    qkv = np.transpose(qkv, axes=[1, 0, 2, 3])
    return qkv

TRANSFORMATIONS = defaultdict(
    lambda: lambda x: x,
    {
        "embedder.weight": lambda x: np.concatenate([np.zeros([128, 2048]), x], 0),
        ## padding goes at end per discussion
        "embedder.weight": lambda x: np.concatenate([x, np.zeros([128, 2048])], 0),
        "self_attn.qkv_proj.weight": expand_qkv,
        "self_attn.o_proj.weight": lambda x: x,  # TODO: which of the 2048 is unpacked to 8 x 256, and which is model_dim?
        "mlp.gate_proj.weight": lambda x: x,
        "mlp.up_proj.weight": lambda x: x,

        ## based on line 234 of
        ## https://github.com/google/gemma_pytorch/blob/main/gemma/model.py
        "self_attn.o_proj.weight": lambda x: x.reshape(2048, 8, 256).transpose([1, 0, 2]),  # TODO: which of the 2048 is unpacked to 8 x 256, and which is model_dim?
        "mlp.gate_proj.weight": lambda x: x[np.newaxis, :, :],
        "mlp.up_proj.weight": lambda x: x[np.newaxis, :, :],
        "mlp.down_proj.weight": lambda x: x,
    },
)
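As a reading aid (not part of the commit): a minimal shape check of the MQA-to-MHA expansion performed by the new expand_qkv above, assuming the 2B layout of 8 query heads plus one shared K and one shared V head, qkv_dim 256, model_dim 2048. The random data and assertions are illustrative only.

import numpy as np

# Hypothetical sanity check of expand_qkv as added in this commit (2B shapes assumed).
qkv_proj = np.random.randn(2560, 2048).astype(np.float32)
expanded = expand_qkv(qkv_proj)
# (heads=8, q/k/v=3, qkv_dim=256, model_dim=2048)
assert expanded.shape == (8, 3, 256, 2048)
# The single shared K and V projections are broadcast to all 8 heads,
# so they should be identical along the head axis.
assert np.array_equal(expanded[0, 1], expanded[7, 1])  # K
assert np.array_equal(expanded[0, 2], expanded[7, 2])  # V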
@@ -72,9 +84,9 @@ def param_names():
    ]
    layer_params = [
        # TODO(austinvhuang): transpositions here ...
        # ( q_heads = 8 + kv = 2 ) x qkv_dim = 2560
        "self_attn.qkv_proj.weight",  # qkv_einsum_w (2560, 2048) -> (heads=8, qkv=3, qkv_dim=256, model_dim=2048)
        "self_attn.o_proj.weight",  # attn_vec_einsum_w (2048, 2048) -> (heads=8, model_dim=2048, qkv_dim=256)
        # # ( q_heads = 8 + kv = 2 ) x qkv_dim = 2560
        "self_attn.qkv_proj.weight",  # qkv_einsum_w (2560, 2048) -> (heads=8, qkv=3, qkv_dim=256, model_dim=2048)
        # these are the same without any change
        "mlp.gate_proj.weight",  # gating_einsum_w (16384, 2048) => (gate/up=2, hidden=16384, model_dim=2048)
        "mlp.up_proj.weight",
@@ -86,6 +98,7 @@ def param_names():
    for layer in range(18):
        for layer_param in layer_params:
            names = names + [(f"model.layers.{layer}.{layer_param}", layer_param)]
    print("names:", names)
    return names
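For clarity on what the loop above produces (not part of the diff): each entry of layer_params expands into one (checkpoint name, transformation key) tuple per layer. A standalone illustration with a truncated layer_params list, assuming the same 18-layer 2B layout:

# Standalone illustration only; layer_params is truncated for the example.
layer_params = ["self_attn.qkv_proj.weight"]
names = []
for layer in range(18):
    for layer_param in layer_params:
        names = names + [(f"model.layers.{layer}.{layer_param}", layer_param)]
assert names[0] == ("model.layers.0.self_attn.qkv_proj.weight", "self_attn.qkv_proj.weight")
assert len(names) == 18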
@@ -94,22 +107,54 @@ def convert_weights():
    output_file = "2bit-f32.sbs"
    model_config = config.get_model_config("2b")
    model_config.dtype = "float32"
    model_config.quant = "store_true"

    ## this turns on int8 quantization
    # model_config.quant = "store_true"
    model_config.tokenizer = "models/tokenizer.spm"
    device = torch.device("cpu")
    torch.set_default_dtype(torch.float)
    model = gemma_model.GemmaForCausalLM(model_config)
    model.load_weights("models/gemma-2b-it.ckpt")
    model_dict = dict(model.named_parameters())

    for layer_name in model_dict:
        ## Make sure we're not silently having int8 quantization turned on.
        print(layer_name, model_dict[layer_name].max())
        assert(model_dict[layer_name].max() > 0.0)

    param_order = param_names()
    print("Writing parameters ...")
    with open(output_file, "wb") as bin_handle:
        for name, layer_name in param_order:
            arr = model_dict[name].detach().numpy()
            arr = TRANSFORMATIONS[name](arr)
            check = "OK" if VALIDATIONS[layer_name](arr) else "FAILED"
            print(f" {name : <60}{str(arr.shape) : <20}{check}")
            arr.flatten().astype(np.float32).tofile(bin_handle)

    all_ok = True
    print("Checking transformations ...")
    for name, layer_name in param_order:
        arr = model_dict[name].detach().numpy()
        arr = TRANSFORMATIONS[layer_name](arr)
        check = "OK" if VALIDATIONS[layer_name](arr) else "FAILED"

        if check == "FAILED":
            all_ok = False

        print(f" {name : <60}{str(arr.shape) : <20}{check}")

    if all_ok:
        print("Writing parameters ...")
        gate = None
        with open(output_file, "wb") as bin_handle:
            for name, layer_name in param_order:
                arr = model_dict[name].detach().numpy()
                arr = TRANSFORMATIONS[layer_name](arr)
                check = "OK" if VALIDATIONS[layer_name](arr) else "FAILED"
                print(f" {name : <60}{str(arr.shape) : <20}{check}")

                if "gate_proj" in name:
                    gate = arr
                elif "up_proj" in name:
                    up = arr
                    f = np.concatenate([gate, up])
                    print(f.shape)
                    f.flatten().astype(np.float32).tofile(bin_handle)
                else:
                    arr.flatten().astype(np.float32).tofile(bin_handle)


if __name__ == "__main__":
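For clarity on the gate/up handling in the new write loop (a sketch under the shapes this commit assumes, not code from the repo): after TRANSFORMATIONS, mlp.gate_proj and mlp.up_proj each carry a leading singleton axis, and the loop fuses them into a single gating_einsum_w tensor before writing it to the .sbs file.

import numpy as np

# Hypothetical 2B MLP weight shapes, matching the newaxis transformations above.
gate = np.random.randn(16384, 2048).astype(np.float32)[np.newaxis, :, :]  # gate_proj after TRANSFORMATIONS
up = np.random.randn(16384, 2048).astype(np.float32)[np.newaxis, :, :]    # up_proj after TRANSFORMATIONS
fused = np.concatenate([gate, up])  # same concatenation as in the write loop
assert fused.shape == (2, 16384, 2048)  # (gate/up=2, hidden=16384, model_dim=2048)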