mirror of https://github.com/google/gemma.cpp.git

commit f9b390b134 (parent 24db2ff725)

Support all weight types in a single binary.

This changes the command-line flags, but the default value retains the
previous behavior. Also adds a CreateGemma helper so that extra arguments
can be passed without further interface changes.

PiperOrigin-RevId: 641266411
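The CreateGemma helper mentioned above is used but not defined in the hunks
below. A minimal sketch of what it plausibly looks like, assuming a LoaderArgs
that exposes the tokenizer/weights paths plus the ModelType()/WeightType()
accessors introduced in this change:

```cpp
// Sketch only: forwards parsed flags into the Gemma constructor so call sites
// do not change when the constructor gains new arguments.
static inline gcpp::Gemma CreateGemma(const gcpp::LoaderArgs& loader,
                                      hwy::ThreadPool& pool) {
  return gcpp::Gemma(loader.tokenizer, loader.weights, loader.ModelType(),
                     loader.WeightType(), pool);
}
```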
@@ -116,6 +116,7 @@ cc_library(
         "gemma/cross_entropy.h",
     ],
     deps = [
+        ":common",
         ":gemma_lib",
     ],
 )

@@ -76,17 +76,6 @@ if(NOT CMAKE_BUILD_TYPE)
   set(CMAKE_BUILD_TYPE "Release")
 endif()
 
-# Allowable types for WEIGHT_TYPE:
-# float - slow, not recommended
-# hwy::bfloat16_t - bfloat16 as implemented by https://github.com/google/highway
-# SfpStream - 8-bit switched floating point (recommended)
-# NuqStream - experimental, work-in-progress
-option(WEIGHT_TYPE "Set weight type" "")
-
-if (WEIGHT_TYPE)
-  add_definitions(-DGEMMA_WEIGHT_T=${WEIGHT_TYPE})
-endif()
-
 FetchContent_GetProperties(sentencepiece)
 
 ## Library Target

@@ -105,18 +105,12 @@ the resulting file as `--weights` and the desired .sbs name as the
 There are several compile-time flags to be aware of (note these may or may not
 be exposed to the build system):
 
-- `GEMMA_WEIGHT_T` : Sets the level of compression for weights (surfaced as
-  WEIGHT_TYPE in CMakeLists.txt). Currently this should be set to `SfpStream`
-  (default, if no flag is specified) for 8-bit SFP, or `hwy::bfloat16_t` to
-  enable for higher-fidelity (but slower) bfloat16 support. This is defined in
-  `gemma.h`.
 - `GEMMA_MAX_SEQ_LEN` : Sets maximum sequence length to preallocate for the KV
   Cache. The default is 4096 tokens but can be overridden. This is not exposed
   through `CMakeLists.txt` yet.
 
-In the medium term both of these will likely be deprecated in favor of handling
-options at runtime - allowing for multiple weight compression schemes in a single
-build and dynamically resizes the KV cache as needed.
+In the medium term this will likely be deprecated in favor of handling options
+at runtime - dynamically resizing the KV cache as needed.
 
 ## Using gemma.cpp as a Library (Advanced)

README.md (52 lines changed)

@@ -138,33 +138,16 @@ convenient directory location (e.g. the `build/` directory in this repo).
 The build system uses [CMake](https://cmake.org/). To build the gemma inference
 runtime, create a build directory and generate the build files using `cmake`
 from the top-level project directory. Note if you previous ran `cmake` and are
-re-running with a different setting, be sure to clean out the `build/` directory
-with `rm -rf build/*` (warning this will delete any other files in the `build/`
-directory.
-
-For the 8-bit switched floating point weights (sfp), run cmake with no options:
+re-running with a different setting, be sure to delete all files in the `build/`
+directory with `rm -rf build/*`.
 
 #### Unix-like Platforms
 ```sh
 cmake -B build
 ```
 
-**or** if you downloaded bfloat16 weights (any model *without* `-sfp` in the
-name), instead of running cmake with no options as above, run cmake with
-WEIGHT_TYPE set to [highway's](https://github.com/google/highway)
-`hwy::bfloat16_t` type. Alternatively, you can also add
-`-DGEMMA_WEIGHT_T=hwy::bfloat16_t` to the C++ compiler flags.
-
-We intend to soon support all weight types without requiring extra flags. Note
-that we recommend using `-sfp` weights instead of bfloat16 for faster inference.
-
-```sh
-cmake -B build -DWEIGHT_TYPE=hwy::bfloat16_t
-```
-
-After running whichever of the above `cmake` invocations that is appropriate for
-your weights, you can enter the `build/` directory and run `make` to build the
-`./gemma` executable:
+After running `cmake`, you can enter the `build/` directory and run `make` to
+build the `./gemma` executable:
 
 ```sh
 # Configure `build` directory

@@ -221,11 +204,12 @@ You can now run `gemma` from inside the `build/` directory.
 
 `gemma` has the following required arguments:
 
-| Argument      | Description                  | Example value              |
-| ------------- | ---------------------------- | -------------------------- |
-| `--model`     | The model type.              | `2b-it`, `2b-pt`, `7b-it`, `7b-pt`, ... (see above) |
-| `--weights`   | The compressed weights file. | `2b-it-sfp.sbs`, ... (see above) |
-| `--tokenizer` | The tokenizer file.          | `tokenizer.spm` |
+Argument        | Description                  | Example value
+--------------- | ---------------------------- | -----------------------
+`--model`       | The model type.              | `2b-it` ... (see below)
+`--weights`     | The compressed weights file. | `2b-it-sfp.sbs`
+`--weight_type` | The compressed weight type.  | `sfp`
+`--tokenizer`   | The tokenizer file.          | `tokenizer.spm`
 
 `gemma` is invoked as:
 
@@ -233,6 +217,7 @@ You can now run `gemma` from inside the `build/` directory.
 ./gemma \
 --tokenizer [tokenizer file] \
 --weights [compressed weights file] \
+--weight_type [f32 or bf16 or sfp] \
 --model [2b-it or 2b-pt or 7b-it or 7b-pt or ...]
 ```

@@ -245,8 +230,7 @@ Example invocation for the following configuration:
 ```sh
 ./gemma \
 --tokenizer tokenizer.spm \
---weights 2b-it-sfp.sbs \
---model 2b-it
+--weights 2b-it-sfp.sbs --weight_type sfp --model 2b-it
 ```
 
 ### RecurrentGemma

@@ -270,14 +254,12 @@ Step 1, and run the binary as follows:
 
 **Running `./gemma` fails with "Failed to read cache gating_ein_0 (error 294) ..."**
 
-The most common problem is that `cmake` was built with the wrong weight type and
-`gemma` is attempting to load `bfloat16` weights (`2b-it`, `2b-pt`, `7b-it`,
-`7b-pt`) using the default switched floating point (sfp) or vice versa. Revisit
-step #3 and check that the `cmake` command used to build `gemma` was correct for
-the weights that you downloaded.
+The most common problem is that the `--weight_type` argument does not match that
+of the model file. Revisit step #3 and check which weights you downloaded.
 
-In the future we will handle model format handling from compile time to runtime
-to simplify this.
+Note that we have already moved weight type from a compile-time decision to a
+runtime argument. In a subsequent step, we plan to bake this information into
+the weights.
 
 **Problems building in Windows / Visual Studio**

@@ -21,7 +21,6 @@
 #define THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_INL_H_
 
 #include <stddef.h>
-#include <stdint.h>
 
 #include <array>
 #include <cmath>
@@ -44,6 +43,7 @@
 #endif
 
 #include "gemma/ops.h"
+#include "hwy/highway.h"
 
 HWY_BEFORE_NAMESPACE();
 namespace gcpp {

@@ -15,6 +15,11 @@
 
 #include "backprop/backward.h"
 
+#include "backprop/prompt.h"
+#include "gemma/activations.h"
+#include "gemma/common.h"
+#include "hwy/contrib/thread_pool/thread_pool.h"
+
 // Compiles this file for multiple architectures via "foreach_target.h", to
 // which we pass the filename via macro 'argument'.
 #undef HWY_TARGET_INCLUDE
@@ -29,7 +34,6 @@
 HWY_BEFORE_NAMESPACE();
 namespace gcpp {
 namespace HWY_NAMESPACE {
-namespace hn = hwy::HWY_NAMESPACE;
 
 template <typename TConfig>
 void CrossEntropyLossBackwardPass(const Prompt& prompt,
@@ -57,11 +61,11 @@ void CrossEntropyLossBackwardPassT(Model model,
   // TODO(janwas): use CallFunctorForModel
   switch (model) {
     case Model::GEMMA_2B:
-      CrossEntropyLossBackwardPass<ConfigGemma2B>(
+      CrossEntropyLossBackwardPass<ConfigGemma2B<float>>(
           prompt, weights, forward, grad, backward, pool);
       break;
     case Model::GEMMA_TINY:
-      CrossEntropyLossBackwardPass<ConfigGemmaTiny>(
+      CrossEntropyLossBackwardPass<ConfigGemmaTiny<float>>(
          prompt, weights, forward, grad, backward, pool);
       break;
     default:

@@ -15,6 +15,11 @@
 
 #include "backprop/forward.h"
 
+#include "backprop/prompt.h"
+#include "gemma/activations.h"
+#include "gemma/common.h"
+#include "hwy/contrib/thread_pool/thread_pool.h"
+
 // Compiles this file for multiple architectures via "foreach_target.h", to
 // which we pass the filename via macro 'argument'.
 #undef HWY_TARGET_INCLUDE
@@ -29,7 +34,6 @@
 HWY_BEFORE_NAMESPACE();
 namespace gcpp {
 namespace HWY_NAMESPACE {
-namespace hn = hwy::HWY_NAMESPACE;
 
 template <typename TConfig>
 float CrossEntropyLossForwardPass(const Prompt& prompt,
@@ -51,10 +55,10 @@ float CrossEntropyLossForwardPassT(Model model, const Prompt& prompt,
   // TODO(janwas): use CallFunctorForModel
   switch (model) {
     case Model::GEMMA_2B:
-      return CrossEntropyLossForwardPass<ConfigGemma2B>(
-          prompt, weights, forward, pool);
+      return CrossEntropyLossForwardPass<ConfigGemma2B<float>>(prompt, weights,
+                                                               forward, pool);
     case Model::GEMMA_TINY:
-      return CrossEntropyLossForwardPass<ConfigGemmaTiny>(
+      return CrossEntropyLossForwardPass<ConfigGemmaTiny<float>>(
           prompt, weights, forward, pool);
     default:
       HWY_ABORT("Model type %d unknown.", static_cast<int>(model));

@@ -13,18 +13,23 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include <iostream>
-#include <string>
+#include <stddef.h>
 
+#include <limits>
+#include <random>
+#include <vector>
+
+#include "gtest/gtest.h"
 #include "backprop/backward.h"
 #include "backprop/forward.h"
 #include "backprop/optimizer.h"
+#include "backprop/prompt.h"
 #include "backprop/sampler.h"
 #include "gemma/activations.h"
 #include "gemma/common.h"
 #include "gemma/gemma.h"
 #include "gemma/weights.h"
-#include "gtest/gtest.h"
+#include "hwy/contrib/thread_pool/thread_pool.h"
 
 namespace gcpp {
 
@@ -35,11 +40,17 @@ TEST(OptimizeTest, GradientDescent) {
   std::mt19937 gen(42);
 
   Model model_type = Model::GEMMA_TINY;
-  ByteStorageT grad = CallFunctorForModel<AllocateWeightsF>(model_type, pool);
-  ByteStorageT grad_m = CallFunctorForModel<AllocateWeightsF>(model_type, pool);
-  ByteStorageT grad_v = CallFunctorForModel<AllocateWeightsF>(model_type, pool);
-  ByteStorageT forward = CallFunctorForModel<AllocateForwardPass>(model_type);
-  ByteStorageT backward = CallFunctorForModel<AllocateForwardPass>(model_type);
+  Type weight_type = Type::kF32;
+  ByteStorageT grad =
+      CallForModelAndWeight<AllocateWeightsF>(model_type, weight_type, pool);
+  ByteStorageT grad_m =
+      CallForModelAndWeight<AllocateWeightsF>(model_type, weight_type, pool);
+  ByteStorageT grad_v =
+      CallForModelAndWeight<AllocateWeightsF>(model_type, weight_type, pool);
+  ByteStorageT forward =
+      CallForModelAndWeight<AllocateForwardPass>(model_type, weight_type);
+  ByteStorageT backward =
+      CallForModelAndWeight<AllocateForwardPass>(model_type, weight_type);
   KVCache kv_cache = KVCache::Create(model_type);
   size_t max_tokens = 32;
   size_t max_generated_tokens = 16;
@@ -47,7 +58,7 @@ TEST(OptimizeTest, GradientDescent) {
   int verbosity = 0;
   const auto accept_token = [](int) { return true; };
 
-  Gemma gemma(GemmaTokenizer(), model_type, pool);
+  Gemma gemma(GemmaTokenizer(), model_type, weight_type, pool);
 
   const auto generate = [&](const std::vector<int>& prompt) {
     std::vector<int> reply;
@@ -76,12 +87,14 @@ TEST(OptimizeTest, GradientDescent) {
     return ok;
   };
 
-  RandInitWeights(model_type, gemma.Weights(), pool, gen);
-  CallFunctorForModel<ZeroInitWeightsF>(model_type, grad_m, pool);
-  CallFunctorForModel<ZeroInitWeightsF>(model_type, grad_v, pool);
+  RandInitWeights(model_type, weight_type, gemma.Weights(), pool, gen);
+  CallForModelAndWeight<ZeroInitWeightsF>(model_type, weight_type, grad_m,
+                                          pool);
+  CallForModelAndWeight<ZeroInitWeightsF>(model_type, weight_type, grad_v,
+                                          pool);
 
   printf("Initial weights:\n");
-  LogWeightStats(model_type, gemma.Weights());
+  LogWeightStats(model_type, weight_type, gemma.Weights());
 
   constexpr size_t kBatchSize = 8;
   const float alpha = 0.001f;
@@ -96,7 +109,8 @@ TEST(OptimizeTest, GradientDescent) {
   size_t num_ok;
   for (; steps < 1000000; ++steps) {
     std::mt19937 sgen(42);
-    CallFunctorForModel<ZeroInitWeightsF>(model_type, grad, pool);
+    CallForModelAndWeight<ZeroInitWeightsF>(model_type, weight_type, grad,
+                                            pool);
     float total_loss = 0.0f;
     num_ok = 0;
     for (size_t i = 0; i < kBatchSize; ++i) {
@@ -109,13 +123,13 @@ TEST(OptimizeTest, GradientDescent) {
     }
     total_loss /= kBatchSize;
 
-    AdamUpdate(model_type, grad, alpha, beta1, beta2, epsilon, steps + 1,
-               gemma.Weights(), grad_m, grad_v, pool);
+    AdamUpdate(model_type, weight_type, grad, alpha, beta1, beta2, epsilon,
+               steps + 1, gemma.Weights(), grad_m, grad_v, pool);
     printf("step: %zu total_loss: %.15f num_ok: %zu/%zu\n",
            steps, total_loss, num_ok, kBatchSize);
     if (steps % 100 == 0) {
       printf("Batch gradient:\n");
-      LogWeightStats(model_type, grad);
+      LogWeightStats(model_type, weight_type, grad);
     }
     if (total_loss < 0.5f) {
       break;
@@ -124,7 +138,7 @@ TEST(OptimizeTest, GradientDescent) {
   }
   printf("Num steps: %zu\n", steps);
   printf("Final weights:\n");
-  LogWeightStats(model_type, gemma.Weights());
+  LogWeightStats(model_type, weight_type, gemma.Weights());
   EXPECT_LT(steps, 200);
   EXPECT_EQ(num_ok, kBatchSize);
 }

@@ -107,18 +107,20 @@ struct AdamUpdateT {
 
 }  // namespace
 
-void RandInitWeights(Model model, const ByteStorageT& weights,
-                     hwy::ThreadPool& pool,
+void RandInitWeights(Model model_type, Type weight_type,
+                     const ByteStorageT& weights, hwy::ThreadPool& pool,
                      std::mt19937& gen) {
-  CallFunctorForModel<RandInitWeightsT>(model, weights, pool, gen);
+  CallForModelAndWeight<RandInitWeightsT>(model_type, weight_type, weights,
+                                          pool, gen);
 }
 
-void AdamUpdate(Model model, const ByteStorageT& grad, float alpha, float beta1,
-                float beta2, float epsilon, size_t t,
+void AdamUpdate(Model model_type, Type weight_type, const ByteStorageT& grad,
+                float alpha, float beta1, float beta2, float epsilon, size_t t,
                 const ByteStorageT& weights, const ByteStorageT& grad_m,
                 const ByteStorageT& grad_v, hwy::ThreadPool& pool) {
-  CallFunctorForModel<AdamUpdateT>(model, grad, alpha, beta1, beta2, epsilon, t,
-                                   weights, grad_m, grad_v, pool);
+  CallForModelAndWeight<AdamUpdateT>(model_type, weight_type, grad, alpha,
+                                     beta1, beta2, epsilon, t, weights, grad_m,
+                                     grad_v, pool);
 }
 
 }  // namespace gcpp

@@ -19,16 +19,16 @@
 #include <random>
 
 #include "gemma/common.h"
-#include "gemma/weights.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
 
 namespace gcpp {
 
-void RandInitWeights(Model model, const ByteStorageT& weights,
-                     hwy::ThreadPool& pool, std::mt19937& gen);
+void RandInitWeights(Model model_type, Type weight_type,
+                     const ByteStorageT& weights, hwy::ThreadPool& pool,
+                     std::mt19937& gen);
 
-void AdamUpdate(Model model, const ByteStorageT& grad, float alpha, float beta1,
-                float beta2, float epsilon, size_t t,
+void AdamUpdate(Model model_type, Type weight_type, const ByteStorageT& grad,
+                float alpha, float beta1, float beta2, float epsilon, size_t t,
                 const ByteStorageT& weights, const ByteStorageT& grad_m,
                 const ByteStorageT& grad_v, hwy::ThreadPool& pool);

@@ -113,7 +113,7 @@ int main(int argc, char** argv) {
     gcpp::PinWorkersToCores(pool);
   }
 
-  gcpp::Gemma model(loader.tokenizer, loader.weights, loader.ModelType(), pool);
+  gcpp::Gemma model = gcpp::CreateGemma(loader, pool);
   gcpp::KVCache kv_cache = gcpp::KVCache::Create(loader.ModelType());
 
   const std::string& prompt = prompt_args.prompt;

@@ -39,7 +39,7 @@ int main(int argc, char** argv) {
   hwy::ThreadPool pool(num_threads);
 
   // Instantiate model and KV Cache
-  gcpp::Gemma model(loader.tokenizer, loader.weights, loader.ModelType(), pool);
+  gcpp::Gemma model = gcpp::CreateGemma(loader, pool);
   gcpp::KVCache kv_cache = gcpp::KVCache::Create(loader.ModelType());
   size_t pos = 0;  // KV Cache position

@@ -280,7 +280,7 @@ int main(int argc, char** argv) {
     gcpp::PinWorkersToCores(pool);
   }
 
-  gcpp::Gemma model(loader.tokenizer, loader.weights, loader.ModelType(), pool);
+  gcpp::Gemma model = gcpp::CreateGemma(loader, pool);
   gcpp::KVCache kv_cache = gcpp::KVCache::Create(loader.ModelType());
 
   if (!benchmark_args.goldens.path.empty()) {

@@ -64,4 +64,28 @@ const char* ParseModelTypeAndTraining(const std::string& model_flag,
   return kErrorMessageBuffer;
 }
 
+const char* ParseType(const std::string& type_string, Type& type) {
+  constexpr Type kTypes[] = {Type::kF32, Type::kBF16, Type::kSFP};
+  constexpr const char* kStrings[] = {"f32", "bf16", "sfp"};
+  constexpr size_t kNum = std::end(kStrings) - std::begin(kStrings);
+  static char kErrorMessageBuffer[kNum * 8 + 100] =
+      "Invalid or missing type, need to specify one of ";
+  for (size_t i = 0; i + 1 < kNum; i++) {
+    strcat(kErrorMessageBuffer, kStrings[i]);  // NOLINT
+    strcat(kErrorMessageBuffer, ", ");         // NOLINT
+  }
+  strcat(kErrorMessageBuffer, kStrings[kNum - 1]);  // NOLINT
+  strcat(kErrorMessageBuffer, ".");                 // NOLINT
+  std::string type_lc = type_string;
+  std::transform(begin(type_lc), end(type_lc), begin(type_lc),
+                 [](unsigned char c) { return std::tolower(c); });
+  for (size_t i = 0; i < kNum; i++) {
+    if (kStrings[i] == type_lc) {
+      type = kTypes[i];
+      return nullptr;
+    }
+  }
+  return kErrorMessageBuffer;
+}
+
 }  // namespace gcpp

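A hypothetical call site for the new parser, illustrating the error-string
convention it shares with ParseModelTypeAndTraining (nullptr means success;
the flag value here is made up):

```cpp
#include <stdio.h>
#include "gemma/common.h"

int DemoParse() {
  gcpp::Type weight_type;
  // Matching is case-insensitive ("BF16" also works).
  if (const char* err = gcpp::ParseType("bf16", weight_type)) {
    fprintf(stderr, "%s\n", err);  // unknown or missing type string
    return 1;
  }
  return 0;  // weight_type == gcpp::Type::kBF16
}
```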
gemma/common.h (107 lines changed)

@@ -21,6 +21,7 @@
 
 #include <string>
 
+#include "compression/compress.h"
 #include "gemma/configs.h"  // IWYU pragma: export
 #include "hwy/aligned_allocator.h"
 #include "hwy/base.h"  // ConvertScalarTo
@@ -37,56 +38,86 @@ ByteStorageT AllocateSizeof() {
 // Model variants: see configs.h for details.
 enum class Model { GEMMA_2B, GEMMA_7B, GRIFFIN_2B, GEMMA_TINY };
 
+// Instruction-tuned models require extra 'turn structure' tokens in prompts.
 enum class ModelTraining { GEMMA_IT, GEMMA_PT };
 
-// Returns the return value of Func<T>().operator() called with `args`, where
-// `T` is selected based on `model`.
+// Tensor types for loading weights.
+enum class Type { kF32, kBF16, kSFP };
+
+// Returns the return value of FuncT<Config*<TWeight>>().operator()(args), where
+// Config* is selected via `model`. Typically called by CallForModelAndWeight,
+// but can also be called directly when FuncT does not actually use TWeight.
 //
-// This is used to implement type-erased functions such as
-// LoadCompressedWeights, which can be called from other .cc files, by calling a
-// functor LoadCompressedWeightsT, which has a template argument. `Func` must
-// be a functor because function templates cannot be passed as a template
-// template argument, and we prefer to avoid the overhead of std::function.
+// Note that a T prefix indicates a concrete type template argument, whereas a
+// T suffix indicates the argument is itself a template.
 //
-// This function avoids having to update all call sites when we extend `Model`.
-template <template <typename Config> class Func, typename... Args>
-decltype(auto) CallFunctorForModel(Model model, Args&&... args) {
+// `FuncT` must be a functor because function templates cannot be passed as a
+// template template argument, and we prefer to avoid the overhead of
+// std::function.
+template <typename TWeight, template <typename TConfig> class FuncT,
+          typename... TArgs>
+decltype(auto) CallForModel(Model model, TArgs&&... args) {
   switch (model) {
     case Model::GEMMA_TINY:
-      return Func<ConfigGemmaTiny>()(std::forward<Args>(args)...);
+      return FuncT<ConfigGemmaTiny<TWeight>>()(std::forward<TArgs>(args)...);
     case Model::GEMMA_2B:
-      return Func<ConfigGemma2B>()(std::forward<Args>(args)...);
+      return FuncT<ConfigGemma2B<TWeight>>()(std::forward<TArgs>(args)...);
     case Model::GEMMA_7B:
-      return Func<ConfigGemma7B>()(std::forward<Args>(args)...);
+      return FuncT<ConfigGemma7B<TWeight>>()(std::forward<TArgs>(args)...);
     case Model::GRIFFIN_2B:
-      return Func<ConfigGriffin2B>()(std::forward<Args>(args)...);
+      return FuncT<ConfigGriffin2B<TWeight>>()(std::forward<TArgs>(args)...);
     default:
       HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
   }
 }
 
-// Like CallFunctorForModel, but for SIMD function templates. This is a macro
-// because it boils down to N_SSE4::FUNC, which would not work if FUNC was a
-// normal function argument.
-#define GEMMA_EXPORT_AND_DISPATCH_MODEL(MODEL, FUNC, ARGS)              \
+// Returns the return value of FuncT<TConfig>().operator()(args),
+// where `TConfig` is selected based on `model` and `weight`.
+// This makes it easy to extend `Model` or `Type` without updating callers.
+//
+// Usage example: LoadWeights is type-erased so that it can be called from other
+// .cc files. It uses this function to call the appropriate instantiation of a
+// template functor LoadCompressedWeightsT<TConfig>.
+template <template <typename TConfig> class FuncT, typename... TArgs>
+decltype(auto) CallForModelAndWeight(Model model, Type weight,
+                                     TArgs&&... args) {
+  switch (weight) {
+    case Type::kF32:
+      return CallForModel<float, FuncT, TArgs...>(  //
+          model, std::forward<TArgs>(args)...);
+    case Type::kBF16:
+      return CallForModel<hwy::bfloat16_t, FuncT, TArgs...>(
+          model, std::forward<TArgs>(args)...);
+    case Type::kSFP:
+      return CallForModel<SfpStream, FuncT, TArgs...>(
+          model, std::forward<TArgs>(args)...);
+    default:
+      HWY_ABORT("Weight type %d unknown.", static_cast<int>(weight));
+  }
+}
+
+// Used by GEMMA_EXPORT_AND_DISPATCH. For a given TWEIGHT (e.g. float),
+// calls FUNC<ConfigT<TWEIGHT>> where ConfigT is chosen via MODEL enum.
+#define GEMMA_DISPATCH_MODEL(MODEL, TWEIGHT, FUNC, ARGS)                \
   switch (MODEL) {                                                      \
     case Model::GEMMA_TINY: {                                           \
-      HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<ConfigGemmaTiny>)          \
+      HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<ConfigGemmaTiny<TWEIGHT>>) \
       ARGS;                                                             \
       break;                                                            \
     }                                                                   \
     case Model::GEMMA_2B: {                                             \
-      HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<ConfigGemma2B>)            \
+      HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<ConfigGemma2B<TWEIGHT>>)   \
       ARGS;                                                             \
       break;                                                            \
     }                                                                   \
     case Model::GEMMA_7B: {                                             \
-      HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<ConfigGemma7B>)            \
+      HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<ConfigGemma7B<TWEIGHT>>)   \
       ARGS;                                                             \
       break;                                                            \
     }                                                                   \
     case Model::GRIFFIN_2B: {                                           \
-      HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<ConfigGriffin2B>)          \
+      HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<ConfigGriffin2B<TWEIGHT>>) \
       ARGS;                                                             \
       break;                                                            \
     }                                                                   \
@@ -94,10 +125,42 @@ decltype(auto) CallFunctorForModel(Model model, Args&&... args) {
       HWY_ABORT("Model type %d unknown.", static_cast<int>(MODEL));     \
   }
 
+// Like CallForModelAndWeight, but for SIMD function templates. This is a macro
+// because it boils down to N_SSE4::FUNC, which would not work if FUNC was a
+// normal function argument. MODEL and WEIGHT are enums.
+#define GEMMA_EXPORT_AND_DISPATCH(MODEL, WEIGHT, FUNC, ARGS)            \
+  switch (WEIGHT) {                                                     \
+    case Type::kF32:                                                    \
+      GEMMA_DISPATCH_MODEL(MODEL, float, FUNC, ARGS);                   \
+      break;                                                            \
+    case Type::kBF16:                                                   \
+      GEMMA_DISPATCH_MODEL(MODEL, hwy::bfloat16_t, FUNC, ARGS);         \
+      break;                                                            \
+    case Type::kSFP:                                                    \
+      GEMMA_DISPATCH_MODEL(MODEL, SfpStream, FUNC, ARGS);               \
+      break;                                                            \
+    default:                                                            \
+      HWY_ABORT("Weight type %d unknown.", static_cast<int>(WEIGHT));   \
+  }
+
 // Returns error string or nullptr if OK.
 // Thread-hostile.
 const char* ParseModelTypeAndTraining(const std::string& model_flag,
                                       Model& model, ModelTraining& training);
+const char* ParseType(const std::string& type_string, Type& type);
+
+static inline const char* StringFromType(Type type) {
+  switch (type) {
+    case Type::kF32:
+      return "f32";
+    case Type::kBF16:
+      return "bf16";
+    case Type::kSFP:
+      return "sfp";
+    default:
+      return "?";
+  }
+}
 
 // __builtin_sqrt is not constexpr as of Clang 17.
 #if HWY_COMPILER_GCC_ACTUAL

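To make the double dispatch concrete, here is a hedged sketch of a functor in
the style this header expects; GetSeqLen is our invented example, not part of
the change:

```cpp
#include "gemma/common.h"

// TConfig will be e.g. ConfigGemma2B<SfpStream> after dispatch; this functor
// only reads a compile-time constant, so any TWeight would give the same value.
template <typename TConfig>
struct GetSeqLen {
  int operator()() const { return TConfig::kSeqLen; }
};

int DemoDispatch() {
  // Selects the config from both enums, then invokes the functor.
  return gcpp::CallForModelAndWeight<GetSeqLen>(gcpp::Model::GEMMA_2B,
                                                gcpp::Type::kSFP);
}
```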
@@ -62,14 +62,16 @@ struct Args : public ArgsBase<Args> {
     ChooseNumThreads();
   }
 
-  gcpp::Model ModelType() const { return model_type; }
-
   // Returns error string or nullptr if OK.
   const char* Validate() {
     ModelTraining model_training;
-    const char* parse_result =
-        ParseModelTypeAndTraining(model_type_str, model_type, model_training);
-    if (parse_result) return parse_result;
+    if (const char* err = ParseModelTypeAndTraining(model_type_str, model_type_,
+                                                    model_training)) {
+      return err;
+    }
+    if (const char* err = ParseType(weight_type_str, weight_type_)) {
+      return err;
+    }
     if (weights.path.empty()) {
       return "Missing --weights flag, a file for the uncompressed model.";
     }
@@ -86,7 +88,7 @@ struct Args : public ArgsBase<Args> {
   Path weights;             // uncompressed weights file location
   Path compressed_weights;  // compressed weights file location
   std::string model_type_str;
-  Model model_type;
+  std::string weight_type_str;
   size_t num_threads;
 
   template <class Visitor>
@@ -101,6 +103,9 @@ struct Args : public ArgsBase<Args> {
             "gr2b-it = griffin 2B parameters, instruction-tuned\n    "
             "gr2b-pt = griffin 2B parameters, pretrained\n    "
             "    Required argument.");
+    visitor(weight_type_str, "weight_type", std::string("sfp"),
+            "Weight type\n    f32 = float, bf16 = bfloat16, SFP = 8-bit FP\n"
+            "    Required argument.");
     visitor(compressed_weights, "compressed_weights", Path(),
             "Path name where compressed weights (.sbs) file will be written.\n"
             "    Required argument.");
@@ -110,6 +115,14 @@ struct Args : public ArgsBase<Args> {
             "number of suupported concurrent threads.",
             2);
   }
 
+  // Uninitialized before Validate, must call after that.
+  gcpp::Model ModelType() const { return model_type_; }
+  gcpp::Type WeightType() const { return weight_type_; }
+
+ private:
+  Model model_type_;
+  Type weight_type_;
 };
 
 void ShowHelp(gcpp::Args& args) {
@@ -132,7 +145,7 @@ namespace HWY_NAMESPACE {
 template <class TConfig>
 void CompressWeights(const Path& weights_path,
                      const Path& compressed_weights_path, Model model_type,
-                     hwy::ThreadPool& pool) {
+                     Type weight_type, hwy::ThreadPool& pool) {
   if (!weights_path.Exists()) {
     HWY_ABORT("The model weights file '%s' does not exist.",
               weights_path.path.c_str());
@@ -147,7 +160,7 @@ void CompressWeights(const Path& weights_path,
   // Get weights, compress, and store.
   const bool scale_for_compression = TConfig::kNumTensorScales > 0;
   const ByteStorageT weights_u8 = gcpp::LoadRawWeights(
-      weights_path, model_type, pool, scale_for_compression);
+      weights_path, model_type, weight_type, pool, scale_for_compression);
   WeightsF<TConfig>* weights =
       reinterpret_cast<WeightsF<TConfig>*>(weights_u8.get());
   Compressor compressor(pool);
@@ -169,9 +182,10 @@ namespace gcpp {
 void Run(Args& args) {
   hwy::ThreadPool pool(args.num_threads);
   const Model model_type = args.ModelType();
-  GEMMA_EXPORT_AND_DISPATCH_MODEL(
-      model_type, CompressWeights,
-      (args.weights, args.compressed_weights, model_type, pool));
+  const Type weight_type = args.WeightType();
+  GEMMA_EXPORT_AND_DISPATCH(
+      model_type, weight_type, CompressWeights,
+      (args.weights, args.compressed_weights, model_type, weight_type, pool));
 }
 
 }  // namespace gcpp

@@ -22,7 +22,6 @@
 
 #include <array>
 
-#include "compression/compress.h"  // SfpStream
 #include "hwy/base.h"  // hwy::bfloat16_t
 
 namespace gcpp {
@@ -42,18 +41,10 @@ namespace gcpp {
 #define GEMMA_MAX_THREADS 128
 #endif  // !GEMMA_MAX_THREADS
 
-// Allowable types for GEMMA_WEIGHT_T (can be specified at compilation time):
-// float, hwy::bfloat16_t, SfpStream, NuqStream
-#ifndef GEMMA_WEIGHT_T
-#define GEMMA_WEIGHT_T SfpStream
-#endif  // !GEMMA_WEIGHT_T
-
 static constexpr size_t kSeqLen = GEMMA_MAX_SEQLEN;
 static constexpr size_t kTopK = GEMMA_TOPK;
 static constexpr size_t kMaxThreads = GEMMA_MAX_THREADS;
 
-using GemmaWeightT = GEMMA_WEIGHT_T;
-
 using EmbedderInputT = hwy::bfloat16_t;
 
 enum class LayerAttentionType {
@@ -82,7 +73,10 @@ constexpr size_t NumLayersOfTypeBefore(
   return count;
 }
 
+template <typename TWeight>
 struct ConfigGemma7B {
+  using Weight = TWeight;  // make accessible where we only have a TConfig
+
   static constexpr int kSeqLen = gcpp::kSeqLen;
   static constexpr int kVocabSize = 256000;
   static constexpr std::array<LayerAttentionType, 28> kLayerConfig =
@@ -111,10 +105,12 @@ struct ConfigGemma7B {
   static constexpr bool kUseLocalAttention = false;
   static constexpr bool kInterleaveQKV = true;
   static constexpr int kNumTensorScales = 0;
-  using WeightT = GEMMA_WEIGHT_T;
 };
 
+template <typename TWeight>
 struct ConfigGemma2B {
+  using Weight = TWeight;  // make accessible where we only have a TConfig
+
   static constexpr int kSeqLen = gcpp::kSeqLen;
   static constexpr int kVocabSize = 256000;
   static constexpr std::array<LayerAttentionType, 18> kLayerConfig =
@@ -143,10 +139,12 @@ struct ConfigGemma2B {
   static constexpr bool kUseLocalAttention = false;
   static constexpr bool kInterleaveQKV = true;
   static constexpr int kNumTensorScales = 0;
-  using WeightT = GEMMA_WEIGHT_T;
 };
 
+template <typename TWeight>
 struct ConfigGemmaTiny {
+  using Weight = TWeight;  // make accessible where we only have a TConfig
+
   static constexpr int kSeqLen = 32;
   static constexpr int kVocabSize = 16;
   static constexpr std::array<LayerAttentionType, 3> kLayerConfig =
@@ -175,10 +173,12 @@ struct ConfigGemmaTiny {
   static constexpr bool kUseLocalAttention = false;
   static constexpr bool kInterleaveQKV = true;
   static constexpr int kNumTensorScales = 0;
-  using WeightT = GEMMA_WEIGHT_T;
 };
 
+template <typename TWeight>
 struct ConfigGriffin2B {
+  using Weight = TWeight;  // make accessible where we only have a TConfig
+
   // Griffin uses local attention, so kSeqLen is actually the local attention
   // window.
   static constexpr int kSeqLen = 2048;
@@ -235,7 +235,6 @@ struct ConfigGriffin2B {
   static constexpr bool kUseLocalAttention = true;
   static constexpr bool kInterleaveQKV = false;
   static constexpr int kNumTensorScales = 140;
-  using WeightT = GEMMA_WEIGHT_T;
 };
 
 }  // namespace gcpp

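With the configs now templated on TWeight, one model shape can be instantiated
for every weight representation: the shape constants are shared, only the
storage type differs. A small illustration (the static_asserts are ours, and we
assume SfpStream lives in namespace gcpp via compression/compress.h):

```cpp
#include "compression/compress.h"  // SfpStream (assumed in gcpp)
#include "gemma/configs.h"
#include "hwy/base.h"

using CfgF32 = gcpp::ConfigGemma2B<float>;
using CfgSfp = gcpp::ConfigGemma2B<gcpp::SfpStream>;
static_assert(CfgF32::kVocabSize == CfgSfp::kVocabSize);  // same model shape
static_assert(!hwy::IsSame<CfgF32::Weight, CfgSfp::Weight>());  // storage differs
```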
@@ -15,13 +15,20 @@
 
 #include "gemma/cross_entropy.h"
 
+#include <stddef.h>
+#include <stdio.h>
+
 #include <algorithm>
 #include <cmath>
+#include <functional>
 #include <regex>  // NOLINT
 #include <string>
 #include <utility>
 #include <vector>
 
+#include "gemma/common.h"
+#include "gemma/gemma.h"
+
 namespace gcpp {
 
 namespace {
@@ -63,7 +70,9 @@ float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
   auto stream_token = [](int, float) { return true; };
   auto accept_token = [](int) { return true; };
 
-  const int vocab_size = CallFunctorForModel<GetVocabSize>(gemma.ModelType());
+  // TWeight is unused, but we have to pass it to Config*.
+  const int vocab_size =
+      CallForModel</*TWeight=*/float, GetVocabSize>(gemma.ModelType());
   float cross_entropy = std::log(vocab_size);  // first token
   size_t pos = 1;
   std::function<int(const float*, size_t)> sample_token =

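GetVocabSize is defined elsewhere in this file; a plausible reconstruction,
which also shows why the float placeholder for TWeight is harmless - the
functor only reads a compile-time constant and never touches TConfig::Weight:

```cpp
// Hypothetical reconstruction of the functor used above.
template <typename TConfig>
struct GetVocabSize {
  int operator()() const { return TConfig::kVocabSize; }
};
```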
@@ -16,6 +16,8 @@
 #ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_CROSS_ENTROPY_H_
 #define THIRD_PARTY_GEMMA_CPP_GEMMA_CROSS_ENTROPY_H_
 
+#include <stddef.h>
+
 #include <vector>
 
 #include "gemma/gemma.h"

@@ -147,8 +147,10 @@ struct CreateKVCache {
 
 }  // namespace
 
-KVCache KVCache::Create(Model type) {
-  return CallFunctorForModel<CreateKVCache>(type);
+KVCache KVCache::Create(Model model_type) {
+  // TWeight=float is a placeholder and unused because CreateKVCache does not
+  // use TConfig::Weight.
+  return CallForModel</*TWeight=*/float, CreateKVCache>(model_type);
 }
 
 class GemmaTokenizer::Impl {
@@ -727,7 +729,7 @@ Activations<TConfig, kBatchSize>& GetActivations(const ByteStorageT& state_u8) {
 }  // namespace
 
 template <class TConfig>
-void Generate(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8,
+void GenerateT(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8,
               const ByteStorageT& decode_u8,
               const RuntimeConfig& runtime_config,
               const std::vector<int>& prompt, size_t pos, KVCache& kv_cache,
@@ -871,23 +873,31 @@ struct AllocateDecode {
 
 }  // namespace
 
 Gemma::Gemma(const Path& tokenizer_path, const Path& weights, Model model_type,
-             hwy::ThreadPool& pool)
-    : pool_(pool), tokenizer_(tokenizer_path), model_type_(model_type) {
-  weights_u8_ = LoadWeights(weights, model_type, pool);
-  prefill_u8_ = CallFunctorForModel<AllocatePrefill>(model_type);
-  decode_u8_ = CallFunctorForModel<AllocateDecode>(model_type);
+             Type weight_type, hwy::ThreadPool& pool)
+    : pool_(pool),
+      tokenizer_(tokenizer_path),
+      model_type_(model_type),
+      weight_type_(weight_type) {
+  weights_u8_ = LoadWeights(weights, model_type, weight_type, pool);
+  prefill_u8_ = CallForModelAndWeight<AllocatePrefill>(model_type, weight_type);
+  decode_u8_ = CallForModelAndWeight<AllocateDecode>(model_type, weight_type);
 }
 
-Gemma::Gemma(GemmaTokenizer&& tokenizer, Model model_type,
+Gemma::Gemma(GemmaTokenizer&& tokenizer, Model model_type, Type weight_type,
              hwy::ThreadPool& pool)
-    : pool_(pool), tokenizer_(std::move(tokenizer)), model_type_(model_type) {
-  weights_u8_ = CallFunctorForModel<AllocateWeightsF>(model_type, pool);
-  prefill_u8_ = CallFunctorForModel<AllocatePrefill>(model_type);
-  decode_u8_ = CallFunctorForModel<AllocateDecode>(model_type);
+    : pool_(pool),
+      tokenizer_(std::move(tokenizer)),
+      model_type_(model_type),
+      weight_type_(weight_type) {
+  weights_u8_ =
+      CallForModelAndWeight<AllocateWeightsF>(model_type, weight_type, pool);
+  prefill_u8_ = CallForModelAndWeight<AllocatePrefill>(model_type, weight_type);
+  decode_u8_ = CallForModelAndWeight<AllocateDecode>(model_type, weight_type);
 }
 
 Gemma::~Gemma() {
-  CallFunctorForModel<DeleteLayersPtrs>(model_type_, weights_u8_);
+  CallForModelAndWeight<DeleteLayersPtrs>(model_type_, weight_type_,
+                                          weights_u8_);
 }
 
 void Gemma::Generate(const RuntimeConfig& runtime_config,
@@ -896,8 +906,8 @@ void Gemma::Generate(const RuntimeConfig& runtime_config,
                      LayersOutputT* layers_output) {
   pool_.SetWaitMode(hwy::PoolWaitMode::kSpin);
 
-  GEMMA_EXPORT_AND_DISPATCH_MODEL(
-      model_type_, Generate,
+  GEMMA_EXPORT_AND_DISPATCH(
+      model_type_, weight_type_, GenerateT,
       (weights_u8_, prefill_u8_, decode_u8_, runtime_config, prompt, start_pos,
        kv_cache, pool_, timing_info, layers_output));

|
||||||
class Gemma {
|
class Gemma {
|
||||||
public:
|
public:
|
||||||
Gemma(const Path& tokenizer_path, const Path& weights, Model model_type,
|
Gemma(const Path& tokenizer_path, const Path& weights, Model model_type,
|
||||||
hwy::ThreadPool& pool);
|
Type weight_type, hwy::ThreadPool& pool);
|
||||||
|
|
||||||
// Allocates weights, caller is responsible for filling them.
|
// Allocates weights, caller is responsible for filling them.
|
||||||
Gemma(GemmaTokenizer&& tokenizer, Model model_type, hwy::ThreadPool& pool);
|
Gemma(GemmaTokenizer&& tokenizer, Model model_type, Type weight_type,
|
||||||
|
hwy::ThreadPool& pool);
|
||||||
~Gemma();
|
~Gemma();
|
||||||
|
|
||||||
Model ModelType() const { return model_type_; }
|
Model ModelType() const { return model_type_; }
|
||||||
|
|
@ -136,6 +137,7 @@ class Gemma {
|
||||||
ByteStorageT prefill_u8_;
|
ByteStorageT prefill_u8_;
|
||||||
ByteStorageT decode_u8_;
|
ByteStorageT decode_u8_;
|
||||||
Model model_type_;
|
Model model_type_;
|
||||||
|
Type weight_type_;
|
||||||
};
|
};
|
||||||
|
|
||||||
// DEPRECATED, call Gemma::Generate directly.
|
// DEPRECATED, call Gemma::Generate directly.
|
||||||
|
|
|
||||||
|
|
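A hypothetical construction with the new signature (file names are
placeholders; Path is assumed to be brace-initializable from a string, as
elsewhere in the repo):

```cpp
#include "gemma/gemma.h"
#include "hwy/contrib/thread_pool/thread_pool.h"

void DemoConstruct() {
  hwy::ThreadPool pool(4);
  gcpp::Gemma model(gcpp::Path{"tokenizer.spm"}, gcpp::Path{"2b-it-sfp.sbs"},
                    gcpp::Model::GEMMA_2B, gcpp::Type::kSFP, pool);
  gcpp::KVCache kv_cache = gcpp::KVCache::Create(model.ModelType());
}
```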
@@ -38,8 +38,7 @@ class GemmaTest : public ::testing::Test {
       : weights("./2b-it-mqa.sbs"),
         tokenizer("./tokenizer.spm"),
         pool(std::min<int>(20, (std::thread::hardware_concurrency() - 1) / 2)),
-        model_type(gcpp::Model::GEMMA_2B),
-        model(tokenizer, weights, model_type, pool) {
+        model(tokenizer, weights, model_type, weight_type, pool) {
     KVCache kv_cache = KVCache::Create(model_type);
   }
@@ -96,7 +95,8 @@ class GemmaTest : public ::testing::Test {
   gcpp::Path tokenizer;
   gcpp::KVCache kv_cache;
   hwy::ThreadPool pool;
-  gcpp::Model model_type = {};
+  gcpp::Model model_type = gcpp::Model::GEMMA_2B;
+  gcpp::Type weight_type = gcpp::Type::kSFP;
   gcpp::Gemma model;
 };

@@ -71,7 +71,7 @@ void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
             << hwy::VectorBytes() * 8 << " bits)" << "\n"
             << "Compiled config : " << CompiledConfig() << "\n"
             << "Weight Type : "
-            << gcpp::TypeName(gcpp::GemmaWeightT()) << "\n"
+            << gcpp::StringFromType(loader.WeightType()) << "\n"
             << "EmbedderInput Type : "
             << gcpp::TypeName(gcpp::EmbedderInputT()) << "\n";
 }
@@ -251,8 +251,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
     PinWorkersToCores(pool);
   }
 
-  gcpp::Gemma model(loader.tokenizer, loader.weights, loader.ModelType(), pool);
-
+  gcpp::Gemma model = gcpp::CreateGemma(loader, pool);
 
   KVCache kv_cache = KVCache::Create(loader.ModelType());
 
   if (app.verbosity >= 1) {
@@ -278,7 +277,8 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
   }
 
   ReplGemma(
-      model, loader.ModelTraining(), kv_cache, pool, inference, app.verbosity,
+      model, loader.ModelTrainingType(), kv_cache, pool, inference,
+      app.verbosity,
       /*accept_token=*/[](int) { return true; }, app.eot_line);
 }

@@ -157,8 +157,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
     PinWorkersToCores(pool);
   }
 
-  gcpp::Gemma model(loader.tokenizer, loader.weights, loader.ModelType(), pool);
-
+  gcpp::Gemma model = gcpp::CreateGemma(loader, pool);
 
   gcpp::KVCache kv_cache = gcpp::KVCache::Create(loader.ModelType());
 
   JsonGemma(model, kv_cache, pool, inference, app.verbosity, app.eot_line);

@ -173,10 +173,11 @@ struct LoadRawWeightsT {
|
||||||
#undef SCALE_WEIGHTS
|
#undef SCALE_WEIGHTS
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
ByteStorageT LoadRawWeights(const Path& weights, Model model,
|
ByteStorageT LoadRawWeights(const Path& weights, Model model_type,
|
||||||
hwy::ThreadPool& pool, bool scale_for_compression) {
|
Type weight_type, hwy::ThreadPool& pool,
|
||||||
return CallFunctorForModel<LoadRawWeightsT>(model, weights, pool,
|
bool scale_for_compression) {
|
||||||
scale_for_compression);
|
return CallForModelAndWeight<LoadRawWeightsT>(
|
||||||
|
model_type, weight_type, weights, pool, scale_for_compression);
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
@@ -227,17 +228,18 @@ struct LoadCompressedWeightsT {
 };
 }  // namespace

-ByteStorageT LoadCompressedWeights(const Path& weights, Model model,
-                                   hwy::ThreadPool& pool) {
-  return CallFunctorForModel<LoadCompressedWeightsT>(model, weights, pool);
+ByteStorageT LoadCompressedWeights(const Path& weights, Model model_type,
+                                   Type weight_type, hwy::ThreadPool& pool) {
+  return CallForModelAndWeight<LoadCompressedWeightsT>(model_type, weight_type,
+                                                       weights, pool);
 }

-ByteStorageT LoadWeights(const Path& weights, Model model,
-                         hwy::ThreadPool& pool) {
+ByteStorageT LoadWeights(const Path& weights, Model model_type,
+                         Type weight_type, hwy::ThreadPool& pool) {
   if constexpr (kWeightsAreCompressed) {
-    return LoadCompressedWeights(weights, model, pool);
+    return LoadCompressedWeights(weights, model_type, weight_type, pool);
   } else {
-    return LoadRawWeights(weights, model, pool,
+    return LoadRawWeights(weights, model_type, weight_type, pool,
                           /*scale_for_compression=*/false);
   }
 }

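For callers that bypass LoaderArgs, the new dispatch takes both enums explicitly. A hypothetical direct call (the file name is illustrative and `pool` is assumed; Path's public `path` member is the same one the LoaderArgs hunks below use via `tokenizer.path`):

  gcpp::Path weights_path;
  weights_path.path = "2b-it-sfp.sbs";  // assumed file, not from this commit
  gcpp::ByteStorageT weights = gcpp::LoadWeights(
      weights_path, gcpp::Model::GEMMA_2B, gcpp::Type::kSFP, pool);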
@@ -274,8 +276,9 @@ struct LogWeightStatsT {
 };
 }  // namespace

-void LogWeightStats(gcpp::Model model, const ByteStorageT& weights) {
-  CallFunctorForModel<LogWeightStatsT>(model, weights);
+void LogWeightStats(gcpp::Model model_type, Type weight_type,
+                    const ByteStorageT& weights) {
+  CallForModelAndWeight<LogWeightStatsT>(model_type, weight_type, weights);
 }

 }  // namespace gcpp

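The stats helper follows the same pattern; a one-line sketch of the updated call, assuming a `loader` and loaded `weights` as elsewhere in this commit:

  gcpp::LogWeightStats(loader.ModelType(), loader.WeightType(), weights);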
@@ -129,21 +129,17 @@ using WeightsF = Weights<float, TConfig>;
 // ----------------------------------------------------------------------------
 // Compressed

-// If weights are f32, also f32; otherwise at least bf16. Useful for ops that do
-// not yet support smaller compressed types, or require at least bf16. When
-// weights are f32, we also want such tensors to be f32.
-template <class TConfig>
-using WeightF32OrBF16T =
-    hwy::If<hwy::IsSame<typename TConfig::WeightT, float>(), float,
-            hwy::bfloat16_t>;
-
 template <class TConfig>
 struct CompressedLayer {
   // No ctor/dtor, allocated via AllocateAligned.

   using TLayer = gcpp::LayerF<TConfig>;
-  using WeightT = typename TConfig::WeightT;
-  using WeightF32OrBF16 = WeightF32OrBF16T<TConfig>;
+  using Weight = typename TConfig::Weight;
+  // If weights are f32, also f32; otherwise at least bf16. Useful for ops that
+  // do not yet support smaller compressed types, or require at least bf16. When
+  // weights are f32, we also want such tensors to be f32.
+  using WeightF32OrBF16 =
+      hwy::If<hwy::IsSame<Weight, float>(), float, hwy::bfloat16_t>;

   static constexpr size_t kHeads = TLayer::kHeads;
   static constexpr size_t kKVHeads = TLayer::kKVHeads;

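The alias resolves entirely at compile time. A self-contained sketch of the dispatch, using made-up stand-in configs (the real TConfig types come from gemma/configs.h):

  #include "hwy/base.h"  // hwy::If, hwy::IsSame, hwy::bfloat16_t

  // Made-up configs, only to show how the alias resolves.
  struct ConfigF32  { using Weight = float; };
  struct ConfigBF16 { using Weight = hwy::bfloat16_t; };

  template <class TConfig>
  using WeightF32OrBF16 =
      hwy::If<hwy::IsSame<typename TConfig::Weight, float>(), float,
              hwy::bfloat16_t>;

  static_assert(hwy::IsSame<WeightF32OrBF16<ConfigF32>, float>(),
                "f32 weights keep f32 norm tensors");
  static_assert(hwy::IsSame<WeightF32OrBF16<ConfigBF16>, hwy::bfloat16_t>(),
                "non-f32 weights get at least bf16");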
@ -166,29 +162,29 @@ struct CompressedLayer {
|
||||||
|
|
||||||
union {
|
union {
|
||||||
struct {
|
struct {
|
||||||
ArrayT<WeightT, kAttVecEinsumWSize> attn_vec_einsum_w;
|
ArrayT<Weight, kAttVecEinsumWSize> attn_vec_einsum_w;
|
||||||
ArrayT<WeightT, kQKVEinsumWSize> qkv_einsum_w;
|
ArrayT<Weight, kQKVEinsumWSize> qkv_einsum_w;
|
||||||
ArrayT<float, kAOBiasDim> attention_output_biases;
|
ArrayT<float, kAOBiasDim> attention_output_biases;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct {
|
struct {
|
||||||
ArrayT<WeightT, kGriffinDim * kGriffinDim> linear_x_w;
|
ArrayT<Weight, kGriffinDim * kGriffinDim> linear_x_w;
|
||||||
ArrayT<float, kGriffinDim> linear_x_biases;
|
ArrayT<float, kGriffinDim> linear_x_biases;
|
||||||
ArrayT<WeightT, kGriffinDim * kGriffinDim> linear_y_w;
|
ArrayT<Weight, kGriffinDim * kGriffinDim> linear_y_w;
|
||||||
ArrayT<float, kGriffinDim> linear_y_biases;
|
ArrayT<float, kGriffinDim> linear_y_biases;
|
||||||
ArrayT<WeightT, kGriffinDim * kGriffinDim> linear_out_w;
|
ArrayT<Weight, kGriffinDim * kGriffinDim> linear_out_w;
|
||||||
ArrayT<float, kGriffinDim> linear_out_biases;
|
ArrayT<float, kGriffinDim> linear_out_biases;
|
||||||
ArrayT<float, TConfig::kConv1dWidth * kGriffinDim> conv_w;
|
ArrayT<float, TConfig::kConv1dWidth * kGriffinDim> conv_w;
|
||||||
ArrayT<float, kGriffinDim> conv_biases;
|
ArrayT<float, kGriffinDim> conv_biases;
|
||||||
ArrayT<WeightT, kGriffinDim * kGriffinDim / kHeads * 2> gate_w;
|
ArrayT<Weight, kGriffinDim * kGriffinDim / kHeads * 2> gate_w;
|
||||||
ArrayT<float, kGriffinDim * 2> gate_biases;
|
ArrayT<float, kGriffinDim * 2> gate_biases;
|
||||||
ArrayT<float, kGriffinDim> a;
|
ArrayT<float, kGriffinDim> a;
|
||||||
} griffin;
|
} griffin;
|
||||||
};
|
};
|
||||||
|
|
||||||
ArrayT<WeightT, TLayer::kGatingEinsumWSize> gating_einsum_w;
|
ArrayT<Weight, TLayer::kGatingEinsumWSize> gating_einsum_w;
|
||||||
ArrayT<WeightT, kModelDim * kFFHiddenDim> linear_w;
|
ArrayT<Weight, kModelDim * kFFHiddenDim> linear_w;
|
||||||
// We don't yet have an RMSNorm that accepts all WeightT.
|
// We don't yet have an RMSNorm that accepts all Weight.
|
||||||
ArrayT<WeightF32OrBF16, kModelDim> pre_attention_norm_scale;
|
ArrayT<WeightF32OrBF16, kModelDim> pre_attention_norm_scale;
|
||||||
ArrayT<WeightF32OrBF16, kModelDim> pre_ffw_norm_scale;
|
ArrayT<WeightF32OrBF16, kModelDim> pre_ffw_norm_scale;
|
||||||
ArrayT<WeightF32OrBF16, kPostNormScale ? kModelDim : 0>
|
ArrayT<WeightF32OrBF16, kPostNormScale ? kModelDim : 0>
|
||||||
|
|
@@ -241,7 +237,7 @@ template <class TConfig>
 using WeightsT = hwy::If<kWeightsAreCompressed, CompressedWeights<TConfig>,
                          WeightsF<TConfig>>;

-// Call via CallFunctorForModel.
+// TODO: can we use TConfig::Weight instead of T?
 template <typename T, typename TConfig>
 struct AllocateWeights {
   ByteStorageT operator()(hwy::ThreadPool& pool) const {

@@ -335,14 +331,15 @@ class WeightsWrapper {
 };

 // For use by compress_weights.cc.
-ByteStorageT LoadRawWeights(const Path& weights, Model model,
-                            hwy::ThreadPool& pool, bool scale_for_compression);
+ByteStorageT LoadRawWeights(const Path& weights, Model model_type,
+                            Type weight_type, hwy::ThreadPool& pool,
+                            bool scale_for_compression);

 // For gemma.cc; calls LoadRawWeights if !kWeightsAreCompressed.
-ByteStorageT LoadWeights(const Path& weights, Model model,
-                         hwy::ThreadPool& pool);
+ByteStorageT LoadWeights(const Path& weights, Model model_type,
+                         Type weight_type, hwy::ThreadPool& pool);

-void LogWeightStats(Model model, const ByteStorageT& weights);
+void LogWeightStats(Model model, Type weight_type, const ByteStorageT& weights);

 // ----------------------------------------------------------------------------
 // Iterators

util/app.h (46 changed lines)
@@ -21,19 +21,17 @@
 #include "hwy/contrib/thread_pool/thread_pool.h"
 #if HWY_OS_LINUX
 #include <sched.h>

-#include <cctype>
-#include <cerrno>  // IDE does not recognize errno.h as providing errno.
-#include <string>
 #endif  // HWY_OS_LINUX
 #include <stddef.h>
 #include <stdio.h>

 #include <algorithm>  // std::clamp
+#include <string>
 #include <thread>  // NOLINT
 #include <vector>

 #include "compression/io.h"  // Path
+#include "gemma/common.h"
 #include "gemma/configs.h"
 #include "gemma/gemma.h"
 #include "util/args.h"

@@ -49,14 +47,10 @@ static inline const char* CompiledConfig() {
     return "msan";
   } else if (HWY_IS_TSAN) {
     return "tsan";
-#if defined(HWY_IS_HWASAN)
   } else if (HWY_IS_HWASAN) {
     return "hwasan";
-#endif
-#if defined(HWY_IS_UBSAN)
   } else if (HWY_IS_UBSAN) {
     return "ubsan";
-#endif
   } else if (HWY_IS_DEBUG_BUILD) {
     return "dbg";
   } else {

@@ -172,15 +166,15 @@ class AppArgs : public ArgsBase<AppArgs> {
 struct LoaderArgs : public ArgsBase<LoaderArgs> {
   LoaderArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }

-  gcpp::Model ModelType() const { return model_type; }
-
-  gcpp::ModelTraining ModelTraining() const { return model_training; }
-
   // Returns error string or nullptr if OK.
   const char* Validate() {
-    const char* parse_result =
-        ParseModelTypeAndTraining(model_type_str, model_type, model_training);
-    if (parse_result) return parse_result;
+    if (const char* err = ParseModelTypeAndTraining(model_type_str, model_type_,
+                                                    model_training_)) {
+      return err;
+    }
+    if (const char* err = ParseType(weight_type_str, weight_type_)) {
+      return err;
+    }
     if (tokenizer.path.empty()) {
       return "Missing --tokenizer flag, a file for the tokenizer is required.";
     }

@@ -209,8 +203,7 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
   Path weights;  // weights file location
   Path compressed_weights;
   std::string model_type_str;
-  Model model_type;
-  enum ModelTraining model_training;
+  std::string weight_type_str;

   template <class Visitor>
   void ForEach(const Visitor& visitor) {

@@ -227,9 +220,28 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
            "gr2b-it = griffin 2B parameters, instruction-tuned\n    "
            "gr2b-pt = griffin 2B parameters, pretrained\n    "
            "    Required argument.");
+    visitor(weight_type_str, "weight_type", std::string("sfp"),
+            "Weight type\n    f32 = float, bf16 = bfloat16, SFP = 8-bit FP\n"
+            "    Required argument.");
   }

+  // Uninitialized before Validate, must call after that.
+  gcpp::Model ModelType() const { return model_type_; }
+  gcpp::ModelTraining ModelTrainingType() const { return model_training_; }
+  gcpp::Type WeightType() const { return weight_type_; }
+
+ private:
+  Model model_type_;
+  ModelTraining model_training_;
+  Type weight_type_;
 };

+static inline Gemma CreateGemma(const LoaderArgs& loader,
+                                hwy::ThreadPool& pool) {
+  return Gemma(loader.tokenizer, loader.weights, loader.ModelType(),
+               loader.WeightType(), pool);
+}
+
 struct InferenceArgs : public ArgsBase<InferenceArgs> {
   InferenceArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }

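Net effect: the weight format becomes a runtime choice. LoaderArgs parses the new --weight_type flag (f32, bf16, or the default sfp) in Validate(), and CreateGemma forwards it to the Gemma constructor, so a single binary serves all weight types. An invocation might look like ./gemma --tokenizer tokenizer.spm --model 2b-it --weight_type sfp --weights 2b-it-sfp.sbs (file names illustrative).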