mirror of https://github.com/google/gemma.cpp.git
Add conversion tool for HF safetensors to gemma.cpp for PaliGemma.
PiperOrigin-RevId: 725990158
This commit is contained in:
parent c495b25995
commit f173aa776e

README.md | 22 changed lines
README.md:
@@ -349,8 +349,26 @@ and not a pre-trained model (any model with a `-pt` suffix).
 
 **How do I convert my fine-tune to a `.sbs` compressed model file?**
 
-See compression/convert_weights.py to convert a pytorch checkpint. (The code may
-need updates to work with Gemma-2 models.)
+For PaliGemma (1 and 2) checkpoints, you can use
+python/convert_from_safetensors.py to convert from the safetensors format
+(tested with building via bazel). For an adapter model, you will likely need
+to call merge_and_unload() to convert the adapter model to a single-file
+format before converting it.
+
+Here is how to use it with a bazel build of the compression library, assuming
+locally installed (venv) torch, numpy, safetensors, absl-py, etc.:
+
+```sh
+bazel build //compression/python:compression
+BAZEL_OUTPUT_DIR="${PWD}/bazel-bin/compression"
+python3 -c "import site; print(site.getsitepackages())"
+# Use your site-packages path here:
+ln -s $BAZEL_OUTPUT_DIR [...]/site-packages/compression
+python3 python/convert_from_safetensors.py --load_path [...].safetensors.index.json
+```
+
+See also compression/convert_weights.py for a slightly older option to convert
+a pytorch checkpoint. (The code may need updates to work with Gemma-2 models.)
 
 **What are some easy ways to make the model run faster?**
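For reference, the adapter-merge step mentioned in the added FAQ text can look like the following sketch, which mirrors the recipe recorded in the header comment of the new converter script (the model and adapter names are the examples used there):

```python
from peft import PeftModel
from transformers import PaliGemmaForConditionalGeneration

# Load the base model, then apply the LoRA adapter on top of it.
model_name = "google/paligemma2-3b-pt-448"
lora_weights_path = "merve/paligemma2-3b-vqav2"
model = PaliGemmaForConditionalGeneration.from_pretrained(model_name)
model = PeftModel.from_pretrained(model, lora_weights_path)

# Fold the adapter weights into the base model and save a single-file
# checkpoint that the converter can read.
model = model.merge_and_unload()
model.save_pretrained("/tmp/lora-model")
```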
BUILD file (py_binary; path not shown in this view):
@@ -36,7 +36,7 @@ py_binary(
     deps = [
         ":gemma",
         "@python_deps//absl_py",
-        # placeholder forabsl/flags
+        # placeholder for absl/flags
         "@compression_deps//numpy",
     ],
 )
python/convert_from_safetensors.py (new file):
@@ -0,0 +1,436 @@

# Copyright 2025 Google LLC
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Convert a PaliGemma[1/2] model from SafeTensors to gemma.cpp format."""
# Tested with:
# - PG1: huggingface.co/google/paligemma-3b-pt-224
# - PG1: huggingface.co/merve/paligemma_vqav2
# - PG2: huggingface.co/google/paligemma2-3b-pt-448
# - PG2: huggingface.co/merve/paligemma2-3b-vqav2
# The last one above is a LoRA model, so the merged weights were saved using:
#   model_name = "google/paligemma2-3b-pt-448"
#   lora_weights_path = "merve/paligemma2-3b-vqav2"
#   model = PaliGemmaForConditionalGeneration.from_pretrained(model_name)
#   model = PeftModel.from_pretrained(model, lora_weights_path)
#   model = model.merge_and_unload()
#   model.save_pretrained("/tmp/lora-model")

from collections.abc import Sequence
import csv
import json
import os
import sys
from typing import Any, Dict

from absl import app
from absl import flags
from absl import logging
import numpy as np
import safetensors
import torch

from compression.python import compression


def flatten_f32(x: np.ndarray) -> np.ndarray:
  """Flattens an array and converts it to float32.

  Args:
    x: input array

  Returns:
    Flattened float32 array.
  """
  return x.ravel().astype(np.float32, copy=False)


def compute_scale(x: np.ndarray) -> float:
  """Computes the scale that brings the max magnitude within 1.875.

  Args:
    x: input array

  Returns:
    Scale value (1.0 means no rescaling).
  """
  magnitude = np.max(np.abs(x))
  return max(1.0, magnitude / 1.875)

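# Worked example (added note, not part of the original file): for a tensor
# whose max absolute value is 3.75, compute_scale returns 3.75 / 1.875 = 2.0;
# dividing the values by that scale brings them back into [-1.875, 1.875].
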
def _is_float_param(param_name: str) -> bool:
  """Returns True if the sbs param should be stored as float32."""
  for prefix in ["img_pos_emb", "attn_out_b", "linear_0_b", "linear_1_b",
                 "qkv_ein_b", "img_emb_bias", "img_head_bias"]:
    if param_name.startswith(prefix):
      return True
  return False


def _is_bf16_param(param_name: str) -> bool:
  """Returns True if the sbs param should be stored as bfloat16."""
  for prefix in ["pre_", "post_", "c_", "img_head_kernel"]:
    if param_name.startswith(prefix):
      return True
  return False


# Layernorm names are slightly confusing in HF transformers between versions.
# Gemma layernorms:
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma/modeling_gemma.py
#   input_layernorm -> attn -> residual -> post_attention_layernorm -> mlp -> residual
# Gemma2 layernorms:
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma2/modeling_gemma2.py
#   input_layernorm -> attn -> post_attention_layernorm -> residual
#   pre_feedforward_layernorm -> mlp -> post_feedforward_layernorm -> residual
# Note that post_attention_layernorm denotes something different in the two.
# For comparison, the Big Vision Gemma2 keeps the same name for the same norm:
#   pre_attention_norm -> attn -> [post_attention_norm] -> residual
#   pre_ffw_norm -> mlp -> [post_ffw_norm] -> residual


# Tuples correspond to (transformers-name, shape, sbs-name).
# The qkv-einsum weights are part of llm-layers but are handled separately and
# thus not included in the list.
def _get_layer_config(dims: Dict[str, Any]):
  """Returns a dictionary of layer configurations.

  Args:
    dims: A dictionary of (mostly) dimension values.

  Returns:
    A dictionary of layer configurations.
  """
  model_dim = dims["model_dim"]
  hidden_dim = dims["hidden_dim"]
  vit_seq_len = dims["vit_seq_len"]
  config = {
      "llm-non-layers": [
          ("language_model.model.embed_tokens.weight",
           (257152, model_dim), "c_embedding"),
          ("language_model.model.norm.weight", (model_dim,), "c_final_norm"),
      ],
      "llm-layers": [
          ("language_model.model.layers.%d.mlp.down_proj.weight",
           (model_dim, hidden_dim), "linear_w"),
      ],
      "img-non-layers": [
          ("vision_tower.vision_model.post_layernorm.bias",
           (1152,), "enc_norm_bias"),
          ("vision_tower.vision_model.post_layernorm.weight",
           (1152,), "enc_norm_scale"),
          ("vision_tower.vision_model.embeddings.patch_embedding.bias",
           (1152,), "img_emb_bias"),
          ("vision_tower.vision_model.embeddings.patch_embedding.weight",
           (1152, 14, 14, 3), "img_emb_kernel"),
          ("multi_modal_projector.linear.bias", (model_dim,), "img_head_bias"),
          ("multi_modal_projector.linear.weight",
           (model_dim, 1152), "img_head_kernel"),
          ("vision_tower.vision_model.embeddings.position_embedding.weight",
           (vit_seq_len, 1152), "img_pos_emb"),
      ],
      "img-layers": [
          ("vision_tower.vision_model.encoder.layers.%d.layer_norm1.bias",
           (1152,), "ln_0_bias"),
          ("vision_tower.vision_model.encoder.layers.%d.layer_norm1.weight",
           (1152,), "ln_0_scale"),
          ("vision_tower.vision_model.encoder.layers.%d.layer_norm2.bias",
           (1152,), "ln_1_bias"),
          ("vision_tower.vision_model.encoder.layers.%d.layer_norm2.weight",
           (1152,), "ln_1_scale"),
          ("vision_tower.vision_model.encoder.layers.%d.mlp.fc1.bias",
           (4304,), "linear_0_b"),
          ("vision_tower.vision_model.encoder.layers.%d.mlp.fc1.weight",
           (4304, 1152), "linear_0_w"),
          ("vision_tower.vision_model.encoder.layers.%d.mlp.fc2.bias",
           (1152,), "linear_1_b"),
          ("vision_tower.vision_model.encoder.layers.%d.mlp.fc2.weight",
           (1152, 4304), "linear_1_w"),
          ("vision_tower.vision_model.encoder.layers.%d.self_attn.out_proj.bias",
           (1152,), "attn_out_b"),
          ("vision_tower.vision_model.encoder.layers.%d.self_attn.out_proj.weight",
           (1152, 16 * 72), "attn_out_w"),
      ],
  }
  if dims["has_post_norm"]:  # See the longer layernorm comment above.
    config["llm-layers"] += [
        ("language_model.model.layers.%d.input_layernorm.weight",
         (model_dim,), "pre_att_ns"),
        ("language_model.model.layers.%d.pre_feedforward_layernorm.weight",
         (model_dim,), "pre_ff_ns"),
        ("language_model.model.layers.%d.post_attention_layernorm.weight",
         (model_dim,), "post_att_ns"),
        ("language_model.model.layers.%d.post_feedforward_layernorm.weight",
         (model_dim,), "post_ff_ns"),
    ]
  else:
    config["llm-layers"] += [
        ("language_model.model.layers.%d.input_layernorm.weight",
         (model_dim,), "pre_att_ns"),
        ("language_model.model.layers.%d.post_attention_layernorm.weight",
         (model_dim,), "pre_ff_ns"),
    ]
  return config


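# Example (added note, not part of the original file): an "llm-layers" entry
# expands once per layer, e.g. layer 3's down projection is read from
# "language_model.model.layers.3.mlp.down_proj.weight" and written as the sbs
# tensor "linear_w_3" (see add_data below).
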
def _get_dimensions(params):
  """Returns a dictionary of dimension values.

  Args:
    params: A dictionary with parameters.

  Returns:
    A dictionary of dimension values.
  """
  dims = {}
  # For PG1 and PG2-{3B,10B} head_dim is 256; this would need an update for
  # PG2-28B. Unfortunately it is not easily derivable from the input shapes.
  dims["head_dim"] = 256
  dims["model_dim"] = params["multi_modal_projector.linear.bias"].shape[0]
  dims["hidden_dim"] = params[
      "language_model.model.layers.0.mlp.gate_proj.weight"
  ].shape[0]
  dims["num_heads"] = (
      params["language_model.model.layers.0.self_attn.q_proj.weight"].shape[0]
      // dims["head_dim"]
  )
  dims["vit_seq_len"] = params[
      "vision_tower.vision_model.embeddings.position_embedding.weight"
  ].shape[0]
  dims["num_llm_layers"] = len(
      set([k for k in params.keys() if "input_layernorm.weight" in k])
  )
  dims["has_post_norm"] = (
      "language_model.model.layers.0.post_feedforward_layernorm.weight"
      in params
  )
  return dims


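# For orientation (added note, values not derived from this file): for
# paligemma-3b-pt-224 (Gemma-2B LLM) one expects model_dim = 2048,
# hidden_dim = 16384, num_heads = 8, vit_seq_len = (224 / 14)**2 = 256,
# num_llm_layers = 18, and has_post_norm = False (Gemma-1 style layernorms).
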
def export_paligemma_sbs(
    load_path: str,
    csv_file: str,
    sbs_file: str,
) -> None:
  """Exports an sbs file from PaliGemma safetensors file(s)."""

  # If this is a multi-part checkpoint, get the list of files from the json.
  if load_path.endswith(".json"):
    with open(load_path, "r") as f:
      j_obj = json.load(f)
    files = list(set(j_obj["weight_map"].values()))
    files = [os.path.join(os.path.dirname(load_path), f) for f in files]
  else:
    files = [load_path]

  # Read the parameters from the files.
  params = {}
  for file in files:
    with safetensors.safe_open(file, framework="pt") as f:
      for k in f.keys():
        params[k] = f.get_tensor(k)
        print(k, params[k].shape, params[k].view(-1)[0].item())

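  # For reference (added note): a safetensors index.json maps tensor names to
  # shard files, e.g. {"weight_map": {"language_model.model.embed_tokens.weight":
  # "model-00001-of-00003.safetensors", ...}}; the set() above deduplicates
  # the shard file names.
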
  # See https://tinyurl.com/paligemmavocab - HF transformers extends the
  # embedding matrix by 64 entries. Undo that here.
  params["language_model.model.embed_tokens.weight"] = params[
      "language_model.model.embed_tokens.weight"
  ][:-64]

  # Initialize a few things.
  writer = compression.SbsWriter(compression.CompressorMode.NO_TOC)
  metadata = []
  scales = {}
  dims = _get_dimensions(params)
  layer_config = _get_layer_config(dims)

  # Adds a parameter with the expected shape to the writer.
  def add_data(param_name, data, expected_shape, sbs_name, layer_index=None):
    # Check shape.
    if not isinstance(expected_shape, tuple):
      expected_shape = (expected_shape,)
    print(f"Writing {param_name} with shape {data.shape} e:{expected_shape}")
    assert data.shape == expected_shape, param_name

    # Here we assume that the read data is a torch tensor and then convert it
    # to a numpy array.
    assert isinstance(data, torch.Tensor)
    data = data.to(torch.float32).numpy()
    data = np.array(data)

    # Add the layer index to the param name and sbs name if needed.
    if layer_index is not None:
      param_name = param_name % layer_index
      sbs_name = sbs_name + f"_{layer_index}"

    # Flatten the data and get its scale.
    value = flatten_f32(data)
    scale = compute_scale(value)
    both_names = param_name + "::" + sbs_name
    print(f"Param {both_names} has scale {scale}")
    metadata.append((both_names, data.dtype, data.shape, scale))

    # Determine the type as which to insert.
    if _is_float_param(sbs_name):
      insert = writer.insert_float  # Insert as float.
      print(f"Inserting {both_names} as float (f32) (no scaling)")
    elif _is_bf16_param(sbs_name) or param_name.startswith("vision_tower"):
      insert = writer.insert_bf16  # Insert as BF16.
      print(f"Inserting {both_names} as BF16 (no scaling)")
    else:
      insert = writer.insert_sfp  # Insert as SFP.
      # Assumes that all scales are 1.0 for SFP. Consider adding scales.
      # They would still need to be written, but would be collected here.
      assert scale == 1.0, f"Scale for {both_names} is not 1.0"
      if scale != 1.0:
        value = value / scale
        scales[sbs_name] = scale  # Unused at the moment.
      print(f"Inserting {both_names} as SFP with scale {scale}")
    sys.stdout.flush()

    # Add the data to the writer.
    insert(sbs_name, value)

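  # Summary of the type selection above (added note): biases and the image
  # position embedding stay float32; norm scales (pre_*/post_*), c_embedding /
  # c_final_norm, the image head kernel, and the remaining vision_tower
  # tensors go to bf16; the large LLM matmul weights are compressed to SFP
  # (8-bit switched floating point), which is why their scale must be 1.0.
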
  def add_qkv_einsum(i):  # Handle qkv for layer i.
    name = "language_model.model.layers.%d.self_attn.q_proj.weight"  # (N*H, D)
    q_i = params.pop(name % i)
    (nh, d) = q_i.shape
    h = dims["head_dim"]
    n = dims["num_heads"]
    assert nh == n * h
    assert dims["model_dim"] == d
    q_i = q_i.reshape(n, h, d)
    name = "language_model.model.layers.%d.self_attn.k_proj.weight"  # (K*H, D)
    k_i = params.pop(name % i)
    kh = k_i.shape[0]
    k = kh // h
    assert k_i.shape[1] == d
    k_i = k_i.reshape(k, h, d)
    name = "language_model.model.layers.%d.self_attn.v_proj.weight"  # (K*H, D)
    v_i = params.pop(name % i)
    assert v_i.shape[0] == kh
    assert v_i.shape[1] == d
    v_i = v_i.reshape(k, h, d)
    # Stack and reshape KV to interleave (k,v), (k,v), ...
    stacked = torch.stack((k_i, v_i), dim=0)  # (2, K, H, D)
    transposed = stacked.transpose(0, 1)  # (K, 2, H, D)
    reshaped = transposed.reshape(2 * k, h, d)  # (2K, H, D)
    # Concatenate Q and KV to get the full qkv.
    qkv_i = torch.cat([q_i, reshaped], dim=0)
    name = "language_model.model.layers.%d.self_attn.qkv_proj.weight"
    expected_shape = (n + 2 * k, h, d)  # (N+2K, H, D)
    sbs_name = "qkv_ein"
    add_data(name, qkv_i, expected_shape, sbs_name, i)

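  # Illustration (added note): with K = 2 KV heads, the rows of qkv_i come out
  # as [q_0 ... q_{N-1}, k_0, v_0, k_1, v_1], i.e. all Q heads first, then
  # (k, v) pairs interleaved per KV head.
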
  def add_att_einsum(i):  # Handle att_ein for layer i.
    name = "language_model.model.layers.%d.self_attn.o_proj.weight"  # (D, N*H)
    o_i = params.pop(name % i)
    (d, nh) = o_i.shape
    h = dims["head_dim"]
    n = dims["num_heads"]
    assert nh == n * h
    o_i = o_i.reshape(d, n, h).permute(1, 0, 2)  # (D, N, H) -> (N, D, H)
    expected_shape = (n, d, h)
    sbs_name = "att_ein"
    add_data(name, o_i, expected_shape, sbs_name, i)

  # Join gate and up projection weights to gating_einsum for layer i.
  def add_gating_einsum(i):
    name = "language_model.model.layers.%d.mlp.gate_proj.weight"
    gate_i = params.pop(name % i)
    f, d = gate_i.shape
    assert dims["hidden_dim"] == f
    assert dims["model_dim"] == d
    name = "language_model.model.layers.%d.mlp.up_proj.weight"
    up_i = params.pop(name % i)
    assert up_i.shape == gate_i.shape
    gating_einsum_i = torch.stack([gate_i, up_i], dim=0)
    name = "language_model.model.layers.%d.mlp.gating_einsum.weight"
    expected_shape = (2, f, d)
    sbs_name = "gating_ein"
    add_data(name, gating_einsum_i, expected_shape, sbs_name, i)

  # Handle the q and kv einsum parts for layer i in the ViT - merge into qkv.
  def add_vit_qkv_einsum(i):
    # Weights first.
    prefix = "vision_tower.vision_model.encoder.layers.%d.self_attn"
    name = prefix + ".q_proj.weight"  # (16 * 72, 1152)
    q_i = params.pop(name % i)
    q_i = q_i.reshape(16, 72, 1152)
    name = prefix + ".k_proj.weight"  # (16 * 72, 1152)
    k_i = params.pop(name % i)
    k_i = k_i.reshape(16, 72, 1152)
    name = prefix + ".v_proj.weight"  # (16 * 72, 1152)
    v_i = params.pop(name % i)
    v_i = v_i.reshape(16, 72, 1152)
    qkv_i, shape = torch.stack([q_i, k_i, v_i], dim=1), (16, 3, 72, 1152)
    name = prefix + ".qkv_proj.weight"
    sbs_name = "qkv_ein_w"
    add_data(name, qkv_i, shape, sbs_name, i)
    # Now the biases.
    name = prefix + ".q_proj.bias"  # (16 * 72)
    q_i = params.pop(name % i)
    q_i = q_i.reshape(16, 72)
    name = prefix + ".k_proj.bias"  # (16 * 72)
    k_i = params.pop(name % i)
    k_i = k_i.reshape(16, 72)
    name = prefix + ".v_proj.bias"  # (16 * 72)
    v_i = params.pop(name % i)
    v_i = v_i.reshape(16, 72)
    qkv_i, shape = torch.stack([q_i, k_i, v_i], dim=1), (16, 3, 72)
    name = prefix + ".qkv_proj.bias"
    sbs_name = "qkv_ein_b"
    add_data(name, qkv_i, shape, sbs_name, i)

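  # Note (added): 16 and 72 above are the SigLIP ViT's attention head count
  # and per-head dimension (16 * 72 = 1152, the ViT embedding width); they are
  # hard-coded rather than derived from the checkpoint.
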
  # Handle the image embedding kernel transpose.
  name = "vision_tower.vision_model.embeddings.patch_embedding.weight"
  assert params[name].shape == (1152, 3, 14, 14)
  params[name] = params[name].permute(0, 2, 3, 1)  # -> (1152, 14, 14, 3)

  # Add the non-layer params.
  for name, shape, sbs_name in (
      layer_config["llm-non-layers"] + layer_config["img-non-layers"]
  ):
    add_data(name, params.pop(name), shape, sbs_name)

  # Go through the LLM layers and add the weights.
  for i in range(dims["num_llm_layers"]):
    add_att_einsum(i)
    add_gating_einsum(i)
    for name, shape, sbs_name in layer_config["llm-layers"]:
      add_data(name, params.pop(name % i), shape, sbs_name, i)
    add_qkv_einsum(i)

  # Go through the ViT layers and add the weights.
  for i in range(27):  # The SigLIP ViT encoder has 27 layers.
    for name, shape, sbs_name in layer_config["img-layers"]:
      add_data(name, params.pop(name % i), shape, sbs_name, i)
    add_vit_qkv_einsum(i)

  assert not params, "Some params were not used: %s" % params.keys()

  # Write everything to the sbs file.
  writer.write(sbs_file)

  # Write the metadata for manual inspection.
  with open(csv_file, "w") as csv_handle:
    csv.writer(csv_handle).writerows(metadata)


_LOAD_PATH = flags.DEFINE_string(
    "load_path",
    "",
    "Path to the safetensors index.json file to read",
)
_METADATA_FILE = flags.DEFINE_string(
    "metadata_file",
    "/tmp/gemmacpp.csv",
    "Path to the metadata file to write",
)
_SBS_FILE = flags.DEFINE_string(
    "sbs_file",
    "/tmp/gemmacpp.sbs",
    "Path to the sbs file to write",
)


def main(argv: Sequence[str]) -> None:
  if len(argv) > 1:
    raise app.UsageError("Too many command-line arguments.")
  logging.use_python_logging()
  logging.set_verbosity(logging.INFO)
  load_path = _LOAD_PATH.value
  metadata_file = _METADATA_FILE.value
  sbs_file = _SBS_FILE.value

  logging.info(
      "\n====\nReading from %s and writing to %s\n====", load_path, sbs_file
  )
  export_paligemma_sbs(load_path, metadata_file, sbs_file)


if __name__ == "__main__":
  app.run(main)