mirror of https://github.com/google/gemma.cpp.git
Fixed 7B conversion.
This commit is contained in:
parent 2161908f50
commit b6831a2256
@@ -96,7 +96,7 @@ TRANSFORMATIONS = {
         {
             "embedder.weight": lambda x: np.concatenate([x, np.zeros([128, 2048])], 0),
             "self_attn.qkv_proj.weight": expand_qkv,
-            "self_attn.o_proj.weight": lambda x: x.reshape(2048, 8, 256).transpose([1,0,2]),
+            "self_attn.o_proj.weight": lambda x: x.reshape((2048, 8, 256)).transpose([1,0,2]),
             "mlp.gate_proj.weight": lambda x: x[np.newaxis, :, :],
             "mlp.up_proj.weight": lambda x: x[np.newaxis, :, :],
             "mlp.down_proj.weight": lambda x: x,
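For the 2B table the o_proj change is purely a spelling fix: NumPy's ndarray.reshape accepts either separate integers or a single tuple, so both forms produce the same array; the tuple form just matches the 7B entries below. A minimal sketch, with toy dimensions standing in for the real 2B values:

import numpy as np

# Toy stand-ins for the 2B values (d_model=2048, n_heads=8, head_dim=256).
d_model, n_heads, head_dim = 6, 4, 2
x = np.arange(d_model * n_heads * head_dim).reshape(d_model, n_heads * head_dim)

a = x.reshape(d_model, n_heads, head_dim).transpose([1, 0, 2])    # varargs shape
b = x.reshape((d_model, n_heads, head_dim)).transpose([1, 0, 2])  # tuple shape

print(a.shape)               # (4, 6, 2): the heads-first layout the converter wants
print(np.array_equal(a, b))  # True: the two reshape spellings are equivalent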
@@ -106,8 +106,8 @@ TRANSFORMATIONS = {
         lambda: lambda x: x,
         {
             "embedder.weight": lambda x: np.concatenate([x, np.zeros([128, 3072])], 0),
-            "self_attn.qkv_proj.weight": lambda x: x.reshape((16, 3, 256, 3072)),
-            "self_attn.o_proj.weight": lambda x: x.reshape(3072, 16, 256).transpose([1,0,2]),
+            "self_attn.qkv_proj.weight": lambda x: x.reshape((3, 16, 256, 3072)).transpose([1,0,2,3]),
+            "self_attn.o_proj.weight": lambda x: x.reshape((3072, 16, 256)).transpose([1,0,2]),
             "mlp.gate_proj.weight": lambda x: x[np.newaxis, :, :],
             "mlp.up_proj.weight": lambda x: x[np.newaxis, :, :],
             "mlp.down_proj.weight": lambda x: x,
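The qkv_proj line is the substantive fix. Reshaping the fused QKV weight straight into the target shape (16, 3, 256, 3072) assumes heads are the outermost axis of the flat buffer; if the checkpoint actually stacks q/k/v outermost, that reshape silently interleaves heads and projections, and the correct recipe is to reshape to the source layout (3, 16, 256, 3072) first and then transpose. That reading of the layout is inferred from the diff, not stated in the commit message. A NumPy sketch with toy dimensions shows the two expressions agree on shape but not on element placement:

import numpy as np

# Toy stand-ins for the 7B values (n_heads=16, head_dim=256, d_model=3072).
d_model, n_heads, head_dim = 6, 4, 2

# Fused QKV checkpoint tensor: q, k, v stacked along the leading axis.
flat = np.arange(3 * n_heads * head_dim * d_model).reshape(3 * n_heads * head_dim, d_model)

# Old transform: pour the flat buffer directly into the target shape.
old = flat.reshape((n_heads, 3, head_dim, d_model))

# Fixed transform: recover the source (qkv, heads, ...) layout first,
# then swap the leading axes to put heads outermost.
new = flat.reshape((3, n_heads, head_dim, d_model)).transpose([1, 0, 2, 3])

print(old.shape == new.shape)    # True: both are (4, 3, 2, 6)
print(np.array_equal(old, new))  # False: same shape, different element placement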
@@ -183,6 +183,8 @@ def convert_weights():
     model = gemma_model.GemmaForCausalLM(model_config)
 
+    model.load_weights(args.weights)
+    model.to(device).eval()
 
     model_dict = dict(model.named_parameters())
     param_order = param_names(model_config.num_hidden_layers)
 
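The loading block above feeds the structures that drive serialization: model_dict maps parameter names to tensors, param_order fixes the order they are written out, and each tensor is passed through its TRANSFORMATIONS entry. The defaultdict factory visible in the context lines, lambda: lambda x: x, means any parameter name without an explicit entry falls through as an identity transform. A hypothetical sketch of that consumption loop (export_params and the prefix-stripping are my assumptions, not the repo's actual code):

from collections import defaultdict

# Identity for anything not listed, as in the TRANSFORMATIONS context above.
transforms = defaultdict(lambda: lambda x: x)

def export_params(model_dict, param_order, transforms):
    # Yield (name, converted array) pairs in serialization order.
    for name in param_order:
        array = model_dict[name].detach().numpy()  # torch Parameter -> ndarray
        # Assumed naming scheme: strip the "model.layers.N." prefix so every
        # layer shares the per-name keys used in TRANSFORMATIONS.
        key = name.split(".", 3)[-1] if name.startswith("model.layers.") else name
        yield name, transforms[key](array)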