From f82337169160c7ec13cdd4b72fdc96170aa87138 Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Fri, 5 Jul 2024 04:16:14 -0700 Subject: [PATCH] Cleanup: move util/compress and convert_weights to compression/ Also remove unused models/, lint convert_weights PiperOrigin-RevId: 649613088 --- BUILD.bazel | 18 -- CMakeLists.txt | 2 +- DEVELOPERS.md | 10 +- compression/BUILD | 18 ++ {util => compression}/compress_weights.cc | 2 +- compression/convert_weights.py | 209 +++++++++++++++++++++ compression/python/compression_clif_aux.cc | 2 +- models/.gitignore | 0 util/convert_weights.py | 198 ------------------- 9 files changed, 235 insertions(+), 224 deletions(-) rename {util => compression}/compress_weights.cc (99%) create mode 100644 compression/convert_weights.py delete mode 100644 models/.gitignore delete mode 100644 util/convert_weights.py diff --git a/BUILD.bazel b/BUILD.bazel index 48dc3df..90ec0c0 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -259,24 +259,6 @@ cc_binary( ], ) -cc_binary( - name = "compress_weights", - srcs = ["util/compress_weights.cc"], - deps = [ - ":args", - ":common", - ":gemma_lib", - ":weights", - ":weights_raw", - # Placeholder for internal dep, do not remove., - "//compression:compress", - "@hwy//:hwy", - "@hwy//:nanobenchmark", - "@hwy//:profiler", - "@hwy//:thread_pool", - ], -) - cc_binary( name = "single_benchmark", srcs = ["evals/benchmark.cc"], diff --git a/CMakeLists.txt b/CMakeLists.txt index 8b4fc49..819bec7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -169,5 +169,5 @@ endif() # GEMMA_ENABLE_TESTS ## Tools -add_executable(compress_weights util/compress_weights.cc) +add_executable(compress_weights compression/compress_weights.cc) target_link_libraries(compress_weights libgemma hwy hwy_contrib) diff --git a/DEVELOPERS.md b/DEVELOPERS.md index c1c6505..fdebad4 100644 --- a/DEVELOPERS.md +++ b/DEVELOPERS.md @@ -94,11 +94,11 @@ If starting with Keras, first run this script to convert to Pytorch: https://github.com/keras-team/keras-nlp/blob/master/tools/gemma/export_gemma_to_torch_xla.py From Pytorch, use the following script to generate uncompressed weights: -https://github.com/google/gemma.cpp/blob/dev/util/convert_weights.py +https://github.com/google/gemma.cpp/blob/dev/compression/convert_weights.py -Then run gemma/compress_weights.cc (Bazel target :compress_weights), specifying -the resulting file as `--weights` and the desired .sbs name as the -`--compressed_weights`. +Then run `compression/compress_weights.cc` (Bazel target +`compression:compress_weights`), specifying the resulting file as `--weights` +and the desired .sbs name as the `--compressed_weights`. ## Compile-Time Flags (Advanced) @@ -192,7 +192,7 @@ transforms we apply to Gemma via Copybara. ## Debugging At the first sign of incorrect or unexpected results, we recommend running with -ASan/MSan enabled. When using blaze/bazel, you can add `--config=asan` or +ASan/MSan enabled. When using bazel, you can add `--config=asan` or `--config=msan-track-origins` to the build command. In addition to their checks for memory overruns or uninitialized memory, we also enable debug-only asserts in Gemma.cpp for those build configurations. 
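The DEVELOPERS.md change above describes a two-step flow: first generate uncompressed f32 weights with compression/convert_weights.py, then compress them with the compression:compress_weights tool. A minimal sketch of that flow, driven from Python, is shown below; the checkpoint and output paths are placeholders, the location of the compress_weights binary depends on your build, and any flags beyond the --weights/--compressed_weights pair named in DEVELOPERS.md (for example a model-type flag) are not shown.

```python
# Sketch of the DEVELOPERS.md flow: convert a PyTorch checkpoint to raw f32,
# then compress it. Paths and the compress_weights binary location are placeholders.
import subprocess

ckpt = "models/gemma-2b-it.ckpt"   # PyTorch checkpoint (the script's default)
raw_f32 = "2b-f32.bin"             # uncompressed output of convert_weights.py
sbs = "2b-it-sfp.sbs"              # compressed weights for gemma.cpp

# Step 1: PyTorch checkpoint -> raw f32
# (requires torch 2.2 and the gemma package from google/gemma_pytorch).
subprocess.run(
    ["python3", "compression/convert_weights.py",
     "--tokenizer", "models/tokenizer.spm",
     "--weights", ckpt,
     "--output_file", raw_f32,
     "--model_type", "2b"],
    check=True,
)

# Step 2: raw f32 -> .sbs, using the binary built from //compression:compress_weights
# (or the CMake `compress_weights` target). Only the two flags named in
# DEVELOPERS.md are shown; the real tool may require additional flags.
subprocess.run(
    ["./compress_weights", "--weights", raw_f32, "--compressed_weights", sbs],
    check=True,
)
```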
diff --git a/compression/BUILD b/compression/BUILD index c8b077b..962573f 100644 --- a/compression/BUILD +++ b/compression/BUILD @@ -180,3 +180,21 @@ cc_library( "@hwy//hwy/contrib/sort:vqsort", ], ) + +cc_binary( + name = "compress_weights", + srcs = ["compress_weights.cc"], + deps = [ + ":compress", + # Placeholder for internal dep, do not remove., + "//third_party/gemma_cpp:args", + "//third_party/gemma_cpp:common", + "//third_party/gemma_cpp:gemma_lib", + "//third_party/gemma_cpp:weights", + "//third_party/gemma_cpp:weights_raw", + "@hwy//:hwy", + "@hwy//:nanobenchmark", + "@hwy//:profiler", + "@hwy//:thread_pool", + ], +) diff --git a/util/compress_weights.cc b/compression/compress_weights.cc similarity index 99% rename from util/compress_weights.cc rename to compression/compress_weights.cc index cc14c42..b2ca0b0 100644 --- a/util/compress_weights.cc +++ b/compression/compress_weights.cc @@ -19,7 +19,7 @@ // which we pass the filename via macro 'argument'. #undef HWY_TARGET_INCLUDE #define HWY_TARGET_INCLUDE \ - "util/compress_weights.cc" // NOLINT + "compression/compress_weights.cc" // NOLINT #include "hwy/foreach_target.h" // IWYU pragma: keep // Must come after foreach_target.h to avoid redefinition errors. #include "compression/compress-inl.h" diff --git a/compression/convert_weights.py b/compression/convert_weights.py new file mode 100644 index 0000000..3ba1642 --- /dev/null +++ b/compression/convert_weights.py @@ -0,0 +1,209 @@ +# Copyright 2024 Google LLC +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Converts pytorch to f32 for use by compress_weights.cc.""" + +import argparse +import collections +import os +from gemma import config +from gemma import model as gemma_model +import numpy as np +import torch + +# Requires torch 2.2 and gemma package from +# https://github.com/google/gemma_pytorch + + +def check_file_exists(value): + if not os.path.exists(str(value)): + raise argparse.ArgumentTypeError( + "The file %s does not appear to exist." % value + ) + return value + + +def check_model_types(value): + if str(value).lower() not in ["2b", "7b"]: + raise argparse.ArgumentTypeError( + "Model type value %s is not in [2b, 7b]." 
% value + ) + return value + + +parser = argparse.ArgumentParser() +parser.add_argument( + "--tokenizer", + dest="tokenizer", + default="models/tokenizer.spm", + help="Location of tokenizer file (.model or .spm)", + type=check_file_exists, +) + +parser.add_argument( + "--weights", + dest="weights", + default="models/gemma-2b-it.ckpt", + help="Location of input checkpoint file (.ckpt)", + type=check_file_exists, +) + +parser.add_argument( + "--output_file", + dest="output_file", + default="2bit-f32.sbs", + help="Location to write converted weights", + type=str, +) + +parser.add_argument( + "--model_type", + dest="model_type", + default="2b", + help="Model size / type (2b, 7b)", + type=check_model_types, +) + +args = parser.parse_args() + + +TRANSFORMATIONS = { + "2b": collections.defaultdict( + lambda: lambda x: x, + { + "embedder.weight": lambda x: x, + "self_attn.qkv_proj.weight": lambda x: x.reshape((10, 256, 2048)), + "self_attn.o_proj.weight": lambda x: x.reshape( + (2048, 8, 256) + ).transpose([1, 0, 2]), + "mlp.gate_proj.weight": lambda x: x[np.newaxis, :, :], + "mlp.up_proj.weight": lambda x: x[np.newaxis, :, :], + "mlp.down_proj.weight": lambda x: x, + }, + ), + "7b": collections.defaultdict( + lambda: lambda x: x, + { + "embedder.weight": lambda x: x, + "self_attn.qkv_proj.weight": lambda x: x.reshape( + (3, 16, 256, 3072) + ).transpose([1, 0, 2, 3]), + "self_attn.o_proj.weight": lambda x: x.reshape( + (3072, 16, 256) + ).transpose([1, 0, 2]), + "mlp.gate_proj.weight": lambda x: x[np.newaxis, :, :], + "mlp.up_proj.weight": lambda x: x[np.newaxis, :, :], + "mlp.down_proj.weight": lambda x: x, + }, + ), +} + +VALIDATIONS = { + "2b": { + "embedder.weight": lambda x: x.shape == (256000, 2048), + "model.norm.weight": lambda x: x.shape == (2048,), + "self_attn.qkv_proj.weight": lambda x: x.shape == (10, 256, 2048), + "self_attn.o_proj.weight": lambda x: x.shape == (8, 2048, 256), + "mlp.gate_proj.weight": lambda x: x.shape == (1, 16384, 2048), + "mlp.up_proj.weight": lambda x: x.shape == (1, 16384, 2048), + "mlp.down_proj.weight": lambda x: x.shape == (2048, 16384), + "input_layernorm.weight": lambda x: x.shape == (2048,), + "post_attention_layernorm.weight": lambda x: x.shape == (2048,), + }, + "7b": { + "embedder.weight": lambda x: x.shape == (256000, 3072), + "model.norm.weight": lambda x: x.shape == (3072,), + "self_attn.qkv_proj.weight": lambda x: x.shape == (16, 3, 256, 3072), + "self_attn.o_proj.weight": lambda x: x.shape == (16, 3072, 256), + "mlp.gate_proj.weight": lambda x: x.shape == (1, 24576, 3072), + "mlp.up_proj.weight": lambda x: x.shape == (1, 24576, 3072), + "mlp.down_proj.weight": lambda x: x.shape == (3072, 24576), + "input_layernorm.weight": lambda x: x.shape == (3072,), + "post_attention_layernorm.weight": lambda x: x.shape == (3072,), + }, +} + + +def param_names(num_hidden_layers: int): + """Return parameter names in the order they are expected for deserialization.""" + + # note *weight_scaler params are ignored in the forward computation unless + # quantization is being used. + # + # since we are working with the full precision weights as input, don't + # include these in the parameters being iterated over. 
+ + names = [ + ("embedder.weight",) * 2, # embedder_input_embedding + ("model.norm.weight",) * 2, # final_norm_scale + ] + layer_params = [ + "self_attn.o_proj.weight", # attn_vec_einsum_w + "self_attn.qkv_proj.weight", # qkv_einsum_w + "mlp.gate_proj.weight", # gating_einsum_w + "mlp.up_proj.weight", + "mlp.down_proj.weight", # linear_w + "input_layernorm.weight", # pre_attention_norm_scale + "post_attention_layernorm.weight", # pre_ffw_norm_scale + ] + for layer in range(num_hidden_layers): + for layer_param in layer_params: + names = names + [(f"model.layers.{layer}.{layer_param}", layer_param)] + return names + + +def convert_weights(): + """Main function; loads weights, runs transformations, writes f32.""" + model_type = args.model_type + output_file = args.output_file + + model_config = config.get_model_config(model_type) + model_config.dtype = "float32" + model_config.tokenizer = args.tokenizer + device = torch.device("cpu") + torch.set_default_dtype(torch.float) + model = gemma_model.GemmaForCausalLM(model_config) + + model.load_weights(args.weights) + model.to(device).eval() + + model_dict = dict(model.named_parameters()) + param_order = param_names(model_config.num_hidden_layers) + + all_ok = True + print("Checking transformations ...") + for name, layer_name in param_order: + arr = model_dict[name].detach().numpy() + arr = TRANSFORMATIONS[model_type][layer_name](arr) + check = "OK" if VALIDATIONS[model_type][layer_name](arr) else "FAILED" + + if check == "FAILED": + all_ok = False + print(f" {name : <60}{str(arr.shape) : <20}{check}") + + if all_ok: + print("Writing parameters ...") + with open(output_file, "wb") as bin_handle: + for name, layer_name in param_order: + arr = model_dict[name].detach().numpy() + arr = TRANSFORMATIONS[model_type][layer_name](arr) + check = "OK" if VALIDATIONS[model_type][layer_name](arr) else "FAILED" + print(f" {name : <60}{str(arr.shape) : <20}{check}") + arr.flatten().astype(np.float32).tofile(bin_handle) + + +if __name__ == "__main__": + convert_weights() + print("Done") diff --git a/compression/python/compression_clif_aux.cc b/compression/python/compression_clif_aux.cc index c017e64..6c938d9 100644 --- a/compression/python/compression_clif_aux.cc +++ b/compression/python/compression_clif_aux.cc @@ -2,7 +2,7 @@ #undef HWY_TARGET_INCLUDE #define HWY_TARGET_INCLUDE \ - "third_party/gemma_cpp/compression/python/compression_clif_aux.cc" // NOLINT + "compression/python/compression_clif_aux.cc" // NOLINT #include "hwy/foreach_target.h" // IWYU pragma: keep // Must come after foreach_target.h to avoid redefinition errors. #include "compression/compress-inl.h" diff --git a/models/.gitignore b/models/.gitignore deleted file mode 100644 index e69de29..0000000 diff --git a/util/convert_weights.py b/util/convert_weights.py deleted file mode 100644 index 0211c01..0000000 --- a/util/convert_weights.py +++ /dev/null @@ -1,198 +0,0 @@ -# Copyright 2024 Google LLC -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -from collections import defaultdict -import torch -from gemma import config -from gemma import model as gemma_model -import numpy as np -import argparse -import os - -# Requires torch 2.2 and gemma package from https://github.com/google/gemma_pytorch - -def check_file_exists(value): - if not os.path.exists(str(value)): - raise argparse.ArgumentTypeError("The file %s does not appear to exist." % value) - return value - - -def check_model_types(value): - if str(value).lower() not in ["2b", "7b"]: - raise argparse.ArgumentTypeError("Model type value %s is not in [2b, 7b]." % value) - return value - - -parser = argparse.ArgumentParser() -parser.add_argument( - "--tokenizer", - dest="tokenizer", - default="models/tokenizer.spm", - help="Location of tokenizer file (.model or .spm)", - type=check_file_exists, -) - -parser.add_argument( - "--weights", - dest="weights", - default="models/gemma-2b-it.ckpt", - help="Location of input checkpoint file (.ckpt)", - type=check_file_exists, -) - -parser.add_argument( - "--output_file", - dest="output_file", - default="2bit-f32.sbs", - help="Location to write converted weights", - type=str, -) - -parser.add_argument( - "--model_type", - dest="model_type", - default="2b", - help="Model size / type (2b, 7b)", - type=check_model_types, -) - -args = parser.parse_args() - - -TRANSFORMATIONS = { - "2b":defaultdict( - lambda: lambda x: x, - { - "embedder.weight": lambda x: x, - "self_attn.qkv_proj.weight": lambda x: x.reshape((10, 256, 2048)), - "self_attn.o_proj.weight": lambda x: x.reshape((2048, 8, 256)).transpose([1,0,2]), - "mlp.gate_proj.weight": lambda x: x[np.newaxis, :, :], - "mlp.up_proj.weight": lambda x: x[np.newaxis, :, :], - "mlp.down_proj.weight": lambda x: x, - } - ), - "7b":defaultdict( - lambda: lambda x: x, - { - "embedder.weight": lambda x: x, - "self_attn.qkv_proj.weight": lambda x: x.reshape((3, 16, 256, 3072)).transpose([1,0,2,3]), - "self_attn.o_proj.weight": lambda x: x.reshape((3072, 16, 256)).transpose([1,0,2]), - "mlp.gate_proj.weight": lambda x: x[np.newaxis, :, :], - "mlp.up_proj.weight": lambda x: x[np.newaxis, :, :], - "mlp.down_proj.weight": lambda x: x, - } - ), -} - -VALIDATIONS = { - "2b": { - "embedder.weight": lambda x: x.shape == (256000, 2048), - "model.norm.weight": lambda x: x.shape == (2048,), - "self_attn.qkv_proj.weight": lambda x: x.shape == (10, 256, 2048), - "self_attn.o_proj.weight": lambda x: x.shape == (8, 2048, 256), - "mlp.gate_proj.weight": lambda x: x.shape == (1, 16384, 2048), - "mlp.up_proj.weight": lambda x: x.shape == (1, 16384, 2048), - "mlp.down_proj.weight": lambda x: x.shape == (2048, 16384), - "input_layernorm.weight": lambda x: x.shape == (2048,), - "post_attention_layernorm.weight": lambda x: x.shape == (2048,), - }, - "7b": { - "embedder.weight": lambda x: x.shape == (256000, 3072), - "model.norm.weight": lambda x: x.shape == (3072,), - "self_attn.qkv_proj.weight": lambda x: x.shape == (16, 3, 256, 3072), - "self_attn.o_proj.weight": lambda x: x.shape == (16, 3072, 256), - "mlp.gate_proj.weight": lambda x: x.shape == (1, 24576, 3072), - "mlp.up_proj.weight": lambda x: x.shape == (1, 24576, 3072), - "mlp.down_proj.weight": lambda x: x.shape == (3072, 24576), - "input_layernorm.weight": lambda x: x.shape == (3072,), - "post_attention_layernorm.weight": lambda x: x.shape == (3072,), - }, -} - - -def param_names(num_hidden_layers: int): - """Return parameter names in the order they are expected for deserialization.""" - - # note *weight_scaler params are ignored in the forward computation unless - 
# quantization is being used. - # - # since we are working with the full precision weights as input, don't - # include these in the parameters being iterated over. - - # fmt: off - names = [ - ("embedder.weight", ) * 2, # embedder_input_embedding - ("model.norm.weight", ) * 2 # final_norm_scale - ] - layer_params = [ - "self_attn.o_proj.weight", # attn_vec_einsum_w - "self_attn.qkv_proj.weight", # qkv_einsum_w - "mlp.gate_proj.weight", # gating_einsum_w - "mlp.up_proj.weight", - "mlp.down_proj.weight", # linear_w - "input_layernorm.weight", # pre_attention_norm_scale - "post_attention_layernorm.weight", # pre_ffw_norm_scale - ] - # fmt: on - for layer in range(num_hidden_layers): - for layer_param in layer_params: - names = names + [(f"model.layers.{layer}.{layer_param}", layer_param)] - return names - - -def convert_weights(): - model_type = args.model_type - output_file = args.output_file - - model_config = config.get_model_config(model_type) - model_config.dtype = "float32" - model_config.tokenizer = args.tokenizer - device = torch.device("cpu") - torch.set_default_dtype(torch.float) - model = gemma_model.GemmaForCausalLM(model_config) - - model.load_weights(args.weights) - model.to(device).eval() - - model_dict = dict(model.named_parameters()) - param_order = param_names(model_config.num_hidden_layers) - - all_ok = True - print("Checking transformations ...") - for name, layer_name in param_order: - arr = model_dict[name].detach().numpy() - arr = TRANSFORMATIONS[model_type][layer_name](arr) - check = "OK" if VALIDATIONS[model_type][layer_name](arr) else "FAILED" - - if check == "FAILED": - all_ok = False - print(f" {name : <60}{str(arr.shape) : <20}{check}") - - if all_ok: - print("Writing parameters ...") - gate = None - with open(output_file, "wb") as bin_handle: - for name, layer_name in param_order: - arr = model_dict[name].detach().numpy() - arr = TRANSFORMATIONS[model_type][layer_name](arr) - check = "OK" if VALIDATIONS[model_type][layer_name](arr) else "FAILED" - print(f" {name : <60}{str(arr.shape) : <20}{check}") - arr.flatten().astype(np.float32).tofile(bin_handle) - - -if __name__ == "__main__": - convert_weights() - print("Done")
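For reference, the file written by compression/convert_weights.py is simply each parameter from param_names(), transformed, flattened, and appended as raw float32 (arr.flatten().astype(np.float32).tofile(...)). The sketch below is a hypothetical size sanity check for a 2b conversion: the per-tensor shapes mirror the 2b entries of VALIDATIONS, while NUM_LAYERS = 18 is an assumption about the 2b configuration rather than something stated in this patch.

```python
# Hypothetical sanity check for the raw-f32 file produced by convert_weights.py.
# Assumes: tensors are concatenated as float32 in param_names() order, and the
# 2b configuration has 18 hidden layers (an assumption, not stated in this patch).
import os
import numpy as np

NUM_LAYERS = 18  # assumed for the 2b model

# Global tensors: embedder.weight, model.norm.weight (shapes from VALIDATIONS["2b"]).
GLOBAL_SHAPES = [(256000, 2048), (2048,)]

# Per-layer tensors in param_names() order: o_proj, qkv_proj, gate_proj,
# up_proj, down_proj, input_layernorm, post_attention_layernorm.
LAYER_SHAPES = [
    (8, 2048, 256),
    (10, 256, 2048),
    (1, 16384, 2048),
    (1, 16384, 2048),
    (2048, 16384),
    (2048,),
    (2048,),
]


def expected_bytes() -> int:
  elems = sum(int(np.prod(s)) for s in GLOBAL_SHAPES)
  elems += NUM_LAYERS * sum(int(np.prod(s)) for s in LAYER_SHAPES)
  return elems * np.dtype(np.float32).itemsize


if __name__ == "__main__":
  path = "2bit-f32.sbs"  # default --output_file of convert_weights.py
  print(f"expected: {expected_bytes()} bytes")
  if os.path.exists(path):
    print(f"actual:   {os.path.getsize(path)} bytes")
```

The resulting raw-f32 file is what compress_weights.cc then consumes via its --weights flag, as described in the DEVELOPERS.md hunk above.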