diff --git a/BUILD.bazel b/BUILD.bazel index 152657e..27c6cd7 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -115,3 +115,20 @@ cc_binary( "@hwy//:thread_pool", ], ) + +cc_binary( + name = "compress_weights", + srcs = [ + "compress_weights.cc", + ], + deps = [ + ":args", + ":gemma_lib", + # "//base", + "//compression:compress", + "@hwy//:hwy", + "@hwy//:nanobenchmark", + "@hwy//:profiler", + "@hwy//:thread_pool", + ], +) diff --git a/CMakeLists.txt b/CMakeLists.txt index 844ddf2..7cc586b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -80,7 +80,7 @@ set_target_properties(libgemma PROPERTIES PREFIX "") set_property(TARGET libgemma PROPERTY POSITION_INDEPENDENT_CODE ON) target_include_directories(libgemma PUBLIC ./) target_link_libraries(libgemma hwy hwy_contrib sentencepiece) -target_include_directories(libgemma PRIVATE ${sentencepiece_SOURCE_DIR}) +target_include_directories(libgemma PUBLIC ${sentencepiece_SOURCE_DIR}) target_compile_definitions(libgemma PRIVATE $<$:_CRT_SECURE_NO_WARNINGS NOMINMAX>) target_compile_options(libgemma PRIVATE $<$:-Wno-deprecated-declarations>) @@ -115,3 +115,8 @@ foreach (TESTFILE IN LISTS GEMMA_TEST_FILES) gtest_discover_tests(${TESTNAME}) endforeach () endif() # GEMMA_ENABLE_TESTS + +## Tools + +add_executable(compress_weights compress_weights.cc) +target_link_libraries(compress_weights libgemma hwy hwy_contrib) diff --git a/compress_weights.cc b/compress_weights.cc new file mode 100644 index 0000000..ce4f642 --- /dev/null +++ b/compress_weights.cc @@ -0,0 +1,148 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Command line tool to create compressed weights. + +#include +#include + +// copybara:import_next_line:gemma_cpp +#include "gemma.h" // Gemma +// copybara:end +// copybara:import_next_line:gemma_cpp +#include "util/args.h" +// copybara:end + +namespace gcpp { + +struct Args : public ArgsBase { + static constexpr size_t kDefaultNumThreads = ~size_t{0}; + + void ChooseNumThreads() { + if (num_threads == kDefaultNumThreads) { + // This is a rough heuristic, replace with something better in the future. + num_threads = static_cast(std::clamp( + static_cast(std::thread::hardware_concurrency()) - 2, 1, 18)); + } + } + + public: + Args(int argc, char* argv[]) { + InitAndParse(argc, argv); + ChooseNumThreads(); + } + + static std::string ToLower(const std::string& text) { + std::string result = text; + std::transform(begin(result), end(result), begin(result), + [](unsigned char c) { return std::tolower(c); }); + return result; + } + + gcpp::Model ModelType() const { + const std::string model_type_lc = ToLower(model_type); + if (model_type_lc.substr(0, 2) == "2b") { + return gcpp::Model::GEMMA_2B; + } else if (model_type_lc.substr(0, 2) == "7b") { + return gcpp::Model::GEMMA_7B; + } else { + HWY_ABORT("Unknown model type %s", model_type_lc.c_str()); + } + } + + // Returns error string or nullptr if OK. 
+ const char* Validate() const { + const std::string model_type_lc = ToLower(model_type); + if (model_type.empty()) { + return "Missing --model flag, need to specify either 2b-pt, 7b-pt, " + "2b-it, 7b-it."; + } + if (model_type_lc != "2b-pt" && model_type_lc != "7b-pt" && + model_type_lc != "2b-it" && model_type_lc != "7b-it") { + return "Model type must be 2b-pt, 7b-pt, 2b-it, 7b-it."; + } + if (weights.path.empty()) { + return "Missing --weights flag, a file for the uncompressed model."; + } + if (compressed_weights.path.empty()) { + return "Missing --compressed_weights flag, a file for the compressed " + "model."; + } + if (!weights.exists()) { + return "Can't open file specified with --weights flag."; + } + return nullptr; + } + + Path weights; // uncompressed weights file location + Path compressed_weights; // compressed weights file location + std::string model_type; + size_t num_threads; + + template + void ForEach(const Visitor& visitor) { + visitor(weights, "weights", Path(), + "Path name of model weights (.sbs) file.\n" + " Required argument."); + visitor(model_type, "model", std::string(), + "Model type\n 2b-it = 2B parameters, instruction-tuned\n " + "2b-pt = 2B parameters, pretrained\n 7b-it = 7B parameters " + "instruction-tuned\n 7b-pt = 7B parameters, pretrained\n " + " Required argument."); + visitor(compressed_weights, "compressed_weights", Path(), + "Path name where compressed weights file will be written.\n" + " Required argument."); + visitor(num_threads, "num_threads", + kDefaultNumThreads, // see ChooseNumThreads + "Number of threads to use.\n Default = Estimate of the " + "number of supported concurrent threads.", + 2); + } +}; + +void ShowHelp(gcpp::Args& args) { + std::cerr + << "Usage:\n./compress_weights --weights " + " --model --compressed_weights \n"; + std::cerr << "\n*Arguments*\n\n"; + args.Help(); + std::cerr << "\n"; +} + +void Run(Args& args) { + hwy::ThreadPool pool(args.num_threads); + gcpp::CompressWeights(args.ModelType(), 
args.weights, args.compressed_weights, + pool); +} + +} // namespace gcpp + +int main(int argc, char** argv) { + gcpp::Args args(argc, argv); + + if (gcpp::HasHelp(argc, argv)) { + ShowHelp(args); + return 0; + } + + if (const char* error = args.Validate()) { + ShowHelp(args); + HWY_ABORT("\nInvalid args: %s", error); + } + + gcpp::Run(args); + + return 0; +} diff --git a/gemma.cc b/gemma.cc index ff58a82..edc5dfd 100644 --- a/gemma.cc +++ b/gemma.cc @@ -116,10 +116,13 @@ hwy::AlignedUniquePtr> LoadWeights(const Path& checkpoint) { checkpoint.path.c_str()); } bool ok = true; + uint64_t total_size = 0; ok &= 1 == fread(&(weights->embedder_input_embedding), sizeof(weights->embedder_input_embedding), 1, fptr); ok &= 1 == fread(&(weights->final_norm_scale), sizeof(weights->final_norm_scale), 1, fptr); + total_size += sizeof(weights->embedder_input_embedding) + + sizeof(weights->final_norm_scale); for (size_t layer = 0; layer < TConfig::kLayers; ++layer) { Layer* layer_view = &weights->layers[layer]; ok &= 1 == fread(&layer_view->attn_vec_einsum_w, @@ -134,10 +137,12 @@ hwy::AlignedUniquePtr> LoadWeights(const Path& checkpoint) { sizeof(layer_view->pre_attention_norm_scale), 1, fptr); ok &= 1 == fread(&layer_view->pre_ffw_norm_scale, sizeof(layer_view->pre_ffw_norm_scale), 1, fptr); + total_size += sizeof(*layer_view); } if (!ok) { - HWY_ABORT("Failed to read from %s - might be a directory, or too small?", - checkpoint.path.c_str()); + HWY_ABORT("Failed to read from %s - might be a directory, or too small? 
" + "expected size: %d kB", checkpoint.path.c_str(), + static_cast(total_size >> 10)); } HWY_ASSERT(0 == fclose(fptr)); return weights; @@ -813,6 +818,47 @@ hwy::AlignedFreeUniquePtr GetCompressedWeightsT( } } +template +void CompressWeights(const Path& weights_path, + const Path& compressed_weights_path, + hwy::ThreadPool& pool) { + if (!std::filesystem::exists(weights_path.path)) { + HWY_ABORT("The model weights file '%s' does not exist.", + weights_path.path.c_str()); + } + + // Allocate compressed weights. + using CWeights = CompressedWeights; + hwy::AlignedFreeUniquePtr c_weights_u8 = + hwy::AllocateAligned(sizeof(CWeights)); + CWeights* c_weights = reinterpret_cast(c_weights_u8.get()); + new (&c_weights->c_layer_ptrs) CompressedLayerPointers(pool); + + // Get weights, compress, and store. + const hwy::AlignedUniquePtr> weights = + LoadWeights(weights_path); + Compressor compressor(pool); + ForEachTensor(weights.get(), *c_weights, compressor); + compressor.WriteAll(pool, compressed_weights_path.path.c_str()); + + c_weights->c_layer_ptrs.~CompressedLayerPointers(); +} + +void CompressWeightsT(gcpp::Model model, const Path& weights, + const Path& compressed_weights, + hwy::ThreadPool& pool) { + switch (model) { + case Model::GEMMA_2B: + CompressWeights(weights, compressed_weights, pool); + break; + case Model::GEMMA_7B: + CompressWeights(weights, compressed_weights, pool); + break; + default: + HWY_ABORT("Model type %d unknown.", static_cast(model)); + } +} + } // namespace HWY_NAMESPACE } // namespace gcpp HWY_AFTER_NAMESPACE(); @@ -821,6 +867,7 @@ HWY_AFTER_NAMESPACE(); namespace gcpp { HWY_EXPORT(GetCompressedWeightsT); +HWY_EXPORT(CompressWeightsT); HWY_EXPORT(Generate2B); HWY_EXPORT(Generate7B); @@ -922,5 +969,12 @@ void GenerateGemma(Gemma& gemma, RuntimeConfig runtime_config, stream_token, [](int) { return true; }, gen, runtime_config.verbosity); } +void CompressWeights(gcpp::Model model, const Path& weights, + const Path& compressed_weights, + 
hwy::ThreadPool& pool) { + HWY_DYNAMIC_DISPATCH(CompressWeightsT)( + model, weights, compressed_weights, pool); +} + } // namespace gcpp #endif // HWY_ONCE diff --git a/gemma.h b/gemma.h index 8c4cab8..a3caa43 100644 --- a/gemma.h +++ b/gemma.h @@ -98,6 +98,10 @@ void GenerateGemma(Gemma& gemma, RuntimeConfig runtime_config, KVCache& kv_cache, hwy::ThreadPool& pool, const StreamFunc& stream_token, std::mt19937& gen); +void CompressWeights(gcpp::Model model, const Path& weights, + const Path& compressed_weights, + hwy::ThreadPool& pool); + constexpr int EOS_ID = 1; } // namespace gcpp diff --git a/util/args.h b/util/args.h index b9ab985..7b17c99 100644 --- a/util/args.h +++ b/util/args.h @@ -25,27 +25,43 @@ #include "hwy/base.h" // HWY_ABORT +#if defined(_WIN32) +#include +#define F_OK 0 +#define access _access +#else +#include +#endif + namespace gcpp { // Wrapper for strings representing a path name. Differentiates vs. arbitrary // strings and supports shortening for display purposes. struct Path { + Path() {} + explicit Path(const char* p) : path(p) {} + Path& operator=(const char* other) { path = other; return *this; } std::string Shortened() const { - constexpr size_t max_len = 48; - constexpr size_t cut_point = max_len / 2 - 5; - if (path.size() > max_len) { - return std::string(begin(path), begin(path) + cut_point) + " ... " + - std::string(end(path) - cut_point, end(path)); + constexpr size_t kMaxLen = 48; + constexpr size_t kCutPoint = kMaxLen / 2 - 5; + if (path.size() > kMaxLen) { + return std::string(begin(path), begin(path) + kCutPoint) + " ... " + + std::string(end(path) - kCutPoint, end(path)); } if (path.empty()) return "[no path specified]"; return path; } + // Beware, TOCTOU. + bool exists() const { + return (access(path.c_str(), F_OK) == 0); + } + std::string path; };