mirror of https://github.com/google/gemma.cpp.git
Add standalone tool to compress weights.
Co-authored-by: Eugene Kliuchnikov <eustas@google.com>
This commit is contained in:
parent
93a648926c
commit
b670d43e4f
17
BUILD.bazel
17
BUILD.bazel
|
|
@ -115,3 +115,20 @@ cc_binary(
|
||||||
"@hwy//:thread_pool",
|
"@hwy//:thread_pool",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_binary(
|
||||||
|
name = "compress_weights",
|
||||||
|
srcs = [
|
||||||
|
"compress_weights.cc",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":args",
|
||||||
|
":gemma_lib",
|
||||||
|
# "//base",
|
||||||
|
"//compression:compress",
|
||||||
|
"@hwy//:hwy",
|
||||||
|
"@hwy//:nanobenchmark",
|
||||||
|
"@hwy//:profiler",
|
||||||
|
"@hwy//:thread_pool",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -80,7 +80,7 @@ set_target_properties(libgemma PROPERTIES PREFIX "")
|
||||||
set_property(TARGET libgemma PROPERTY POSITION_INDEPENDENT_CODE ON)
|
set_property(TARGET libgemma PROPERTY POSITION_INDEPENDENT_CODE ON)
|
||||||
target_include_directories(libgemma PUBLIC ./)
|
target_include_directories(libgemma PUBLIC ./)
|
||||||
target_link_libraries(libgemma hwy hwy_contrib sentencepiece)
|
target_link_libraries(libgemma hwy hwy_contrib sentencepiece)
|
||||||
target_include_directories(libgemma PRIVATE ${sentencepiece_SOURCE_DIR})
|
target_include_directories(libgemma PUBLIC ${sentencepiece_SOURCE_DIR})
|
||||||
target_compile_definitions(libgemma PRIVATE $<$<PLATFORM_ID:Windows>:_CRT_SECURE_NO_WARNINGS NOMINMAX>)
|
target_compile_definitions(libgemma PRIVATE $<$<PLATFORM_ID:Windows>:_CRT_SECURE_NO_WARNINGS NOMINMAX>)
|
||||||
target_compile_options(libgemma PRIVATE $<$<PLATFORM_ID:Windows>:-Wno-deprecated-declarations>)
|
target_compile_options(libgemma PRIVATE $<$<PLATFORM_ID:Windows>:-Wno-deprecated-declarations>)
|
||||||
|
|
||||||
|
|
@ -115,3 +115,8 @@ foreach (TESTFILE IN LISTS GEMMA_TEST_FILES)
|
||||||
gtest_discover_tests(${TESTNAME})
|
gtest_discover_tests(${TESTNAME})
|
||||||
endforeach ()
|
endforeach ()
|
||||||
endif() # GEMMA_ENABLE_TESTS
|
endif() # GEMMA_ENABLE_TESTS
|
||||||
|
|
||||||
|
## Tools
|
||||||
|
|
||||||
|
add_executable(compress_weights compress_weights.cc)
|
||||||
|
target_link_libraries(compress_weights libgemma hwy hwy_contrib)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,148 @@
|
||||||
|
// Copyright 2024 Google LLC
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
// Command line tool to create compressed weights.
|
||||||
|
|
||||||
|
#include <iostream>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
// copybara:import_next_line:gemma_cpp
|
||||||
|
#include "gemma.h" // Gemma
|
||||||
|
// copybara:end
|
||||||
|
// copybara:import_next_line:gemma_cpp
|
||||||
|
#include "util/args.h"
|
||||||
|
// copybara:end
|
||||||
|
|
||||||
|
namespace gcpp {
|
||||||
|
|
||||||
|
struct Args : public ArgsBase<Args> {
|
||||||
|
static constexpr size_t kDefaultNumThreads = ~size_t{0};
|
||||||
|
|
||||||
|
void ChooseNumThreads() {
|
||||||
|
if (num_threads == kDefaultNumThreads) {
|
||||||
|
// This is a rough heuristic, replace with something better in the future.
|
||||||
|
num_threads = static_cast<size_t>(std::clamp(
|
||||||
|
static_cast<int>(std::thread::hardware_concurrency()) - 2, 1, 18));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public:
|
||||||
|
Args(int argc, char* argv[]) {
|
||||||
|
InitAndParse(argc, argv);
|
||||||
|
ChooseNumThreads();
|
||||||
|
}
|
||||||
|
|
||||||
|
static std::string ToLower(const std::string& text) {
|
||||||
|
std::string result = text;
|
||||||
|
std::transform(begin(result), end(result), begin(result),
|
||||||
|
[](unsigned char c) { return std::tolower(c); });
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
gcpp::Model ModelType() const {
|
||||||
|
const std::string model_type_lc = ToLower(model_type);
|
||||||
|
if (model_type_lc.substr(0, 2) == "2b") {
|
||||||
|
return gcpp::Model::GEMMA_2B;
|
||||||
|
} else if (model_type_lc.substr(0, 2) == "7b") {
|
||||||
|
return gcpp::Model::GEMMA_7B;
|
||||||
|
} else {
|
||||||
|
HWY_ABORT("Unknown model type %s", model_type_lc.c_str());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns error string or nullptr if OK.
|
||||||
|
const char* Validate() const {
|
||||||
|
const std::string model_type_lc = ToLower(model_type);
|
||||||
|
if (model_type.empty()) {
|
||||||
|
return "Missing --model flag, need to specify either 2b-pt, 7b-pt, "
|
||||||
|
"2b-it, 7b-it.";
|
||||||
|
}
|
||||||
|
if (model_type_lc != "2b-pt" && model_type_lc != "7b-pt" &&
|
||||||
|
model_type_lc != "2b-it" && model_type_lc != "7b-it") {
|
||||||
|
return "Model type must be 2b-pt, 7b-pt, 2b-it, 7b-it.";
|
||||||
|
}
|
||||||
|
if (weights.path.empty()) {
|
||||||
|
return "Missing --weights flag, a file for the uncompressed model.";
|
||||||
|
}
|
||||||
|
if (compressed_weights.path.empty()) {
|
||||||
|
return "Missing --compressed_weights flag, a file for the compressed "
|
||||||
|
"model.";
|
||||||
|
}
|
||||||
|
if (!weights.exists()) {
|
||||||
|
return "Can't open file specified with --weights flag.";
|
||||||
|
}
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
Path weights; // uncompressed weights file location
|
||||||
|
Path compressed_weights; // compressed weights file location
|
||||||
|
std::string model_type;
|
||||||
|
size_t num_threads;
|
||||||
|
|
||||||
|
template <class Visitor>
|
||||||
|
void ForEach(const Visitor& visitor) {
|
||||||
|
visitor(weights, "weights", Path(),
|
||||||
|
"Path name of model weights (.sbs) file.\n"
|
||||||
|
" Required argument.");
|
||||||
|
visitor(model_type, "model", std::string(),
|
||||||
|
"Model type\n 2b-it = 2B parameters, instruction-tuned\n "
|
||||||
|
"2b-pt = 2B parameters, pretrained\n 7b-it = 7B parameters "
|
||||||
|
"instruction-tuned\n 7b-pt = 7B parameters, pretrained\n "
|
||||||
|
" Required argument.");
|
||||||
|
visitor(compressed_weights, "compressed_weights", Path(),
|
||||||
|
"Path name where compressed weights file will be written.\n"
|
||||||
|
" Required argument.");
|
||||||
|
visitor(num_threads, "num_threads",
|
||||||
|
kDefaultNumThreads, // see ChooseNumThreads
|
||||||
|
"Number of threads to use.\n Default = Estimate of the "
|
||||||
|
"number of suupported concurrent threads.",
|
||||||
|
2);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
void ShowHelp(gcpp::Args& args) {
|
||||||
|
std::cerr
|
||||||
|
<< "Usage:\n./compress_weights --weights <path to uncompressed weights> "
|
||||||
|
" --model <model type> --compressed_weights <output path>\n";
|
||||||
|
std::cerr << "\n*Arguments*\n\n";
|
||||||
|
args.Help();
|
||||||
|
std::cerr << "\n";
|
||||||
|
}
|
||||||
|
|
||||||
|
void Run(Args& args) {
|
||||||
|
hwy::ThreadPool pool(args.num_threads);
|
||||||
|
gcpp::CompressWeights(args.ModelType(), args.weights, args.compressed_weights,
|
||||||
|
pool);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace gcpp
|
||||||
|
|
||||||
|
int main(int argc, char** argv) {
|
||||||
|
gcpp::Args args(argc, argv);
|
||||||
|
|
||||||
|
if (gcpp::HasHelp(argc, argv)) {
|
||||||
|
ShowHelp(args);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (const char* error = args.Validate()) {
|
||||||
|
ShowHelp(args);
|
||||||
|
HWY_ABORT("\nInvalid args: %s", error);
|
||||||
|
}
|
||||||
|
|
||||||
|
gcpp::Run(args);
|
||||||
|
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
58
gemma.cc
58
gemma.cc
|
|
@ -116,10 +116,13 @@ hwy::AlignedUniquePtr<Weights<TConfig>> LoadWeights(const Path& checkpoint) {
|
||||||
checkpoint.path.c_str());
|
checkpoint.path.c_str());
|
||||||
}
|
}
|
||||||
bool ok = true;
|
bool ok = true;
|
||||||
|
uint64_t total_size = 0;
|
||||||
ok &= 1 == fread(&(weights->embedder_input_embedding),
|
ok &= 1 == fread(&(weights->embedder_input_embedding),
|
||||||
sizeof(weights->embedder_input_embedding), 1, fptr);
|
sizeof(weights->embedder_input_embedding), 1, fptr);
|
||||||
ok &= 1 == fread(&(weights->final_norm_scale),
|
ok &= 1 == fread(&(weights->final_norm_scale),
|
||||||
sizeof(weights->final_norm_scale), 1, fptr);
|
sizeof(weights->final_norm_scale), 1, fptr);
|
||||||
|
total_size += sizeof(weights->embedder_input_embedding) +
|
||||||
|
sizeof(weights->final_norm_scale);
|
||||||
for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
|
for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
|
||||||
Layer<TConfig>* layer_view = &weights->layers[layer];
|
Layer<TConfig>* layer_view = &weights->layers[layer];
|
||||||
ok &= 1 == fread(&layer_view->attn_vec_einsum_w,
|
ok &= 1 == fread(&layer_view->attn_vec_einsum_w,
|
||||||
|
|
@ -134,10 +137,12 @@ hwy::AlignedUniquePtr<Weights<TConfig>> LoadWeights(const Path& checkpoint) {
|
||||||
sizeof(layer_view->pre_attention_norm_scale), 1, fptr);
|
sizeof(layer_view->pre_attention_norm_scale), 1, fptr);
|
||||||
ok &= 1 == fread(&layer_view->pre_ffw_norm_scale,
|
ok &= 1 == fread(&layer_view->pre_ffw_norm_scale,
|
||||||
sizeof(layer_view->pre_ffw_norm_scale), 1, fptr);
|
sizeof(layer_view->pre_ffw_norm_scale), 1, fptr);
|
||||||
|
total_size += sizeof(*layer_view);
|
||||||
}
|
}
|
||||||
if (!ok) {
|
if (!ok) {
|
||||||
HWY_ABORT("Failed to read from %s - might be a directory, or too small?",
|
HWY_ABORT("Failed to read from %s - might be a directory, or too small? "
|
||||||
checkpoint.path.c_str());
|
"expected size: %d kB", checkpoint.path.c_str(),
|
||||||
|
static_cast<uint32_t>(total_size >> 10));
|
||||||
}
|
}
|
||||||
HWY_ASSERT(0 == fclose(fptr));
|
HWY_ASSERT(0 == fclose(fptr));
|
||||||
return weights;
|
return weights;
|
||||||
|
|
@ -813,6 +818,47 @@ hwy::AlignedFreeUniquePtr<uint8_t[]> GetCompressedWeightsT(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <class TConfig>
|
||||||
|
void CompressWeights(const Path& weights_path,
|
||||||
|
const Path& compressed_weights_path,
|
||||||
|
hwy::ThreadPool& pool) {
|
||||||
|
if (!std::filesystem::exists(weights_path.path)) {
|
||||||
|
HWY_ABORT("The model weights file '%s' does not exist.",
|
||||||
|
weights_path.path.c_str());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Allocate compressed weights.
|
||||||
|
using CWeights = CompressedWeights<TConfig>;
|
||||||
|
hwy::AlignedFreeUniquePtr<uint8_t[]> c_weights_u8 =
|
||||||
|
hwy::AllocateAligned<uint8_t>(sizeof(CWeights));
|
||||||
|
CWeights* c_weights = reinterpret_cast<CWeights*>(c_weights_u8.get());
|
||||||
|
new (&c_weights->c_layer_ptrs) CompressedLayerPointers<TConfig>(pool);
|
||||||
|
|
||||||
|
// Get weights, compress, and store.
|
||||||
|
const hwy::AlignedUniquePtr<Weights<TConfig>> weights =
|
||||||
|
LoadWeights<TConfig>(weights_path);
|
||||||
|
Compressor compressor(pool);
|
||||||
|
ForEachTensor<TConfig>(weights.get(), *c_weights, compressor);
|
||||||
|
compressor.WriteAll(pool, compressed_weights_path.path.c_str());
|
||||||
|
|
||||||
|
c_weights->c_layer_ptrs.~CompressedLayerPointers<TConfig>();
|
||||||
|
}
|
||||||
|
|
||||||
|
void CompressWeightsT(gcpp::Model model, const Path& weights,
|
||||||
|
const Path& compressed_weights,
|
||||||
|
hwy::ThreadPool& pool) {
|
||||||
|
switch (model) {
|
||||||
|
case Model::GEMMA_2B:
|
||||||
|
CompressWeights<ConfigGemma2B>(weights, compressed_weights, pool);
|
||||||
|
break;
|
||||||
|
case Model::GEMMA_7B:
|
||||||
|
CompressWeights<ConfigGemma7B>(weights, compressed_weights, pool);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace HWY_NAMESPACE
|
} // namespace HWY_NAMESPACE
|
||||||
} // namespace gcpp
|
} // namespace gcpp
|
||||||
HWY_AFTER_NAMESPACE();
|
HWY_AFTER_NAMESPACE();
|
||||||
|
|
@ -821,6 +867,7 @@ HWY_AFTER_NAMESPACE();
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
||||||
HWY_EXPORT(GetCompressedWeightsT);
|
HWY_EXPORT(GetCompressedWeightsT);
|
||||||
|
HWY_EXPORT(CompressWeightsT);
|
||||||
HWY_EXPORT(Generate2B);
|
HWY_EXPORT(Generate2B);
|
||||||
HWY_EXPORT(Generate7B);
|
HWY_EXPORT(Generate7B);
|
||||||
|
|
||||||
|
|
@ -922,5 +969,12 @@ void GenerateGemma(Gemma& gemma, RuntimeConfig runtime_config,
|
||||||
stream_token, [](int) { return true; }, gen, runtime_config.verbosity);
|
stream_token, [](int) { return true; }, gen, runtime_config.verbosity);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void CompressWeights(gcpp::Model model, const Path& weights,
|
||||||
|
const Path& compressed_weights,
|
||||||
|
hwy::ThreadPool& pool) {
|
||||||
|
HWY_DYNAMIC_DISPATCH(CompressWeightsT)(
|
||||||
|
model, weights, compressed_weights, pool);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace gcpp
|
} // namespace gcpp
|
||||||
#endif // HWY_ONCE
|
#endif // HWY_ONCE
|
||||||
|
|
|
||||||
4
gemma.h
4
gemma.h
|
|
@ -98,6 +98,10 @@ void GenerateGemma(Gemma& gemma, RuntimeConfig runtime_config,
|
||||||
KVCache& kv_cache, hwy::ThreadPool& pool,
|
KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||||
const StreamFunc& stream_token, std::mt19937& gen);
|
const StreamFunc& stream_token, std::mt19937& gen);
|
||||||
|
|
||||||
|
void CompressWeights(gcpp::Model model, const Path& weights,
|
||||||
|
const Path& compressed_weights,
|
||||||
|
hwy::ThreadPool& pool);
|
||||||
|
|
||||||
constexpr int EOS_ID = 1;
|
constexpr int EOS_ID = 1;
|
||||||
|
|
||||||
} // namespace gcpp
|
} // namespace gcpp
|
||||||
|
|
|
||||||
26
util/args.h
26
util/args.h
|
|
@ -25,27 +25,43 @@
|
||||||
|
|
||||||
#include "hwy/base.h" // HWY_ABORT
|
#include "hwy/base.h" // HWY_ABORT
|
||||||
|
|
||||||
|
#if defined(_WIN32)
|
||||||
|
#include <io.h>
|
||||||
|
#define F_OK 0
|
||||||
|
#define access _access
|
||||||
|
#else
|
||||||
|
#include <unistd.h>
|
||||||
|
#endif
|
||||||
|
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
||||||
// Wrapper for strings representing a path name. Differentiates vs. arbitrary
|
// Wrapper for strings representing a path name. Differentiates vs. arbitrary
|
||||||
// strings and supports shortening for display purposes.
|
// strings and supports shortening for display purposes.
|
||||||
struct Path {
|
struct Path {
|
||||||
|
Path() {}
|
||||||
|
explicit Path(const char* p) : path(p) {}
|
||||||
|
|
||||||
Path& operator=(const char* other) {
|
Path& operator=(const char* other) {
|
||||||
path = other;
|
path = other;
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string Shortened() const {
|
std::string Shortened() const {
|
||||||
constexpr size_t max_len = 48;
|
constexpr size_t kMaxLen = 48;
|
||||||
constexpr size_t cut_point = max_len / 2 - 5;
|
constexpr size_t kCutPoint = kMaxLen / 2 - 5;
|
||||||
if (path.size() > max_len) {
|
if (path.size() > kMaxLen) {
|
||||||
return std::string(begin(path), begin(path) + cut_point) + " ... " +
|
return std::string(begin(path), begin(path) + kCutPoint) + " ... " +
|
||||||
std::string(end(path) - cut_point, end(path));
|
std::string(end(path) - kCutPoint, end(path));
|
||||||
}
|
}
|
||||||
if (path.empty()) return "[no path specified]";
|
if (path.empty()) return "[no path specified]";
|
||||||
return path;
|
return path;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Beware, TOCTOU.
|
||||||
|
bool exists() const {
|
||||||
|
return (access(path.c_str(), F_OK) == 0);
|
||||||
|
}
|
||||||
|
|
||||||
std::string path;
|
std::string path;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue