Support all weight types in a single binary.

This changes the command line flags, but the default value retains the previous behavior.

Also add a CreateGemma helper to enable extra args without interface changes.

PiperOrigin-RevId: 641266411
This commit is contained in:
Jan Wassenberg 2024-06-07 09:04:06 -07:00 committed by Copybara-Service
parent 24db2ff725
commit f9b390b134
27 changed files with 372 additions and 248 deletions

View File

@ -116,6 +116,7 @@ cc_library(
"gemma/cross_entropy.h",
],
deps = [
":common",
":gemma_lib",
],
)

View File

@ -76,17 +76,6 @@ if(NOT CMAKE_BUILD_TYPE)
set(CMAKE_BUILD_TYPE "Release")
endif()
# Allowable types for WEIGHT_TYPE:
# float - slow, not recommended
# hwy::bfloat16_t - bfloat16 as implemented by https://github.com/google/highway
# SfpStream - 8-bit switched floating point (recommended)
# NuqStream - experimental, work-in-progress
option(WEIGHT_TYPE "Set weight type" "")
if (WEIGHT_TYPE)
add_definitions(-DGEMMA_WEIGHT_T=${WEIGHT_TYPE})
endif()
FetchContent_GetProperties(sentencepiece)
## Library Target

View File

@ -105,18 +105,12 @@ the resulting file as `--weights` and the desired .sbs name as the
There are several compile-time flags to be aware of (note these may or may not
be exposed to the build system):
- `GEMMA_WEIGHT_T` : Sets the level of compression for weights (surfaced as
WEIGHT_TYPE in CMakeLists.txt). Currently this should be set to `SfpStream`
(default, if no flag is specified) for 8-bit SFP, or `hwy::bfloat16_t` to
enable for higher-fidelity (but slower) bfloat16 support. This is defined in
`gemma.h`.
- `GEMMA_MAX_SEQ_LEN` : Sets maximum sequence length to preallocate for the KV
Cache. The default is 4096 tokens but can be overridden. This is not exposed
through `CMakeLists.txt` yet.
In the medium term both of these will likely be deprecated in favor of handling
options at runtime - allowing for multiple weight compression schemes in a single
build and dynamically resizing the KV cache as needed.
In the medium term this will likely be deprecated in favor of handling options
at runtime - dynamically resizing the KV cache as needed.
## Using gemma.cpp as a Library (Advanced)

View File

@ -138,33 +138,16 @@ convenient directory location (e.g. the `build/` directory in this repo).
The build system uses [CMake](https://cmake.org/). To build the gemma inference
runtime, create a build directory and generate the build files using `cmake`
from the top-level project directory. Note that if you previously ran `cmake` and are
re-running with a different setting, be sure to clean out the `build/` directory
with `rm -rf build/*` (warning: this will delete any other files in the `build/`
directory).
For the 8-bit switched floating point weights (sfp), run cmake with no options:
re-running with a different setting, be sure to delete all files in the `build/`
directory with `rm -rf build/*`.
#### Unix-like Platforms
```sh
cmake -B build
```
**or** if you downloaded bfloat16 weights (any model *without* `-sfp` in the
name), instead of running cmake with no options as above, run cmake with
WEIGHT_TYPE set to [highway's](https://github.com/google/highway)
`hwy::bfloat16_t` type. Alternatively, you can also add
`-DGEMMA_WEIGHT_T=hwy::bfloat16_t` to the C++ compiler flags.
We intend to soon support all weight types without requiring extra flags. Note
that we recommend using `-sfp` weights instead of bfloat16 for faster inference.
```sh
cmake -B build -DWEIGHT_TYPE=hwy::bfloat16_t
```
After running whichever of the above `cmake` invocations that is appropriate for
your weights, you can enter the `build/` directory and run `make` to build the
`./gemma` executable:
After running `cmake`, you can enter the `build/` directory and run `make` to
build the `./gemma` executable:
```sh
# Configure `build` directory
@ -221,11 +204,12 @@ You can now run `gemma` from inside the `build/` directory.
`gemma` has the following required arguments:
| Argument | Description | Example value |
| ------------- | ---------------------------- | -------------------------- |
| `--model` | The model type. | `2b-it`, `2b-pt`, `7b-it`, `7b-pt`, ... (see above) |
| `--weights` | The compressed weights file. | `2b-it-sfp.sbs`, ... (see above) |
| `--tokenizer` | The tokenizer file. | `tokenizer.spm` |
Argument | Description | Example value
--------------- | ---------------------------- | -----------------------
`--model` | The model type. | `2b-it` ... (see below)
`--weights` | The compressed weights file. | `2b-it-sfp.sbs`
`--weight_type` | The compressed weight type. | `sfp`
`--tokenizer` | The tokenizer file. | `tokenizer.spm`
`gemma` is invoked as:
@ -233,6 +217,7 @@ You can now run `gemma` from inside the `build/` directory.
./gemma \
--tokenizer [tokenizer file] \
--weights [compressed weights file] \
--weight_type [f32 or bf16 or sfp] \
--model [2b-it or 2b-pt or 7b-it or 7b-pt or ...]
```
@ -245,8 +230,7 @@ Example invocation for the following configuration:
```sh
./gemma \
--tokenizer tokenizer.spm \
--weights 2b-it-sfp.sbs \
--model 2b-it
--weights 2b-it-sfp.sbs --weight_type sfp --model 2b-it
```
### RecurrentGemma
@ -270,14 +254,12 @@ Step 1, and run the binary as follows:
**Running `./gemma` fails with "Failed to read cache gating_ein_0 (error 294) ..."**
The most common problem is that `cmake` was built with the wrong weight type and
`gemma` is attempting to load `bfloat16` weights (`2b-it`, `2b-pt`, `7b-it`,
`7b-pt`) using the default switched floating point (sfp) or vice versa. Revisit
step #3 and check that the `cmake` command used to build `gemma` was correct for
the weights that you downloaded.
The most common problem is that the `--weight_type` argument does not match that
of the model file. Revisit step #3 and check which weights you downloaded.
In the future we will move model format handling from compile time to runtime
to simplify this.
Note that we have already moved weight type from a compile-time decision to a
runtime argument. In a subsequent step, we plan to bake this information into
the weights.
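For example, assuming you downloaded bfloat16 weights saved as `2b-it.sbs` (a
model without `-sfp` in the name; the filename here is only illustrative), the
matching invocation would pass `bf16`:
```sh
./gemma \
    --tokenizer tokenizer.spm \
    --weights 2b-it.sbs \
    --weight_type bf16 \
    --model 2b-it
```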
**Problems building in Windows / Visual Studio**

View File

@ -21,7 +21,6 @@
#define THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_INL_H_
#include <stddef.h>
#include <stdint.h>
#include <array>
#include <cmath>
@ -44,6 +43,7 @@
#endif
#include "gemma/ops.h"
#include "hwy/highway.h"
HWY_BEFORE_NAMESPACE();
namespace gcpp {

View File

@ -15,6 +15,11 @@
#include "backprop/backward.h"
#include "backprop/prompt.h"
#include "gemma/activations.h"
#include "gemma/common.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
// Compiles this file for multiple architectures via "foreach_target.h", to
// which we pass the filename via macro 'argument'.
#undef HWY_TARGET_INCLUDE
@ -29,7 +34,6 @@
HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
namespace hn = hwy::HWY_NAMESPACE;
template <typename TConfig>
void CrossEntropyLossBackwardPass(const Prompt& prompt,
@ -57,11 +61,11 @@ void CrossEntropyLossBackwardPassT(Model model,
// TODO(janwas): use CallFunctorForModel
switch (model) {
case Model::GEMMA_2B:
CrossEntropyLossBackwardPass<ConfigGemma2B>(
CrossEntropyLossBackwardPass<ConfigGemma2B<float>>(
prompt, weights, forward, grad, backward, pool);
break;
case Model::GEMMA_TINY:
CrossEntropyLossBackwardPass<ConfigGemmaTiny>(
CrossEntropyLossBackwardPass<ConfigGemmaTiny<float>>(
prompt, weights, forward, grad, backward, pool);
break;
default:

View File

@ -15,6 +15,11 @@
#include "backprop/forward.h"
#include "backprop/prompt.h"
#include "gemma/activations.h"
#include "gemma/common.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
// Compiles this file for multiple architectures via "foreach_target.h", to
// which we pass the filename via macro 'argument'.
#undef HWY_TARGET_INCLUDE
@ -29,7 +34,6 @@
HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
namespace hn = hwy::HWY_NAMESPACE;
template <typename TConfig>
float CrossEntropyLossForwardPass(const Prompt& prompt,
@ -51,10 +55,10 @@ float CrossEntropyLossForwardPassT(Model model, const Prompt& prompt,
// TODO(janwas): use CallFunctorForModel
switch (model) {
case Model::GEMMA_2B:
return CrossEntropyLossForwardPass<ConfigGemma2B>(
prompt, weights, forward, pool);
return CrossEntropyLossForwardPass<ConfigGemma2B<float>>(prompt, weights,
forward, pool);
case Model::GEMMA_TINY:
return CrossEntropyLossForwardPass<ConfigGemmaTiny>(
return CrossEntropyLossForwardPass<ConfigGemmaTiny<float>>(
prompt, weights, forward, pool);
default:
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));

View File

@ -13,18 +13,23 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include <string>
#include <stddef.h>
#include <limits>
#include <random>
#include <vector>
#include "gtest/gtest.h"
#include "backprop/backward.h"
#include "backprop/forward.h"
#include "backprop/optimizer.h"
#include "backprop/prompt.h"
#include "backprop/sampler.h"
#include "gemma/activations.h"
#include "gemma/common.h"
#include "gemma/gemma.h"
#include "gemma/weights.h"
#include "gtest/gtest.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
namespace gcpp {
@ -35,11 +40,17 @@ TEST(OptimizeTest, GradientDescent) {
std::mt19937 gen(42);
Model model_type = Model::GEMMA_TINY;
ByteStorageT grad = CallFunctorForModel<AllocateWeightsF>(model_type, pool);
ByteStorageT grad_m = CallFunctorForModel<AllocateWeightsF>(model_type, pool);
ByteStorageT grad_v = CallFunctorForModel<AllocateWeightsF>(model_type, pool);
ByteStorageT forward = CallFunctorForModel<AllocateForwardPass>(model_type);
ByteStorageT backward = CallFunctorForModel<AllocateForwardPass>(model_type);
Type weight_type = Type::kF32;
ByteStorageT grad =
CallForModelAndWeight<AllocateWeightsF>(model_type, weight_type, pool);
ByteStorageT grad_m =
CallForModelAndWeight<AllocateWeightsF>(model_type, weight_type, pool);
ByteStorageT grad_v =
CallForModelAndWeight<AllocateWeightsF>(model_type, weight_type, pool);
ByteStorageT forward =
CallForModelAndWeight<AllocateForwardPass>(model_type, weight_type);
ByteStorageT backward =
CallForModelAndWeight<AllocateForwardPass>(model_type, weight_type);
KVCache kv_cache = KVCache::Create(model_type);
size_t max_tokens = 32;
size_t max_generated_tokens = 16;
@ -47,7 +58,7 @@ TEST(OptimizeTest, GradientDescent) {
int verbosity = 0;
const auto accept_token = [](int) { return true; };
Gemma gemma(GemmaTokenizer(), model_type, pool);
Gemma gemma(GemmaTokenizer(), model_type, weight_type, pool);
const auto generate = [&](const std::vector<int>& prompt) {
std::vector<int> reply;
@ -76,12 +87,14 @@ TEST(OptimizeTest, GradientDescent) {
return ok;
};
RandInitWeights(model_type, gemma.Weights(), pool, gen);
CallFunctorForModel<ZeroInitWeightsF>(model_type, grad_m, pool);
CallFunctorForModel<ZeroInitWeightsF>(model_type, grad_v, pool);
RandInitWeights(model_type, weight_type, gemma.Weights(), pool, gen);
CallForModelAndWeight<ZeroInitWeightsF>(model_type, weight_type, grad_m,
pool);
CallForModelAndWeight<ZeroInitWeightsF>(model_type, weight_type, grad_v,
pool);
printf("Initial weights:\n");
LogWeightStats(model_type, gemma.Weights());
LogWeightStats(model_type, weight_type, gemma.Weights());
constexpr size_t kBatchSize = 8;
const float alpha = 0.001f;
@ -96,7 +109,8 @@ TEST(OptimizeTest, GradientDescent) {
size_t num_ok;
for (; steps < 1000000; ++steps) {
std::mt19937 sgen(42);
CallFunctorForModel<ZeroInitWeightsF>(model_type, grad, pool);
CallForModelAndWeight<ZeroInitWeightsF>(model_type, weight_type, grad,
pool);
float total_loss = 0.0f;
num_ok = 0;
for (size_t i = 0; i < kBatchSize; ++i) {
@ -109,13 +123,13 @@ TEST(OptimizeTest, GradientDescent) {
}
total_loss /= kBatchSize;
AdamUpdate(model_type, grad, alpha, beta1, beta2, epsilon, steps + 1,
gemma.Weights(), grad_m, grad_v, pool);
AdamUpdate(model_type, weight_type, grad, alpha, beta1, beta2, epsilon,
steps + 1, gemma.Weights(), grad_m, grad_v, pool);
printf("step: %zu total_loss: %.15f num_ok: %zu/%zu\n",
steps, total_loss, num_ok, kBatchSize);
if (steps % 100 == 0) {
printf("Batch gradient:\n");
LogWeightStats(model_type, grad);
LogWeightStats(model_type, weight_type, grad);
}
if (total_loss < 0.5f) {
break;
@ -124,7 +138,7 @@ TEST(OptimizeTest, GradientDescent) {
}
printf("Num steps: %zu\n", steps);
printf("Final weights:\n");
LogWeightStats(model_type, gemma.Weights());
LogWeightStats(model_type, weight_type, gemma.Weights());
EXPECT_LT(steps, 200);
EXPECT_EQ(num_ok, kBatchSize);
}

View File

@ -107,18 +107,20 @@ struct AdamUpdateT {
} // namespace
void RandInitWeights(Model model, const ByteStorageT& weights,
hwy::ThreadPool& pool,
void RandInitWeights(Model model_type, Type weight_type,
const ByteStorageT& weights, hwy::ThreadPool& pool,
std::mt19937& gen) {
CallFunctorForModel<RandInitWeightsT>(model, weights, pool, gen);
CallForModelAndWeight<RandInitWeightsT>(model_type, weight_type, weights,
pool, gen);
}
void AdamUpdate(Model model, const ByteStorageT& grad, float alpha, float beta1,
float beta2, float epsilon, size_t t,
void AdamUpdate(Model model_type, Type weight_type, const ByteStorageT& grad,
float alpha, float beta1, float beta2, float epsilon, size_t t,
const ByteStorageT& weights, const ByteStorageT& grad_m,
const ByteStorageT& grad_v, hwy::ThreadPool& pool) {
CallFunctorForModel<AdamUpdateT>(model, grad, alpha, beta1, beta2, epsilon, t,
weights, grad_m, grad_v, pool);
CallForModelAndWeight<AdamUpdateT>(model_type, weight_type, grad, alpha,
beta1, beta2, epsilon, t, weights, grad_m,
grad_v, pool);
}
} // namespace gcpp

View File

@ -19,16 +19,16 @@
#include <random>
#include "gemma/common.h"
#include "gemma/weights.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
namespace gcpp {
void RandInitWeights(Model model, const ByteStorageT& weights,
hwy::ThreadPool& pool, std::mt19937& gen);
void RandInitWeights(Model model_type, Type weight_type,
const ByteStorageT& weights, hwy::ThreadPool& pool,
std::mt19937& gen);
void AdamUpdate(Model model, const ByteStorageT& grad, float alpha, float beta1,
float beta2, float epsilon, size_t t,
void AdamUpdate(Model model_type, Type weight_type, const ByteStorageT& grad,
float alpha, float beta1, float beta2, float epsilon, size_t t,
const ByteStorageT& weights, const ByteStorageT& grad_m,
const ByteStorageT& grad_v, hwy::ThreadPool& pool);

View File

@ -113,7 +113,7 @@ int main(int argc, char** argv) {
gcpp::PinWorkersToCores(pool);
}
gcpp::Gemma model(loader.tokenizer, loader.weights, loader.ModelType(), pool);
gcpp::Gemma model = gcpp::CreateGemma(loader, pool);
gcpp::KVCache kv_cache = gcpp::KVCache::Create(loader.ModelType());
const std::string& prompt = prompt_args.prompt;

View File

@ -39,7 +39,7 @@ int main(int argc, char** argv) {
hwy::ThreadPool pool(num_threads);
// Instantiate model and KV Cache
gcpp::Gemma model(loader.tokenizer, loader.weights, loader.ModelType(), pool);
gcpp::Gemma model = gcpp::CreateGemma(loader, pool);
gcpp::KVCache kv_cache = gcpp::KVCache::Create(loader.ModelType());
size_t pos = 0; // KV Cache position

View File

@ -280,7 +280,7 @@ int main(int argc, char** argv) {
gcpp::PinWorkersToCores(pool);
}
gcpp::Gemma model(loader.tokenizer, loader.weights, loader.ModelType(), pool);
gcpp::Gemma model = gcpp::CreateGemma(loader, pool);
gcpp::KVCache kv_cache = gcpp::KVCache::Create(loader.ModelType());
if (!benchmark_args.goldens.path.empty()) {

View File

@ -64,4 +64,28 @@ const char* ParseModelTypeAndTraining(const std::string& model_flag,
return kErrorMessageBuffer;
}
const char* ParseType(const std::string& type_string, Type& type) {
constexpr Type kTypes[] = {Type::kF32, Type::kBF16, Type::kSFP};
constexpr const char* kStrings[] = {"f32", "bf16", "sfp"};
constexpr size_t kNum = std::end(kStrings) - std::begin(kStrings);
static char kErrorMessageBuffer[kNum * 8 + 100] =
"Invalid or missing type, need to specify one of ";
for (size_t i = 0; i + 1 < kNum; i++) {
strcat(kErrorMessageBuffer, kStrings[i]); // NOLINT
strcat(kErrorMessageBuffer, ", "); // NOLINT
}
strcat(kErrorMessageBuffer, kStrings[kNum - 1]); // NOLINT
strcat(kErrorMessageBuffer, "."); // NOLINT
std::string type_lc = type_string;
std::transform(begin(type_lc), end(type_lc), begin(type_lc),
[](unsigned char c) { return std::tolower(c); });
for (size_t i = 0; i < kNum; i++) {
if (kStrings[i] == type_lc) {
type = kTypes[i];
return nullptr;
}
}
return kErrorMessageBuffer;
}
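// Illustrative usage (sketch, not part of this change): parsing is
// case-insensitive, so "bf16" and "BF16" both map to Type::kBF16; nullptr is
// returned on success, otherwise the error message.
//
//   Type type;
//   if (const char* err = ParseType(type_string, type)) {
//     fprintf(stderr, "%s\n", err);
//   }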
} // namespace gcpp

View File

@ -21,6 +21,7 @@
#include <string>
#include "compression/compress.h"
#include "gemma/configs.h" // IWYU pragma: export
#include "hwy/aligned_allocator.h"
#include "hwy/base.h" // ConvertScalarTo
@ -37,67 +38,129 @@ ByteStorageT AllocateSizeof() {
// Model variants: see configs.h for details.
enum class Model { GEMMA_2B, GEMMA_7B, GRIFFIN_2B, GEMMA_TINY };
// Instruction-tuned models require extra 'turn structure' tokens in prompts.
enum class ModelTraining { GEMMA_IT, GEMMA_PT };
// Returns the return value of Func<T>().operator() called with `args`, where
// `T` is selected based on `model`.
// Tensor types for loading weights.
enum class Type { kF32, kBF16, kSFP };
// Returns the return value of FuncT<Config*<TWeight>>().operator()(args), where
// Config* is selected via `model`. Typically called by CallForModelAndWeight,
// but can also be called directly when FuncT does not actually use TWeight.
//
// This is used to implement type-erased functions such as
// LoadCompressedWeights, which can be called from other .cc files, by calling a
// functor LoadCompressedWeightsT, which has a template argument. `Func` must
// be a functor because function templates cannot be passed as a template
// template argument, and we prefer to avoid the overhead of std::function.
// Note that a T prefix indicates a concrete type template argument, whereas a
// T suffix indicates the argument is itself a template.
//
// This function avoids having to update all call sites when we extend `Model`.
template <template <typename Config> class Func, typename... Args>
decltype(auto) CallFunctorForModel(Model model, Args&&... args) {
// `FuncT` must be a functor because function templates cannot be passed as a
// template template argument, and we prefer to avoid the overhead of
// std::function.
template <typename TWeight, template <typename TConfig> class FuncT,
typename... TArgs>
decltype(auto) CallForModel(Model model, TArgs&&... args) {
switch (model) {
case Model::GEMMA_TINY:
return Func<ConfigGemmaTiny>()(std::forward<Args>(args)...);
return FuncT<ConfigGemmaTiny<TWeight>>()(std::forward<TArgs>(args)...);
case Model::GEMMA_2B:
return Func<ConfigGemma2B>()(std::forward<Args>(args)...);
return FuncT<ConfigGemma2B<TWeight>>()(std::forward<TArgs>(args)...);
case Model::GEMMA_7B:
return Func<ConfigGemma7B>()(std::forward<Args>(args)...);
return FuncT<ConfigGemma7B<TWeight>>()(std::forward<TArgs>(args)...);
case Model::GRIFFIN_2B:
return Func<ConfigGriffin2B>()(std::forward<Args>(args)...);
return FuncT<ConfigGriffin2B<TWeight>>()(std::forward<TArgs>(args)...);
default:
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
}
}
// Like CallFunctorForModel, but for SIMD function templates. This is a macro
// Returns the return value of FuncT<TConfig>().operator()(args),
// where `TConfig` is selected based on `model` and `weight`.
// This makes it easy to extend `Model` or `Type` without updating callers.
//
// Usage example: LoadWeights is type-erased so that it can be called from other
// .cc files. It uses this function to call the appropriate instantiation of a
// template functor LoadCompressedWeightsT<TConfig>.
template <template <typename TConfig> class FuncT, typename... TArgs>
decltype(auto) CallForModelAndWeight(Model model, Type weight,
TArgs&&... args) {
switch (weight) {
case Type::kF32:
return CallForModel<float, FuncT, TArgs...>( //
model, std::forward<TArgs>(args)...);
case Type::kBF16:
return CallForModel<hwy::bfloat16_t, FuncT, TArgs...>(
model, std::forward<TArgs>(args)...);
case Type::kSFP:
return CallForModel<SfpStream, FuncT, TArgs...>(
model, std::forward<TArgs>(args)...);
default:
HWY_ABORT("Weight type %d unknown.", static_cast<int>(weight));
}
}
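// Example (illustrative sketch, not part of this change; the functor name is
// hypothetical): a type-erased helper pairs a template functor with
// CallForModelAndWeight so callers only deal with the Model and Type enums.
//
//   template <typename TConfig>
//   struct SeqLenT {
//     size_t operator()() const { return TConfig::kSeqLen; }
//   };
//
//   size_t MaxSeqLen(Model model, Type weight) {
//     return CallForModelAndWeight<SeqLenT>(model, weight);
//   }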
// Used by GEMMA_EXPORT_AND_DISPATCH. For a given TWEIGHT (e.g. float),
// calls FUNC<ConfigT<TWEIGHT>> where ConfigT is chosen via MODEL enum.
#define GEMMA_DISPATCH_MODEL(MODEL, TWEIGHT, FUNC, ARGS) \
switch (MODEL) { \
case Model::GEMMA_TINY: { \
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<ConfigGemmaTiny<TWEIGHT>>) \
ARGS; \
break; \
} \
case Model::GEMMA_2B: { \
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<ConfigGemma2B<TWEIGHT>>) \
ARGS; \
break; \
} \
case Model::GEMMA_7B: { \
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<ConfigGemma7B<TWEIGHT>>) \
ARGS; \
break; \
} \
case Model::GRIFFIN_2B: { \
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<ConfigGriffin2B<TWEIGHT>>) \
ARGS; \
break; \
} \
default: \
HWY_ABORT("Model type %d unknown.", static_cast<int>(MODEL)); \
}
// Like CallForModelAndWeight, but for SIMD function templates. This is a macro
// because it boils down to N_SSE4::FUNC, which would not work if FUNC was a
// normal function argument.
#define GEMMA_EXPORT_AND_DISPATCH_MODEL(MODEL, FUNC, ARGS) \
switch (MODEL) { \
case Model::GEMMA_TINY: { \
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<ConfigGemmaTiny>) \
ARGS; \
break; \
} \
case Model::GEMMA_2B: { \
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<ConfigGemma2B>) \
ARGS; \
break; \
} \
case Model::GEMMA_7B: { \
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<ConfigGemma7B>) \
ARGS; \
break; \
} \
case Model::GRIFFIN_2B: { \
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<ConfigGriffin2B>) \
ARGS; \
break; \
} \
default: \
HWY_ABORT("Model type %d unknown.", static_cast<int>(MODEL)); \
// normal function argument. MODEL and WEIGHT are enums.
#define GEMMA_EXPORT_AND_DISPATCH(MODEL, WEIGHT, FUNC, ARGS) \
switch (WEIGHT) { \
case Type::kF32: \
GEMMA_DISPATCH_MODEL(MODEL, float, FUNC, ARGS); \
break; \
case Type::kBF16: \
GEMMA_DISPATCH_MODEL(MODEL, hwy::bfloat16_t, FUNC, ARGS); \
break; \
case Type::kSFP: \
GEMMA_DISPATCH_MODEL(MODEL, SfpStream, FUNC, ARGS); \
break; \
default: \
HWY_ABORT("Weight type %d unknown.", static_cast<int>(WEIGHT)); \
}
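// Usage sketch (mirrors the Generate dispatch later in this change): FUNC is a
// SIMD function template and ARGS is its parenthesized argument list.
//
//   GEMMA_EXPORT_AND_DISPATCH(model_type_, weight_type_, GenerateT,
//                             (weights_u8_, prefill_u8_, decode_u8_,
//                              runtime_config, prompt, start_pos, kv_cache,
//                              pool_, timing_info, layers_output));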
// Returns error string or nullptr if OK.
// Thread-hostile.
const char* ParseModelTypeAndTraining(const std::string& model_flag,
Model& model, ModelTraining& training);
const char* ParseType(const std::string& type_string, Type& type);
static inline const char* StringFromType(Type type) {
switch (type) {
case Type::kF32:
return "f32";
case Type::kBF16:
return "bf16";
case Type::kSFP:
return "sfp";
default:
return "?";
}
}
// __builtin_sqrt is not constexpr as of Clang 17.
#if HWY_COMPILER_GCC_ACTUAL

View File

@ -62,14 +62,16 @@ struct Args : public ArgsBase<Args> {
ChooseNumThreads();
}
gcpp::Model ModelType() const { return model_type; }
// Returns error string or nullptr if OK.
const char* Validate() {
ModelTraining model_training;
const char* parse_result =
ParseModelTypeAndTraining(model_type_str, model_type, model_training);
if (parse_result) return parse_result;
if (const char* err = ParseModelTypeAndTraining(model_type_str, model_type_,
model_training)) {
return err;
}
if (const char* err = ParseType(weight_type_str, weight_type_)) {
return err;
}
if (weights.path.empty()) {
return "Missing --weights flag, a file for the uncompressed model.";
}
@ -86,7 +88,7 @@ struct Args : public ArgsBase<Args> {
Path weights; // uncompressed weights file location
Path compressed_weights; // compressed weights file location
std::string model_type_str;
Model model_type;
std::string weight_type_str;
size_t num_threads;
template <class Visitor>
@ -101,6 +103,9 @@ struct Args : public ArgsBase<Args> {
"gr2b-it = griffin 2B parameters, instruction-tuned\n "
"gr2b-pt = griffin 2B parameters, pretrained\n "
" Required argument.");
visitor(weight_type_str, "weight_type", std::string("sfp"),
"Weight type\n f32 = float, bf16 = bfloat16, SFP = 8-bit FP\n"
" Required argument.");
visitor(compressed_weights, "compressed_weights", Path(),
"Path name where compressed weights (.sbs) file will be written.\n"
" Required argument.");
@ -110,6 +115,14 @@ struct Args : public ArgsBase<Args> {
"number of suupported concurrent threads.",
2);
}
// Uninitialized before Validate; call these accessors only after Validate.
gcpp::Model ModelType() const { return model_type_; }
gcpp::Type WeightType() const { return weight_type_; }
private:
Model model_type_;
Type weight_type_;
};
void ShowHelp(gcpp::Args& args) {
@ -132,7 +145,7 @@ namespace HWY_NAMESPACE {
template <class TConfig>
void CompressWeights(const Path& weights_path,
const Path& compressed_weights_path, Model model_type,
hwy::ThreadPool& pool) {
Type weight_type, hwy::ThreadPool& pool) {
if (!weights_path.Exists()) {
HWY_ABORT("The model weights file '%s' does not exist.",
weights_path.path.c_str());
@ -147,7 +160,7 @@ void CompressWeights(const Path& weights_path,
// Get weights, compress, and store.
const bool scale_for_compression = TConfig::kNumTensorScales > 0;
const ByteStorageT weights_u8 = gcpp::LoadRawWeights(
weights_path, model_type, pool, scale_for_compression);
weights_path, model_type, weight_type, pool, scale_for_compression);
WeightsF<TConfig>* weights =
reinterpret_cast<WeightsF<TConfig>*>(weights_u8.get());
Compressor compressor(pool);
@ -169,9 +182,10 @@ namespace gcpp {
void Run(Args& args) {
hwy::ThreadPool pool(args.num_threads);
const Model model_type = args.ModelType();
GEMMA_EXPORT_AND_DISPATCH_MODEL(
model_type, CompressWeights,
(args.weights, args.compressed_weights, model_type, pool));
const Type weight_type = args.WeightType();
GEMMA_EXPORT_AND_DISPATCH(
model_type, weight_type, CompressWeights,
(args.weights, args.compressed_weights, model_type, weight_type, pool));
}
} // namespace gcpp

View File

@ -22,7 +22,6 @@
#include <array>
#include "compression/compress.h" // SfpStream
#include "hwy/base.h" // hwy::bfloat16_t
namespace gcpp {
@ -42,18 +41,10 @@ namespace gcpp {
#define GEMMA_MAX_THREADS 128
#endif // !GEMMA_MAX_THREADS
// Allowable types for GEMMA_WEIGHT_T (can be specified at compilation time):
// float, hwy::bfloat16_t, SfpStream, NuqStream
#ifndef GEMMA_WEIGHT_T
#define GEMMA_WEIGHT_T SfpStream
#endif // !GEMMA_WEIGHT_T
static constexpr size_t kSeqLen = GEMMA_MAX_SEQLEN;
static constexpr size_t kTopK = GEMMA_TOPK;
static constexpr size_t kMaxThreads = GEMMA_MAX_THREADS;
using GemmaWeightT = GEMMA_WEIGHT_T;
using EmbedderInputT = hwy::bfloat16_t;
enum class LayerAttentionType {
@ -82,7 +73,10 @@ constexpr size_t NumLayersOfTypeBefore(
return count;
}
template <typename TWeight>
struct ConfigGemma7B {
using Weight = TWeight; // make accessible where we only have a TConfig
static constexpr int kSeqLen = gcpp::kSeqLen;
static constexpr int kVocabSize = 256000;
static constexpr std::array<LayerAttentionType, 28> kLayerConfig =
@ -111,10 +105,12 @@ struct ConfigGemma7B {
static constexpr bool kUseLocalAttention = false;
static constexpr bool kInterleaveQKV = true;
static constexpr int kNumTensorScales = 0;
using WeightT = GEMMA_WEIGHT_T;
};
template <typename TWeight>
struct ConfigGemma2B {
using Weight = TWeight; // make accessible where we only have a TConfig
static constexpr int kSeqLen = gcpp::kSeqLen;
static constexpr int kVocabSize = 256000;
static constexpr std::array<LayerAttentionType, 18> kLayerConfig =
@ -143,10 +139,12 @@ struct ConfigGemma2B {
static constexpr bool kUseLocalAttention = false;
static constexpr bool kInterleaveQKV = true;
static constexpr int kNumTensorScales = 0;
using WeightT = GEMMA_WEIGHT_T;
};
template <typename TWeight>
struct ConfigGemmaTiny {
using Weight = TWeight; // make accessible where we only have a TConfig
static constexpr int kSeqLen = 32;
static constexpr int kVocabSize = 16;
static constexpr std::array<LayerAttentionType, 3> kLayerConfig =
@ -175,10 +173,12 @@ struct ConfigGemmaTiny {
static constexpr bool kUseLocalAttention = false;
static constexpr bool kInterleaveQKV = true;
static constexpr int kNumTensorScales = 0;
using WeightT = GEMMA_WEIGHT_T;
};
template <typename TWeight>
struct ConfigGriffin2B {
using Weight = TWeight; // make accessible where we only have a TConfig
// Griffin uses local attention, so kSeqLen is actually the local attention
// window.
static constexpr int kSeqLen = 2048;
@ -235,7 +235,6 @@ struct ConfigGriffin2B {
static constexpr bool kUseLocalAttention = true;
static constexpr bool kInterleaveQKV = false;
static constexpr int kNumTensorScales = 140;
using WeightT = GEMMA_WEIGHT_T;
};
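// With the weight type as a template parameter, the same config can be
// instantiated for any supported type, e.g. ConfigGemma2B<float>,
// ConfigGemma2B<hwy::bfloat16_t>, or ConfigGemma2B<SfpStream>
// (see CallForModel).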
} // namespace gcpp

View File

@ -15,13 +15,20 @@
#include "gemma/cross_entropy.h"
#include <stddef.h>
#include <stdio.h>
#include <algorithm>
#include <cmath>
#include <functional>
#include <regex> // NOLINT
#include <string>
#include <utility>
#include <vector>
#include "gemma/common.h"
#include "gemma/gemma.h"
namespace gcpp {
namespace {
@ -63,7 +70,9 @@ float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
auto stream_token = [](int, float) { return true; };
auto accept_token = [](int) { return true; };
const int vocab_size = CallFunctorForModel<GetVocabSize>(gemma.ModelType());
// TWeight is unused, but we have to pass it to Config*.
const int vocab_size =
CallForModel</*TWeight=*/float, GetVocabSize>(gemma.ModelType());
float cross_entropy = std::log(vocab_size); // first token
size_t pos = 1;
std::function<int(const float*, size_t)> sample_token =

View File

@ -16,6 +16,8 @@
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_CROSS_ENTROPY_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_CROSS_ENTROPY_H_
#include <stddef.h>
#include <vector>
#include "gemma/gemma.h"

View File

@ -147,8 +147,10 @@ struct CreateKVCache {
} // namespace
KVCache KVCache::Create(Model type) {
return CallFunctorForModel<CreateKVCache>(type);
KVCache KVCache::Create(Model model_type) {
// TWeight=float is a placeholder and unused because CreateKVCache does not
// use TConfig::Weight.
return CallForModel</*TWeight=*/float, CreateKVCache>(model_type);
}
class GemmaTokenizer::Impl {
@ -727,12 +729,12 @@ Activations<TConfig, kBatchSize>& GetActivations(const ByteStorageT& state_u8) {
} // namespace
template <class TConfig>
void Generate(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8,
const ByteStorageT& decode_u8,
const RuntimeConfig& runtime_config,
const std::vector<int>& prompt, size_t pos, KVCache& kv_cache,
hwy::ThreadPool& pool, TimingInfo& timing_info,
LayersOutputT* layers_output) {
void GenerateT(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8,
const ByteStorageT& decode_u8,
const RuntimeConfig& runtime_config,
const std::vector<int>& prompt, size_t pos, KVCache& kv_cache,
hwy::ThreadPool& pool, TimingInfo& timing_info,
LayersOutputT* layers_output) {
const WeightsT<TConfig>& weights = GetWeights<TConfig>(weights_u8);
auto& prefill_activations =
GetActivations<TConfig, kPrefillBatchSize>(prefill_u8);
@ -871,23 +873,31 @@ struct AllocateDecode {
} // namespace
Gemma::Gemma(const Path& tokenizer_path, const Path& weights, Model model_type,
hwy::ThreadPool& pool)
: pool_(pool), tokenizer_(tokenizer_path), model_type_(model_type) {
weights_u8_ = LoadWeights(weights, model_type, pool);
prefill_u8_ = CallFunctorForModel<AllocatePrefill>(model_type);
decode_u8_ = CallFunctorForModel<AllocateDecode>(model_type);
Type weight_type, hwy::ThreadPool& pool)
: pool_(pool),
tokenizer_(tokenizer_path),
model_type_(model_type),
weight_type_(weight_type) {
weights_u8_ = LoadWeights(weights, model_type, weight_type, pool);
prefill_u8_ = CallForModelAndWeight<AllocatePrefill>(model_type, weight_type);
decode_u8_ = CallForModelAndWeight<AllocateDecode>(model_type, weight_type);
}
Gemma::Gemma(GemmaTokenizer&& tokenizer, Model model_type,
Gemma::Gemma(GemmaTokenizer&& tokenizer, Model model_type, Type weight_type,
hwy::ThreadPool& pool)
: pool_(pool), tokenizer_(std::move(tokenizer)), model_type_(model_type) {
weights_u8_ = CallFunctorForModel<AllocateWeightsF>(model_type, pool);
prefill_u8_ = CallFunctorForModel<AllocatePrefill>(model_type);
decode_u8_ = CallFunctorForModel<AllocateDecode>(model_type);
: pool_(pool),
tokenizer_(std::move(tokenizer)),
model_type_(model_type),
weight_type_(weight_type) {
weights_u8_ =
CallForModelAndWeight<AllocateWeightsF>(model_type, weight_type, pool);
prefill_u8_ = CallForModelAndWeight<AllocatePrefill>(model_type, weight_type);
decode_u8_ = CallForModelAndWeight<AllocateDecode>(model_type, weight_type);
}
Gemma::~Gemma() {
CallFunctorForModel<DeleteLayersPtrs>(model_type_, weights_u8_);
CallForModelAndWeight<DeleteLayersPtrs>(model_type_, weight_type_,
weights_u8_);
}
void Gemma::Generate(const RuntimeConfig& runtime_config,
@ -896,8 +906,8 @@ void Gemma::Generate(const RuntimeConfig& runtime_config,
LayersOutputT* layers_output) {
pool_.SetWaitMode(hwy::PoolWaitMode::kSpin);
GEMMA_EXPORT_AND_DISPATCH_MODEL(
model_type_, Generate,
GEMMA_EXPORT_AND_DISPATCH(
model_type_, weight_type_, GenerateT,
(weights_u8_, prefill_u8_, decode_u8_, runtime_config, prompt, start_pos,
kv_cache, pool_, timing_info, layers_output));

View File

@ -107,10 +107,11 @@ using LayersOutputT =
class Gemma {
public:
Gemma(const Path& tokenizer_path, const Path& weights, Model model_type,
hwy::ThreadPool& pool);
Type weight_type, hwy::ThreadPool& pool);
// Allocates weights, caller is responsible for filling them.
Gemma(GemmaTokenizer&& tokenizer, Model model_type, hwy::ThreadPool& pool);
Gemma(GemmaTokenizer&& tokenizer, Model model_type, Type weight_type,
hwy::ThreadPool& pool);
~Gemma();
Model ModelType() const { return model_type_; }
@ -136,6 +137,7 @@ class Gemma {
ByteStorageT prefill_u8_;
ByteStorageT decode_u8_;
Model model_type_;
Type weight_type_;
};
// DEPRECATED, call Gemma::Generate directly.

View File

@ -38,8 +38,7 @@ class GemmaTest : public ::testing::Test {
: weights("./2b-it-mqa.sbs"),
tokenizer("./tokenizer.spm"),
pool(std::min<int>(20, (std::thread::hardware_concurrency() - 1) / 2)),
model_type(gcpp::Model::GEMMA_2B),
model(tokenizer, weights, model_type, pool) {
model(tokenizer, weights, model_type, weight_type, pool) {
KVCache kv_cache = KVCache::Create(model_type);
}
@ -96,7 +95,8 @@ class GemmaTest : public ::testing::Test {
gcpp::Path tokenizer;
gcpp::KVCache kv_cache;
hwy::ThreadPool pool;
gcpp::Model model_type = {};
gcpp::Model model_type = gcpp::Model::GEMMA_2B;
gcpp::Type weight_type = gcpp::Type::kSFP;
gcpp::Gemma model;
};

View File

@ -71,7 +71,7 @@ void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
<< hwy::VectorBytes() * 8 << " bits)" << "\n"
<< "Compiled config : " << CompiledConfig() << "\n"
<< "Weight Type : "
<< gcpp::TypeName(gcpp::GemmaWeightT()) << "\n"
<< gcpp::StringFromType(loader.WeightType()) << "\n"
<< "EmbedderInput Type : "
<< gcpp::TypeName(gcpp::EmbedderInputT()) << "\n";
}
@ -251,8 +251,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
PinWorkersToCores(pool);
}
gcpp::Gemma model(loader.tokenizer, loader.weights, loader.ModelType(), pool);
gcpp::Gemma model = gcpp::CreateGemma(loader, pool);
KVCache kv_cache = KVCache::Create(loader.ModelType());
if (app.verbosity >= 1) {
@ -278,7 +277,8 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
}
ReplGemma(
model, loader.ModelTraining(), kv_cache, pool, inference, app.verbosity,
model, loader.ModelTrainingType(), kv_cache, pool, inference,
app.verbosity,
/*accept_token=*/[](int) { return true; }, app.eot_line);
}

View File

@ -157,8 +157,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
PinWorkersToCores(pool);
}
gcpp::Gemma model(loader.tokenizer, loader.weights, loader.ModelType(), pool);
gcpp::Gemma model = gcpp::CreateGemma(loader, pool);
gcpp::KVCache kv_cache = gcpp::KVCache::Create(loader.ModelType());
JsonGemma(model, kv_cache, pool, inference, app.verbosity, app.eot_line);

View File

@ -173,10 +173,11 @@ struct LoadRawWeightsT {
#undef SCALE_WEIGHTS
} // namespace
ByteStorageT LoadRawWeights(const Path& weights, Model model,
hwy::ThreadPool& pool, bool scale_for_compression) {
return CallFunctorForModel<LoadRawWeightsT>(model, weights, pool,
scale_for_compression);
ByteStorageT LoadRawWeights(const Path& weights, Model model_type,
Type weight_type, hwy::ThreadPool& pool,
bool scale_for_compression) {
return CallForModelAndWeight<LoadRawWeightsT>(
model_type, weight_type, weights, pool, scale_for_compression);
}
namespace {
@ -227,17 +228,18 @@ struct LoadCompressedWeightsT {
};
} // namespace
ByteStorageT LoadCompressedWeights(const Path& weights, Model model,
hwy::ThreadPool& pool) {
return CallFunctorForModel<LoadCompressedWeightsT>(model, weights, pool);
ByteStorageT LoadCompressedWeights(const Path& weights, Model model_type,
Type weight_type, hwy::ThreadPool& pool) {
return CallForModelAndWeight<LoadCompressedWeightsT>(model_type, weight_type,
weights, pool);
}
ByteStorageT LoadWeights(const Path& weights, Model model,
hwy::ThreadPool& pool) {
ByteStorageT LoadWeights(const Path& weights, Model model_type,
Type weight_type, hwy::ThreadPool& pool) {
if constexpr (kWeightsAreCompressed) {
return LoadCompressedWeights(weights, model, pool);
return LoadCompressedWeights(weights, model_type, weight_type, pool);
} else {
return LoadRawWeights(weights, model, pool,
return LoadRawWeights(weights, model_type, weight_type, pool,
/*scale_for_compression=*/false);
}
}
@ -274,8 +276,9 @@ struct LogWeightStatsT {
};
} // namespace
void LogWeightStats(gcpp::Model model, const ByteStorageT& weights) {
CallFunctorForModel<LogWeightStatsT>(model, weights);
void LogWeightStats(gcpp::Model model_type, Type weight_type,
const ByteStorageT& weights) {
CallForModelAndWeight<LogWeightStatsT>(model_type, weight_type, weights);
}
} // namespace gcpp

View File

@ -129,21 +129,17 @@ using WeightsF = Weights<float, TConfig>;
// ----------------------------------------------------------------------------
// Compressed
// If weights are f32, also f32; otherwise at least bf16. Useful for ops that do
// not yet support smaller compressed types, or require at least bf16. When
// weights are f32, we also want such tensors to be f32.
template <class TConfig>
using WeightF32OrBF16T =
hwy::If<hwy::IsSame<typename TConfig::WeightT, float>(), float,
hwy::bfloat16_t>;
template <class TConfig>
struct CompressedLayer {
// No ctor/dtor, allocated via AllocateAligned.
using TLayer = gcpp::LayerF<TConfig>;
using WeightT = typename TConfig::WeightT;
using WeightF32OrBF16 = WeightF32OrBF16T<TConfig>;
using Weight = typename TConfig::Weight;
// If weights are f32, also f32; otherwise at least bf16. Useful for ops that
// do not yet support smaller compressed types, or require at least bf16. When
// weights are f32, we also want such tensors to be f32.
using WeightF32OrBF16 =
hwy::If<hwy::IsSame<Weight, float>(), float, hwy::bfloat16_t>;
static constexpr size_t kHeads = TLayer::kHeads;
static constexpr size_t kKVHeads = TLayer::kKVHeads;
@ -166,29 +162,29 @@ struct CompressedLayer {
union {
struct {
ArrayT<WeightT, kAttVecEinsumWSize> attn_vec_einsum_w;
ArrayT<WeightT, kQKVEinsumWSize> qkv_einsum_w;
ArrayT<Weight, kAttVecEinsumWSize> attn_vec_einsum_w;
ArrayT<Weight, kQKVEinsumWSize> qkv_einsum_w;
ArrayT<float, kAOBiasDim> attention_output_biases;
};
struct {
ArrayT<WeightT, kGriffinDim * kGriffinDim> linear_x_w;
ArrayT<Weight, kGriffinDim * kGriffinDim> linear_x_w;
ArrayT<float, kGriffinDim> linear_x_biases;
ArrayT<WeightT, kGriffinDim * kGriffinDim> linear_y_w;
ArrayT<Weight, kGriffinDim * kGriffinDim> linear_y_w;
ArrayT<float, kGriffinDim> linear_y_biases;
ArrayT<WeightT, kGriffinDim * kGriffinDim> linear_out_w;
ArrayT<Weight, kGriffinDim * kGriffinDim> linear_out_w;
ArrayT<float, kGriffinDim> linear_out_biases;
ArrayT<float, TConfig::kConv1dWidth * kGriffinDim> conv_w;
ArrayT<float, kGriffinDim> conv_biases;
ArrayT<WeightT, kGriffinDim * kGriffinDim / kHeads * 2> gate_w;
ArrayT<Weight, kGriffinDim * kGriffinDim / kHeads * 2> gate_w;
ArrayT<float, kGriffinDim * 2> gate_biases;
ArrayT<float, kGriffinDim> a;
} griffin;
};
ArrayT<WeightT, TLayer::kGatingEinsumWSize> gating_einsum_w;
ArrayT<WeightT, kModelDim * kFFHiddenDim> linear_w;
// We don't yet have an RMSNorm that accepts all WeightT.
ArrayT<Weight, TLayer::kGatingEinsumWSize> gating_einsum_w;
ArrayT<Weight, kModelDim * kFFHiddenDim> linear_w;
// We don't yet have an RMSNorm that accepts all Weight.
ArrayT<WeightF32OrBF16, kModelDim> pre_attention_norm_scale;
ArrayT<WeightF32OrBF16, kModelDim> pre_ffw_norm_scale;
ArrayT<WeightF32OrBF16, kPostNormScale ? kModelDim : 0>
@ -241,7 +237,7 @@ template <class TConfig>
using WeightsT = hwy::If<kWeightsAreCompressed, CompressedWeights<TConfig>,
WeightsF<TConfig>>;
// Call via CallFunctorForModel.
// TODO: can we use TConfig::Weight instead of T?
template <typename T, typename TConfig>
struct AllocateWeights {
ByteStorageT operator()(hwy::ThreadPool& pool) const {
@ -335,14 +331,15 @@ class WeightsWrapper {
};
// For use by compress_weights.cc.
ByteStorageT LoadRawWeights(const Path& weights, Model model,
hwy::ThreadPool& pool, bool scale_for_compression);
ByteStorageT LoadRawWeights(const Path& weights, Model model_type,
Type weight_type, hwy::ThreadPool& pool,
bool scale_for_compression);
// For gemma.cc; calls LoadRawWeights if !kWeightsAreCompressed.
ByteStorageT LoadWeights(const Path& weights, Model model,
hwy::ThreadPool& pool);
ByteStorageT LoadWeights(const Path& weights, Model model_type,
Type weight_type, hwy::ThreadPool& pool);
void LogWeightStats(Model model, const ByteStorageT& weights);
void LogWeightStats(Model model, Type weight_type, const ByteStorageT& weights);
// ----------------------------------------------------------------------------
// Iterators

View File

@ -21,19 +21,17 @@
#include "hwy/contrib/thread_pool/thread_pool.h"
#if HWY_OS_LINUX
#include <sched.h>
#include <cctype>
#include <cerrno> // IDE does not recognize errno.h as providing errno.
#include <string>
#endif // HWY_OS_LINUX
#include <stddef.h>
#include <stdio.h>
#include <algorithm> // std::clamp
#include <thread> // NOLINT>
#include <string>
#include <thread> // NOLINT>
#include <vector>
#include "compression/io.h" // Path
#include "gemma/common.h"
#include "gemma/configs.h"
#include "gemma/gemma.h"
#include "util/args.h"
@ -49,14 +47,10 @@ static inline const char* CompiledConfig() {
return "msan";
} else if (HWY_IS_TSAN) {
return "tsan";
#if defined(HWY_IS_HWASAN)
} else if (HWY_IS_HWASAN) {
return "hwasan";
#endif
#if defined(HWY_IS_UBSAN)
} else if (HWY_IS_UBSAN) {
return "ubsan";
#endif
} else if (HWY_IS_DEBUG_BUILD) {
return "dbg";
} else {
@ -172,15 +166,15 @@ class AppArgs : public ArgsBase<AppArgs> {
struct LoaderArgs : public ArgsBase<LoaderArgs> {
LoaderArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
gcpp::Model ModelType() const { return model_type; }
gcpp::ModelTraining ModelTraining() const { return model_training; }
// Returns error string or nullptr if OK.
const char* Validate() {
const char* parse_result =
ParseModelTypeAndTraining(model_type_str, model_type, model_training);
if (parse_result) return parse_result;
if (const char* err = ParseModelTypeAndTraining(model_type_str, model_type_,
model_training_)) {
return err;
}
if (const char* err = ParseType(weight_type_str, weight_type_)) {
return err;
}
if (tokenizer.path.empty()) {
return "Missing --tokenizer flag, a file for the tokenizer is required.";
}
@ -209,8 +203,7 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
Path weights; // weights file location
Path compressed_weights;
std::string model_type_str;
Model model_type;
enum ModelTraining model_training;
std::string weight_type_str;
template <class Visitor>
void ForEach(const Visitor& visitor) {
@ -227,9 +220,28 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
"gr2b-it = griffin 2B parameters, instruction-tuned\n "
"gr2b-pt = griffin 2B parameters, pretrained\n "
" Required argument.");
visitor(weight_type_str, "weight_type", std::string("sfp"),
"Weight type\n f32 = float, bf16 = bfloat16, SFP = 8-bit FP\n"
" Required argument.");
}
// Uninitialized before Validate; call these accessors only after Validate.
gcpp::Model ModelType() const { return model_type_; }
gcpp::ModelTraining ModelTrainingType() const { return model_training_; }
gcpp::Type WeightType() const { return weight_type_; }
private:
Model model_type_;
ModelTraining model_training_;
Type weight_type_;
};
static inline Gemma CreateGemma(const LoaderArgs& loader,
hwy::ThreadPool& pool) {
return Gemma(loader.tokenizer, loader.weights, loader.ModelType(),
loader.WeightType(), pool);
}
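// Example: binaries construct the model without spelling out the loader
// fields, e.g.
//   gcpp::Gemma model = gcpp::CreateGemma(loader, pool);
// so future loader-derived arguments can be added without changing call sites.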
struct InferenceArgs : public ArgsBase<InferenceArgs> {
InferenceArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }