mirror of https://github.com/google/gemma.cpp.git
initial (wip) convert_weights script from pytorch
This commit is contained in:
parent
808dbdc42b
commit
5be9a2243f
|
|
@ -0,0 +1,55 @@
|
|||
# WIP - DO NOT MERGE
|
||||
|
||||
import torch
|
||||
from gemma import config
|
||||
from gemma import model as gemma_model
|
||||
import numpy as np
|
||||
|
||||
|
||||
def param_names():
|
||||
"""Return parameter names in the order they are expected for deserialization."""
|
||||
names = ["embedder.weight", "model.norm.weight"]
|
||||
# note *weight_scaler params are ignored in the forward computation unless quantization is being used.
|
||||
# since we are working with the full precision weights as input, don't include these in the parameters being iterated over
|
||||
layer_params = [
|
||||
"self_attn.qkv_proj.weight", # attn_vec_einsum_w
|
||||
"self_attn.o_proj.weight", # qkv_einsum_w
|
||||
"mlp.gate_proj.weight", # qkv_einsum_w
|
||||
"mlp.up_proj.weight", # gating_einsum_w
|
||||
"mlp.down_proj.weight", # linear_w
|
||||
"input_layernorm.weight", # pre_attention_norm_scale
|
||||
"post_attention_layernorm.weight", # pre_ffw_norm_scale
|
||||
]
|
||||
for layer in range(18):
|
||||
for layer_param in layer_params:
|
||||
names = names + ["model.layers." + str(layer) + "." + layer_param]
|
||||
return names
|
||||
|
||||
|
||||
def convert_weights():
|
||||
# TODO(austinvhuang): move code in here
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# TODO(austinvhuang): parameterize paths
|
||||
output_file = "2bit-f32.sbs"
|
||||
model_config = config.get_model_config("2b")
|
||||
model_config.dtype = "float32"
|
||||
model_config.quant = "store_true"
|
||||
model_config.tokenizer = "models/tokenizer.spm"
|
||||
device = torch.device("cpu")
|
||||
torch.set_default_dtype(torch.float)
|
||||
model = gemma_model.GemmaForCausalLM(model_config)
|
||||
model.load_weights("models/gemma-2b-it.ckpt")
|
||||
model_dict = dict(model.named_parameters())
|
||||
param_order = param_names()
|
||||
print("Writing parameters ...")
|
||||
with open(output_file, "wb") as bin_handle:
|
||||
for name in param_order:
|
||||
arr = model_dict[name].detach().numpy()
|
||||
# TODO(austinvhuang): reshapes
|
||||
print(f" {name : <60}{str(arr.shape)}")
|
||||
arr.flatten().astype(np.float32).tofile(bin_handle)
|
||||
|
||||
print("Done")
|
||||
Loading…
Reference in New Issue