Fixed 7B conversion.

Phil Culliton 2024-03-12 21:12:28 +00:00
parent 2161908f50
commit b6831a2256
1 changed file with 5 additions and 3 deletions


@@ -96,7 +96,7 @@ TRANSFORMATIONS = {
 {
 "embedder.weight": lambda x: np.concatenate([x, np.zeros([128, 2048])], 0),
 "self_attn.qkv_proj.weight": expand_qkv,
-"self_attn.o_proj.weight": lambda x: x.reshape(2048, 8, 256).transpose([1,0,2]),
+"self_attn.o_proj.weight": lambda x: x.reshape((2048, 8, 256)).transpose([1,0,2]),
 "mlp.gate_proj.weight": lambda x: x[np.newaxis, :, :],
 "mlp.up_proj.weight": lambda x: x[np.newaxis, :, :],
 "mlp.down_proj.weight": lambda x: x,
@@ -106,8 +106,8 @@ TRANSFORMATIONS = {
 lambda: lambda x: x,
 {
 "embedder.weight": lambda x: np.concatenate([x, np.zeros([128, 3072])], 0),
-"self_attn.qkv_proj.weight": lambda x: x.reshape((16, 3, 256, 3072)),
-"self_attn.o_proj.weight": lambda x: x.reshape(3072, 16, 256).transpose([1,0,2]),
+"self_attn.qkv_proj.weight": lambda x: x.reshape((3, 16, 256, 3072)).transpose([1,0,2,3]),
+"self_attn.o_proj.weight": lambda x: x.reshape((3072, 16, 256)).transpose([1,0,2]),
 "mlp.gate_proj.weight": lambda x: x[np.newaxis, :, :],
 "mlp.up_proj.weight": lambda x: x[np.newaxis, :, :],
 "mlp.down_proj.weight": lambda x: x,
@@ -183,6 +183,8 @@ def convert_weights():
 model = gemma_model.GemmaForCausalLM(model_config)
 model.load_weights(args.weights)
 model.to(device).eval()
+model_dict = dict(model.named_parameters())
+param_order = param_names(model_config.num_hidden_layers)
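
The two lines added in the last hunk collect the PyTorch parameters by name and an ordered list of parameter names to emit. As a rough illustration only (the parameter name, shape, and suffix-matching loop below are assumptions for the sketch, not the script's actual logic), such a dict plus ordering could feed a per-suffix transformation table like the ones above:

import torch

# Stand-ins so the sketch runs on its own; in convert_weights() these come from
# dict(model.named_parameters()) and param_names(...).  The parameter name and
# shape are illustrative, not taken from the script.
model_dict = {"model.layers.0.self_attn.o_proj.weight": torch.zeros(2048, 2048)}
param_order = ["model.layers.0.self_attn.o_proj.weight"]

# One entry borrowed from the 2B table in the first hunk.
transforms = {
    "self_attn.o_proj.weight": lambda x: x.reshape((2048, 8, 256)).transpose([1, 0, 2]),
}

converted = {}
for name in param_order:
    tensor = model_dict[name].detach().cpu().numpy()
    for suffix, fn in transforms.items():  # hypothetical suffix match
        if name.endswith(suffix):
            tensor = fn(tensor)
            break
    converted[name] = tensor

print(converted["model.layers.0.self_attn.o_proj.weight"].shape)  # (8, 2048, 256)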