mirror of https://github.com/google/gemma.cpp.git
Fixed 7B conversion.
This commit is contained in:
parent
2161908f50
commit
b6831a2256
|
|
@ -96,7 +96,7 @@ TRANSFORMATIONS = {
|
||||||
{
|
{
|
||||||
"embedder.weight": lambda x: np.concatenate([x, np.zeros([128, 2048])], 0),
|
"embedder.weight": lambda x: np.concatenate([x, np.zeros([128, 2048])], 0),
|
||||||
"self_attn.qkv_proj.weight": expand_qkv,
|
"self_attn.qkv_proj.weight": expand_qkv,
|
||||||
"self_attn.o_proj.weight": lambda x: x.reshape(2048, 8, 256).transpose([1,0,2]),
|
"self_attn.o_proj.weight": lambda x: x.reshape((2048, 8, 256)).transpose([1,0,2]),
|
||||||
"mlp.gate_proj.weight": lambda x: x[np.newaxis, :, :],
|
"mlp.gate_proj.weight": lambda x: x[np.newaxis, :, :],
|
||||||
"mlp.up_proj.weight": lambda x: x[np.newaxis, :, :],
|
"mlp.up_proj.weight": lambda x: x[np.newaxis, :, :],
|
||||||
"mlp.down_proj.weight": lambda x: x,
|
"mlp.down_proj.weight": lambda x: x,
|
||||||
|
|
@ -106,8 +106,8 @@ TRANSFORMATIONS = {
|
||||||
lambda: lambda x: x,
|
lambda: lambda x: x,
|
||||||
{
|
{
|
||||||
"embedder.weight": lambda x: np.concatenate([x, np.zeros([128, 3072])], 0),
|
"embedder.weight": lambda x: np.concatenate([x, np.zeros([128, 3072])], 0),
|
||||||
"self_attn.qkv_proj.weight": lambda x: x.reshape((16, 3, 256, 3072)),
|
"self_attn.qkv_proj.weight": lambda x: x.reshape((3, 16, 256, 3072)).transpose([1,0,2,3]),
|
||||||
"self_attn.o_proj.weight": lambda x: x.reshape(3072, 16, 256).transpose([1,0,2]),
|
"self_attn.o_proj.weight": lambda x: x.reshape((3072, 16, 256)).transpose([1,0,2]),
|
||||||
"mlp.gate_proj.weight": lambda x: x[np.newaxis, :, :],
|
"mlp.gate_proj.weight": lambda x: x[np.newaxis, :, :],
|
||||||
"mlp.up_proj.weight": lambda x: x[np.newaxis, :, :],
|
"mlp.up_proj.weight": lambda x: x[np.newaxis, :, :],
|
||||||
"mlp.down_proj.weight": lambda x: x,
|
"mlp.down_proj.weight": lambda x: x,
|
||||||
|
|
@ -183,6 +183,8 @@ def convert_weights():
|
||||||
model = gemma_model.GemmaForCausalLM(model_config)
|
model = gemma_model.GemmaForCausalLM(model_config)
|
||||||
|
|
||||||
model.load_weights(args.weights)
|
model.load_weights(args.weights)
|
||||||
|
model.to(device).eval()
|
||||||
|
|
||||||
model_dict = dict(model.named_parameters())
|
model_dict = dict(model.named_parameters())
|
||||||
param_order = param_names(model_config.num_hidden_layers)
|
param_order = param_names(model_config.num_hidden_layers)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue