Support all weight types in a single binary.

This changes the command-line flags, but the default value preserves the previous behavior.

Also add a CreateGemma helper to enable extra args without interface changes.

PiperOrigin-RevId: 641266411
Jan Wassenberg 2024-06-07 09:04:06 -07:00 committed by Copybara-Service
parent 24db2ff725
commit f9b390b134
27 changed files with 372 additions and 248 deletions
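For context, the caller-side pattern that results from this change, as a minimal sketch (the `main` scaffolding, thread count, and the `util/app.h` header location are assumptions; `LoaderArgs`, `CreateGemma`, `Gemma`, and `KVCache` are the interfaces touched below):

```cpp
#include <stdio.h>

#include "gemma/gemma.h"  // Gemma, KVCache
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "util/app.h"  // LoaderArgs, CreateGemma (header location assumed)

int main(int argc, char** argv) {
  // Parses --tokenizer, --weights, the new --weight_type, and --model.
  gcpp::LoaderArgs loader(argc, argv);
  if (const char* err = loader.Validate()) {
    fprintf(stderr, "%s\n", err);
    return 1;
  }
  hwy::ThreadPool pool(/*num_worker_threads=*/4);
  // CreateGemma forwards ModelType() and the new WeightType(), so call sites
  // do not have to change again when further constructor arguments are added.
  gcpp::Gemma model = gcpp::CreateGemma(loader, pool);
  gcpp::KVCache kv_cache = gcpp::KVCache::Create(loader.ModelType());
  // model and kv_cache are now ready for Generate().
  return 0;
}
```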


@@ -116,6 +116,7 @@ cc_library(
         "gemma/cross_entropy.h",
     ],
     deps = [
+        ":common",
         ":gemma_lib",
     ],
 )


@@ -76,17 +76,6 @@ if(NOT CMAKE_BUILD_TYPE)
   set(CMAKE_BUILD_TYPE "Release")
 endif()
-# Allowable types for WEIGHT_TYPE:
-#   float - slow, not recommended
-#   hwy::bfloat16_t - bfloat16 as implemented by https://github.com/google/highway
-#   SfpStream - 8-bit switched floating point (recommended)
-#   NuqStream - experimental, work-in-progress
-option(WEIGHT_TYPE "Set weight type" "")
-if (WEIGHT_TYPE)
-  add_definitions(-DGEMMA_WEIGHT_T=${WEIGHT_TYPE})
-endif()
 FetchContent_GetProperties(sentencepiece)
 ## Library Target


@@ -105,18 +105,12 @@ the resulting file as `--weights` and the desired .sbs name as the
 There are several compile-time flags to be aware of (note these may or may not
 be exposed to the build system):
-- `GEMMA_WEIGHT_T` : Sets the level of compression for weights (surfaced as
-  WEIGHT_TYPE in CMakeLists.txt). Currently this should be set to `SfpStream`
-  (default, if no flag is specified) for 8-bit SFP, or `hwy::bfloat16_t` to
-  enable for higher-fidelity (but slower) bfloat16 support. This is defined in
-  `gemma.h`.
 - `GEMMA_MAX_SEQ_LEN` : Sets maximum sequence length to preallocate for the KV
   Cache. The default is 4096 tokens but can be overridden. This is not exposed
   through `CMakeLists.txt` yet.
-In the medium term both of these will likely be deprecated in favor of handling
-options at runtime - allowing for multiple weight compression schemes in a single
-build and dynamically resizes the KV cache as needed.
+In the medium term this will likely be deprecated in favor of handling options
+at runtime - dynamically resizing the KV cache as needed.
 ## Using gemma.cpp as a Library (Advanced)


@@ -138,33 +138,16 @@ convenient directory location (e.g. the `build/` directory in this repo).
 The build system uses [CMake](https://cmake.org/). To build the gemma inference
 runtime, create a build directory and generate the build files using `cmake`
 from the top-level project directory. Note if you previous ran `cmake` and are
-re-running with a different setting, be sure to clean out the `build/` directory
-with `rm -rf build/*` (warning this will delete any other files in the `build/`
-directory.
-For the 8-bit switched floating point weights (sfp), run cmake with no options:
+re-running with a different setting, be sure to delete all files in the `build/`
+directory with `rm -rf build/*`.
 #### Unix-like Platforms
 ```sh
 cmake -B build
 ```
-**or** if you downloaded bfloat16 weights (any model *without* `-sfp` in the
-name), instead of running cmake with no options as above, run cmake with
-WEIGHT_TYPE set to [highway's](https://github.com/google/highway)
-`hwy::bfloat16_t` type. Alternatively, you can also add
-`-DGEMMA_WEIGHT_T=hwy::bfloat16_t` to the C++ compiler flags.
-We intend to soon support all weight types without requiring extra flags. Note
-that we recommend using `-sfp` weights instead of bfloat16 for faster inference.
-```sh
-cmake -B build -DWEIGHT_TYPE=hwy::bfloat16_t
-```
-After running whichever of the above `cmake` invocations that is appropriate for
-your weights, you can enter the `build/` directory and run `make` to build the
-`./gemma` executable:
+After running `cmake`, you can enter the `build/` directory and run `make` to
+build the `./gemma` executable:
 ```sh
 # Configure `build` directory
@@ -221,11 +204,12 @@ You can now run `gemma` from inside the `build/` directory.
 `gemma` has the following required arguments:
 | Argument | Description | Example value
-| ------------- | ---------------------------- | --------------------------
-| `--model` | The model type. | `2b-it`, `2b-pt`, `7b-it`, `7b-pt`, ... (see above)
-| `--weights` | The compressed weights file. | `2b-it-sfp.sbs`, ... (see above)
-| `--tokenizer` | The tokenizer file. | `tokenizer.spm`
+| --------------- | ---------------------------- | -----------------------
+| `--model` | The model type. | `2b-it` ... (see below)
+| `--weights` | The compressed weights file. | `2b-it-sfp.sbs`
+| `--weight_type` | The compressed weight type. | `sfp`
+| `--tokenizer` | The tokenizer file. | `tokenizer.spm`
 `gemma` is invoked as:
@@ -233,6 +217,7 @@ You can now run `gemma` from inside the `build/` directory.
 ./gemma \
 --tokenizer [tokenizer file] \
 --weights [compressed weights file] \
+--weight_type [f32 or bf16 or sfp] \
 --model [2b-it or 2b-pt or 7b-it or 7b-pt or ...]
 ```
@@ -245,8 +230,7 @@ Example invocation for the following configuration:
 ```sh
 ./gemma \
 --tokenizer tokenizer.spm \
---weights 2b-it-sfp.sbs \
---model 2b-it
+--weights 2b-it-sfp.sbs --weight_type sfp --model 2b-it
 ```
 ### RecurrentGemma
@@ -270,14 +254,12 @@ Step 1, and run the binary as follows:
 **Running `./gemma` fails with "Failed to read cache gating_ein_0 (error 294) ..."**
-The most common problem is that `cmake` was built with the wrong weight type and
-`gemma` is attempting to load `bfloat16` weights (`2b-it`, `2b-pt`, `7b-it`,
-`7b-pt`) using the default switched floating point (sfp) or vice versa. Revisit
-step #3 and check that the `cmake` command used to build `gemma` was correct for
-the weights that you downloaded.
-In the future we will handle model format handling from compile time to runtime
-to simplify this.
+The most common problem is that the `--weight_type` argument does not match that
+of the model file. Revisit step #3 and check which weights you downloaded.
+Note that we have already moved weight type from a compile-time decision to a
+runtime argument. In a subsequent step, we plan to bake this information into
+the weights.
 **Problems building in Windows / Visual Studio**


@@ -21,7 +21,6 @@
 #define THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_INL_H_
 #include <stddef.h>
-#include <stdint.h>
 #include <array>
 #include <cmath>
@@ -44,6 +43,7 @@
 #endif
 #include "gemma/ops.h"
+#include "hwy/highway.h"
 HWY_BEFORE_NAMESPACE();
 namespace gcpp {


@@ -15,6 +15,11 @@
 #include "backprop/backward.h"
+#include "backprop/prompt.h"
+#include "gemma/activations.h"
+#include "gemma/common.h"
+#include "hwy/contrib/thread_pool/thread_pool.h"
 // Compiles this file for multiple architectures via "foreach_target.h", to
 // which we pass the filename via macro 'argument'.
 #undef HWY_TARGET_INCLUDE
@@ -29,7 +34,6 @@
 HWY_BEFORE_NAMESPACE();
 namespace gcpp {
 namespace HWY_NAMESPACE {
-namespace hn = hwy::HWY_NAMESPACE;
 template <typename TConfig>
 void CrossEntropyLossBackwardPass(const Prompt& prompt,
@@ -57,11 +61,11 @@ void CrossEntropyLossBackwardPassT(Model model,
   // TODO(janwas): use CallFunctorForModel
   switch (model) {
     case Model::GEMMA_2B:
-      CrossEntropyLossBackwardPass<ConfigGemma2B>(
+      CrossEntropyLossBackwardPass<ConfigGemma2B<float>>(
           prompt, weights, forward, grad, backward, pool);
       break;
     case Model::GEMMA_TINY:
-      CrossEntropyLossBackwardPass<ConfigGemmaTiny>(
+      CrossEntropyLossBackwardPass<ConfigGemmaTiny<float>>(
           prompt, weights, forward, grad, backward, pool);
       break;
     default:


@@ -15,6 +15,11 @@
 #include "backprop/forward.h"
+#include "backprop/prompt.h"
+#include "gemma/activations.h"
+#include "gemma/common.h"
+#include "hwy/contrib/thread_pool/thread_pool.h"
 // Compiles this file for multiple architectures via "foreach_target.h", to
 // which we pass the filename via macro 'argument'.
 #undef HWY_TARGET_INCLUDE
@@ -29,7 +34,6 @@
 HWY_BEFORE_NAMESPACE();
 namespace gcpp {
 namespace HWY_NAMESPACE {
-namespace hn = hwy::HWY_NAMESPACE;
 template <typename TConfig>
 float CrossEntropyLossForwardPass(const Prompt& prompt,
@@ -51,10 +55,10 @@ float CrossEntropyLossForwardPassT(Model model, const Prompt& prompt,
   // TODO(janwas): use CallFunctorForModel
   switch (model) {
     case Model::GEMMA_2B:
-      return CrossEntropyLossForwardPass<ConfigGemma2B>(
-          prompt, weights, forward, pool);
+      return CrossEntropyLossForwardPass<ConfigGemma2B<float>>(prompt, weights,
+                                                               forward, pool);
     case Model::GEMMA_TINY:
-      return CrossEntropyLossForwardPass<ConfigGemmaTiny>(
+      return CrossEntropyLossForwardPass<ConfigGemmaTiny<float>>(
           prompt, weights, forward, pool);
     default:
       HWY_ABORT("Model type %d unknown.", static_cast<int>(model));


@@ -13,18 +13,23 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
-#include <iostream>
-#include <string>
+#include <stddef.h>
+#include <limits>
+#include <random>
+#include <vector>
+#include "gtest/gtest.h"
 #include "backprop/backward.h"
 #include "backprop/forward.h"
 #include "backprop/optimizer.h"
+#include "backprop/prompt.h"
 #include "backprop/sampler.h"
 #include "gemma/activations.h"
 #include "gemma/common.h"
 #include "gemma/gemma.h"
 #include "gemma/weights.h"
-#include "gtest/gtest.h"
+#include "hwy/contrib/thread_pool/thread_pool.h"
 namespace gcpp {
@@ -35,11 +40,17 @@ TEST(OptimizeTest, GradientDescent) {
   std::mt19937 gen(42);
   Model model_type = Model::GEMMA_TINY;
-  ByteStorageT grad = CallFunctorForModel<AllocateWeightsF>(model_type, pool);
-  ByteStorageT grad_m = CallFunctorForModel<AllocateWeightsF>(model_type, pool);
-  ByteStorageT grad_v = CallFunctorForModel<AllocateWeightsF>(model_type, pool);
-  ByteStorageT forward = CallFunctorForModel<AllocateForwardPass>(model_type);
-  ByteStorageT backward = CallFunctorForModel<AllocateForwardPass>(model_type);
+  Type weight_type = Type::kF32;
+  ByteStorageT grad =
+      CallForModelAndWeight<AllocateWeightsF>(model_type, weight_type, pool);
+  ByteStorageT grad_m =
+      CallForModelAndWeight<AllocateWeightsF>(model_type, weight_type, pool);
+  ByteStorageT grad_v =
+      CallForModelAndWeight<AllocateWeightsF>(model_type, weight_type, pool);
+  ByteStorageT forward =
+      CallForModelAndWeight<AllocateForwardPass>(model_type, weight_type);
+  ByteStorageT backward =
+      CallForModelAndWeight<AllocateForwardPass>(model_type, weight_type);
   KVCache kv_cache = KVCache::Create(model_type);
   size_t max_tokens = 32;
   size_t max_generated_tokens = 16;
@@ -47,7 +58,7 @@ TEST(OptimizeTest, GradientDescent) {
   int verbosity = 0;
   const auto accept_token = [](int) { return true; };
-  Gemma gemma(GemmaTokenizer(), model_type, pool);
+  Gemma gemma(GemmaTokenizer(), model_type, weight_type, pool);
   const auto generate = [&](const std::vector<int>& prompt) {
     std::vector<int> reply;
@@ -76,12 +87,14 @@ TEST(OptimizeTest, GradientDescent) {
     return ok;
   };
-  RandInitWeights(model_type, gemma.Weights(), pool, gen);
-  CallFunctorForModel<ZeroInitWeightsF>(model_type, grad_m, pool);
-  CallFunctorForModel<ZeroInitWeightsF>(model_type, grad_v, pool);
+  RandInitWeights(model_type, weight_type, gemma.Weights(), pool, gen);
+  CallForModelAndWeight<ZeroInitWeightsF>(model_type, weight_type, grad_m,
+                                          pool);
+  CallForModelAndWeight<ZeroInitWeightsF>(model_type, weight_type, grad_v,
+                                          pool);
   printf("Initial weights:\n");
-  LogWeightStats(model_type, gemma.Weights());
+  LogWeightStats(model_type, weight_type, gemma.Weights());
   constexpr size_t kBatchSize = 8;
   const float alpha = 0.001f;
@@ -96,7 +109,8 @@ TEST(OptimizeTest, GradientDescent) {
   size_t num_ok;
   for (; steps < 1000000; ++steps) {
     std::mt19937 sgen(42);
-    CallFunctorForModel<ZeroInitWeightsF>(model_type, grad, pool);
+    CallForModelAndWeight<ZeroInitWeightsF>(model_type, weight_type, grad,
+                                            pool);
     float total_loss = 0.0f;
     num_ok = 0;
     for (size_t i = 0; i < kBatchSize; ++i) {
@@ -109,13 +123,13 @@ TEST(OptimizeTest, GradientDescent) {
     }
     total_loss /= kBatchSize;
-    AdamUpdate(model_type, grad, alpha, beta1, beta2, epsilon, steps + 1,
-               gemma.Weights(), grad_m, grad_v, pool);
+    AdamUpdate(model_type, weight_type, grad, alpha, beta1, beta2, epsilon,
+               steps + 1, gemma.Weights(), grad_m, grad_v, pool);
     printf("step: %zu total_loss: %.15f num_ok: %zu/%zu\n",
            steps, total_loss, num_ok, kBatchSize);
     if (steps % 100 == 0) {
       printf("Batch gradient:\n");
-      LogWeightStats(model_type, grad);
+      LogWeightStats(model_type, weight_type, grad);
     }
     if (total_loss < 0.5f) {
       break;
@@ -124,7 +138,7 @@ TEST(OptimizeTest, GradientDescent) {
   }
   printf("Num steps: %zu\n", steps);
   printf("Final weights:\n");
-  LogWeightStats(model_type, gemma.Weights());
+  LogWeightStats(model_type, weight_type, gemma.Weights());
   EXPECT_LT(steps, 200);
   EXPECT_EQ(num_ok, kBatchSize);
 }


@@ -107,18 +107,20 @@ struct AdamUpdateT {
 }  // namespace
-void RandInitWeights(Model model, const ByteStorageT& weights,
-                     hwy::ThreadPool& pool,
+void RandInitWeights(Model model_type, Type weight_type,
+                     const ByteStorageT& weights, hwy::ThreadPool& pool,
                      std::mt19937& gen) {
-  CallFunctorForModel<RandInitWeightsT>(model, weights, pool, gen);
+  CallForModelAndWeight<RandInitWeightsT>(model_type, weight_type, weights,
+                                          pool, gen);
 }
-void AdamUpdate(Model model, const ByteStorageT& grad, float alpha, float beta1,
-                float beta2, float epsilon, size_t t,
+void AdamUpdate(Model model_type, Type weight_type, const ByteStorageT& grad,
+                float alpha, float beta1, float beta2, float epsilon, size_t t,
                 const ByteStorageT& weights, const ByteStorageT& grad_m,
                 const ByteStorageT& grad_v, hwy::ThreadPool& pool) {
-  CallFunctorForModel<AdamUpdateT>(model, grad, alpha, beta1, beta2, epsilon, t,
-                                   weights, grad_m, grad_v, pool);
+  CallForModelAndWeight<AdamUpdateT>(model_type, weight_type, grad, alpha,
+                                     beta1, beta2, epsilon, t, weights, grad_m,
+                                     grad_v, pool);
 }
 }  // namespace gcpp


@@ -19,16 +19,16 @@
 #include <random>
 #include "gemma/common.h"
-#include "gemma/weights.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
 namespace gcpp {
-void RandInitWeights(Model model, const ByteStorageT& weights,
-                     hwy::ThreadPool& pool, std::mt19937& gen);
+void RandInitWeights(Model model_type, Type weight_type,
+                     const ByteStorageT& weights, hwy::ThreadPool& pool,
+                     std::mt19937& gen);
-void AdamUpdate(Model model, const ByteStorageT& grad, float alpha, float beta1,
-                float beta2, float epsilon, size_t t,
+void AdamUpdate(Model model_type, Type weight_type, const ByteStorageT& grad,
+                float alpha, float beta1, float beta2, float epsilon, size_t t,
                 const ByteStorageT& weights, const ByteStorageT& grad_m,
                 const ByteStorageT& grad_v, hwy::ThreadPool& pool);


@@ -113,7 +113,7 @@ int main(int argc, char** argv) {
     gcpp::PinWorkersToCores(pool);
   }
-  gcpp::Gemma model(loader.tokenizer, loader.weights, loader.ModelType(), pool);
+  gcpp::Gemma model = gcpp::CreateGemma(loader, pool);
   gcpp::KVCache kv_cache = gcpp::KVCache::Create(loader.ModelType());
   const std::string& prompt = prompt_args.prompt;


@@ -39,7 +39,7 @@ int main(int argc, char** argv) {
   hwy::ThreadPool pool(num_threads);
   // Instantiate model and KV Cache
-  gcpp::Gemma model(loader.tokenizer, loader.weights, loader.ModelType(), pool);
+  gcpp::Gemma model = gcpp::CreateGemma(loader, pool);
   gcpp::KVCache kv_cache = gcpp::KVCache::Create(loader.ModelType());
   size_t pos = 0;  // KV Cache position


@@ -280,7 +280,7 @@ int main(int argc, char** argv) {
     gcpp::PinWorkersToCores(pool);
   }
-  gcpp::Gemma model(loader.tokenizer, loader.weights, loader.ModelType(), pool);
+  gcpp::Gemma model = gcpp::CreateGemma(loader, pool);
   gcpp::KVCache kv_cache = gcpp::KVCache::Create(loader.ModelType());
   if (!benchmark_args.goldens.path.empty()) {


@@ -64,4 +64,28 @@ const char* ParseModelTypeAndTraining(const std::string& model_flag,
   return kErrorMessageBuffer;
 }
+const char* ParseType(const std::string& type_string, Type& type) {
+  constexpr Type kTypes[] = {Type::kF32, Type::kBF16, Type::kSFP};
+  constexpr const char* kStrings[] = {"f32", "bf16", "sfp"};
+  constexpr size_t kNum = std::end(kStrings) - std::begin(kStrings);
+  static char kErrorMessageBuffer[kNum * 8 + 100] =
+      "Invalid or missing type, need to specify one of ";
+  for (size_t i = 0; i + 1 < kNum; i++) {
+    strcat(kErrorMessageBuffer, kStrings[i]);  // NOLINT
+    strcat(kErrorMessageBuffer, ", ");         // NOLINT
+  }
+  strcat(kErrorMessageBuffer, kStrings[kNum - 1]);  // NOLINT
+  strcat(kErrorMessageBuffer, ".");                 // NOLINT
+  std::string type_lc = type_string;
+  std::transform(begin(type_lc), end(type_lc), begin(type_lc),
+                 [](unsigned char c) { return std::tolower(c); });
+  for (size_t i = 0; i < kNum; i++) {
+    if (kStrings[i] == type_lc) {
+      type = kTypes[i];
+      return nullptr;
+    }
+  }
+  return kErrorMessageBuffer;
+}
 }  // namespace gcpp
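A small usage sketch for the new parser (the `main` wrapper is illustrative; `ParseType` and `StringFromType` are the helpers declared in `gemma/common.h` by this change):

```cpp
#include <stdio.h>
#include <string>

#include "gemma/common.h"

int main() {
  gcpp::Type weight_type;
  // Matching is case-insensitive; accepted spellings are "f32", "bf16", "sfp".
  if (const char* err = gcpp::ParseType("BF16", weight_type)) {
    fprintf(stderr, "%s\n", err);  // error text lists the valid strings
    return 1;
  }
  printf("parsed: %s\n", gcpp::StringFromType(weight_type));  // prints "bf16"
  return 0;
}
```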


@@ -21,6 +21,7 @@
 #include <string>
+#include "compression/compress.h"
 #include "gemma/configs.h"  // IWYU pragma: export
 #include "hwy/aligned_allocator.h"
 #include "hwy/base.h"  // ConvertScalarTo
@@ -37,67 +38,129 @@ ByteStorageT AllocateSizeof() {
 // Model variants: see configs.h for details.
 enum class Model { GEMMA_2B, GEMMA_7B, GRIFFIN_2B, GEMMA_TINY };
+// Instruction-tuned models require extra 'turn structure' tokens in prompts.
 enum class ModelTraining { GEMMA_IT, GEMMA_PT };
-// Returns the return value of Func<T>().operator() called with `args`, where
-// `T` is selected based on `model`.
+// Tensor types for loading weights.
+enum class Type { kF32, kBF16, kSFP };
+// Returns the return value of FuncT<Config*<TWeight>>().operator()(args), where
+// Config* is selected via `model`. Typically called by CallForModelAndWeight,
+// but can also be called directly when FuncT does not actually use TWeight.
 //
-// This is used to implement type-erased functions such as
-// LoadCompressedWeights, which can be called from other .cc files, by calling a
-// functor LoadCompressedWeightsT, which has a template argument. `Func` must
-// be a functor because function templates cannot be passed as a template
-// template argument, and we prefer to avoid the overhead of std::function.
+// Note that a T prefix indicates a concrete type template argument, whereas a
+// T suffix indicates the argument is itself a template.
 //
-// This function avoids having to update all call sites when we extend `Model`.
-template <template <typename Config> class Func, typename... Args>
-decltype(auto) CallFunctorForModel(Model model, Args&&... args) {
+// `FuncT` must be a functor because function templates cannot be passed as a
+// template template argument, and we prefer to avoid the overhead of
+// std::function.
+template <typename TWeight, template <typename TConfig> class FuncT,
+          typename... TArgs>
+decltype(auto) CallForModel(Model model, TArgs&&... args) {
   switch (model) {
     case Model::GEMMA_TINY:
-      return Func<ConfigGemmaTiny>()(std::forward<Args>(args)...);
+      return FuncT<ConfigGemmaTiny<TWeight>>()(std::forward<TArgs>(args)...);
     case Model::GEMMA_2B:
-      return Func<ConfigGemma2B>()(std::forward<Args>(args)...);
+      return FuncT<ConfigGemma2B<TWeight>>()(std::forward<TArgs>(args)...);
     case Model::GEMMA_7B:
-      return Func<ConfigGemma7B>()(std::forward<Args>(args)...);
+      return FuncT<ConfigGemma7B<TWeight>>()(std::forward<TArgs>(args)...);
     case Model::GRIFFIN_2B:
-      return Func<ConfigGriffin2B>()(std::forward<Args>(args)...);
+      return FuncT<ConfigGriffin2B<TWeight>>()(std::forward<TArgs>(args)...);
     default:
       HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
   }
 }
+// Returns the return value of FuncT<TConfig>().operator()(args),
+// where `TConfig` is selected based on `model` and `weight`.
+// This makes it easy to extend `Model` or `Type` without updating callers.
+//
+// Usage example: LoadWeights is type-erased so that it can be called from other
+// .cc files. It uses this function to call the appropriate instantiation of a
+// template functor LoadCompressedWeightsT<TConfig>.
+template <template <typename TConfig> class FuncT, typename... TArgs>
+decltype(auto) CallForModelAndWeight(Model model, Type weight,
+                                     TArgs&&... args) {
+  switch (weight) {
+    case Type::kF32:
+      return CallForModel<float, FuncT, TArgs...>(  //
+          model, std::forward<TArgs>(args)...);
+    case Type::kBF16:
+      return CallForModel<hwy::bfloat16_t, FuncT, TArgs...>(
+          model, std::forward<TArgs>(args)...);
+    case Type::kSFP:
+      return CallForModel<SfpStream, FuncT, TArgs...>(
+          model, std::forward<TArgs>(args)...);
+    default:
+      HWY_ABORT("Weight type %d unknown.", static_cast<int>(weight));
+  }
+}
+// Used by GEMMA_EXPORT_AND_DISPATCH. For a given TWEIGHT (e.g. float),
+// calls FUNC<ConfigT<TWEIGHT>> where ConfigT is chosen via MODEL enum.
+#define GEMMA_DISPATCH_MODEL(MODEL, TWEIGHT, FUNC, ARGS)                \
+  switch (MODEL) {                                                      \
+    case Model::GEMMA_TINY: {                                           \
+      HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<ConfigGemmaTiny<TWEIGHT>>) \
+      ARGS;                                                             \
+      break;                                                            \
+    }                                                                   \
+    case Model::GEMMA_2B: {                                             \
+      HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<ConfigGemma2B<TWEIGHT>>)   \
+      ARGS;                                                             \
+      break;                                                            \
+    }                                                                   \
+    case Model::GEMMA_7B: {                                             \
+      HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<ConfigGemma7B<TWEIGHT>>)   \
+      ARGS;                                                             \
+      break;                                                            \
+    }                                                                   \
+    case Model::GRIFFIN_2B: {                                           \
+      HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<ConfigGriffin2B<TWEIGHT>>) \
+      ARGS;                                                             \
+      break;                                                            \
+    }                                                                   \
+    default:                                                            \
+      HWY_ABORT("Model type %d unknown.", static_cast<int>(MODEL));     \
+  }
-// Like CallFunctorForModel, but for SIMD function templates. This is a macro
+// Like CallForModelAndWeight, but for SIMD function templates. This is a macro
 // because it boils down to N_SSE4::FUNC, which would not work if FUNC was a
-// normal function argument.
-#define GEMMA_EXPORT_AND_DISPATCH_MODEL(MODEL, FUNC, ARGS)      \
-  switch (MODEL) {                                              \
-    case Model::GEMMA_TINY: {                                   \
-      HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<ConfigGemmaTiny>)  \
-      ARGS;                                                     \
-      break;                                                    \
-    }                                                           \
-    case Model::GEMMA_2B: {                                     \
-      HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<ConfigGemma2B>)    \
-      ARGS;                                                     \
-      break;                                                    \
-    }                                                           \
-    case Model::GEMMA_7B: {                                     \
-      HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<ConfigGemma7B>)    \
-      ARGS;                                                     \
-      break;                                                    \
-    }                                                           \
-    case Model::GRIFFIN_2B: {                                   \
-      HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<ConfigGriffin2B>)  \
-      ARGS;                                                     \
-      break;                                                    \
-    }                                                           \
-    default:                                                    \
-      HWY_ABORT("Model type %d unknown.", static_cast<int>(MODEL)); \
-  }
+// normal function argument. MODEL and WEIGHT are enums.
+#define GEMMA_EXPORT_AND_DISPATCH(MODEL, WEIGHT, FUNC, ARGS)           \
+  switch (WEIGHT) {                                                    \
+    case Type::kF32:                                                   \
+      GEMMA_DISPATCH_MODEL(MODEL, float, FUNC, ARGS);                  \
+      break;                                                           \
+    case Type::kBF16:                                                  \
+      GEMMA_DISPATCH_MODEL(MODEL, hwy::bfloat16_t, FUNC, ARGS);        \
+      break;                                                           \
+    case Type::kSFP:                                                   \
+      GEMMA_DISPATCH_MODEL(MODEL, SfpStream, FUNC, ARGS);              \
+      break;                                                           \
+    default:                                                           \
+      HWY_ABORT("Weight type %d unknown.", static_cast<int>(WEIGHT));  \
  }
 // Returns error string or nullptr if OK.
 // Thread-hostile.
 const char* ParseModelTypeAndTraining(const std::string& model_flag,
                                       Model& model, ModelTraining& training);
+const char* ParseType(const std::string& type_string, Type& type);
+static inline const char* StringFromType(Type type) {
+  switch (type) {
+    case Type::kF32:
+      return "f32";
+    case Type::kBF16:
+      return "bf16";
+    case Type::kSFP:
+      return "sfp";
+    default:
+      return "?";
+  }
+}
 // __builtin_sqrt is not constexpr as of Clang 17.
 #if HWY_COMPILER_GCC_ACTUAL
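To illustrate the two-level dispatch added above: a hypothetical functor `TensorBytesPerWeight` (not part of this change), selected first by `Type` and then by `Model`:

```cpp
#include <stddef.h>

#include "gemma/common.h"

namespace gcpp {

// Hypothetical FuncT: templated on TConfig; TConfig::Weight is the weight
// type that was chosen at runtime via the Type enum.
template <typename TConfig>
struct TensorBytesPerWeight {
  size_t operator()() const { return sizeof(typename TConfig::Weight); }
};

size_t BytesPerWeight(Model model, Type weight) {
  // Switches on `weight` to pick TWeight, then on `model` to pick the
  // Config*<TWeight> instantiation, and invokes the functor.
  return CallForModelAndWeight<TensorBytesPerWeight>(model, weight);
}

}  // namespace gcpp
```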


@@ -62,14 +62,16 @@ struct Args : public ArgsBase<Args> {
     ChooseNumThreads();
   }
-  gcpp::Model ModelType() const { return model_type; }
   // Returns error string or nullptr if OK.
   const char* Validate() {
     ModelTraining model_training;
-    const char* parse_result =
-        ParseModelTypeAndTraining(model_type_str, model_type, model_training);
-    if (parse_result) return parse_result;
+    if (const char* err = ParseModelTypeAndTraining(model_type_str, model_type_,
+                                                    model_training)) {
+      return err;
+    }
+    if (const char* err = ParseType(weight_type_str, weight_type_)) {
+      return err;
+    }
     if (weights.path.empty()) {
       return "Missing --weights flag, a file for the uncompressed model.";
     }
@@ -86,7 +88,7 @@ struct Args : public ArgsBase<Args> {
   Path weights;             // uncompressed weights file location
   Path compressed_weights;  // compressed weights file location
   std::string model_type_str;
-  Model model_type;
+  std::string weight_type_str;
   size_t num_threads;
   template <class Visitor>
@@ -101,6 +103,9 @@ struct Args : public ArgsBase<Args> {
             "gr2b-it = griffin 2B parameters, instruction-tuned\n "
             "gr2b-pt = griffin 2B parameters, pretrained\n "
             " Required argument.");
+    visitor(weight_type_str, "weight_type", std::string("sfp"),
+            "Weight type\n f32 = float, bf16 = bfloat16, SFP = 8-bit FP\n"
+            " Required argument.");
     visitor(compressed_weights, "compressed_weights", Path(),
             "Path name where compressed weights (.sbs) file will be written.\n"
             " Required argument.");
@@ -110,6 +115,14 @@ struct Args : public ArgsBase<Args> {
             "number of suupported concurrent threads.",
             2);
   }
+  // Uninitialized before Validate, must call after that.
+  gcpp::Model ModelType() const { return model_type_; }
+  gcpp::Type WeightType() const { return weight_type_; }
+ private:
+  Model model_type_;
+  Type weight_type_;
 };
 void ShowHelp(gcpp::Args& args) {
@@ -132,7 +145,7 @@ namespace HWY_NAMESPACE {
 template <class TConfig>
 void CompressWeights(const Path& weights_path,
                      const Path& compressed_weights_path, Model model_type,
-                     hwy::ThreadPool& pool) {
+                     Type weight_type, hwy::ThreadPool& pool) {
   if (!weights_path.Exists()) {
     HWY_ABORT("The model weights file '%s' does not exist.",
               weights_path.path.c_str());
@@ -147,7 +160,7 @@ void CompressWeights(const Path& weights_path,
   // Get weights, compress, and store.
   const bool scale_for_compression = TConfig::kNumTensorScales > 0;
   const ByteStorageT weights_u8 = gcpp::LoadRawWeights(
-      weights_path, model_type, pool, scale_for_compression);
+      weights_path, model_type, weight_type, pool, scale_for_compression);
   WeightsF<TConfig>* weights =
       reinterpret_cast<WeightsF<TConfig>*>(weights_u8.get());
   Compressor compressor(pool);
@@ -169,9 +182,10 @@ namespace gcpp {
 void Run(Args& args) {
   hwy::ThreadPool pool(args.num_threads);
   const Model model_type = args.ModelType();
-  GEMMA_EXPORT_AND_DISPATCH_MODEL(
-      model_type, CompressWeights,
-      (args.weights, args.compressed_weights, model_type, pool));
+  const Type weight_type = args.WeightType();
+  GEMMA_EXPORT_AND_DISPATCH(
+      model_type, weight_type, CompressWeights,
+      (args.weights, args.compressed_weights, model_type, weight_type, pool));
 }
 }  // namespace gcpp


@@ -22,7 +22,6 @@
 #include <array>
-#include "compression/compress.h"  // SfpStream
 #include "hwy/base.h"  // hwy::bfloat16_t
 namespace gcpp {
@@ -42,18 +41,10 @@
 #define GEMMA_MAX_THREADS 128
 #endif  // !GEMMA_MAX_THREADS
-// Allowable types for GEMMA_WEIGHT_T (can be specified at compilation time):
-// float, hwy::bfloat16_t, SfpStream, NuqStream
-#ifndef GEMMA_WEIGHT_T
-#define GEMMA_WEIGHT_T SfpStream
-#endif  // !GEMMA_WEIGHT_T
 static constexpr size_t kSeqLen = GEMMA_MAX_SEQLEN;
 static constexpr size_t kTopK = GEMMA_TOPK;
 static constexpr size_t kMaxThreads = GEMMA_MAX_THREADS;
-using GemmaWeightT = GEMMA_WEIGHT_T;
 using EmbedderInputT = hwy::bfloat16_t;
 enum class LayerAttentionType {
@@ -82,7 +73,10 @@ constexpr size_t NumLayersOfTypeBefore(
   return count;
 }
+template <typename TWeight>
 struct ConfigGemma7B {
+  using Weight = TWeight;  // make accessible where we only have a TConfig
   static constexpr int kSeqLen = gcpp::kSeqLen;
   static constexpr int kVocabSize = 256000;
   static constexpr std::array<LayerAttentionType, 28> kLayerConfig =
@@ -111,10 +105,12 @@ struct ConfigGemma7B {
   static constexpr bool kUseLocalAttention = false;
   static constexpr bool kInterleaveQKV = true;
   static constexpr int kNumTensorScales = 0;
-  using WeightT = GEMMA_WEIGHT_T;
 };
+template <typename TWeight>
 struct ConfigGemma2B {
+  using Weight = TWeight;  // make accessible where we only have a TConfig
   static constexpr int kSeqLen = gcpp::kSeqLen;
   static constexpr int kVocabSize = 256000;
   static constexpr std::array<LayerAttentionType, 18> kLayerConfig =
@@ -143,10 +139,12 @@ struct ConfigGemma2B {
   static constexpr bool kUseLocalAttention = false;
   static constexpr bool kInterleaveQKV = true;
   static constexpr int kNumTensorScales = 0;
-  using WeightT = GEMMA_WEIGHT_T;
 };
+template <typename TWeight>
 struct ConfigGemmaTiny {
+  using Weight = TWeight;  // make accessible where we only have a TConfig
   static constexpr int kSeqLen = 32;
   static constexpr int kVocabSize = 16;
   static constexpr std::array<LayerAttentionType, 3> kLayerConfig =
@@ -175,10 +173,12 @@ struct ConfigGemmaTiny {
   static constexpr bool kUseLocalAttention = false;
   static constexpr bool kInterleaveQKV = true;
   static constexpr int kNumTensorScales = 0;
-  using WeightT = GEMMA_WEIGHT_T;
 };
+template <typename TWeight>
 struct ConfigGriffin2B {
+  using Weight = TWeight;  // make accessible where we only have a TConfig
   // Griffin uses local attention, so kSeqLen is actually the local attention
   // window.
   static constexpr int kSeqLen = 2048;
@@ -235,7 +235,6 @@ struct ConfigGriffin2B {
   static constexpr bool kUseLocalAttention = true;
   static constexpr bool kInterleaveQKV = false;
   static constexpr int kNumTensorScales = 140;
-  using WeightT = GEMMA_WEIGHT_T;
 };
 }  // namespace gcpp
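A short sketch of what the new `TWeight` parameter provides; the aliases below are illustrative, while the configs, `SfpStream`, and `hwy::IsSame` are as used in this change:

```cpp
#include "compression/compress.h"  // SfpStream
#include "gemma/configs.h"
#include "hwy/base.h"  // hwy::bfloat16_t, hwy::IsSame

namespace gcpp {

// The same model config can now be instantiated once per weight type; only
// the Weight alias differs, the shape constants stay identical.
using Gemma2B_SFP = ConfigGemma2B<SfpStream>;
using Gemma2B_BF16 = ConfigGemma2B<hwy::bfloat16_t>;

static_assert(Gemma2B_SFP::kVocabSize == Gemma2B_BF16::kVocabSize, "");
static_assert(!hwy::IsSame<Gemma2B_SFP::Weight, Gemma2B_BF16::Weight>(), "");

}  // namespace gcpp
```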


@@ -15,13 +15,20 @@
 #include "gemma/cross_entropy.h"
+#include <stddef.h>
+#include <stdio.h>
 #include <algorithm>
 #include <cmath>
+#include <functional>
 #include <regex>  // NOLINT
 #include <string>
 #include <utility>
 #include <vector>
+#include "gemma/common.h"
+#include "gemma/gemma.h"
 namespace gcpp {
 namespace {
@@ -63,7 +70,9 @@ float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
   auto stream_token = [](int, float) { return true; };
   auto accept_token = [](int) { return true; };
-  const int vocab_size = CallFunctorForModel<GetVocabSize>(gemma.ModelType());
+  // TWeight is unused, but we have to pass it to Config*.
+  const int vocab_size =
+      CallForModel</*TWeight=*/float, GetVocabSize>(gemma.ModelType());
   float cross_entropy = std::log(vocab_size);  // first token
   size_t pos = 1;
   std::function<int(const float*, size_t)> sample_token =


@@ -16,6 +16,8 @@
 #ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_CROSS_ENTROPY_H_
 #define THIRD_PARTY_GEMMA_CPP_GEMMA_CROSS_ENTROPY_H_
+#include <stddef.h>
 #include <vector>
 #include "gemma/gemma.h"


@@ -147,8 +147,10 @@ struct CreateKVCache {
 }  // namespace
-KVCache KVCache::Create(Model type) {
-  return CallFunctorForModel<CreateKVCache>(type);
+KVCache KVCache::Create(Model model_type) {
+  // TWeight=float is a placeholder and unused because CreateKVCache does not
+  // use TConfig::Weight.
+  return CallForModel</*TWeight=*/float, CreateKVCache>(model_type);
 }
 class GemmaTokenizer::Impl {
@@ -727,12 +729,12 @@ Activations<TConfig, kBatchSize>& GetActivations(const ByteStorageT& state_u8) {
 }  // namespace
 template <class TConfig>
-void Generate(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8,
-              const ByteStorageT& decode_u8,
-              const RuntimeConfig& runtime_config,
-              const std::vector<int>& prompt, size_t pos, KVCache& kv_cache,
-              hwy::ThreadPool& pool, TimingInfo& timing_info,
-              LayersOutputT* layers_output) {
+void GenerateT(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8,
+               const ByteStorageT& decode_u8,
+               const RuntimeConfig& runtime_config,
+               const std::vector<int>& prompt, size_t pos, KVCache& kv_cache,
+               hwy::ThreadPool& pool, TimingInfo& timing_info,
+               LayersOutputT* layers_output) {
   const WeightsT<TConfig>& weights = GetWeights<TConfig>(weights_u8);
   auto& prefill_activations =
       GetActivations<TConfig, kPrefillBatchSize>(prefill_u8);
@@ -871,23 +873,31 @@ struct AllocateDecode {
 }  // namespace
 Gemma::Gemma(const Path& tokenizer_path, const Path& weights, Model model_type,
-             hwy::ThreadPool& pool)
-    : pool_(pool), tokenizer_(tokenizer_path), model_type_(model_type) {
-  weights_u8_ = LoadWeights(weights, model_type, pool);
-  prefill_u8_ = CallFunctorForModel<AllocatePrefill>(model_type);
-  decode_u8_ = CallFunctorForModel<AllocateDecode>(model_type);
+             Type weight_type, hwy::ThreadPool& pool)
+    : pool_(pool),
+      tokenizer_(tokenizer_path),
+      model_type_(model_type),
+      weight_type_(weight_type) {
+  weights_u8_ = LoadWeights(weights, model_type, weight_type, pool);
+  prefill_u8_ = CallForModelAndWeight<AllocatePrefill>(model_type, weight_type);
+  decode_u8_ = CallForModelAndWeight<AllocateDecode>(model_type, weight_type);
 }
-Gemma::Gemma(GemmaTokenizer&& tokenizer, Model model_type,
-             hwy::ThreadPool& pool)
-    : pool_(pool), tokenizer_(std::move(tokenizer)), model_type_(model_type) {
-  weights_u8_ = CallFunctorForModel<AllocateWeightsF>(model_type, pool);
-  prefill_u8_ = CallFunctorForModel<AllocatePrefill>(model_type);
-  decode_u8_ = CallFunctorForModel<AllocateDecode>(model_type);
+Gemma::Gemma(GemmaTokenizer&& tokenizer, Model model_type, Type weight_type,
+             hwy::ThreadPool& pool)
+    : pool_(pool),
+      tokenizer_(std::move(tokenizer)),
+      model_type_(model_type),
+      weight_type_(weight_type) {
+  weights_u8_ =
+      CallForModelAndWeight<AllocateWeightsF>(model_type, weight_type, pool);
+  prefill_u8_ = CallForModelAndWeight<AllocatePrefill>(model_type, weight_type);
+  decode_u8_ = CallForModelAndWeight<AllocateDecode>(model_type, weight_type);
 }
 Gemma::~Gemma() {
-  CallFunctorForModel<DeleteLayersPtrs>(model_type_, weights_u8_);
+  CallForModelAndWeight<DeleteLayersPtrs>(model_type_, weight_type_,
+                                          weights_u8_);
 }
 void Gemma::Generate(const RuntimeConfig& runtime_config,
@@ -896,8 +906,8 @@ void Gemma::Generate(const RuntimeConfig& runtime_config,
                      LayersOutputT* layers_output) {
   pool_.SetWaitMode(hwy::PoolWaitMode::kSpin);
-  GEMMA_EXPORT_AND_DISPATCH_MODEL(
-      model_type_, Generate,
+  GEMMA_EXPORT_AND_DISPATCH(
+      model_type_, weight_type_, GenerateT,
       (weights_u8_, prefill_u8_, decode_u8_, runtime_config, prompt, start_pos,
        kv_cache, pool_, timing_info, layers_output));


@@ -107,10 +107,11 @@ using LayersOutputT =
 class Gemma {
  public:
   Gemma(const Path& tokenizer_path, const Path& weights, Model model_type,
-        hwy::ThreadPool& pool);
+        Type weight_type, hwy::ThreadPool& pool);
   // Allocates weights, caller is responsible for filling them.
-  Gemma(GemmaTokenizer&& tokenizer, Model model_type, hwy::ThreadPool& pool);
+  Gemma(GemmaTokenizer&& tokenizer, Model model_type, Type weight_type,
+        hwy::ThreadPool& pool);
   ~Gemma();
   Model ModelType() const { return model_type_; }
@@ -136,6 +137,7 @@ class Gemma {
   ByteStorageT prefill_u8_;
   ByteStorageT decode_u8_;
   Model model_type_;
+  Type weight_type_;
 };
 // DEPRECATED, call Gemma::Generate directly.


@@ -38,8 +38,7 @@ class GemmaTest : public ::testing::Test {
       : weights("./2b-it-mqa.sbs"),
         tokenizer("./tokenizer.spm"),
         pool(std::min<int>(20, (std::thread::hardware_concurrency() - 1) / 2)),
-        model_type(gcpp::Model::GEMMA_2B),
-        model(tokenizer, weights, model_type, pool) {
+        model(tokenizer, weights, model_type, weight_type, pool) {
     KVCache kv_cache = KVCache::Create(model_type);
   }
@@ -96,7 +95,8 @@ class GemmaTest : public ::testing::Test {
   gcpp::Path tokenizer;
   gcpp::KVCache kv_cache;
   hwy::ThreadPool pool;
-  gcpp::Model model_type = {};
+  gcpp::Model model_type = gcpp::Model::GEMMA_2B;
+  gcpp::Type weight_type = gcpp::Type::kSFP;
   gcpp::Gemma model;
 };


@@ -71,7 +71,7 @@ void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
             << hwy::VectorBytes() * 8 << " bits)" << "\n"
             << "Compiled config : " << CompiledConfig() << "\n"
             << "Weight Type : "
-            << gcpp::TypeName(gcpp::GemmaWeightT()) << "\n"
+            << gcpp::StringFromType(loader.WeightType()) << "\n"
             << "EmbedderInput Type : "
             << gcpp::TypeName(gcpp::EmbedderInputT()) << "\n";
 }
@@ -251,8 +251,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
     PinWorkersToCores(pool);
   }
-  gcpp::Gemma model(loader.tokenizer, loader.weights, loader.ModelType(), pool);
+  gcpp::Gemma model = gcpp::CreateGemma(loader, pool);
   KVCache kv_cache = KVCache::Create(loader.ModelType());
   if (app.verbosity >= 1) {
@@ -278,7 +277,8 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
   }
   ReplGemma(
-      model, loader.ModelTraining(), kv_cache, pool, inference, app.verbosity,
+      model, loader.ModelTrainingType(), kv_cache, pool, inference,
+      app.verbosity,
       /*accept_token=*/[](int) { return true; }, app.eot_line);
 }


@@ -157,8 +157,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
     PinWorkersToCores(pool);
   }
-  gcpp::Gemma model(loader.tokenizer, loader.weights, loader.ModelType(), pool);
+  gcpp::Gemma model = gcpp::CreateGemma(loader, pool);
   gcpp::KVCache kv_cache = gcpp::KVCache::Create(loader.ModelType());
   JsonGemma(model, kv_cache, pool, inference, app.verbosity, app.eot_line);


@@ -173,10 +173,11 @@ struct LoadRawWeightsT {
 #undef SCALE_WEIGHTS
 }  // namespace
-ByteStorageT LoadRawWeights(const Path& weights, Model model,
-                            hwy::ThreadPool& pool, bool scale_for_compression) {
-  return CallFunctorForModel<LoadRawWeightsT>(model, weights, pool,
-                                              scale_for_compression);
+ByteStorageT LoadRawWeights(const Path& weights, Model model_type,
+                            Type weight_type, hwy::ThreadPool& pool,
+                            bool scale_for_compression) {
+  return CallForModelAndWeight<LoadRawWeightsT>(
+      model_type, weight_type, weights, pool, scale_for_compression);
 }
 namespace {
@@ -227,17 +228,18 @@ struct LoadCompressedWeightsT {
 };
 }  // namespace
-ByteStorageT LoadCompressedWeights(const Path& weights, Model model,
-                                   hwy::ThreadPool& pool) {
-  return CallFunctorForModel<LoadCompressedWeightsT>(model, weights, pool);
+ByteStorageT LoadCompressedWeights(const Path& weights, Model model_type,
+                                   Type weight_type, hwy::ThreadPool& pool) {
+  return CallForModelAndWeight<LoadCompressedWeightsT>(model_type, weight_type,
+                                                       weights, pool);
 }
-ByteStorageT LoadWeights(const Path& weights, Model model,
-                         hwy::ThreadPool& pool) {
+ByteStorageT LoadWeights(const Path& weights, Model model_type,
+                         Type weight_type, hwy::ThreadPool& pool) {
   if constexpr (kWeightsAreCompressed) {
-    return LoadCompressedWeights(weights, model, pool);
+    return LoadCompressedWeights(weights, model_type, weight_type, pool);
   } else {
-    return LoadRawWeights(weights, model, pool,
+    return LoadRawWeights(weights, model_type, weight_type, pool,
                           /*scale_for_compression=*/false);
   }
 }
@@ -274,8 +276,9 @@ struct LogWeightStatsT {
 };
 }  // namespace
-void LogWeightStats(gcpp::Model model, const ByteStorageT& weights) {
-  CallFunctorForModel<LogWeightStatsT>(model, weights);
+void LogWeightStats(gcpp::Model model_type, Type weight_type,
+                    const ByteStorageT& weights) {
+  CallForModelAndWeight<LogWeightStatsT>(model_type, weight_type, weights);
 }
 }  // namespace gcpp


@@ -129,21 +129,17 @@ using WeightsF = Weights<float, TConfig>;
 // ----------------------------------------------------------------------------
 // Compressed
-// If weights are f32, also f32; otherwise at least bf16. Useful for ops that do
-// not yet support smaller compressed types, or require at least bf16. When
-// weights are f32, we also want such tensors to be f32.
-template <class TConfig>
-using WeightF32OrBF16T =
-    hwy::If<hwy::IsSame<typename TConfig::WeightT, float>(), float,
-            hwy::bfloat16_t>;
 template <class TConfig>
 struct CompressedLayer {
   // No ctor/dtor, allocated via AllocateAligned.
   using TLayer = gcpp::LayerF<TConfig>;
-  using WeightT = typename TConfig::WeightT;
-  using WeightF32OrBF16 = WeightF32OrBF16T<TConfig>;
+  using Weight = typename TConfig::Weight;
+  // If weights are f32, also f32; otherwise at least bf16. Useful for ops that
+  // do not yet support smaller compressed types, or require at least bf16. When
+  // weights are f32, we also want such tensors to be f32.
+  using WeightF32OrBF16 =
+      hwy::If<hwy::IsSame<Weight, float>(), float, hwy::bfloat16_t>;
   static constexpr size_t kHeads = TLayer::kHeads;
   static constexpr size_t kKVHeads = TLayer::kKVHeads;
@@ -166,29 +162,29 @@ struct CompressedLayer {
   union {
     struct {
-      ArrayT<WeightT, kAttVecEinsumWSize> attn_vec_einsum_w;
-      ArrayT<WeightT, kQKVEinsumWSize> qkv_einsum_w;
+      ArrayT<Weight, kAttVecEinsumWSize> attn_vec_einsum_w;
+      ArrayT<Weight, kQKVEinsumWSize> qkv_einsum_w;
       ArrayT<float, kAOBiasDim> attention_output_biases;
     };
     struct {
-      ArrayT<WeightT, kGriffinDim * kGriffinDim> linear_x_w;
+      ArrayT<Weight, kGriffinDim * kGriffinDim> linear_x_w;
       ArrayT<float, kGriffinDim> linear_x_biases;
-      ArrayT<WeightT, kGriffinDim * kGriffinDim> linear_y_w;
+      ArrayT<Weight, kGriffinDim * kGriffinDim> linear_y_w;
       ArrayT<float, kGriffinDim> linear_y_biases;
-      ArrayT<WeightT, kGriffinDim * kGriffinDim> linear_out_w;
+      ArrayT<Weight, kGriffinDim * kGriffinDim> linear_out_w;
       ArrayT<float, kGriffinDim> linear_out_biases;
       ArrayT<float, TConfig::kConv1dWidth * kGriffinDim> conv_w;
       ArrayT<float, kGriffinDim> conv_biases;
-      ArrayT<WeightT, kGriffinDim * kGriffinDim / kHeads * 2> gate_w;
+      ArrayT<Weight, kGriffinDim * kGriffinDim / kHeads * 2> gate_w;
       ArrayT<float, kGriffinDim * 2> gate_biases;
       ArrayT<float, kGriffinDim> a;
     } griffin;
   };
-  ArrayT<WeightT, TLayer::kGatingEinsumWSize> gating_einsum_w;
-  ArrayT<WeightT, kModelDim * kFFHiddenDim> linear_w;
-  // We don't yet have an RMSNorm that accepts all WeightT.
+  ArrayT<Weight, TLayer::kGatingEinsumWSize> gating_einsum_w;
+  ArrayT<Weight, kModelDim * kFFHiddenDim> linear_w;
+  // We don't yet have an RMSNorm that accepts all Weight.
   ArrayT<WeightF32OrBF16, kModelDim> pre_attention_norm_scale;
   ArrayT<WeightF32OrBF16, kModelDim> pre_ffw_norm_scale;
   ArrayT<WeightF32OrBF16, kPostNormScale ? kModelDim : 0>
@@ -241,7 +237,7 @@ template <class TConfig>
 using WeightsT = hwy::If<kWeightsAreCompressed, CompressedWeights<TConfig>,
                          WeightsF<TConfig>>;
-// Call via CallFunctorForModel.
+// TODO: can we use TConfig::Weight instead of T?
 template <typename T, typename TConfig>
 struct AllocateWeights {
   ByteStorageT operator()(hwy::ThreadPool& pool) const {
@@ -335,14 +331,15 @@ class WeightsWrapper {
 };
 // For use by compress_weights.cc.
-ByteStorageT LoadRawWeights(const Path& weights, Model model,
-                            hwy::ThreadPool& pool, bool scale_for_compression);
+ByteStorageT LoadRawWeights(const Path& weights, Model model_type,
+                            Type weight_type, hwy::ThreadPool& pool,
+                            bool scale_for_compression);
 // For gemma.cc; calls LoadRawWeights if !kWeightsAreCompressed.
-ByteStorageT LoadWeights(const Path& weights, Model model,
-                         hwy::ThreadPool& pool);
+ByteStorageT LoadWeights(const Path& weights, Model model_type,
+                         Type weight_type, hwy::ThreadPool& pool);
-void LogWeightStats(Model model, const ByteStorageT& weights);
+void LogWeightStats(Model model, Type weight_type, const ByteStorageT& weights);
 // ----------------------------------------------------------------------------
 // Iterators
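Restating the `WeightF32OrBF16` selection that moved into `CompressedLayer`, as a standalone sketch (the free-standing alias here is illustrative; the logic mirrors the member alias above):

```cpp
#include "compression/compress.h"  // SfpStream
#include "hwy/base.h"              // hwy::bfloat16_t, hwy::If, hwy::IsSame

// Norm scales stay f32 when the weights are f32; every compressed weight type
// falls back to bf16 because no RMSNorm accepts the smaller types yet.
template <typename Weight>
using WeightF32OrBF16 =
    hwy::If<hwy::IsSame<Weight, float>(), float, hwy::bfloat16_t>;

static_assert(hwy::IsSame<WeightF32OrBF16<float>, float>(), "");
static_assert(
    hwy::IsSame<WeightF32OrBF16<gcpp::SfpStream>, hwy::bfloat16_t>(), "");
static_assert(
    hwy::IsSame<WeightF32OrBF16<hwy::bfloat16_t>, hwy::bfloat16_t>(), "");
```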


@@ -21,19 +21,17 @@
 #include "hwy/contrib/thread_pool/thread_pool.h"
 #if HWY_OS_LINUX
 #include <sched.h>
-#include <cctype>
-#include <cerrno>  // IDE does not recognize errno.h as providing errno.
-#include <string>
 #endif  // HWY_OS_LINUX
 #include <stddef.h>
 #include <stdio.h>
 #include <algorithm>  // std::clamp
-#include <thread>  // NOLINT
+#include <string>
+#include <thread>  // NOLINT
 #include <vector>
 #include "compression/io.h"  // Path
+#include "gemma/common.h"
 #include "gemma/configs.h"
 #include "gemma/gemma.h"
 #include "util/args.h"
@@ -49,14 +47,10 @@ static inline const char* CompiledConfig() {
     return "msan";
   } else if (HWY_IS_TSAN) {
     return "tsan";
-#if defined(HWY_IS_HWASAN)
   } else if (HWY_IS_HWASAN) {
     return "hwasan";
-#endif
-#if defined(HWY_IS_UBSAN)
   } else if (HWY_IS_UBSAN) {
     return "ubsan";
-#endif
   } else if (HWY_IS_DEBUG_BUILD) {
     return "dbg";
   } else {
@@ -172,15 +166,15 @@ class AppArgs : public ArgsBase<AppArgs> {
 struct LoaderArgs : public ArgsBase<LoaderArgs> {
   LoaderArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
-  gcpp::Model ModelType() const { return model_type; }
-  gcpp::ModelTraining ModelTraining() const { return model_training; }
   // Returns error string or nullptr if OK.
   const char* Validate() {
-    const char* parse_result =
-        ParseModelTypeAndTraining(model_type_str, model_type, model_training);
-    if (parse_result) return parse_result;
+    if (const char* err = ParseModelTypeAndTraining(model_type_str, model_type_,
+                                                    model_training_)) {
+      return err;
+    }
+    if (const char* err = ParseType(weight_type_str, weight_type_)) {
+      return err;
+    }
     if (tokenizer.path.empty()) {
       return "Missing --tokenizer flag, a file for the tokenizer is required.";
     }
@@ -209,8 +203,7 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
   Path weights;  // weights file location
   Path compressed_weights;
   std::string model_type_str;
-  Model model_type;
-  enum ModelTraining model_training;
+  std::string weight_type_str;
   template <class Visitor>
   void ForEach(const Visitor& visitor) {
@@ -227,9 +220,28 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
             "gr2b-it = griffin 2B parameters, instruction-tuned\n "
             "gr2b-pt = griffin 2B parameters, pretrained\n "
            " Required argument.");
+    visitor(weight_type_str, "weight_type", std::string("sfp"),
+            "Weight type\n f32 = float, bf16 = bfloat16, SFP = 8-bit FP\n"
+            " Required argument.");
   }
+  // Uninitialized before Validate, must call after that.
+  gcpp::Model ModelType() const { return model_type_; }
+  gcpp::ModelTraining ModelTrainingType() const { return model_training_; }
+  gcpp::Type WeightType() const { return weight_type_; }
+ private:
+  Model model_type_;
+  ModelTraining model_training_;
+  Type weight_type_;
 };
+static inline Gemma CreateGemma(const LoaderArgs& loader,
+                                hwy::ThreadPool& pool) {
+  return Gemma(loader.tokenizer, loader.weights, loader.ModelType(),
+               loader.WeightType(), pool);
+}
 struct InferenceArgs : public ArgsBase<InferenceArgs> {
   InferenceArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }