mirror of https://github.com/google/gemma.cpp.git
Merge branch 'dev' into feature/ISS-60/implement-self-extend
This commit is contained in:
commit
51a708e957
20
BUILD.bazel
20
BUILD.bazel
|
|
@ -388,6 +388,26 @@ cc_test(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_test(
|
||||||
|
name = "gemma_batch_bench",
|
||||||
|
srcs = ["evals/gemma_batch_bench.cc"],
|
||||||
|
# Requires model files
|
||||||
|
tags = [
|
||||||
|
"local",
|
||||||
|
"manual",
|
||||||
|
"no_tap",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":benchmark_helper",
|
||||||
|
":common",
|
||||||
|
":gemma_lib",
|
||||||
|
":tokenizer",
|
||||||
|
"@googletest//:gtest_main",
|
||||||
|
"@highway//:hwy",
|
||||||
|
"@highway//:hwy_test_util",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_binary(
|
cc_binary(
|
||||||
name = "gemma",
|
name = "gemma",
|
||||||
srcs = ["gemma/run.cc"],
|
srcs = ["gemma/run.cc"],
|
||||||
|
|
|
||||||
|
|
@ -135,6 +135,9 @@ install(TARGETS gemma DESTINATION bin)
|
||||||
add_executable(single_benchmark evals/benchmark.cc)
|
add_executable(single_benchmark evals/benchmark.cc)
|
||||||
target_link_libraries(single_benchmark libgemma hwy hwy_contrib nlohmann_json::nlohmann_json)
|
target_link_libraries(single_benchmark libgemma hwy hwy_contrib nlohmann_json::nlohmann_json)
|
||||||
|
|
||||||
|
add_executable(gemma_batch_bench evals/gemma_batch_bench.cc)
|
||||||
|
target_link_libraries(gemma_batch_bench libgemma hwy hwy_contrib nlohmann_json::nlohmann_json)
|
||||||
|
|
||||||
add_executable(benchmarks evals/benchmarks.cc)
|
add_executable(benchmarks evals/benchmarks.cc)
|
||||||
target_link_libraries(benchmarks libgemma hwy hwy_contrib nlohmann_json::nlohmann_json benchmark)
|
target_link_libraries(benchmarks libgemma hwy hwy_contrib nlohmann_json::nlohmann_json benchmark)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -690,12 +690,13 @@ class Compressor {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void WriteAll(hwy::ThreadPool& pool, const Path& blob_filename) {
|
BlobError WriteAll(hwy::ThreadPool& pool, const Path& blob_filename) {
|
||||||
const BlobError err = writer_.WriteAll(pool, blob_filename);
|
const BlobError err = writer_.WriteAll(pool, blob_filename);
|
||||||
if (err != 0) {
|
if (err != 0) {
|
||||||
fprintf(stderr, "Failed to write blobs to %s (error %d)\n",
|
fprintf(stderr, "Failed to write blobs to %s (error %d)\n",
|
||||||
blob_filename.path.c_str(), err);
|
blob_filename.path.c_str(), err);
|
||||||
}
|
}
|
||||||
|
return err;
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
|
||||||
|
|
@ -28,6 +28,7 @@ pybind_extension(
|
||||||
deps = [
|
deps = [
|
||||||
":compression_clif_aux",
|
":compression_clif_aux",
|
||||||
"@abseil-cpp//absl/types:span",
|
"@abseil-cpp//absl/types:span",
|
||||||
|
"//compression:sfp",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -38,6 +39,7 @@ py_test(
|
||||||
deps = [
|
deps = [
|
||||||
":compression",
|
":compression",
|
||||||
"//testing/pybase",
|
"//testing/pybase",
|
||||||
|
"//python:configs",
|
||||||
"//third_party/py/numpy",
|
"//third_party/py/numpy",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -31,7 +31,9 @@ class WriterInterface {
|
||||||
public:
|
public:
|
||||||
virtual ~WriterInterface() = default;
|
virtual ~WriterInterface() = default;
|
||||||
|
|
||||||
virtual void Insert(std::string name, absl::Span<const float> weights) = 0;
|
virtual void Insert(std::string name, absl::Span<const float> weights,
|
||||||
|
Type type) = 0;
|
||||||
|
virtual void InsertSfp(std::string name, absl::Span<const float> weights) = 0;
|
||||||
virtual void InsertNUQ(std::string name, absl::Span<const float> weights) = 0;
|
virtual void InsertNUQ(std::string name, absl::Span<const float> weights) = 0;
|
||||||
virtual void InsertBfloat16(std::string name,
|
virtual void InsertBfloat16(std::string name,
|
||||||
absl::Span<const float> weights) = 0;
|
absl::Span<const float> weights) = 0;
|
||||||
|
|
@ -39,7 +41,7 @@ class WriterInterface {
|
||||||
absl::Span<const float> weights) = 0;
|
absl::Span<const float> weights) = 0;
|
||||||
virtual void AddScales(const std::vector<float>& scales) = 0;
|
virtual void AddScales(const std::vector<float>& scales) = 0;
|
||||||
|
|
||||||
virtual void Write(std::string path) = 0;
|
virtual int Write(std::string path) = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace gcpp
|
} // namespace gcpp
|
||||||
|
|
@ -67,7 +69,27 @@ class SbsWriterImpl : public WriterInterface {
|
||||||
public:
|
public:
|
||||||
SbsWriterImpl() : pool_(0), compressor_(pool_) {}
|
SbsWriterImpl() : pool_(0), compressor_(pool_) {}
|
||||||
|
|
||||||
void Insert(std::string name, absl::Span<const float> weights) override {
|
void Insert(std::string name, absl::Span<const float> weights,
|
||||||
|
Type type) override {
|
||||||
|
switch (type) {
|
||||||
|
case Type::kSFP:
|
||||||
|
AllocateAndCompress<SfpStream>(name, weights);
|
||||||
|
break;
|
||||||
|
case Type::kNUQ:
|
||||||
|
AllocateAndCompress<NuqStream>(name, weights);
|
||||||
|
break;
|
||||||
|
case Type::kBF16:
|
||||||
|
AllocateAndCompress<BF16>(name, weights);
|
||||||
|
break;
|
||||||
|
case Type::kF32:
|
||||||
|
AllocateAndCompress<float>(name, weights);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
HWY_ABORT("Unsupported type");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void InsertSfp(std::string name, absl::Span<const float> weights) override {
|
||||||
AllocateAndCompress<SfpStream>(name, weights);
|
AllocateAndCompress<SfpStream>(name, weights);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -90,8 +112,8 @@ class SbsWriterImpl : public WriterInterface {
|
||||||
compressor_.AddScales(scales_.data(), scales_.size());
|
compressor_.AddScales(scales_.data(), scales_.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
void Write(std::string path) override {
|
int Write(std::string path) override {
|
||||||
compressor_.WriteAll(pool_, gcpp::Path(path));
|
return compressor_.WriteAll(pool_, gcpp::Path(path));
|
||||||
}
|
}
|
||||||
|
|
||||||
hwy::ThreadPool pool_;
|
hwy::ThreadPool pool_;
|
||||||
|
|
@ -115,8 +137,12 @@ HWY_EXPORT(NewSbsWriter);
|
||||||
SbsWriter::SbsWriter() : impl_(HWY_DYNAMIC_DISPATCH(NewSbsWriter)()) {}
|
SbsWriter::SbsWriter() : impl_(HWY_DYNAMIC_DISPATCH(NewSbsWriter)()) {}
|
||||||
SbsWriter::~SbsWriter() = default;
|
SbsWriter::~SbsWriter() = default;
|
||||||
|
|
||||||
void SbsWriter::Insert(std::string name, absl::Span<const float> weights) {
|
void SbsWriter::Insert(std::string name, absl::Span<const float> weights,
|
||||||
impl_->Insert(name, weights);
|
Type type) {
|
||||||
|
impl_->Insert(name, weights, type);
|
||||||
|
}
|
||||||
|
void SbsWriter::InsertSfp(std::string name, absl::Span<const float> weights) {
|
||||||
|
impl_->InsertSfp(name, weights);
|
||||||
}
|
}
|
||||||
void SbsWriter::InsertNUQ(std::string name, absl::Span<const float> weights) {
|
void SbsWriter::InsertNUQ(std::string name, absl::Span<const float> weights) {
|
||||||
impl_->InsertNUQ(name, weights);
|
impl_->InsertNUQ(name, weights);
|
||||||
|
|
@ -132,7 +158,7 @@ void SbsWriter::InsertFloat(std::string name, absl::Span<const float> weights) {
|
||||||
void SbsWriter::AddScales(const std::vector<float>& scales) {
|
void SbsWriter::AddScales(const std::vector<float>& scales) {
|
||||||
impl_->AddScales(scales);
|
impl_->AddScales(scales);
|
||||||
}
|
}
|
||||||
void SbsWriter::Write(std::string path) { impl_->Write(path); }
|
int SbsWriter::Write(std::string path) { return impl_->Write(path); }
|
||||||
|
|
||||||
} // namespace gcpp
|
} // namespace gcpp
|
||||||
#endif // HWY_ONCE
|
#endif // HWY_ONCE
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "absl/types/span.h"
|
#include "absl/types/span.h"
|
||||||
|
#include "compression/shared.h"
|
||||||
|
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
||||||
|
|
@ -16,13 +17,14 @@ class SbsWriter {
|
||||||
SbsWriter();
|
SbsWriter();
|
||||||
~SbsWriter();
|
~SbsWriter();
|
||||||
|
|
||||||
void Insert(std::string name, absl::Span<const float> weights);
|
void Insert(std::string name, absl::Span<const float> weights, Type type);
|
||||||
|
void InsertSfp(std::string name, absl::Span<const float> weights);
|
||||||
void InsertNUQ(std::string name, absl::Span<const float> weights);
|
void InsertNUQ(std::string name, absl::Span<const float> weights);
|
||||||
void InsertBfloat16(std::string name, absl::Span<const float> weights);
|
void InsertBfloat16(std::string name, absl::Span<const float> weights);
|
||||||
void InsertFloat(std::string name, absl::Span<const float> weights);
|
void InsertFloat(std::string name, absl::Span<const float> weights);
|
||||||
void AddScales(const std::vector<float>& scales);
|
void AddScales(const std::vector<float>& scales);
|
||||||
|
|
||||||
void Write(std::string path);
|
int Write(std::string path);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// Isolates Highway-dispatched types and other internals from CLIF.
|
// Isolates Highway-dispatched types and other internals from CLIF.
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,11 @@
|
||||||
#include <pybind11/pybind11.h>
|
#include <pybind11/pybind11.h>
|
||||||
|
|
||||||
#include <exception>
|
|
||||||
#include <stdexcept>
|
#include <stdexcept>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
#include "absl/types/span.h"
|
#include "absl/types/span.h"
|
||||||
#include "compression/python/compression_clif_aux.h"
|
#include "compression/python/compression_clif_aux.h"
|
||||||
|
#include "compression/shared.h"
|
||||||
#include "pybind11/numpy.h"
|
#include "pybind11/numpy.h"
|
||||||
#include "pybind11/pybind11.h"
|
#include "pybind11/pybind11.h"
|
||||||
#include "pybind11/stl.h"
|
#include "pybind11/stl.h"
|
||||||
|
|
@ -22,6 +22,15 @@ void wrap_span(SbsWriter& writer, std::string name, py::array_t<float> data) {
|
||||||
}
|
}
|
||||||
std::invoke(Func, writer, name, absl::MakeSpan(data.data(0), data.size()));
|
std::invoke(Func, writer, name, absl::MakeSpan(data.data(0), data.size()));
|
||||||
}
|
}
|
||||||
|
template <auto Func>
|
||||||
|
void wrap_span_typed(SbsWriter& writer, std::string name,
|
||||||
|
py::array_t<float> data, gcpp::Type type) {
|
||||||
|
if (data.ndim() != 1 || data.strides(0) != sizeof(float)) {
|
||||||
|
throw std::domain_error("Input array must be 1D and densely packed.");
|
||||||
|
}
|
||||||
|
std::invoke(Func, writer, name, absl::MakeSpan(data.data(0), data.size()),
|
||||||
|
type);
|
||||||
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
PYBIND11_MODULE(compression, m) {
|
PYBIND11_MODULE(compression, m) {
|
||||||
|
|
@ -29,7 +38,8 @@ PYBIND11_MODULE(compression, m) {
|
||||||
.def(py::init<>())
|
.def(py::init<>())
|
||||||
// NOTE: Individual compression backends may impose constraints on the
|
// NOTE: Individual compression backends may impose constraints on the
|
||||||
// array length, such as a minimum of (say) 32 elements.
|
// array length, such as a minimum of (say) 32 elements.
|
||||||
.def("insert", wrap_span<&SbsWriter::Insert>)
|
.def("insert", wrap_span_typed<&SbsWriter::Insert>)
|
||||||
|
.def("insert_sfp", wrap_span<&SbsWriter::InsertSfp>)
|
||||||
.def("insert_nuq", wrap_span<&SbsWriter::InsertNUQ>)
|
.def("insert_nuq", wrap_span<&SbsWriter::InsertNUQ>)
|
||||||
.def("insert_bf16", wrap_span<&SbsWriter::InsertBfloat16>)
|
.def("insert_bf16", wrap_span<&SbsWriter::InsertBfloat16>)
|
||||||
.def("insert_float", wrap_span<&SbsWriter::InsertFloat>)
|
.def("insert_float", wrap_span<&SbsWriter::InsertFloat>)
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,7 @@ import numpy as np
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
from compression.python import compression
|
from compression.python import compression
|
||||||
|
from gemma.python import configs
|
||||||
|
|
||||||
|
|
||||||
class CompressionTest(unittest.TestCase):
|
class CompressionTest(unittest.TestCase):
|
||||||
|
|
@ -13,9 +14,11 @@ class CompressionTest(unittest.TestCase):
|
||||||
|
|
||||||
writer = compression.SbsWriter()
|
writer = compression.SbsWriter()
|
||||||
writer.insert(
|
writer.insert(
|
||||||
"foo", np.array([0.0012] * 128 + [0.001] * 64, dtype=np.float32)
|
"foo",
|
||||||
|
np.array([0.0012] * 128 + [0.001] * 64, dtype=np.float32),
|
||||||
|
configs.Type.kSFP,
|
||||||
)
|
)
|
||||||
writer.insert(
|
writer.insert_sfp(
|
||||||
"bar", np.array([0.000375] * 128 + [0.00009] * 128, dtype=np.float32)
|
"bar", np.array([0.000375] * 128 + [0.00009] * 128, dtype=np.float32)
|
||||||
)
|
)
|
||||||
writer.insert_nuq(
|
writer.insert_nuq(
|
||||||
|
|
@ -27,7 +30,7 @@ class CompressionTest(unittest.TestCase):
|
||||||
writer.insert_float(
|
writer.insert_float(
|
||||||
"quux", np.array([0.000375] * 128 + [0.00006] * 128, dtype=np.float32)
|
"quux", np.array([0.000375] * 128 + [0.00006] * 128, dtype=np.float32)
|
||||||
)
|
)
|
||||||
writer.write(temp_file.full_path)
|
self.assertEqual(writer.write(temp_file.full_path), 0)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,146 @@
|
||||||
|
// Copyright 2024 Google LLC
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "gemma/gemma.h"
|
||||||
|
|
||||||
|
#include <stdio.h>
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "evals/benchmark_helper.h"
|
||||||
|
#include "gemma/common.h"
|
||||||
|
#include "hwy/base.h"
|
||||||
|
#include "hwy/tests/hwy_gtest.h"
|
||||||
|
|
||||||
|
// This test can be run manually with the downloaded gemma weights.
|
||||||
|
// To run the test, pass the following flags:
|
||||||
|
// --model <model> --tokenizer <tokenizer_path> --weights <weights_path>
|
||||||
|
// It should pass for the following models:
|
||||||
|
// Gemma1: 2b-it (v1 and v1.1), 7b-it (v1 and v1.1), gr2b-it,
|
||||||
|
// Gemma2: gemma2-2b-it, 9b-it, 27b-it,
|
||||||
|
|
||||||
|
namespace gcpp {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
// Shared state. Requires argc/argv, so construct in main and use the same raw
|
||||||
|
// pointer approach as in benchmarks.cc. Note that the style guide forbids
|
||||||
|
// non-local static variables with dtors.
|
||||||
|
GemmaEnv* s_env = nullptr;
|
||||||
|
|
||||||
|
class GemmaTest : public ::testing::Test {
|
||||||
|
protected:
|
||||||
|
std::vector<std::string> BatchGemmaReply(
|
||||||
|
const std::vector<std::string>& inputs) {
|
||||||
|
s_env->SetMaxGeneratedTokens(64);
|
||||||
|
s_env->MutableConfig().temperature = 0.0f; // deterministic
|
||||||
|
s_env->MutableConfig().verbosity = 5;
|
||||||
|
std::vector<std::string> replies;
|
||||||
|
// Using the turn structure worsens results sometimes.
|
||||||
|
// However, some models need the turn structure to work.
|
||||||
|
// It would be good to make these tests more consistent.
|
||||||
|
if (s_env->GetModel()->Info().model == Model::GEMMA2_27B ||
|
||||||
|
s_env->GetModel()->Info().model == Model::GRIFFIN_2B) {
|
||||||
|
for (QueryResult result : s_env->BatchQueryModel(inputs)) {
|
||||||
|
replies.push_back(result.response);
|
||||||
|
}
|
||||||
|
return replies;
|
||||||
|
}
|
||||||
|
// Otherwise, do not use turn structure.
|
||||||
|
std::vector<std::vector<int>> prompts_vector;
|
||||||
|
prompts_vector.reserve(inputs.size());
|
||||||
|
for (const auto& input_string : inputs) {
|
||||||
|
prompts_vector.push_back(s_env->TokenizeAndPrependBOS(input_string));
|
||||||
|
}
|
||||||
|
std::vector<PromptTokens> prompt_spans;
|
||||||
|
for (const auto& prompt : prompts_vector) {
|
||||||
|
prompt_spans.push_back(PromptTokens(prompt.data(), prompt.size()));
|
||||||
|
}
|
||||||
|
QueriesPromptTokens prompts(prompt_spans.data(), prompt_spans.size());
|
||||||
|
for (const QueryResult& result : s_env->BatchQueryModel(prompts)) {
|
||||||
|
replies.push_back(result.response);
|
||||||
|
}
|
||||||
|
return replies;
|
||||||
|
}
|
||||||
|
|
||||||
|
void GenerateTokens(std::vector<std::string> &kQA, size_t num_questions) {
|
||||||
|
ASSERT_NE(s_env->GetModel(), nullptr);
|
||||||
|
|
||||||
|
std::vector<std::string> inputs;
|
||||||
|
for (size_t i = 0; i < num_questions; ++i) {
|
||||||
|
inputs.push_back(kQA[i]);
|
||||||
|
}
|
||||||
|
std::vector<std::string> responses = BatchGemmaReply(inputs);
|
||||||
|
for (size_t i = 0; i < num_questions; ++i) {
|
||||||
|
std::string response = responses.at(i);
|
||||||
|
fprintf(stderr, "Batch answer %zu '%s'\n\n", i + 1, response.c_str());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(GemmaTest, RandomQuestionsBatched) {
|
||||||
|
s_env->MutableConfig().decode_qbatch_size = 3;
|
||||||
|
s_env->MutableConfig().verbosity = 5;
|
||||||
|
|
||||||
|
static std::vector<std::string> kQA = {
|
||||||
|
{"Write me a poem about Australia?"},
|
||||||
|
{"What's the history of Denmark?"},
|
||||||
|
{"Write me a comedy story about the USA."},
|
||||||
|
{"Teach me about GPU programming."},
|
||||||
|
{"Write me a story about the moon."},
|
||||||
|
{"Write me a story about the universe."},
|
||||||
|
{"Write a poem about planet earth."},
|
||||||
|
{"Tell me more about olympic sports."},
|
||||||
|
{"How would you describe Washington State?"},
|
||||||
|
{"Write me a story about Silicon Valley."},
|
||||||
|
{"Write me about your best friend."},
|
||||||
|
{"How would you describe a unicorn?"},
|
||||||
|
{"Tell me about world war history."},
|
||||||
|
{"Tell me about Google."},
|
||||||
|
{"Explain to me how to use Google Maps."},
|
||||||
|
{"Explain to me how AI works."},
|
||||||
|
{"Write me a poem about France."},
|
||||||
|
{"What's the history of Great Britain?"},
|
||||||
|
{"Write me a comedy story about Florida."},
|
||||||
|
{"Teach me about dynamic programming."},
|
||||||
|
{"Write me a story about Jupiter."},
|
||||||
|
{"Write me a story about space ships."},
|
||||||
|
{"Write a poem about some random planet."},
|
||||||
|
{"Tell me more about team sports."},
|
||||||
|
{"How would you describe Michigan State?"},
|
||||||
|
{"Write me a story about Europe."},
|
||||||
|
{"Write me about your best colleague."},
|
||||||
|
{"How would you describe a horse?"},
|
||||||
|
{"Tell me about World War 2."},
|
||||||
|
{"Please share some good cooking tips."},
|
||||||
|
{"Tell me about space travel."},
|
||||||
|
{"Explain to me how electric cars work."},
|
||||||
|
};
|
||||||
|
static const size_t kNum = kQA.size();
|
||||||
|
GenerateTokens(kQA, kNum);
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
} // namespace gcpp
|
||||||
|
|
||||||
|
int main(int argc, char** argv) {
|
||||||
|
gcpp::GemmaEnv env(argc, argv);
|
||||||
|
gcpp::s_env = &env;
|
||||||
|
|
||||||
|
testing::InitGoogleTest(&argc, argv);
|
||||||
|
|
||||||
|
return RUN_ALL_TESTS();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -49,6 +49,7 @@ static ModelConfig ConfigGemma2_27B() {
|
||||||
.heads = 32,
|
.heads = 32,
|
||||||
.kv_heads = 16,
|
.kv_heads = 16,
|
||||||
.qkv_dim = 128,
|
.qkv_dim = 128,
|
||||||
|
.optimized_gating = false,
|
||||||
.post_norm = PostNormType::Scale};
|
.post_norm = PostNormType::Scale};
|
||||||
config.layer_configs = {46, layer_config};
|
config.layer_configs = {46, layer_config};
|
||||||
config.num_tensor_scales = 4 * config.layer_configs.size();
|
config.num_tensor_scales = 4 * config.layer_configs.size();
|
||||||
|
|
@ -70,6 +71,7 @@ static ModelConfig ConfigGemma2_9B() {
|
||||||
.heads = 16,
|
.heads = 16,
|
||||||
.kv_heads = 8,
|
.kv_heads = 8,
|
||||||
.qkv_dim = 256,
|
.qkv_dim = 256,
|
||||||
|
.optimized_gating = false,
|
||||||
.post_norm = PostNormType::Scale};
|
.post_norm = PostNormType::Scale};
|
||||||
config.layer_configs = {42, layer_config};
|
config.layer_configs = {42, layer_config};
|
||||||
config.num_tensor_scales = 4 * config.layer_configs.size();
|
config.num_tensor_scales = 4 * config.layer_configs.size();
|
||||||
|
|
@ -91,6 +93,7 @@ static ModelConfig ConfigGemma2_2B() {
|
||||||
.heads = 8,
|
.heads = 8,
|
||||||
.kv_heads = 4,
|
.kv_heads = 4,
|
||||||
.qkv_dim = 256,
|
.qkv_dim = 256,
|
||||||
|
.optimized_gating = false,
|
||||||
.post_norm = PostNormType::Scale};
|
.post_norm = PostNormType::Scale};
|
||||||
config.layer_configs = {26, layer_config};
|
config.layer_configs = {26, layer_config};
|
||||||
config.num_tensor_scales = 4 * config.layer_configs.size();
|
config.num_tensor_scales = 4 * config.layer_configs.size();
|
||||||
|
|
|
||||||
|
|
@ -374,7 +374,6 @@ void AssertMatch(const ModelConfig& config) {
|
||||||
}
|
}
|
||||||
ASSERT_EQ(TConfig::kVocabSize, config.vocab_size);
|
ASSERT_EQ(TConfig::kVocabSize, config.vocab_size);
|
||||||
ASSERT_EQ(TConfig::kSeqLen, config.seq_len);
|
ASSERT_EQ(TConfig::kSeqLen, config.seq_len);
|
||||||
// ASSERT_EQ(TConfig::kTopK, config.top_k); - is now a runtime config value.
|
|
||||||
ASSERT_EQ(TConfig::kAttCap, config.att_cap);
|
ASSERT_EQ(TConfig::kAttCap, config.att_cap);
|
||||||
ASSERT_EQ(TConfig::kFinalCap, config.final_cap);
|
ASSERT_EQ(TConfig::kFinalCap, config.final_cap);
|
||||||
ASSERT_EQ(TConfig::kAbsolutePE, config.absolute_pe);
|
ASSERT_EQ(TConfig::kAbsolutePE, config.absolute_pe);
|
||||||
|
|
|
||||||
|
|
@ -1,17 +0,0 @@
|
||||||
load("@pybind11_bazel//:build_defs.bzl", "pybind_extension")
|
|
||||||
|
|
||||||
package(
|
|
||||||
default_applicable_licenses = [
|
|
||||||
"//:license", # Placeholder comment, do not modify
|
|
||||||
],
|
|
||||||
default_visibility = ["//visibility:public"],
|
|
||||||
)
|
|
||||||
|
|
||||||
pybind_extension(
|
|
||||||
name = "configs",
|
|
||||||
srcs = ["configs.cc"],
|
|
||||||
deps = [
|
|
||||||
"//:common",
|
|
||||||
"//compression:sfp",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
@ -1,156 +0,0 @@
|
||||||
#include "gemma/configs.h"
|
|
||||||
|
|
||||||
#include <pybind11/pybind11.h>
|
|
||||||
#include <pybind11/stl.h>
|
|
||||||
|
|
||||||
#include "compression/shared.h"
|
|
||||||
#include "gemma/tensor_index.h"
|
|
||||||
#include "pybind11/cast.h"
|
|
||||||
|
|
||||||
using gcpp::ActivationType;
|
|
||||||
using gcpp::LayerAttentionType;
|
|
||||||
using gcpp::LayerConfig;
|
|
||||||
using gcpp::Model;
|
|
||||||
using gcpp::ModelConfig;
|
|
||||||
using gcpp::ModelTraining;
|
|
||||||
using gcpp::PostNormType;
|
|
||||||
using gcpp::PostQKType;
|
|
||||||
using gcpp::QueryScaleType;
|
|
||||||
using gcpp::ResidualType;
|
|
||||||
using gcpp::TensorIndex;
|
|
||||||
using gcpp::TensorInfo;
|
|
||||||
using gcpp::Type;
|
|
||||||
|
|
||||||
namespace pybind11 {
|
|
||||||
|
|
||||||
PYBIND11_MODULE(configs, py_module) {
|
|
||||||
enum_<ModelTraining>(py_module, "ModelTraining")
|
|
||||||
.value("GEMMA_IT", ModelTraining::GEMMA_IT)
|
|
||||||
.value("GEMMA_PT", ModelTraining::GEMMA_PT)
|
|
||||||
.value("PALIGEMMA", ModelTraining::PALIGEMMA);
|
|
||||||
|
|
||||||
enum_<Type>(py_module, "Type")
|
|
||||||
.value("kUnknown", Type::kUnknown)
|
|
||||||
.value("kF32", Type::kF32)
|
|
||||||
.value("kBF16", Type::kBF16)
|
|
||||||
.value("kSFP", Type::kSFP)
|
|
||||||
.value("kNUQ", Type::kNUQ)
|
|
||||||
.value("kF64", Type::kF64)
|
|
||||||
.value("kC64", Type::kC64)
|
|
||||||
.value("kU128", Type::kU128);
|
|
||||||
|
|
||||||
enum_<LayerAttentionType>(py_module, "LayerAttentionType")
|
|
||||||
.value("kGemma", LayerAttentionType::kGemma)
|
|
||||||
.value("kGriffinRecurrentBlock",
|
|
||||||
LayerAttentionType::kGriffinRecurrentBlock)
|
|
||||||
.value("kVit", LayerAttentionType::kVit);
|
|
||||||
|
|
||||||
enum_<PostNormType>(py_module, "PostNormType")
|
|
||||||
.value("NoPostNorm", PostNormType::None)
|
|
||||||
.value("Scale", PostNormType::Scale);
|
|
||||||
|
|
||||||
enum_<PostQKType>(py_module, "PostQKType")
|
|
||||||
.value("Rope", PostQKType::Rope)
|
|
||||||
.value("HalfRope", PostQKType::HalfRope);
|
|
||||||
|
|
||||||
enum_<ActivationType>(py_module, "ActivationType")
|
|
||||||
.value("Gelu", ActivationType::Gelu);
|
|
||||||
|
|
||||||
enum_<QueryScaleType>(py_module, "QueryScaleType")
|
|
||||||
.value("SqrtKeySize", QueryScaleType::SqrtKeySize)
|
|
||||||
.value("SqrtModelDimDivNumHeads",
|
|
||||||
QueryScaleType::SqrtModelDimDivNumHeads);
|
|
||||||
|
|
||||||
enum_<ResidualType>(py_module, "ResidualType")
|
|
||||||
.value("Add", ResidualType::Add);
|
|
||||||
|
|
||||||
enum_<Model>(py_module, "Model")
|
|
||||||
.value("UNKNOWN", Model::UNKNOWN)
|
|
||||||
.value("GEMMA_2B", Model::GEMMA_2B)
|
|
||||||
.value("GEMMA_7B", Model::GEMMA_7B)
|
|
||||||
.value("GEMMA2_9B", Model::GEMMA2_9B)
|
|
||||||
.value("GEMMA2_27B", Model::GEMMA2_27B)
|
|
||||||
.value("GRIFFIN_2B", Model::GRIFFIN_2B)
|
|
||||||
.value("GEMMA_TINY", Model::GEMMA_TINY)
|
|
||||||
.value("GEMMA2_2B", Model::GEMMA2_2B)
|
|
||||||
.value("PALIGEMMA_224", Model::PALIGEMMA_224);
|
|
||||||
|
|
||||||
class_<TensorInfo>(py_module, "TensorInfo")
|
|
||||||
.def(init())
|
|
||||||
.def_readwrite("name", &TensorInfo::name)
|
|
||||||
.def_readwrite("source_names", &TensorInfo::source_names)
|
|
||||||
.def_readwrite("preshape", &TensorInfo::preshape)
|
|
||||||
.def_readwrite("axes", &TensorInfo::axes)
|
|
||||||
.def_readwrite("shape", &TensorInfo::shape)
|
|
||||||
.def_readwrite("concat_names", &TensorInfo::concat_names)
|
|
||||||
.def_readwrite("concat_axis", &TensorInfo::concat_axis)
|
|
||||||
.def_readwrite("min_size", &TensorInfo::min_size)
|
|
||||||
.def_readwrite("scaled_softplus", &TensorInfo::scaled_softplus)
|
|
||||||
.def_readwrite("cols_take_extra_dims", &TensorInfo::cols_take_extra_dims);
|
|
||||||
|
|
||||||
class_<TensorIndex>(py_module, "TensorIndex")
|
|
||||||
.def(init<const ModelConfig&, int, int, bool>())
|
|
||||||
.def("get_tensor_info", &TensorIndex::GetTensorInfo, arg("path"));
|
|
||||||
|
|
||||||
class_<LayerConfig>(py_module, "LayerConfig")
|
|
||||||
.def(init())
|
|
||||||
.def_readwrite("model_dim", &LayerConfig::model_dim)
|
|
||||||
.def_readwrite("griffin_dim", &LayerConfig::griffin_dim)
|
|
||||||
.def_readwrite("ff_hidden_dim", &LayerConfig::ff_hidden_dim)
|
|
||||||
.def_readwrite("heads", &LayerConfig::heads)
|
|
||||||
.def_readwrite("kv_heads", &LayerConfig::kv_heads)
|
|
||||||
.def_readwrite("qkv_dim", &LayerConfig::qkv_dim)
|
|
||||||
.def_readwrite("conv1d_width", &LayerConfig::conv1d_width)
|
|
||||||
.def_readwrite("ff_biases", &LayerConfig::ff_biases)
|
|
||||||
.def_readwrite("softmax_attn_output_biases",
|
|
||||||
&LayerConfig::softmax_attn_output_biases)
|
|
||||||
.def_readwrite("optimized_gating", &LayerConfig::optimized_gating)
|
|
||||||
.def_readwrite("post_norm", &LayerConfig::post_norm)
|
|
||||||
.def_readwrite("type", &LayerConfig::type)
|
|
||||||
.def_readwrite("activation", &LayerConfig::activation)
|
|
||||||
.def_readwrite("post_qk", &LayerConfig::post_qk);
|
|
||||||
|
|
||||||
class_<ModelConfig>(py_module, "ModelConfig")
|
|
||||||
.def(init())
|
|
||||||
.def_readwrite("model_name", &ModelConfig::model_name)
|
|
||||||
.def_readwrite("model", &ModelConfig::model)
|
|
||||||
.def_readwrite("training", &ModelConfig::training)
|
|
||||||
.def_readwrite("weight", &ModelConfig::weight)
|
|
||||||
.def_readwrite("num_layers", &ModelConfig::num_layers)
|
|
||||||
.def_readwrite("model_dim", &ModelConfig::model_dim)
|
|
||||||
.def_readwrite("vit_model_dim", &ModelConfig::vit_model_dim)
|
|
||||||
.def_readwrite("vocab_size", &ModelConfig::vocab_size)
|
|
||||||
.def_readwrite("seq_len", &ModelConfig::seq_len)
|
|
||||||
.def_readwrite("vit_seq_len", &ModelConfig::vit_seq_len)
|
|
||||||
.def_readwrite("num_tensor_scales", &ModelConfig::num_tensor_scales)
|
|
||||||
.def_readwrite("num_vit_scales", &ModelConfig::num_vit_scales)
|
|
||||||
.def_readwrite("att_cap", &ModelConfig::att_cap)
|
|
||||||
.def_readwrite("final_cap", &ModelConfig::final_cap)
|
|
||||||
.def_readwrite("absolute_pe", &ModelConfig::absolute_pe)
|
|
||||||
.def_readwrite("use_local_attention", &ModelConfig::use_local_attention)
|
|
||||||
.def_readwrite("query_scale", &ModelConfig::query_scale)
|
|
||||||
.def_readwrite("layer_configs", &ModelConfig::layer_configs)
|
|
||||||
.def_readwrite("attention_window_sizes",
|
|
||||||
&ModelConfig::attention_window_sizes)
|
|
||||||
.def_readwrite("vit_layer_configs", &ModelConfig::vit_layer_configs)
|
|
||||||
.def_readwrite("scale_names", &ModelConfig::scale_names)
|
|
||||||
.def_readwrite("norm_num_groups", &ModelConfig::norm_num_groups)
|
|
||||||
.def_readwrite("model_family_version", &ModelConfig::model_family_version)
|
|
||||||
.def_readwrite("patch_width", &ModelConfig::patch_width)
|
|
||||||
.def_readwrite("image_size", &ModelConfig::image_size)
|
|
||||||
.def("add_layer_config", &ModelConfig::AddLayerConfig,
|
|
||||||
arg("layer_config"))
|
|
||||||
.def("test_equal", &ModelConfig::TestEqual, arg("other"), arg("partial"),
|
|
||||||
arg("debug"));
|
|
||||||
|
|
||||||
// Returns the config for the given model.
|
|
||||||
py_module.def("config_from_model", &gcpp::ConfigFromModel, arg("model"));
|
|
||||||
|
|
||||||
// Returns the model for the given config, if it matches any standard model.
|
|
||||||
py_module.def("model_from_config", &gcpp::ModelFromConfig, arg("config"));
|
|
||||||
|
|
||||||
// Returns the sub-config for the ViT model of the PaliGemma model.
|
|
||||||
py_module.def("vit_config", &gcpp::VitConfig, arg("config"));
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace pybind11
|
|
||||||
Loading…
Reference in New Issue