Add standalone tool to compress weights.

Co-authored-by: Eugene Kliuchnikov <eustas@google.com>
This commit is contained in:
Zoltan Szabadka 2024-04-03 12:12:15 +00:00
parent 93a648926c
commit b670d43e4f
6 changed files with 252 additions and 8 deletions

View File

@ -115,3 +115,20 @@ cc_binary(
"@hwy//:thread_pool",
],
)
cc_binary(
name = "compress_weights",
srcs = [
"compress_weights.cc",
],
deps = [
":args",
":gemma_lib",
# "//base",
"//compression:compress",
"@hwy//:hwy",
"@hwy//:nanobenchmark",
"@hwy//:profiler",
"@hwy//:thread_pool",
],
)

View File

@ -80,7 +80,7 @@ set_target_properties(libgemma PROPERTIES PREFIX "")
set_property(TARGET libgemma PROPERTY POSITION_INDEPENDENT_CODE ON)
target_include_directories(libgemma PUBLIC ./)
target_link_libraries(libgemma hwy hwy_contrib sentencepiece)
target_include_directories(libgemma PRIVATE ${sentencepiece_SOURCE_DIR})
target_include_directories(libgemma PUBLIC ${sentencepiece_SOURCE_DIR})
target_compile_definitions(libgemma PRIVATE $<$<PLATFORM_ID:Windows>:_CRT_SECURE_NO_WARNINGS NOMINMAX>)
target_compile_options(libgemma PRIVATE $<$<PLATFORM_ID:Windows>:-Wno-deprecated-declarations>)
@ -115,3 +115,8 @@ foreach (TESTFILE IN LISTS GEMMA_TEST_FILES)
gtest_discover_tests(${TESTNAME})
endforeach ()
endif() # GEMMA_ENABLE_TESTS
## Tools
add_executable(compress_weights compress_weights.cc)
target_link_libraries(compress_weights libgemma hwy hwy_contrib)

148
compress_weights.cc Normal file
View File

@ -0,0 +1,148 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Command line tool to create compressed weights.
#include <iostream>
#include <string>
// copybara:import_next_line:gemma_cpp
#include "gemma.h" // Gemma
// copybara:end
// copybara:import_next_line:gemma_cpp
#include "util/args.h"
// copybara:end
namespace gcpp {
struct Args : public ArgsBase<Args> {
static constexpr size_t kDefaultNumThreads = ~size_t{0};
void ChooseNumThreads() {
if (num_threads == kDefaultNumThreads) {
// This is a rough heuristic, replace with something better in the future.
num_threads = static_cast<size_t>(std::clamp(
static_cast<int>(std::thread::hardware_concurrency()) - 2, 1, 18));
}
}
public:
Args(int argc, char* argv[]) {
InitAndParse(argc, argv);
ChooseNumThreads();
}
static std::string ToLower(const std::string& text) {
std::string result = text;
std::transform(begin(result), end(result), begin(result),
[](unsigned char c) { return std::tolower(c); });
return result;
}
gcpp::Model ModelType() const {
const std::string model_type_lc = ToLower(model_type);
if (model_type_lc.substr(0, 2) == "2b") {
return gcpp::Model::GEMMA_2B;
} else if (model_type_lc.substr(0, 2) == "7b") {
return gcpp::Model::GEMMA_7B;
} else {
HWY_ABORT("Unknown model type %s", model_type_lc.c_str());
}
}
// Returns error string or nullptr if OK.
const char* Validate() const {
const std::string model_type_lc = ToLower(model_type);
if (model_type.empty()) {
return "Missing --model flag, need to specify either 2b-pt, 7b-pt, "
"2b-it, 7b-it.";
}
if (model_type_lc != "2b-pt" && model_type_lc != "7b-pt" &&
model_type_lc != "2b-it" && model_type_lc != "7b-it") {
return "Model type must be 2b-pt, 7b-pt, 2b-it, 7b-it.";
}
if (weights.path.empty()) {
return "Missing --weights flag, a file for the uncompressed model.";
}
if (compressed_weights.path.empty()) {
return "Missing --compressed_weights flag, a file for the compressed "
"model.";
}
if (!weights.exists()) {
return "Can't open file specified with --weights flag.";
}
return nullptr;
}
Path weights; // uncompressed weights file location
Path compressed_weights; // compressed weights file location
std::string model_type;
size_t num_threads;
template <class Visitor>
void ForEach(const Visitor& visitor) {
visitor(weights, "weights", Path(),
"Path name of model weights (.sbs) file.\n"
" Required argument.");
visitor(model_type, "model", std::string(),
"Model type\n 2b-it = 2B parameters, instruction-tuned\n "
"2b-pt = 2B parameters, pretrained\n 7b-it = 7B parameters "
"instruction-tuned\n 7b-pt = 7B parameters, pretrained\n "
" Required argument.");
visitor(compressed_weights, "compressed_weights", Path(),
"Path name where compressed weights file will be written.\n"
" Required argument.");
visitor(num_threads, "num_threads",
kDefaultNumThreads, // see ChooseNumThreads
"Number of threads to use.\n Default = Estimate of the "
"number of suupported concurrent threads.",
2);
}
};
void ShowHelp(gcpp::Args& args) {
std::cerr
<< "Usage:\n./compress_weights --weights <path to uncompressed weights> "
" --model <model type> --compressed_weights <output path>\n";
std::cerr << "\n*Arguments*\n\n";
args.Help();
std::cerr << "\n";
}
void Run(Args& args) {
hwy::ThreadPool pool(args.num_threads);
gcpp::CompressWeights(args.ModelType(), args.weights, args.compressed_weights,
pool);
}
} // namespace gcpp
int main(int argc, char** argv) {
gcpp::Args args(argc, argv);
if (gcpp::HasHelp(argc, argv)) {
ShowHelp(args);
return 0;
}
if (const char* error = args.Validate()) {
ShowHelp(args);
HWY_ABORT("\nInvalid args: %s", error);
}
gcpp::Run(args);
return 0;
}

View File

@ -116,10 +116,13 @@ hwy::AlignedUniquePtr<Weights<TConfig>> LoadWeights(const Path& checkpoint) {
checkpoint.path.c_str());
}
bool ok = true;
uint64_t total_size = 0;
ok &= 1 == fread(&(weights->embedder_input_embedding),
sizeof(weights->embedder_input_embedding), 1, fptr);
ok &= 1 == fread(&(weights->final_norm_scale),
sizeof(weights->final_norm_scale), 1, fptr);
total_size += sizeof(weights->embedder_input_embedding) +
sizeof(weights->final_norm_scale);
for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
Layer<TConfig>* layer_view = &weights->layers[layer];
ok &= 1 == fread(&layer_view->attn_vec_einsum_w,
@ -134,10 +137,12 @@ hwy::AlignedUniquePtr<Weights<TConfig>> LoadWeights(const Path& checkpoint) {
sizeof(layer_view->pre_attention_norm_scale), 1, fptr);
ok &= 1 == fread(&layer_view->pre_ffw_norm_scale,
sizeof(layer_view->pre_ffw_norm_scale), 1, fptr);
total_size += sizeof(*layer_view);
}
if (!ok) {
HWY_ABORT("Failed to read from %s - might be a directory, or too small?",
checkpoint.path.c_str());
HWY_ABORT("Failed to read from %s - might be a directory, or too small? "
"expected size: %d kB", checkpoint.path.c_str(),
static_cast<uint32_t>(total_size >> 10));
}
HWY_ASSERT(0 == fclose(fptr));
return weights;
@ -813,6 +818,47 @@ hwy::AlignedFreeUniquePtr<uint8_t[]> GetCompressedWeightsT(
}
}
template <class TConfig>
void CompressWeights(const Path& weights_path,
const Path& compressed_weights_path,
hwy::ThreadPool& pool) {
if (!std::filesystem::exists(weights_path.path)) {
HWY_ABORT("The model weights file '%s' does not exist.",
weights_path.path.c_str());
}
// Allocate compressed weights.
using CWeights = CompressedWeights<TConfig>;
hwy::AlignedFreeUniquePtr<uint8_t[]> c_weights_u8 =
hwy::AllocateAligned<uint8_t>(sizeof(CWeights));
CWeights* c_weights = reinterpret_cast<CWeights*>(c_weights_u8.get());
new (&c_weights->c_layer_ptrs) CompressedLayerPointers<TConfig>(pool);
// Get weights, compress, and store.
const hwy::AlignedUniquePtr<Weights<TConfig>> weights =
LoadWeights<TConfig>(weights_path);
Compressor compressor(pool);
ForEachTensor<TConfig>(weights.get(), *c_weights, compressor);
compressor.WriteAll(pool, compressed_weights_path.path.c_str());
c_weights->c_layer_ptrs.~CompressedLayerPointers<TConfig>();
}
void CompressWeightsT(gcpp::Model model, const Path& weights,
const Path& compressed_weights,
hwy::ThreadPool& pool) {
switch (model) {
case Model::GEMMA_2B:
CompressWeights<ConfigGemma2B>(weights, compressed_weights, pool);
break;
case Model::GEMMA_7B:
CompressWeights<ConfigGemma7B>(weights, compressed_weights, pool);
break;
default:
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
}
}
} // namespace HWY_NAMESPACE
} // namespace gcpp
HWY_AFTER_NAMESPACE();
@ -821,6 +867,7 @@ HWY_AFTER_NAMESPACE();
namespace gcpp {
HWY_EXPORT(GetCompressedWeightsT);
HWY_EXPORT(CompressWeightsT);
HWY_EXPORT(Generate2B);
HWY_EXPORT(Generate7B);
@ -922,5 +969,12 @@ void GenerateGemma(Gemma& gemma, RuntimeConfig runtime_config,
stream_token, [](int) { return true; }, gen, runtime_config.verbosity);
}
void CompressWeights(gcpp::Model model, const Path& weights,
const Path& compressed_weights,
hwy::ThreadPool& pool) {
HWY_DYNAMIC_DISPATCH(CompressWeightsT)(
model, weights, compressed_weights, pool);
}
} // namespace gcpp
#endif // HWY_ONCE

View File

@ -98,6 +98,10 @@ void GenerateGemma(Gemma& gemma, RuntimeConfig runtime_config,
KVCache& kv_cache, hwy::ThreadPool& pool,
const StreamFunc& stream_token, std::mt19937& gen);
void CompressWeights(gcpp::Model model, const Path& weights,
const Path& compressed_weights,
hwy::ThreadPool& pool);
constexpr int EOS_ID = 1;
} // namespace gcpp

View File

@ -25,27 +25,43 @@
#include "hwy/base.h" // HWY_ABORT
#if defined(_WIN32)
#include <io.h>
#define F_OK 0
#define access _access
#else
#include <unistd.h>
#endif
namespace gcpp {
// Wrapper for strings representing a path name. Differentiates vs. arbitrary
// strings and supports shortening for display purposes.
struct Path {
Path() {}
explicit Path(const char* p) : path(p) {}
Path& operator=(const char* other) {
path = other;
return *this;
}
std::string Shortened() const {
constexpr size_t max_len = 48;
constexpr size_t cut_point = max_len / 2 - 5;
if (path.size() > max_len) {
return std::string(begin(path), begin(path) + cut_point) + " ... " +
std::string(end(path) - cut_point, end(path));
constexpr size_t kMaxLen = 48;
constexpr size_t kCutPoint = kMaxLen / 2 - 5;
if (path.size() > kMaxLen) {
return std::string(begin(path), begin(path) + kCutPoint) + " ... " +
std::string(end(path) - kCutPoint, end(path));
}
if (path.empty()) return "[no path specified]";
return path;
}
// Beware, TOCTOU.
bool exists() const {
return (access(path.c_str(), F_OK) == 0);
}
std::string path;
};