mirror of https://github.com/google/gemma.cpp.git
Added 7B support and args parsing. Still todo: more testing of 7B conversion.
This commit is contained in:
parent c93e1a1e4d
commit 2161908f50
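Example invocation with the new flags (a sketch only: the converter's filename is not shown in this diff, so convert_weights.py is an assumption, as is the 7B checkpoint path; the flag names and defaults come from the argparse definitions added below):

    python convert_weights.py \
        --tokenizer models/tokenizer.spm \
        --weights models/gemma-7b-it.ckpt \
        --output_file 7b-f32.sbs \
        --model_type 7b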
@@ -13,23 +13,74 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-# WIP - DO NOT MERGE
-# Requires torch 2.2 and gemma package from https://github.com/google/gemma_pytorch
 
 from collections import defaultdict
 import torch
 from gemma import config
 from gemma import model as gemma_model
 import numpy as np
+import argparse
+import os
+
+# WIP - DO NOT MERGE
+# Requires torch 2.2 and gemma package from https://github.com/google/gemma_pytorch
+
+## parameters
+## model, tokenizer, model type,
+
+def check_file_exists(value):
+    if not os.path.exists(str(value)):
+        raise argparse.ArgumentTypeError("The file %s does not appear to exist." % value)
+    return value
+
+
+def check_model_types(value):
+    if str(value).lower() not in ["2b", "7b"]:
+        raise argparse.ArgumentTypeError("Model type value %s is not in [2b, 7b]." % value)
+    return value
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument(
+    "--tokenizer",
+    dest="tokenizer",
+    default="models/tokenizer.spm",
+    help="Location of tokenizer file (.model or .spm)",
+    type=check_file_exists,
+)
+
+parser.add_argument(
+    "--weights",
+    dest="weights",
+    default="models/gemma-2b-it.ckpt",
+    help="Location of input checkpoint file (.ckpt)",
+    type=check_file_exists,
+)
+
+parser.add_argument(
+    "--output_file",
+    dest="output_file",
+    default="2bit-f32.sbs",
+    help="Location to write converted weights",
+    type=str,
+)
+
+parser.add_argument(
+    "--model_type",
+    dest="model_type",
+    default="2b",
+    help="Model size / type (2b, 7b)",
+    type=check_model_types,
+)
+
+args = parser.parse_args()
+
+
 def expand_qkv(qkv_proj: np.array) -> np.array:
     """This won't be needed anymore when MQA is implemented"""
-    ## this will only be true for 2b
     assert qkv_proj.shape == (2560, 2048)
     qkv = qkv_proj.reshape((10, 256, 2048))
 
-    ## based on line 230 of
-    ## https://github.com/google/gemma_pytorch/blob/main/gemma/model.py
     q_proj = qkv[:8].reshape((1,8,256,2048))
     kv_proj = qkv[8:]
     kv_proj = kv_proj[:, np.newaxis, :, :]
 
@@ -39,23 +90,33 @@ def expand_qkv(qkv_proj: np.array) -> np.array:
     qkv = np.transpose(qkv, axes=[1,0,2,3])
     return qkv
 
-TRANSFORMATIONS = defaultdict(
+TRANSFORMATIONS = {
+    "2b":defaultdict(
     lambda: lambda x: x,
     {
-        ## padding goes at end per discussion
         "embedder.weight": lambda x: np.concatenate([x, np.zeros([128, 2048])], 0),
         "self_attn.qkv_proj.weight": expand_qkv,
-        ## based on line 234 of
-        ## https://github.com/google/gemma_pytorch/blob/main/gemma/model.py
-        "self_attn.o_proj.weight": lambda x: x.reshape(2048, 8, 256).transpose([1,0,2]), # TODO: which of the 2048 is unpacked to 8 x 256, and which is model_dim?
+        "self_attn.o_proj.weight": lambda x: x.reshape(2048, 8, 256).transpose([1,0,2]),
         "mlp.gate_proj.weight": lambda x: x[np.newaxis, :, :],
         "mlp.up_proj.weight": lambda x: x[np.newaxis, :, :],
        "mlp.down_proj.weight": lambda x: x,
-    },
-)
+    }
+    ),
+    "7b":defaultdict(
+        lambda: lambda x: x,
+        {
+            "embedder.weight": lambda x: np.concatenate([x, np.zeros([128, 3072])], 0),
+            "self_attn.qkv_proj.weight": lambda x: x.reshape((16, 3, 256, 3072)),
+            "self_attn.o_proj.weight": lambda x: x.reshape(3072, 16, 256).transpose([1,0,2]),
+            "mlp.gate_proj.weight": lambda x: x[np.newaxis, :, :],
+            "mlp.up_proj.weight": lambda x: x[np.newaxis, :, :],
+            "mlp.down_proj.weight": lambda x: x,
+        }
+    ),
+}
 
 VALIDATIONS = {
+    "2b": {
     "embedder.weight": lambda x: x.shape == (256128, 2048),
     "model.norm.weight": lambda x: x.shape == (2048,),
     "self_attn.qkv_proj.weight": lambda x: x.shape == (8, 3, 256, 2048),
@@ -65,10 +126,22 @@ VALIDATIONS = {
     "mlp.down_proj.weight": lambda x: x.shape == (2048, 16384),
     "input_layernorm.weight": lambda x: x.shape == (2048,),
     "post_attention_layernorm.weight": lambda x: x.shape == (2048,),
+    },
+    "7b": {
+        "embedder.weight": lambda x: x.shape == (256128, 3072),
+        "model.norm.weight": lambda x: x.shape == (3072,),
+        "self_attn.qkv_proj.weight": lambda x: x.shape == (16, 3, 256, 3072),
+        "self_attn.o_proj.weight": lambda x: x.shape == (16, 3072, 256),
+        "mlp.gate_proj.weight": lambda x: x.shape == (1, 24576, 3072),
+        "mlp.up_proj.weight": lambda x: x.shape == (1, 24576, 3072),
+        "mlp.down_proj.weight": lambda x: x.shape == (3072, 24576),
+        "input_layernorm.weight": lambda x: x.shape == (3072,),
+        "post_attention_layernorm.weight": lambda x: x.shape == (3072,),
+    },
 }
 
 
-def param_names():
+def param_names(num_hidden_layers: int):
     """Return parameter names in the order they are expected for deserialization."""
 
     # note *weight_scaler params are ignored in the forward computation unless
@@ -79,62 +152,50 @@ def param_names():
     # fmt: off
     names = [
-        ("embedder.weight", ) * 2, # embedder_input_embedding (vocab=256000, model_dim=2048) -> (vocab=256128, model_dim=2048)
-        ("model.norm.weight", ) * 2 # final_norm_scale (model_dim=2048)
+        ("embedder.weight", ) * 2, # embedder_input_embedding
+        ("model.norm.weight", ) * 2 # final_norm_scale
     ]
     layer_params = [
-        # TODO(austinvhuang): transpositions here ...
-        "self_attn.o_proj.weight", # attn_vec_einsum_w (2048, 2048) -> (heads=8, model_dim=2048, qkv_dim=256)
-        # # ( q_heads = 8 + kv = 2 ) x qkv_dim = 2560
-        "self_attn.qkv_proj.weight", # qkv_einsum_w (2560, 2048) -> (heads=8, qkv=3, qkv_dim=256, model_dim=2048)
-        # these are the same without any change
-        "mlp.gate_proj.weight", # gating_einsum_w (16384, 2048) => (gate/up=2, hidden=16384, model_dim=2048)
+        "self_attn.o_proj.weight", # attn_vec_einsum_w
+        "self_attn.qkv_proj.weight", # qkv_einsum_w
+        "mlp.gate_proj.weight", # gating_einsum_w
         "mlp.up_proj.weight",
-        "mlp.down_proj.weight", # linear_w (2048, 16384) => (model_dim=2048, hidden=16384)
-        "input_layernorm.weight", # pre_attention_norm_scale (model_dim=2048)
-        "post_attention_layernorm.weight", # pre_ffw_norm_scale (model_dim=2048)
+        "mlp.down_proj.weight", # linear_w
+        "input_layernorm.weight", # pre_attention_norm_scale
+        "post_attention_layernorm.weight", # pre_ffw_norm_scale
     ]
     # fmt: on
-    for layer in range(18):
+    for layer in range(num_hidden_layers):
         for layer_param in layer_params:
             names = names + [(f"model.layers.{layer}.{layer_param}", layer_param)]
-    print("names:", names)
     return names
 
 
 def convert_weights():
-    # TODO: parameterize paths as CLI args instead of hard coding them
-    output_file = "2bit-f32.sbs"
-    model_config = config.get_model_config("2b")
+    model_type = args.model_type
+    output_file = args.output_file
+    model_config = config.get_model_config(model_type)
     model_config.dtype = "float32"
-    ## this turns on int8 quantization
-    # model_config.quant = "store_true"
-    model_config.tokenizer = "models/tokenizer.spm"
+    model_config.tokenizer = args.tokenizer
     device = torch.device("cpu")
     torch.set_default_dtype(torch.float)
     model = gemma_model.GemmaForCausalLM(model_config)
-    model.load_weights("models/gemma-2b-it.ckpt")
-    model_dict = dict(model.named_parameters())
-
-    for layer_name in model_dict:
-        ## Make sure we're not silently having int8 quantization turned on.
-        print(layer_name, model_dict[layer_name].max())
-        assert(model_dict[layer_name].max() > 0.0)
-
-    param_order = param_names()
+    model.load_weights(args.weights)
+    model_dict = dict(model.named_parameters())
+    param_order = param_names(model_config.num_hidden_layers)
 
     all_ok = True
     print("Checking transformations ...")
     for name, layer_name in param_order:
         arr = model_dict[name].detach().numpy()
-        arr = TRANSFORMATIONS[layer_name](arr)
-        check = "OK" if VALIDATIONS[layer_name](arr) else "FAILED"
+        arr = TRANSFORMATIONS[model_type][layer_name](arr)
+        check = "OK" if VALIDATIONS[model_type][layer_name](arr) else "FAILED"
 
         if check == "FAILED":
             all_ok = False
         print(f" {name : <60}{str(arr.shape) : <20}{check}")
 
     if all_ok:
         print("Writing parameters ...")
@@ -142,19 +203,10 @@ def convert_weights():
         with open(output_file, "wb") as bin_handle:
             for name, layer_name in param_order:
                 arr = model_dict[name].detach().numpy()
-                arr = TRANSFORMATIONS[layer_name](arr)
-                check = "OK" if VALIDATIONS[layer_name](arr) else "FAILED"
+                arr = TRANSFORMATIONS[model_type][layer_name](arr)
+                check = "OK" if VALIDATIONS[model_type][layer_name](arr) else "FAILED"
                 print(f" {name : <60}{str(arr.shape) : <20}{check}")
-                if "gate_proj" in name:
-                    gate = arr
-                elif "up_proj" in name:
-                    up = arr
-                    f = np.concatenate([gate, up])
-                    print (f.shape)
-                    f.flatten().astype(np.float32).tofile(bin_handle)
-                else:
-                    arr.flatten().astype(np.float32).tofile(bin_handle)
+                arr.flatten().astype(np.float32).tofile(bin_handle)
 
 
 if __name__ == "__main__":
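The commit message lists more testing of the 7B conversion as a todo. A minimal shape-check sketch, assuming the TRANSFORMATIONS and VALIDATIONS tables from this diff are in scope; the raw (pre-transformation) checkpoint shapes below are inferred from the reshape targets in the 7b entries, not taken from gemma_pytorch itself, so treat them as assumptions:

    import numpy as np

    # Dummy arrays with the pre-transformation shapes implied by the 7b reshapes above.
    dummy_7b = {
        "embedder.weight": np.zeros((256000, 3072), dtype=np.float32),
        "model.norm.weight": np.zeros((3072,), dtype=np.float32),
        "self_attn.qkv_proj.weight": np.zeros((16 * 3 * 256, 3072), dtype=np.float32),
        "self_attn.o_proj.weight": np.zeros((3072, 16 * 256), dtype=np.float32),
        "mlp.gate_proj.weight": np.zeros((24576, 3072), dtype=np.float32),
        "mlp.up_proj.weight": np.zeros((24576, 3072), dtype=np.float32),
        "mlp.down_proj.weight": np.zeros((3072, 24576), dtype=np.float32),
    }

    for name, raw in dummy_7b.items():
        out = TRANSFORMATIONS["7b"][name](raw)  # defaultdict falls through to identity for norm weights
        assert VALIDATIONS["7b"][name](out), f"{name}: unexpected shape {out.shape}"
    print("7b transformation shapes OK")

This only exercises shapes; numerical correctness of the converted 7B weights still needs a comparison against reference outputs.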