mirror of https://github.com/google/gemma.cpp.git
Add standalone tool to compress weights.
Co-authored-by: Eugene Kliuchnikov <eustas@google.com>
This commit is contained in:
parent
93a648926c
commit
b670d43e4f
17
BUILD.bazel
17
BUILD.bazel
|
|
@ -115,3 +115,20 @@ cc_binary(
|
|||
"@hwy//:thread_pool",
|
||||
],
|
||||
)
|
||||
|
||||
cc_binary(
|
||||
name = "compress_weights",
|
||||
srcs = [
|
||||
"compress_weights.cc",
|
||||
],
|
||||
deps = [
|
||||
":args",
|
||||
":gemma_lib",
|
||||
# "//base",
|
||||
"//compression:compress",
|
||||
"@hwy//:hwy",
|
||||
"@hwy//:nanobenchmark",
|
||||
"@hwy//:profiler",
|
||||
"@hwy//:thread_pool",
|
||||
],
|
||||
)
|
||||
|
|
|
|||
|
|
@ -80,7 +80,7 @@ set_target_properties(libgemma PROPERTIES PREFIX "")
|
|||
set_property(TARGET libgemma PROPERTY POSITION_INDEPENDENT_CODE ON)
|
||||
target_include_directories(libgemma PUBLIC ./)
|
||||
target_link_libraries(libgemma hwy hwy_contrib sentencepiece)
|
||||
target_include_directories(libgemma PRIVATE ${sentencepiece_SOURCE_DIR})
|
||||
target_include_directories(libgemma PUBLIC ${sentencepiece_SOURCE_DIR})
|
||||
target_compile_definitions(libgemma PRIVATE $<$<PLATFORM_ID:Windows>:_CRT_SECURE_NO_WARNINGS NOMINMAX>)
|
||||
target_compile_options(libgemma PRIVATE $<$<PLATFORM_ID:Windows>:-Wno-deprecated-declarations>)
|
||||
|
||||
|
|
@ -115,3 +115,8 @@ foreach (TESTFILE IN LISTS GEMMA_TEST_FILES)
|
|||
gtest_discover_tests(${TESTNAME})
|
||||
endforeach ()
|
||||
endif() # GEMMA_ENABLE_TESTS
|
||||
|
||||
## Tools
|
||||
|
||||
add_executable(compress_weights compress_weights.cc)
|
||||
target_link_libraries(compress_weights libgemma hwy hwy_contrib)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,148 @@
|
|||
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Command line tool to create compressed weights.
|
||||
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
|
||||
// copybara:import_next_line:gemma_cpp
|
||||
#include "gemma.h" // Gemma
|
||||
// copybara:end
|
||||
// copybara:import_next_line:gemma_cpp
|
||||
#include "util/args.h"
|
||||
// copybara:end
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
struct Args : public ArgsBase<Args> {
|
||||
static constexpr size_t kDefaultNumThreads = ~size_t{0};
|
||||
|
||||
void ChooseNumThreads() {
|
||||
if (num_threads == kDefaultNumThreads) {
|
||||
// This is a rough heuristic, replace with something better in the future.
|
||||
num_threads = static_cast<size_t>(std::clamp(
|
||||
static_cast<int>(std::thread::hardware_concurrency()) - 2, 1, 18));
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
Args(int argc, char* argv[]) {
|
||||
InitAndParse(argc, argv);
|
||||
ChooseNumThreads();
|
||||
}
|
||||
|
||||
static std::string ToLower(const std::string& text) {
|
||||
std::string result = text;
|
||||
std::transform(begin(result), end(result), begin(result),
|
||||
[](unsigned char c) { return std::tolower(c); });
|
||||
return result;
|
||||
}
|
||||
|
||||
gcpp::Model ModelType() const {
|
||||
const std::string model_type_lc = ToLower(model_type);
|
||||
if (model_type_lc.substr(0, 2) == "2b") {
|
||||
return gcpp::Model::GEMMA_2B;
|
||||
} else if (model_type_lc.substr(0, 2) == "7b") {
|
||||
return gcpp::Model::GEMMA_7B;
|
||||
} else {
|
||||
HWY_ABORT("Unknown model type %s", model_type_lc.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
// Returns error string or nullptr if OK.
|
||||
const char* Validate() const {
|
||||
const std::string model_type_lc = ToLower(model_type);
|
||||
if (model_type.empty()) {
|
||||
return "Missing --model flag, need to specify either 2b-pt, 7b-pt, "
|
||||
"2b-it, 7b-it.";
|
||||
}
|
||||
if (model_type_lc != "2b-pt" && model_type_lc != "7b-pt" &&
|
||||
model_type_lc != "2b-it" && model_type_lc != "7b-it") {
|
||||
return "Model type must be 2b-pt, 7b-pt, 2b-it, 7b-it.";
|
||||
}
|
||||
if (weights.path.empty()) {
|
||||
return "Missing --weights flag, a file for the uncompressed model.";
|
||||
}
|
||||
if (compressed_weights.path.empty()) {
|
||||
return "Missing --compressed_weights flag, a file for the compressed "
|
||||
"model.";
|
||||
}
|
||||
if (!weights.exists()) {
|
||||
return "Can't open file specified with --weights flag.";
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
Path weights; // uncompressed weights file location
|
||||
Path compressed_weights; // compressed weights file location
|
||||
std::string model_type;
|
||||
size_t num_threads;
|
||||
|
||||
template <class Visitor>
|
||||
void ForEach(const Visitor& visitor) {
|
||||
visitor(weights, "weights", Path(),
|
||||
"Path name of model weights (.sbs) file.\n"
|
||||
" Required argument.");
|
||||
visitor(model_type, "model", std::string(),
|
||||
"Model type\n 2b-it = 2B parameters, instruction-tuned\n "
|
||||
"2b-pt = 2B parameters, pretrained\n 7b-it = 7B parameters "
|
||||
"instruction-tuned\n 7b-pt = 7B parameters, pretrained\n "
|
||||
" Required argument.");
|
||||
visitor(compressed_weights, "compressed_weights", Path(),
|
||||
"Path name where compressed weights file will be written.\n"
|
||||
" Required argument.");
|
||||
visitor(num_threads, "num_threads",
|
||||
kDefaultNumThreads, // see ChooseNumThreads
|
||||
"Number of threads to use.\n Default = Estimate of the "
|
||||
"number of suupported concurrent threads.",
|
||||
2);
|
||||
}
|
||||
};
|
||||
|
||||
void ShowHelp(gcpp::Args& args) {
|
||||
std::cerr
|
||||
<< "Usage:\n./compress_weights --weights <path to uncompressed weights> "
|
||||
" --model <model type> --compressed_weights <output path>\n";
|
||||
std::cerr << "\n*Arguments*\n\n";
|
||||
args.Help();
|
||||
std::cerr << "\n";
|
||||
}
|
||||
|
||||
void Run(Args& args) {
|
||||
hwy::ThreadPool pool(args.num_threads);
|
||||
gcpp::CompressWeights(args.ModelType(), args.weights, args.compressed_weights,
|
||||
pool);
|
||||
}
|
||||
|
||||
} // namespace gcpp
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
gcpp::Args args(argc, argv);
|
||||
|
||||
if (gcpp::HasHelp(argc, argv)) {
|
||||
ShowHelp(args);
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (const char* error = args.Validate()) {
|
||||
ShowHelp(args);
|
||||
HWY_ABORT("\nInvalid args: %s", error);
|
||||
}
|
||||
|
||||
gcpp::Run(args);
|
||||
|
||||
return 0;
|
||||
}
|
||||
58
gemma.cc
58
gemma.cc
|
|
@ -116,10 +116,13 @@ hwy::AlignedUniquePtr<Weights<TConfig>> LoadWeights(const Path& checkpoint) {
|
|||
checkpoint.path.c_str());
|
||||
}
|
||||
bool ok = true;
|
||||
uint64_t total_size = 0;
|
||||
ok &= 1 == fread(&(weights->embedder_input_embedding),
|
||||
sizeof(weights->embedder_input_embedding), 1, fptr);
|
||||
ok &= 1 == fread(&(weights->final_norm_scale),
|
||||
sizeof(weights->final_norm_scale), 1, fptr);
|
||||
total_size += sizeof(weights->embedder_input_embedding) +
|
||||
sizeof(weights->final_norm_scale);
|
||||
for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
|
||||
Layer<TConfig>* layer_view = &weights->layers[layer];
|
||||
ok &= 1 == fread(&layer_view->attn_vec_einsum_w,
|
||||
|
|
@ -134,10 +137,12 @@ hwy::AlignedUniquePtr<Weights<TConfig>> LoadWeights(const Path& checkpoint) {
|
|||
sizeof(layer_view->pre_attention_norm_scale), 1, fptr);
|
||||
ok &= 1 == fread(&layer_view->pre_ffw_norm_scale,
|
||||
sizeof(layer_view->pre_ffw_norm_scale), 1, fptr);
|
||||
total_size += sizeof(*layer_view);
|
||||
}
|
||||
if (!ok) {
|
||||
HWY_ABORT("Failed to read from %s - might be a directory, or too small?",
|
||||
checkpoint.path.c_str());
|
||||
HWY_ABORT("Failed to read from %s - might be a directory, or too small? "
|
||||
"expected size: %d kB", checkpoint.path.c_str(),
|
||||
static_cast<uint32_t>(total_size >> 10));
|
||||
}
|
||||
HWY_ASSERT(0 == fclose(fptr));
|
||||
return weights;
|
||||
|
|
@ -813,6 +818,47 @@ hwy::AlignedFreeUniquePtr<uint8_t[]> GetCompressedWeightsT(
|
|||
}
|
||||
}
|
||||
|
||||
template <class TConfig>
|
||||
void CompressWeights(const Path& weights_path,
|
||||
const Path& compressed_weights_path,
|
||||
hwy::ThreadPool& pool) {
|
||||
if (!std::filesystem::exists(weights_path.path)) {
|
||||
HWY_ABORT("The model weights file '%s' does not exist.",
|
||||
weights_path.path.c_str());
|
||||
}
|
||||
|
||||
// Allocate compressed weights.
|
||||
using CWeights = CompressedWeights<TConfig>;
|
||||
hwy::AlignedFreeUniquePtr<uint8_t[]> c_weights_u8 =
|
||||
hwy::AllocateAligned<uint8_t>(sizeof(CWeights));
|
||||
CWeights* c_weights = reinterpret_cast<CWeights*>(c_weights_u8.get());
|
||||
new (&c_weights->c_layer_ptrs) CompressedLayerPointers<TConfig>(pool);
|
||||
|
||||
// Get weights, compress, and store.
|
||||
const hwy::AlignedUniquePtr<Weights<TConfig>> weights =
|
||||
LoadWeights<TConfig>(weights_path);
|
||||
Compressor compressor(pool);
|
||||
ForEachTensor<TConfig>(weights.get(), *c_weights, compressor);
|
||||
compressor.WriteAll(pool, compressed_weights_path.path.c_str());
|
||||
|
||||
c_weights->c_layer_ptrs.~CompressedLayerPointers<TConfig>();
|
||||
}
|
||||
|
||||
void CompressWeightsT(gcpp::Model model, const Path& weights,
|
||||
const Path& compressed_weights,
|
||||
hwy::ThreadPool& pool) {
|
||||
switch (model) {
|
||||
case Model::GEMMA_2B:
|
||||
CompressWeights<ConfigGemma2B>(weights, compressed_weights, pool);
|
||||
break;
|
||||
case Model::GEMMA_7B:
|
||||
CompressWeights<ConfigGemma7B>(weights, compressed_weights, pool);
|
||||
break;
|
||||
default:
|
||||
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace HWY_NAMESPACE
|
||||
} // namespace gcpp
|
||||
HWY_AFTER_NAMESPACE();
|
||||
|
|
@ -821,6 +867,7 @@ HWY_AFTER_NAMESPACE();
|
|||
namespace gcpp {
|
||||
|
||||
HWY_EXPORT(GetCompressedWeightsT);
|
||||
HWY_EXPORT(CompressWeightsT);
|
||||
HWY_EXPORT(Generate2B);
|
||||
HWY_EXPORT(Generate7B);
|
||||
|
||||
|
|
@ -922,5 +969,12 @@ void GenerateGemma(Gemma& gemma, RuntimeConfig runtime_config,
|
|||
stream_token, [](int) { return true; }, gen, runtime_config.verbosity);
|
||||
}
|
||||
|
||||
void CompressWeights(gcpp::Model model, const Path& weights,
|
||||
const Path& compressed_weights,
|
||||
hwy::ThreadPool& pool) {
|
||||
HWY_DYNAMIC_DISPATCH(CompressWeightsT)(
|
||||
model, weights, compressed_weights, pool);
|
||||
}
|
||||
|
||||
} // namespace gcpp
|
||||
#endif // HWY_ONCE
|
||||
|
|
|
|||
4
gemma.h
4
gemma.h
|
|
@ -98,6 +98,10 @@ void GenerateGemma(Gemma& gemma, RuntimeConfig runtime_config,
|
|||
KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||
const StreamFunc& stream_token, std::mt19937& gen);
|
||||
|
||||
void CompressWeights(gcpp::Model model, const Path& weights,
|
||||
const Path& compressed_weights,
|
||||
hwy::ThreadPool& pool);
|
||||
|
||||
constexpr int EOS_ID = 1;
|
||||
|
||||
} // namespace gcpp
|
||||
|
|
|
|||
26
util/args.h
26
util/args.h
|
|
@ -25,27 +25,43 @@
|
|||
|
||||
#include "hwy/base.h" // HWY_ABORT
|
||||
|
||||
#if defined(_WIN32)
|
||||
#include <io.h>
|
||||
#define F_OK 0
|
||||
#define access _access
|
||||
#else
|
||||
#include <unistd.h>
|
||||
#endif
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
// Wrapper for strings representing a path name. Differentiates vs. arbitrary
|
||||
// strings and supports shortening for display purposes.
|
||||
struct Path {
|
||||
Path() {}
|
||||
explicit Path(const char* p) : path(p) {}
|
||||
|
||||
Path& operator=(const char* other) {
|
||||
path = other;
|
||||
return *this;
|
||||
}
|
||||
|
||||
std::string Shortened() const {
|
||||
constexpr size_t max_len = 48;
|
||||
constexpr size_t cut_point = max_len / 2 - 5;
|
||||
if (path.size() > max_len) {
|
||||
return std::string(begin(path), begin(path) + cut_point) + " ... " +
|
||||
std::string(end(path) - cut_point, end(path));
|
||||
constexpr size_t kMaxLen = 48;
|
||||
constexpr size_t kCutPoint = kMaxLen / 2 - 5;
|
||||
if (path.size() > kMaxLen) {
|
||||
return std::string(begin(path), begin(path) + kCutPoint) + " ... " +
|
||||
std::string(end(path) - kCutPoint, end(path));
|
||||
}
|
||||
if (path.empty()) return "[no path specified]";
|
||||
return path;
|
||||
}
|
||||
|
||||
// Beware, TOCTOU.
|
||||
bool exists() const {
|
||||
return (access(path.c_str(), F_OK) == 0);
|
||||
}
|
||||
|
||||
std::string path;
|
||||
};
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue