Mirror of https://github.com/google/gemma.cpp.git

Support all weight types in a single binary.

This changes the command-line flags, but the default value retains the previous
behavior. Also add a CreateGemma helper to enable extra args without interface
changes.

PiperOrigin-RevId: 641266411
Parent: 24db2ff725 · Commit: f9b390b134
```diff
@@ -116,6 +116,7 @@ cc_library(
         "gemma/cross_entropy.h",
     ],
     deps = [
+        ":common",
        ":gemma_lib",
    ],
 )
```
```diff
@@ -76,17 +76,6 @@ if(NOT CMAKE_BUILD_TYPE)
     set(CMAKE_BUILD_TYPE "Release")
 endif()
 
-# Allowable types for WEIGHT_TYPE:
-#   float - slow, not recommended
-#   hwy::bfloat16_t - bfloat16 as implemented by https://github.com/google/highway
-#   SfpStream - 8-bit switched floating point (recommended)
-#   NuqStream - experimental, work-in-progress
-option(WEIGHT_TYPE "Set weight type" "")
-
-if (WEIGHT_TYPE)
-  add_definitions(-DGEMMA_WEIGHT_T=${WEIGHT_TYPE})
-endif()
-
 FetchContent_GetProperties(sentencepiece)
 
 ## Library Target
```
```diff
@@ -105,18 +105,12 @@ the resulting file as `--weights` and the desired .sbs name as the
 There are several compile-time flags to be aware of (note these may or may not
 be exposed to the build system):
 
--   `GEMMA_WEIGHT_T` : Sets the level of compression for weights (surfaced as
-    WEIGHT_TYPE in CMakeLists.txt). Currently this should be set to `SfpStream`
-    (default, if no flag is specified) for 8-bit SFP, or `hwy::bfloat16_t` to
-    enable for higher-fidelity (but slower) bfloat16 support. This is defined in
-    `gemma.h`.
 -   `GEMMA_MAX_SEQ_LEN` : Sets maximum sequence length to preallocate for the KV
     Cache. The default is 4096 tokens but can be overridden. This is not exposed
     through `CMakeLists.txt` yet.
 
-In the medium term both of these will likely be deprecated in favor of handling
-options at runtime - allowing for multiple weight compression schemes in a single
-build and dynamically resizes the KV cache as needed.
+In the medium term this will likely be deprecated in favor of handling options
+at runtime - dynamically resizing the KV cache as needed.
 
 ## Using gemma.cpp as a Library (Advanced)
 
```
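The remaining flag follows the usual preprocessor-default pattern; a minimal sketch of how it takes effect, modeled on the `GEMMA_MAX_SEQLEN`/`kSeqLen` lines visible in the configs.h hunk later in this diff (note the README spells it `GEMMA_MAX_SEQ_LEN` while configs.h uses `GEMMA_MAX_SEQLEN`; the guard shown here is an assumption):

```cpp
#include <stddef.h>

// Compile-time default, overridable via -DGEMMA_MAX_SEQLEN=... on the compiler
// command line; configs.h then freezes it into a constant.
#ifndef GEMMA_MAX_SEQLEN
#define GEMMA_MAX_SEQLEN 4096  // KV cache tokens preallocated by default
#endif

static constexpr size_t kSeqLen = GEMMA_MAX_SEQLEN;
```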
README.md (52 lines changed)
````diff
@@ -138,33 +138,16 @@ convenient directory location (e.g. the `build/` directory in this repo).
 The build system uses [CMake](https://cmake.org/). To build the gemma inference
 runtime, create a build directory and generate the build files using `cmake`
 from the top-level project directory. Note if you previous ran `cmake` and are
-re-running with a different setting, be sure to clean out the `build/` directory
-with `rm -rf build/*` (warning this will delete any other files in the `build/`
-directory.
-
-For the 8-bit switched floating point weights (sfp), run cmake with no options:
+re-running with a different setting, be sure to delete all files in the `build/`
+directory with `rm -rf build/*`.
 
 #### Unix-like Platforms
 ```sh
 cmake -B build
 ```
 
-**or** if you downloaded bfloat16 weights (any model *without* `-sfp` in the
-name), instead of running cmake with no options as above, run cmake with
-WEIGHT_TYPE set to [highway's](https://github.com/google/highway)
-`hwy::bfloat16_t` type. Alternatively, you can also add
-`-DGEMMA_WEIGHT_T=hwy::bfloat16_t` to the C++ compiler flags.
-
-We intend to soon support all weight types without requiring extra flags. Note
-that we recommend using `-sfp` weights instead of bfloat16 for faster inference.
-
-```sh
-cmake -B build -DWEIGHT_TYPE=hwy::bfloat16_t
-```
-
-After running whichever of the above `cmake` invocations that is appropriate for
-your weights, you can enter the `build/` directory and run `make` to build the
-`./gemma` executable:
+After running `cmake`, you can enter the `build/` directory and run `make` to
+build the `./gemma` executable:
 
 ```sh
 # Configure `build` directory
@@ -221,11 +204,12 @@ You can now run `gemma` from inside the `build/` directory.
 
 `gemma` has the following required arguments:
 
-| Argument | Description | Example value |
-| ------------- | ---------------------------- | -------------------------- |
-| `--model` | The model type. | `2b-it`, `2b-pt`, `7b-it`, `7b-pt`, ... (see above) |
-| `--weights` | The compressed weights file. | `2b-it-sfp.sbs`, ... (see above) |
-| `--tokenizer` | The tokenizer file. | `tokenizer.spm` |
+Argument        | Description                  | Example value
+--------------- | ---------------------------- | -----------------------
+`--model`       | The model type.              | `2b-it` ... (see below)
+`--weights`     | The compressed weights file. | `2b-it-sfp.sbs`
+`--weight_type` | The compressed weight type.  | `sfp`
+`--tokenizer`   | The tokenizer file.          | `tokenizer.spm`
 
 `gemma` is invoked as:
 
@@ -233,6 +217,7 @@ You can now run `gemma` from inside the `build/` directory.
 ./gemma \
 --tokenizer [tokenizer file] \
 --weights [compressed weights file] \
+--weight_type [f32 or bf16 or sfp] \
 --model [2b-it or 2b-pt or 7b-it or 7b-pt or ...]
 ```
 
@@ -245,8 +230,7 @@ Example invocation for the following configuration:
 ```sh
 ./gemma \
 --tokenizer tokenizer.spm \
---weights 2b-it-sfp.sbs \
---model 2b-it
+--weights 2b-it-sfp.sbs --weight_type sfp --model 2b-it
 ```
 
 ### RecurrentGemma
@@ -270,14 +254,12 @@ Step 1, and run the binary as follows:
 
 **Running `./gemma` fails with "Failed to read cache gating_ein_0 (error 294) ..."**
 
-The most common problem is that `cmake` was built with the wrong weight type and
-`gemma` is attempting to load `bfloat16` weights (`2b-it`, `2b-pt`, `7b-it`,
-`7b-pt`) using the default switched floating point (sfp) or vice versa. Revisit
-step #3 and check that the `cmake` command used to build `gemma` was correct for
-the weights that you downloaded.
+The most common problem is that the `--weight_type` argument does not match that
+of the model file. Revisit step #3 and check which weights you downloaded.
 
-In the future we will handle model format handling from compile time to runtime
-to simplify this.
+Note that we have already moved weight type from a compile-time decision to a
+runtime argument. In a subsequent step, we plan to bake this information into
+the weights.
 
 **Problems building in Windows / Visual Studio**
````
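For readers driving the runtime from C++ rather than the CLI, a minimal sketch of the same flow, assuming this commit's `util/app.h` (the thread count and error handling are placeholders):

```cpp
#include <stdio.h>

#include "gemma/gemma.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "util/app.h"

int main(int argc, char** argv) {
  // Parses --tokenizer, --weights, --weight_type and --model.
  gcpp::LoaderArgs loader(argc, argv);
  if (const char* err = loader.Validate()) {
    fprintf(stderr, "%s\n", err);  // e.g. an unrecognized --weight_type
    return 1;
  }
  hwy::ThreadPool pool(4);  // arbitrary for this sketch
  gcpp::Gemma model = gcpp::CreateGemma(loader, pool);
  gcpp::KVCache kv_cache = gcpp::KVCache::Create(loader.ModelType());
  (void)model;
  (void)kv_cache;
  return 0;
}
```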
```diff
@@ -21,7 +21,6 @@
 #define THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_INL_H_
 
 #include <stddef.h>
-#include <stdint.h>
 
 #include <array>
 #include <cmath>
@@ -44,6 +43,7 @@
 #endif
 
+#include "gemma/ops.h"
 #include "hwy/highway.h"
 
 HWY_BEFORE_NAMESPACE();
 namespace gcpp {
```
```diff
@@ -15,6 +15,11 @@
 
 #include "backprop/backward.h"
 
+#include "backprop/prompt.h"
+#include "gemma/activations.h"
+#include "gemma/common.h"
+#include "hwy/contrib/thread_pool/thread_pool.h"
+
 // Compiles this file for multiple architectures via "foreach_target.h", to
 // which we pass the filename via macro 'argument'.
 #undef HWY_TARGET_INCLUDE
@@ -29,7 +34,6 @@
 HWY_BEFORE_NAMESPACE();
 namespace gcpp {
 namespace HWY_NAMESPACE {
-namespace hn = hwy::HWY_NAMESPACE;
 
 template <typename TConfig>
 void CrossEntropyLossBackwardPass(const Prompt& prompt,
@@ -57,11 +61,11 @@ void CrossEntropyLossBackwardPassT(Model model,
   // TODO(janwas): use CallFunctorForModel
   switch (model) {
     case Model::GEMMA_2B:
-      CrossEntropyLossBackwardPass<ConfigGemma2B>(
+      CrossEntropyLossBackwardPass<ConfigGemma2B<float>>(
           prompt, weights, forward, grad, backward, pool);
       break;
     case Model::GEMMA_TINY:
-      CrossEntropyLossBackwardPass<ConfigGemmaTiny>(
+      CrossEntropyLossBackwardPass<ConfigGemmaTiny<float>>(
          prompt, weights, forward, grad, backward, pool);
      break;
    default:
```
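The `foreach_target.h` comment retained above is Highway's standard multi-target compilation idiom; a generic sketch of the pattern (not this file's exact contents):

```cpp
// Each translation unit names itself, then foreach_target.h re-includes it
// once per enabled SIMD target (producing e.g. N_SSE4::..., N_AVX2::...).
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "backprop/backward.cc"
#include "hwy/foreach_target.h"
#include "hwy/highway.h"
```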
```diff
@@ -15,6 +15,11 @@
 
 #include "backprop/forward.h"
 
+#include "backprop/prompt.h"
+#include "gemma/activations.h"
+#include "gemma/common.h"
+#include "hwy/contrib/thread_pool/thread_pool.h"
+
 // Compiles this file for multiple architectures via "foreach_target.h", to
 // which we pass the filename via macro 'argument'.
 #undef HWY_TARGET_INCLUDE
@@ -29,7 +34,6 @@
 HWY_BEFORE_NAMESPACE();
 namespace gcpp {
 namespace HWY_NAMESPACE {
-namespace hn = hwy::HWY_NAMESPACE;
 
 template <typename TConfig>
 float CrossEntropyLossForwardPass(const Prompt& prompt,
@@ -51,10 +55,10 @@ float CrossEntropyLossForwardPassT(Model model, const Prompt& prompt,
   // TODO(janwas): use CallFunctorForModel
   switch (model) {
     case Model::GEMMA_2B:
-      return CrossEntropyLossForwardPass<ConfigGemma2B>(
-          prompt, weights, forward, pool);
+      return CrossEntropyLossForwardPass<ConfigGemma2B<float>>(prompt, weights,
+                                                               forward, pool);
     case Model::GEMMA_TINY:
-      return CrossEntropyLossForwardPass<ConfigGemmaTiny>(
+      return CrossEntropyLossForwardPass<ConfigGemmaTiny<float>>(
          prompt, weights, forward, pool);
    default:
      HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
```
```diff
@@ -13,18 +13,23 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include <iostream>
-#include <string>
+#include <stddef.h>
+
+#include <limits>
+#include <random>
+#include <vector>
 
-#include "gtest/gtest.h"
 #include "backprop/backward.h"
 #include "backprop/forward.h"
 #include "backprop/optimizer.h"
 #include "backprop/prompt.h"
 #include "backprop/sampler.h"
 #include "gemma/activations.h"
 #include "gemma/common.h"
 #include "gemma/gemma.h"
 #include "gemma/weights.h"
+#include "gtest/gtest.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
 
 namespace gcpp {
@@ -35,11 +40,17 @@ TEST(OptimizeTest, GradientDescent) {
   std::mt19937 gen(42);
 
   Model model_type = Model::GEMMA_TINY;
-  ByteStorageT grad = CallFunctorForModel<AllocateWeightsF>(model_type, pool);
-  ByteStorageT grad_m = CallFunctorForModel<AllocateWeightsF>(model_type, pool);
-  ByteStorageT grad_v = CallFunctorForModel<AllocateWeightsF>(model_type, pool);
-  ByteStorageT forward = CallFunctorForModel<AllocateForwardPass>(model_type);
-  ByteStorageT backward = CallFunctorForModel<AllocateForwardPass>(model_type);
+  Type weight_type = Type::kF32;
+  ByteStorageT grad =
+      CallForModelAndWeight<AllocateWeightsF>(model_type, weight_type, pool);
+  ByteStorageT grad_m =
+      CallForModelAndWeight<AllocateWeightsF>(model_type, weight_type, pool);
+  ByteStorageT grad_v =
+      CallForModelAndWeight<AllocateWeightsF>(model_type, weight_type, pool);
+  ByteStorageT forward =
+      CallForModelAndWeight<AllocateForwardPass>(model_type, weight_type);
+  ByteStorageT backward =
+      CallForModelAndWeight<AllocateForwardPass>(model_type, weight_type);
   KVCache kv_cache = KVCache::Create(model_type);
   size_t max_tokens = 32;
   size_t max_generated_tokens = 16;
@@ -47,7 +58,7 @@ TEST(OptimizeTest, GradientDescent) {
   int verbosity = 0;
   const auto accept_token = [](int) { return true; };
 
-  Gemma gemma(GemmaTokenizer(), model_type, pool);
+  Gemma gemma(GemmaTokenizer(), model_type, weight_type, pool);
 
   const auto generate = [&](const std::vector<int>& prompt) {
     std::vector<int> reply;
@@ -76,12 +87,14 @@ TEST(OptimizeTest, GradientDescent) {
     return ok;
   };
 
-  RandInitWeights(model_type, gemma.Weights(), pool, gen);
-  CallFunctorForModel<ZeroInitWeightsF>(model_type, grad_m, pool);
-  CallFunctorForModel<ZeroInitWeightsF>(model_type, grad_v, pool);
+  RandInitWeights(model_type, weight_type, gemma.Weights(), pool, gen);
+  CallForModelAndWeight<ZeroInitWeightsF>(model_type, weight_type, grad_m,
+                                          pool);
+  CallForModelAndWeight<ZeroInitWeightsF>(model_type, weight_type, grad_v,
+                                          pool);
 
   printf("Initial weights:\n");
-  LogWeightStats(model_type, gemma.Weights());
+  LogWeightStats(model_type, weight_type, gemma.Weights());
 
   constexpr size_t kBatchSize = 8;
   const float alpha = 0.001f;
@@ -96,7 +109,8 @@ TEST(OptimizeTest, GradientDescent) {
   size_t num_ok;
   for (; steps < 1000000; ++steps) {
     std::mt19937 sgen(42);
-    CallFunctorForModel<ZeroInitWeightsF>(model_type, grad, pool);
+    CallForModelAndWeight<ZeroInitWeightsF>(model_type, weight_type, grad,
+                                            pool);
     float total_loss = 0.0f;
     num_ok = 0;
     for (size_t i = 0; i < kBatchSize; ++i) {
@@ -109,13 +123,13 @@ TEST(OptimizeTest, GradientDescent) {
     }
     total_loss /= kBatchSize;
 
-    AdamUpdate(model_type, grad, alpha, beta1, beta2, epsilon, steps + 1,
-               gemma.Weights(), grad_m, grad_v, pool);
+    AdamUpdate(model_type, weight_type, grad, alpha, beta1, beta2, epsilon,
+               steps + 1, gemma.Weights(), grad_m, grad_v, pool);
     printf("step: %zu total_loss: %.15f num_ok: %zu/%zu\n",
            steps, total_loss, num_ok, kBatchSize);
     if (steps % 100 == 0) {
       printf("Batch gradient:\n");
-      LogWeightStats(model_type, grad);
+      LogWeightStats(model_type, weight_type, grad);
     }
     if (total_loss < 0.5f) {
       break;
@@ -124,7 +138,7 @@ TEST(OptimizeTest, GradientDescent) {
   }
   printf("Num steps: %zu\n", steps);
  printf("Final weights:\n");
-  LogWeightStats(model_type, gemma.Weights());
+  LogWeightStats(model_type, weight_type, gemma.Weights());
   EXPECT_LT(steps, 200);
   EXPECT_EQ(num_ok, kBatchSize);
 }
```
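For reference, the `AdamUpdate(..., alpha, beta1, beta2, epsilon, steps + 1, ...)` call above drives the standard Adam rule with first/second moment buffers `grad_m`/`grad_v` (assuming the optimizer follows the textbook update; its implementation is in backprop/optimizer.cc and not fully shown here):

$$m_t = \beta_1 m_{t-1} + (1-\beta_1)\,g_t, \qquad v_t = \beta_2 v_{t-1} + (1-\beta_2)\,g_t^2$$

$$\hat{m}_t = \frac{m_t}{1-\beta_1^t}, \qquad \hat{v}_t = \frac{v_t}{1-\beta_2^t}, \qquad \theta_t = \theta_{t-1} - \alpha\,\frac{\hat{m}_t}{\sqrt{\hat{v}_t}+\epsilon}$$

The test passes `t = steps + 1`, so the bias corrections are well-defined at the first step.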
```diff
@@ -107,18 +107,20 @@ struct AdamUpdateT {
 
 }  // namespace
 
-void RandInitWeights(Model model, const ByteStorageT& weights,
-                     hwy::ThreadPool& pool,
+void RandInitWeights(Model model_type, Type weight_type,
+                     const ByteStorageT& weights, hwy::ThreadPool& pool,
                      std::mt19937& gen) {
-  CallFunctorForModel<RandInitWeightsT>(model, weights, pool, gen);
+  CallForModelAndWeight<RandInitWeightsT>(model_type, weight_type, weights,
+                                          pool, gen);
 }
 
-void AdamUpdate(Model model, const ByteStorageT& grad, float alpha, float beta1,
-                float beta2, float epsilon, size_t t,
+void AdamUpdate(Model model_type, Type weight_type, const ByteStorageT& grad,
+                float alpha, float beta1, float beta2, float epsilon, size_t t,
                 const ByteStorageT& weights, const ByteStorageT& grad_m,
                 const ByteStorageT& grad_v, hwy::ThreadPool& pool) {
-  CallFunctorForModel<AdamUpdateT>(model, grad, alpha, beta1, beta2, epsilon, t,
-                                   weights, grad_m, grad_v, pool);
+  CallForModelAndWeight<AdamUpdateT>(model_type, weight_type, grad, alpha,
+                                     beta1, beta2, epsilon, t, weights, grad_m,
+                                     grad_v, pool);
 }
 
 }  // namespace gcpp
```
```diff
@@ -19,16 +19,16 @@
 #include <random>
 
 #include "gemma/common.h"
 #include "gemma/weights.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
 
 namespace gcpp {
 
-void RandInitWeights(Model model, const ByteStorageT& weights,
-                     hwy::ThreadPool& pool, std::mt19937& gen);
+void RandInitWeights(Model model_type, Type weight_type,
+                     const ByteStorageT& weights, hwy::ThreadPool& pool,
+                     std::mt19937& gen);
 
-void AdamUpdate(Model model, const ByteStorageT& grad, float alpha, float beta1,
-                float beta2, float epsilon, size_t t,
+void AdamUpdate(Model model_type, Type weight_type, const ByteStorageT& grad,
+                float alpha, float beta1, float beta2, float epsilon, size_t t,
                 const ByteStorageT& weights, const ByteStorageT& grad_m,
                 const ByteStorageT& grad_v, hwy::ThreadPool& pool);
 
```
```diff
@@ -113,7 +113,7 @@ int main(int argc, char** argv) {
     gcpp::PinWorkersToCores(pool);
   }
 
-  gcpp::Gemma model(loader.tokenizer, loader.weights, loader.ModelType(), pool);
+  gcpp::Gemma model = gcpp::CreateGemma(loader, pool);
   gcpp::KVCache kv_cache = gcpp::KVCache::Create(loader.ModelType());
 
   const std::string& prompt = prompt_args.prompt;
```
```diff
@@ -39,7 +39,7 @@ int main(int argc, char** argv) {
   hwy::ThreadPool pool(num_threads);
 
   // Instantiate model and KV Cache
-  gcpp::Gemma model(loader.tokenizer, loader.weights, loader.ModelType(), pool);
+  gcpp::Gemma model = gcpp::CreateGemma(loader, pool);
   gcpp::KVCache kv_cache = gcpp::KVCache::Create(loader.ModelType());
   size_t pos = 0;  // KV Cache position
 
```
```diff
@@ -280,7 +280,7 @@ int main(int argc, char** argv) {
     gcpp::PinWorkersToCores(pool);
   }
 
-  gcpp::Gemma model(loader.tokenizer, loader.weights, loader.ModelType(), pool);
+  gcpp::Gemma model = gcpp::CreateGemma(loader, pool);
   gcpp::KVCache kv_cache = gcpp::KVCache::Create(loader.ModelType());
 
   if (!benchmark_args.goldens.path.empty()) {
```
```diff
@@ -64,4 +64,28 @@ const char* ParseModelTypeAndTraining(const std::string& model_flag,
   return kErrorMessageBuffer;
 }
 
+const char* ParseType(const std::string& type_string, Type& type) {
+  constexpr Type kTypes[] = {Type::kF32, Type::kBF16, Type::kSFP};
+  constexpr const char* kStrings[] = {"f32", "bf16", "sfp"};
+  constexpr size_t kNum = std::end(kStrings) - std::begin(kStrings);
+  static char kErrorMessageBuffer[kNum * 8 + 100] =
+      "Invalid or missing type, need to specify one of ";
+  for (size_t i = 0; i + 1 < kNum; i++) {
+    strcat(kErrorMessageBuffer, kStrings[i]);  // NOLINT
+    strcat(kErrorMessageBuffer, ", ");         // NOLINT
+  }
+  strcat(kErrorMessageBuffer, kStrings[kNum - 1]);  // NOLINT
+  strcat(kErrorMessageBuffer, ".");                 // NOLINT
+  std::string type_lc = type_string;
+  std::transform(begin(type_lc), end(type_lc), begin(type_lc),
+                 [](unsigned char c) { return std::tolower(c); });
+  for (size_t i = 0; i < kNum; i++) {
+    if (kStrings[i] == type_lc) {
+      type = kTypes[i];
+      return nullptr;
+    }
+  }
+  return kErrorMessageBuffer;
+}
+
 }  // namespace gcpp
```
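A hypothetical call site for the new parser (the names come from the hunk above; the demo function itself is invented):

```cpp
#include <stdio.h>

#include "gemma/common.h"

void DemoParseType() {
  gcpp::Type type;
  // Matching lower-cases the input first, so "SFP" and "sfp" both work.
  if (const char* err = gcpp::ParseType("sfp", type)) {
    fprintf(stderr, "%s\n", err);  // lists the valid names: f32, bf16, sfp
    return;
  }
  // Success: ParseType returned nullptr and type == gcpp::Type::kSFP.
}
```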
gemma/common.h (107 lines changed)
```diff
@@ -21,6 +21,7 @@
 
 #include <string>
 
+#include "compression/compress.h"
 #include "gemma/configs.h"  // IWYU pragma: export
 #include "hwy/aligned_allocator.h"
 #include "hwy/base.h"  // ConvertScalarTo
```
```diff
@@ -37,56 +38,86 @@ ByteStorageT AllocateSizeof() {
 // Model variants: see configs.h for details.
 enum class Model { GEMMA_2B, GEMMA_7B, GRIFFIN_2B, GEMMA_TINY };
 
 // Instruction-tuned models require extra 'turn structure' tokens in prompts.
 enum class ModelTraining { GEMMA_IT, GEMMA_PT };
 
-// Returns the return value of Func<T>().operator() called with `args`, where
-// `T` is selected based on `model`.
+// Tensor types for loading weights.
+enum class Type { kF32, kBF16, kSFP };
+
+// Returns the return value of FuncT<Config*<TWeight>>().operator()(args), where
+// Config* is selected via `model`. Typically called by CallForModelAndWeight,
+// but can also be called directly when FuncT does not actually use TWeight.
 //
-// This is used to implement type-erased functions such as
-// LoadCompressedWeights, which can be called from other .cc files, by calling a
-// functor LoadCompressedWeightsT, which has a template argument. `Func` must
-// be a functor because function templates cannot be passed as a template
-// template argument, and we prefer to avoid the overhead of std::function.
 // Note that a T prefix indicates a concrete type template argument, whereas a
 // T suffix indicates the argument is itself a template.
 //
-// This function avoids having to update all call sites when we extend `Model`.
-template <template <typename Config> class Func, typename... Args>
-decltype(auto) CallFunctorForModel(Model model, Args&&... args) {
+// `FuncT` must be a functor because function templates cannot be passed as a
+// template template argument, and we prefer to avoid the overhead of
+// std::function.
+template <typename TWeight, template <typename TConfig> class FuncT,
+          typename... TArgs>
+decltype(auto) CallForModel(Model model, TArgs&&... args) {
   switch (model) {
     case Model::GEMMA_TINY:
-      return Func<ConfigGemmaTiny>()(std::forward<Args>(args)...);
+      return FuncT<ConfigGemmaTiny<TWeight>>()(std::forward<TArgs>(args)...);
     case Model::GEMMA_2B:
-      return Func<ConfigGemma2B>()(std::forward<Args>(args)...);
+      return FuncT<ConfigGemma2B<TWeight>>()(std::forward<TArgs>(args)...);
     case Model::GEMMA_7B:
-      return Func<ConfigGemma7B>()(std::forward<Args>(args)...);
+      return FuncT<ConfigGemma7B<TWeight>>()(std::forward<TArgs>(args)...);
     case Model::GRIFFIN_2B:
-      return Func<ConfigGriffin2B>()(std::forward<Args>(args)...);
+      return FuncT<ConfigGriffin2B<TWeight>>()(std::forward<TArgs>(args)...);
     default:
       HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
   }
 }
 
-// Like CallFunctorForModel, but for SIMD function templates. This is a macro
-// because it boils down to N_SSE4::FUNC, which would not work if FUNC was a
-// normal function argument.
-#define GEMMA_EXPORT_AND_DISPATCH_MODEL(MODEL, FUNC, ARGS)              \
+// Returns the return value of FuncT<TConfig>().operator()(args),
+// where `TConfig` is selected based on `model` and `weight`.
+//
+// This makes it easy to extend `Model` or `Type` without updating callers.
+//
+// Usage example: LoadWeights is type-erased so that it can be called from other
+// .cc files. It uses this function to call the appropriate instantiation of a
+// template functor LoadCompressedWeightsT<TConfig>.
+template <template <typename TConfig> class FuncT, typename... TArgs>
+decltype(auto) CallForModelAndWeight(Model model, Type weight,
+                                     TArgs&&... args) {
+  switch (weight) {
+    case Type::kF32:
+      return CallForModel<float, FuncT, TArgs...>(  //
+          model, std::forward<TArgs>(args)...);
+    case Type::kBF16:
+      return CallForModel<hwy::bfloat16_t, FuncT, TArgs...>(
+          model, std::forward<TArgs>(args)...);
+    case Type::kSFP:
+      return CallForModel<SfpStream, FuncT, TArgs...>(
+          model, std::forward<TArgs>(args)...);
+    default:
+      HWY_ABORT("Weight type %d unknown.", static_cast<int>(weight));
+  }
+}
+
+// Used by GEMMA_EXPORT_AND_DISPATCH. For a given TWEIGHT (e.g. float),
+// calls FUNC<ConfigT<TWEIGHT>> where ConfigT is chosen via MODEL enum.
+#define GEMMA_DISPATCH_MODEL(MODEL, TWEIGHT, FUNC, ARGS)                \
   switch (MODEL) {                                                      \
     case Model::GEMMA_TINY: {                                           \
-      HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<ConfigGemmaTiny>)          \
+      HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<ConfigGemmaTiny<TWEIGHT>>) \
       ARGS;                                                             \
       break;                                                            \
     }                                                                   \
     case Model::GEMMA_2B: {                                             \
-      HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<ConfigGemma2B>)            \
+      HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<ConfigGemma2B<TWEIGHT>>)   \
       ARGS;                                                             \
       break;                                                            \
     }                                                                   \
     case Model::GEMMA_7B: {                                             \
-      HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<ConfigGemma7B>)            \
+      HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<ConfigGemma7B<TWEIGHT>>)   \
       ARGS;                                                             \
       break;                                                            \
     }                                                                   \
     case Model::GRIFFIN_2B: {                                           \
-      HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<ConfigGriffin2B>)          \
+      HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<ConfigGriffin2B<TWEIGHT>>) \
       ARGS;                                                             \
       break;                                                            \
     }                                                                   \
```
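A sketch of the calling convention: any functor templated on `TConfig` can be dispatched by the two enums. `GetVocabSize` here is modeled on the functor used in gemma/cross_entropy.cc later in this diff; the wrapper function is invented for illustration:

```cpp
#include "gemma/common.h"

// Functor with a TConfig template parameter, as CallForModelAndWeight expects.
template <typename TConfig>
struct GetVocabSize {
  int operator()() const { return TConfig::kVocabSize; }
};

int VocabSizeOf(gcpp::Model model, gcpp::Type weight) {
  // Selects e.g. ConfigGemma2B<SfpStream> for (GEMMA_2B, kSFP) and invokes
  // GetVocabSize<ConfigGemma2B<SfpStream>>()().
  return gcpp::CallForModelAndWeight<GetVocabSize>(model, weight);
}
```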
```diff
@@ -94,10 +125,42 @@ decltype(auto) CallFunctorForModel(Model model, Args&&... args) {
       HWY_ABORT("Model type %d unknown.", static_cast<int>(MODEL)); \
   }
 
+// Like CallForModelAndWeight, but for SIMD function templates. This is a macro
+// because it boils down to N_SSE4::FUNC, which would not work if FUNC was a
+// normal function argument. MODEL and WEIGHT are enums.
+#define GEMMA_EXPORT_AND_DISPATCH(MODEL, WEIGHT, FUNC, ARGS)            \
+  switch (WEIGHT) {                                                     \
+    case Type::kF32:                                                    \
+      GEMMA_DISPATCH_MODEL(MODEL, float, FUNC, ARGS);                   \
+      break;                                                            \
+    case Type::kBF16:                                                   \
+      GEMMA_DISPATCH_MODEL(MODEL, hwy::bfloat16_t, FUNC, ARGS);         \
+      break;                                                            \
+    case Type::kSFP:                                                    \
+      GEMMA_DISPATCH_MODEL(MODEL, SfpStream, FUNC, ARGS);               \
+      break;                                                            \
+    default:                                                            \
+      HWY_ABORT("Weight type %d unknown.", static_cast<int>(WEIGHT));   \
+  }
+
 // Returns error string or nullptr if OK.
 // Thread-hostile.
 const char* ParseModelTypeAndTraining(const std::string& model_flag,
                                       Model& model, ModelTraining& training);
+const char* ParseType(const std::string& type_string, Type& type);
+
+static inline const char* StringFromType(Type type) {
+  switch (type) {
+    case Type::kF32:
+      return "f32";
+    case Type::kBF16:
+      return "bf16";
+    case Type::kSFP:
+      return "sfp";
+    default:
+      return "?";
+  }
+}
 
 // __builtin_sqrt is not constexpr as of Clang 17.
 #if HWY_COMPILER_GCC_ACTUAL
```
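Seen in isolation, the two-level dispatch is just nested switches that turn a pair of runtime enums into a compile-time config type. A self-contained sketch with invented toy names (no gemma.cpp or Highway dependencies):

```cpp
#include <cstdio>
#include <utility>

enum class Model { kTiny, k2B };
enum class Type { kF32, kBF16 };

template <typename TWeight>
struct ConfigTiny { static constexpr int kLayers = 3; };
template <typename TWeight>
struct Config2B { static constexpr int kLayers = 18; };

// Inner level: model enum -> config template instantiated with TWeight.
template <typename TWeight, template <typename> class FuncT, typename... TArgs>
decltype(auto) CallForModel(Model model, TArgs&&... args) {
  switch (model) {
    case Model::kTiny:
      return FuncT<ConfigTiny<TWeight>>()(std::forward<TArgs>(args)...);
    default:
      return FuncT<Config2B<TWeight>>()(std::forward<TArgs>(args)...);
  }
}

// Outer level: weight-type enum -> concrete TWeight.
template <template <typename> class FuncT, typename... TArgs>
decltype(auto) CallForModelAndWeight(Model model, Type weight,
                                     TArgs&&... args) {
  switch (weight) {
    case Type::kF32:
      return CallForModel<float, FuncT, TArgs...>(
          model, std::forward<TArgs>(args)...);
    default:  // stand-in for bf16; the real code uses hwy::bfloat16_t
      return CallForModel<unsigned short, FuncT, TArgs...>(
          model, std::forward<TArgs>(args)...);
  }
}

template <typename TConfig>
struct PrintLayers {
  void operator()() const { std::printf("layers: %d\n", TConfig::kLayers); }
};

int main() {
  CallForModelAndWeight<PrintLayers>(Model::k2B, Type::kBF16);  // prints 18
}
```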
```diff
@@ -62,14 +62,16 @@ struct Args : public ArgsBase<Args> {
     ChooseNumThreads();
   }
 
-  gcpp::Model ModelType() const { return model_type; }
-
   // Returns error string or nullptr if OK.
   const char* Validate() {
     ModelTraining model_training;
-    const char* parse_result =
-        ParseModelTypeAndTraining(model_type_str, model_type, model_training);
-    if (parse_result) return parse_result;
+    if (const char* err = ParseModelTypeAndTraining(model_type_str, model_type_,
+                                                    model_training)) {
+      return err;
+    }
+    if (const char* err = ParseType(weight_type_str, weight_type_)) {
+      return err;
+    }
     if (weights.path.empty()) {
       return "Missing --weights flag, a file for the uncompressed model.";
     }
@@ -86,7 +88,7 @@ struct Args : public ArgsBase<Args> {
   Path weights;             // uncompressed weights file location
   Path compressed_weights;  // compressed weights file location
   std::string model_type_str;
-  Model model_type;
+  std::string weight_type_str;
   size_t num_threads;
 
   template <class Visitor>
@@ -101,6 +103,9 @@ struct Args : public ArgsBase<Args> {
             "gr2b-it = griffin 2B parameters, instruction-tuned\n    "
             "gr2b-pt = griffin 2B parameters, pretrained\n    "
             "    Required argument.");
+    visitor(weight_type_str, "weight_type", std::string("sfp"),
+            "Weight type\n    f32 = float, bf16 = bfloat16, SFP = 8-bit FP\n"
+            "    Required argument.");
     visitor(compressed_weights, "compressed_weights", Path(),
             "Path name where compressed weights (.sbs) file will be written.\n"
             "    Required argument.");
@@ -110,6 +115,14 @@ struct Args : public ArgsBase<Args> {
             "number of suupported concurrent threads.",
             2);
   }
+
+  // Uninitialized before Validate, must call after that.
+  gcpp::Model ModelType() const { return model_type_; }
+  gcpp::Type WeightType() const { return weight_type_; }
+
+ private:
+  Model model_type_;
+  Type weight_type_;
 };
 
 void ShowHelp(gcpp::Args& args) {
@@ -132,7 +145,7 @@ namespace HWY_NAMESPACE {
 template <class TConfig>
 void CompressWeights(const Path& weights_path,
                      const Path& compressed_weights_path, Model model_type,
-                     hwy::ThreadPool& pool) {
+                     Type weight_type, hwy::ThreadPool& pool) {
   if (!weights_path.Exists()) {
     HWY_ABORT("The model weights file '%s' does not exist.",
               weights_path.path.c_str());
@@ -147,7 +160,7 @@ void CompressWeights(const Path& weights_path,
   // Get weights, compress, and store.
   const bool scale_for_compression = TConfig::kNumTensorScales > 0;
   const ByteStorageT weights_u8 = gcpp::LoadRawWeights(
-      weights_path, model_type, pool, scale_for_compression);
+      weights_path, model_type, weight_type, pool, scale_for_compression);
   WeightsF<TConfig>* weights =
       reinterpret_cast<WeightsF<TConfig>*>(weights_u8.get());
   Compressor compressor(pool);
```
```diff
@@ -169,9 +182,10 @@ namespace gcpp {
 void Run(Args& args) {
   hwy::ThreadPool pool(args.num_threads);
   const Model model_type = args.ModelType();
-  GEMMA_EXPORT_AND_DISPATCH_MODEL(
-      model_type, CompressWeights,
-      (args.weights, args.compressed_weights, model_type, pool));
+  const Type weight_type = args.WeightType();
+  GEMMA_EXPORT_AND_DISPATCH(
+      model_type, weight_type, CompressWeights,
+      (args.weights, args.compressed_weights, model_type, weight_type, pool));
 }
 
 }  // namespace gcpp
```
```diff
@@ -22,7 +22,6 @@
 
 #include <array>
 
-#include "compression/compress.h"  // SfpStream
 #include "hwy/base.h"  // hwy::bfloat16_t
 
 namespace gcpp {
@@ -42,18 +41,10 @@ namespace gcpp {
 #define GEMMA_MAX_THREADS 128
 #endif  // !GEMMA_MAX_THREADS
 
-// Allowable types for GEMMA_WEIGHT_T (can be specified at compilation time):
-// float, hwy::bfloat16_t, SfpStream, NuqStream
-#ifndef GEMMA_WEIGHT_T
-#define GEMMA_WEIGHT_T SfpStream
-#endif  // !GEMMA_WEIGHT_T
-
 static constexpr size_t kSeqLen = GEMMA_MAX_SEQLEN;
 static constexpr size_t kTopK = GEMMA_TOPK;
 static constexpr size_t kMaxThreads = GEMMA_MAX_THREADS;
 
-using GemmaWeightT = GEMMA_WEIGHT_T;
-
 using EmbedderInputT = hwy::bfloat16_t;
 
 enum class LayerAttentionType {
@@ -82,7 +73,10 @@ constexpr size_t NumLayersOfTypeBefore(
   return count;
 }
 
+template <typename TWeight>
 struct ConfigGemma7B {
+  using Weight = TWeight;  // make accessible where we only have a TConfig
+
   static constexpr int kSeqLen = gcpp::kSeqLen;
   static constexpr int kVocabSize = 256000;
   static constexpr std::array<LayerAttentionType, 28> kLayerConfig =
@@ -111,10 +105,12 @@ struct ConfigGemma7B {
   static constexpr bool kUseLocalAttention = false;
   static constexpr bool kInterleaveQKV = true;
   static constexpr int kNumTensorScales = 0;
-  using WeightT = GEMMA_WEIGHT_T;
 };
 
+template <typename TWeight>
 struct ConfigGemma2B {
+  using Weight = TWeight;  // make accessible where we only have a TConfig
+
   static constexpr int kSeqLen = gcpp::kSeqLen;
   static constexpr int kVocabSize = 256000;
   static constexpr std::array<LayerAttentionType, 18> kLayerConfig =
@@ -143,10 +139,12 @@ struct ConfigGemma2B {
   static constexpr bool kUseLocalAttention = false;
   static constexpr bool kInterleaveQKV = true;
   static constexpr int kNumTensorScales = 0;
-  using WeightT = GEMMA_WEIGHT_T;
 };
 
+template <typename TWeight>
 struct ConfigGemmaTiny {
+  using Weight = TWeight;  // make accessible where we only have a TConfig
+
   static constexpr int kSeqLen = 32;
   static constexpr int kVocabSize = 16;
   static constexpr std::array<LayerAttentionType, 3> kLayerConfig =
@@ -175,10 +173,12 @@ struct ConfigGemmaTiny {
   static constexpr bool kUseLocalAttention = false;
   static constexpr bool kInterleaveQKV = true;
   static constexpr int kNumTensorScales = 0;
-  using WeightT = GEMMA_WEIGHT_T;
 };
 
+template <typename TWeight>
 struct ConfigGriffin2B {
+  using Weight = TWeight;  // make accessible where we only have a TConfig
+
   // Griffin uses local attention, so kSeqLen is actually the local attention
   // window.
   static constexpr int kSeqLen = 2048;
@@ -235,7 +235,6 @@ struct ConfigGriffin2B {
   static constexpr bool kUseLocalAttention = true;
   static constexpr bool kInterleaveQKV = false;
   static constexpr int kNumTensorScales = 140;
-  using WeightT = GEMMA_WEIGHT_T;
 };
 
 }  // namespace gcpp
```
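With the configs now templated on the weight type, downstream code can recover it through the nested `Weight` alias; a small hedged sketch using the Highway helpers that appear elsewhere in this diff:

```cpp
#include "gemma/configs.h"
#include "hwy/base.h"  // hwy::IsSame, hwy::bfloat16_t

// True iff the selected config carries uncompressed float weights.
template <class TConfig>
constexpr bool IsF32Weights() {
  return hwy::IsSame<typename TConfig::Weight, float>();
}

static_assert(IsF32Weights<gcpp::ConfigGemma2B<float>>(), "");
static_assert(!IsF32Weights<gcpp::ConfigGemma2B<hwy::bfloat16_t>>(), "");
```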
```diff
@@ -15,13 +15,20 @@
 
 #include "gemma/cross_entropy.h"
 
+#include <stddef.h>
 #include <stdio.h>
 
 #include <algorithm>
 #include <cmath>
+#include <functional>
+#include <regex>  // NOLINT
+#include <string>
+#include <utility>
 #include <vector>
 
+#include "gemma/common.h"
 #include "gemma/gemma.h"
 
 namespace gcpp {
 
 namespace {
@@ -63,7 +70,9 @@ float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
   auto stream_token = [](int, float) { return true; };
   auto accept_token = [](int) { return true; };
 
-  const int vocab_size = CallFunctorForModel<GetVocabSize>(gemma.ModelType());
+  // TWeight is unused, but we have to pass it to Config*.
+  const int vocab_size =
+      CallForModel</*TWeight=*/float, GetVocabSize>(gemma.ModelType());
   float cross_entropy = std::log(vocab_size);  // first token
   size_t pos = 1;
   std::function<int(const float*, size_t)> sample_token =
```
```diff
@@ -16,6 +16,8 @@
 #ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_CROSS_ENTROPY_H_
 #define THIRD_PARTY_GEMMA_CPP_GEMMA_CROSS_ENTROPY_H_
 
+#include <stddef.h>
+
 #include <vector>
 
 #include "gemma/gemma.h"
```
```diff
@@ -147,8 +147,10 @@ struct CreateKVCache {
 
 }  // namespace
 
-KVCache KVCache::Create(Model type) {
-  return CallFunctorForModel<CreateKVCache>(type);
+KVCache KVCache::Create(Model model_type) {
+  // TWeight=float is a placeholder and unused because CreateKVCache does not
+  // use TConfig::Weight.
+  return CallForModel</*TWeight=*/float, CreateKVCache>(model_type);
 }
 
 class GemmaTokenizer::Impl {
```
```diff
@@ -727,7 +729,7 @@ Activations<TConfig, kBatchSize>& GetActivations(const ByteStorageT& state_u8) {
 }  // namespace
 
 template <class TConfig>
-void Generate(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8,
+void GenerateT(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8,
               const ByteStorageT& decode_u8,
               const RuntimeConfig& runtime_config,
               const std::vector<int>& prompt, size_t pos, KVCache& kv_cache,
@@ -871,23 +873,31 @@ struct AllocateDecode {
 }  // namespace
 
 Gemma::Gemma(const Path& tokenizer_path, const Path& weights, Model model_type,
-             hwy::ThreadPool& pool)
-    : pool_(pool), tokenizer_(tokenizer_path), model_type_(model_type) {
-  weights_u8_ = LoadWeights(weights, model_type, pool);
-  prefill_u8_ = CallFunctorForModel<AllocatePrefill>(model_type);
-  decode_u8_ = CallFunctorForModel<AllocateDecode>(model_type);
+             Type weight_type, hwy::ThreadPool& pool)
+    : pool_(pool),
+      tokenizer_(tokenizer_path),
+      model_type_(model_type),
+      weight_type_(weight_type) {
+  weights_u8_ = LoadWeights(weights, model_type, weight_type, pool);
+  prefill_u8_ = CallForModelAndWeight<AllocatePrefill>(model_type, weight_type);
+  decode_u8_ = CallForModelAndWeight<AllocateDecode>(model_type, weight_type);
 }
 
-Gemma::Gemma(GemmaTokenizer&& tokenizer, Model model_type,
+Gemma::Gemma(GemmaTokenizer&& tokenizer, Model model_type, Type weight_type,
              hwy::ThreadPool& pool)
-    : pool_(pool), tokenizer_(std::move(tokenizer)), model_type_(model_type) {
-  weights_u8_ = CallFunctorForModel<AllocateWeightsF>(model_type, pool);
-  prefill_u8_ = CallFunctorForModel<AllocatePrefill>(model_type);
-  decode_u8_ = CallFunctorForModel<AllocateDecode>(model_type);
+    : pool_(pool),
+      tokenizer_(std::move(tokenizer)),
+      model_type_(model_type),
+      weight_type_(weight_type) {
+  weights_u8_ =
+      CallForModelAndWeight<AllocateWeightsF>(model_type, weight_type, pool);
+  prefill_u8_ = CallForModelAndWeight<AllocatePrefill>(model_type, weight_type);
+  decode_u8_ = CallForModelAndWeight<AllocateDecode>(model_type, weight_type);
 }
 
 Gemma::~Gemma() {
-  CallFunctorForModel<DeleteLayersPtrs>(model_type_, weights_u8_);
+  CallForModelAndWeight<DeleteLayersPtrs>(model_type_, weight_type_,
+                                          weights_u8_);
 }
 
 void Gemma::Generate(const RuntimeConfig& runtime_config,
```
```diff
@@ -896,8 +906,8 @@ void Gemma::Generate(const RuntimeConfig& runtime_config,
                      LayersOutputT* layers_output) {
   pool_.SetWaitMode(hwy::PoolWaitMode::kSpin);
 
-  GEMMA_EXPORT_AND_DISPATCH_MODEL(
-      model_type_, Generate,
+  GEMMA_EXPORT_AND_DISPATCH(
+      model_type_, weight_type_, GenerateT,
       (weights_u8_, prefill_u8_, decode_u8_, runtime_config, prompt, start_pos,
        kv_cache, pool_, timing_info, layers_output));
 
```
```diff
@@ -107,10 +107,11 @@ using LayersOutputT =
 class Gemma {
  public:
   Gemma(const Path& tokenizer_path, const Path& weights, Model model_type,
-        hwy::ThreadPool& pool);
+        Type weight_type, hwy::ThreadPool& pool);
 
   // Allocates weights, caller is responsible for filling them.
-  Gemma(GemmaTokenizer&& tokenizer, Model model_type, hwy::ThreadPool& pool);
+  Gemma(GemmaTokenizer&& tokenizer, Model model_type, Type weight_type,
+        hwy::ThreadPool& pool);
   ~Gemma();
 
   Model ModelType() const { return model_type_; }
@@ -136,6 +137,7 @@ class Gemma {
   ByteStorageT prefill_u8_;
   ByteStorageT decode_u8_;
   Model model_type_;
+  Type weight_type_;
 };
 
 // DEPRECATED, call Gemma::Generate directly.
```
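A hedged sketch of constructing `Gemma` directly with the new parameter; this is the same call that `CreateGemma` in util/app.h wraps (see the last hunk of this diff), with the wrapper function here invented for illustration:

```cpp
#include "compression/io.h"  // gcpp::Path
#include "gemma/gemma.h"
#include "hwy/contrib/thread_pool/thread_pool.h"

gcpp::Gemma MakeSfpGemma(const gcpp::Path& tokenizer, const gcpp::Path& weights,
                         hwy::ThreadPool& pool) {
  // Type::kSFP must match how the .sbs weights file was compressed.
  return gcpp::Gemma(tokenizer, weights, gcpp::Model::GEMMA_2B,
                     gcpp::Type::kSFP, pool);
}
```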
```diff
@@ -38,8 +38,7 @@ class GemmaTest : public ::testing::Test {
       : weights("./2b-it-mqa.sbs"),
         tokenizer("./tokenizer.spm"),
         pool(std::min<int>(20, (std::thread::hardware_concurrency() - 1) / 2)),
-        model_type(gcpp::Model::GEMMA_2B),
-        model(tokenizer, weights, model_type, pool) {
+        model(tokenizer, weights, model_type, weight_type, pool) {
     KVCache kv_cache = KVCache::Create(model_type);
   }
@@ -96,7 +95,8 @@ class GemmaTest : public ::testing::Test {
   gcpp::Path tokenizer;
   gcpp::KVCache kv_cache;
   hwy::ThreadPool pool;
-  gcpp::Model model_type = {};
+  gcpp::Model model_type = gcpp::Model::GEMMA_2B;
+  gcpp::Type weight_type = gcpp::Type::kSFP;
   gcpp::Gemma model;
 };
```
```diff
@@ -71,7 +71,7 @@ void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
             << hwy::VectorBytes() * 8 << " bits)" << "\n"
             << "Compiled config : " << CompiledConfig() << "\n"
             << "Weight Type : "
-            << gcpp::TypeName(gcpp::GemmaWeightT()) << "\n"
+            << gcpp::StringFromType(loader.WeightType()) << "\n"
             << "EmbedderInput Type : "
             << gcpp::TypeName(gcpp::EmbedderInputT()) << "\n";
 }
@@ -251,8 +251,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
     PinWorkersToCores(pool);
   }
 
-  gcpp::Gemma model(loader.tokenizer, loader.weights, loader.ModelType(), pool);
-
+  gcpp::Gemma model = gcpp::CreateGemma(loader, pool);
   KVCache kv_cache = KVCache::Create(loader.ModelType());
 
   if (app.verbosity >= 1) {
@@ -278,7 +277,8 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
   }
 
   ReplGemma(
-      model, loader.ModelTraining(), kv_cache, pool, inference, app.verbosity,
+      model, loader.ModelTrainingType(), kv_cache, pool, inference,
+      app.verbosity,
       /*accept_token=*/[](int) { return true; }, app.eot_line);
 }
```
```diff
@@ -157,8 +157,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
     PinWorkersToCores(pool);
   }
 
-  gcpp::Gemma model(loader.tokenizer, loader.weights, loader.ModelType(), pool);
-
+  gcpp::Gemma model = gcpp::CreateGemma(loader, pool);
   gcpp::KVCache kv_cache = gcpp::KVCache::Create(loader.ModelType());
 
   JsonGemma(model, kv_cache, pool, inference, app.verbosity, app.eot_line);
```
```diff
@@ -173,10 +173,11 @@ struct LoadRawWeightsT {
 #undef SCALE_WEIGHTS
 }  // namespace
 
-ByteStorageT LoadRawWeights(const Path& weights, Model model,
-                            hwy::ThreadPool& pool, bool scale_for_compression) {
-  return CallFunctorForModel<LoadRawWeightsT>(model, weights, pool,
-                                              scale_for_compression);
+ByteStorageT LoadRawWeights(const Path& weights, Model model_type,
+                            Type weight_type, hwy::ThreadPool& pool,
+                            bool scale_for_compression) {
+  return CallForModelAndWeight<LoadRawWeightsT>(
+      model_type, weight_type, weights, pool, scale_for_compression);
 }
 
 namespace {
@@ -227,17 +228,18 @@ struct LoadCompressedWeightsT {
 };
 }  // namespace
 
-ByteStorageT LoadCompressedWeights(const Path& weights, Model model,
-                                   hwy::ThreadPool& pool) {
-  return CallFunctorForModel<LoadCompressedWeightsT>(model, weights, pool);
+ByteStorageT LoadCompressedWeights(const Path& weights, Model model_type,
+                                   Type weight_type, hwy::ThreadPool& pool) {
+  return CallForModelAndWeight<LoadCompressedWeightsT>(model_type, weight_type,
+                                                       weights, pool);
 }
 
-ByteStorageT LoadWeights(const Path& weights, Model model,
-                         hwy::ThreadPool& pool) {
+ByteStorageT LoadWeights(const Path& weights, Model model_type,
+                         Type weight_type, hwy::ThreadPool& pool) {
   if constexpr (kWeightsAreCompressed) {
-    return LoadCompressedWeights(weights, model, pool);
+    return LoadCompressedWeights(weights, model_type, weight_type, pool);
   } else {
-    return LoadRawWeights(weights, model, pool,
+    return LoadRawWeights(weights, model_type, weight_type, pool,
                           /*scale_for_compression=*/false);
   }
 }
@@ -274,8 +276,9 @@ struct LogWeightStatsT {
 };
 }  // namespace
 
-void LogWeightStats(gcpp::Model model, const ByteStorageT& weights) {
-  CallFunctorForModel<LogWeightStatsT>(model, weights);
+void LogWeightStats(gcpp::Model model_type, Type weight_type,
+                    const ByteStorageT& weights) {
+  CallForModelAndWeight<LogWeightStatsT>(model_type, weight_type, weights);
 }
 
 }  // namespace gcpp
```
```diff
@@ -129,21 +129,17 @@ using WeightsF = Weights<float, TConfig>;
 // ----------------------------------------------------------------------------
 // Compressed
 
-// If weights are f32, also f32; otherwise at least bf16. Useful for ops that do
-// not yet support smaller compressed types, or require at least bf16. When
-// weights are f32, we also want such tensors to be f32.
-template <class TConfig>
-using WeightF32OrBF16T =
-    hwy::If<hwy::IsSame<typename TConfig::WeightT, float>(), float,
-            hwy::bfloat16_t>;
-
 template <class TConfig>
 struct CompressedLayer {
   // No ctor/dtor, allocated via AllocateAligned.
 
   using TLayer = gcpp::LayerF<TConfig>;
-  using WeightT = typename TConfig::WeightT;
-  using WeightF32OrBF16 = WeightF32OrBF16T<TConfig>;
+  using Weight = typename TConfig::Weight;
+  // If weights are f32, also f32; otherwise at least bf16. Useful for ops that
+  // do not yet support smaller compressed types, or require at least bf16. When
+  // weights are f32, we also want such tensors to be f32.
+  using WeightF32OrBF16 =
+      hwy::If<hwy::IsSame<Weight, float>(), float, hwy::bfloat16_t>;
 
   static constexpr size_t kHeads = TLayer::kHeads;
   static constexpr size_t kKVHeads = TLayer::kKVHeads;
@@ -166,29 +162,29 @@ struct CompressedLayer {
 
   union {
     struct {
-      ArrayT<WeightT, kAttVecEinsumWSize> attn_vec_einsum_w;
-      ArrayT<WeightT, kQKVEinsumWSize> qkv_einsum_w;
+      ArrayT<Weight, kAttVecEinsumWSize> attn_vec_einsum_w;
+      ArrayT<Weight, kQKVEinsumWSize> qkv_einsum_w;
       ArrayT<float, kAOBiasDim> attention_output_biases;
     };
 
     struct {
-      ArrayT<WeightT, kGriffinDim * kGriffinDim> linear_x_w;
+      ArrayT<Weight, kGriffinDim * kGriffinDim> linear_x_w;
       ArrayT<float, kGriffinDim> linear_x_biases;
-      ArrayT<WeightT, kGriffinDim * kGriffinDim> linear_y_w;
+      ArrayT<Weight, kGriffinDim * kGriffinDim> linear_y_w;
       ArrayT<float, kGriffinDim> linear_y_biases;
-      ArrayT<WeightT, kGriffinDim * kGriffinDim> linear_out_w;
+      ArrayT<Weight, kGriffinDim * kGriffinDim> linear_out_w;
       ArrayT<float, kGriffinDim> linear_out_biases;
       ArrayT<float, TConfig::kConv1dWidth * kGriffinDim> conv_w;
       ArrayT<float, kGriffinDim> conv_biases;
-      ArrayT<WeightT, kGriffinDim * kGriffinDim / kHeads * 2> gate_w;
+      ArrayT<Weight, kGriffinDim * kGriffinDim / kHeads * 2> gate_w;
       ArrayT<float, kGriffinDim * 2> gate_biases;
       ArrayT<float, kGriffinDim> a;
     } griffin;
   };
 
-  ArrayT<WeightT, TLayer::kGatingEinsumWSize> gating_einsum_w;
-  ArrayT<WeightT, kModelDim * kFFHiddenDim> linear_w;
-  // We don't yet have an RMSNorm that accepts all WeightT.
+  ArrayT<Weight, TLayer::kGatingEinsumWSize> gating_einsum_w;
+  ArrayT<Weight, kModelDim * kFFHiddenDim> linear_w;
+  // We don't yet have an RMSNorm that accepts all Weight.
   ArrayT<WeightF32OrBF16, kModelDim> pre_attention_norm_scale;
   ArrayT<WeightF32OrBF16, kModelDim> pre_ffw_norm_scale;
   ArrayT<WeightF32OrBF16, kPostNormScale ? kModelDim : 0>
@@ -241,7 +237,7 @@ template <class TConfig>
 using WeightsT = hwy::If<kWeightsAreCompressed, CompressedWeights<TConfig>,
                          WeightsF<TConfig>>;
 
-// Call via CallFunctorForModel.
+// TODO: can we use TConfig::Weight instead of T?
 template <typename T, typename TConfig>
 struct AllocateWeights {
   ByteStorageT operator()(hwy::ThreadPool& pool) const {
@@ -335,14 +331,15 @@ class WeightsWrapper {
 };
 
 // For use by compress_weights.cc.
-ByteStorageT LoadRawWeights(const Path& weights, Model model,
-                            hwy::ThreadPool& pool, bool scale_for_compression);
+ByteStorageT LoadRawWeights(const Path& weights, Model model_type,
+                            Type weight_type, hwy::ThreadPool& pool,
+                            bool scale_for_compression);
 
 // For gemma.cc; calls LoadRawWeights if !kWeightsAreCompressed.
-ByteStorageT LoadWeights(const Path& weights, Model model,
-                         hwy::ThreadPool& pool);
+ByteStorageT LoadWeights(const Path& weights, Model model_type,
+                         Type weight_type, hwy::ThreadPool& pool);
 
-void LogWeightStats(Model model, const ByteStorageT& weights);
+void LogWeightStats(Model model, Type weight_type, const ByteStorageT& weights);
 
 // ----------------------------------------------------------------------------
 // Iterators
```
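The `WeightF32OrBF16` selection above is compile-time type selection via Highway; a self-contained restatement of just that logic:

```cpp
#include "hwy/base.h"  // hwy::If, hwy::IsSame, hwy::bfloat16_t

// f32 weights keep f32 norm scales; any other (compressed) weight type gets
// at least bf16, mirroring the alias inside CompressedLayer.
template <typename Weight>
using WeightF32OrBF16 =
    hwy::If<hwy::IsSame<Weight, float>(), float, hwy::bfloat16_t>;

static_assert(hwy::IsSame<WeightF32OrBF16<float>, float>(), "");
static_assert(
    hwy::IsSame<WeightF32OrBF16<hwy::bfloat16_t>, hwy::bfloat16_t>(), "");
```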
util/app.h (46 lines changed)
```diff
@@ -21,19 +21,17 @@
 #include "hwy/contrib/thread_pool/thread_pool.h"
 #if HWY_OS_LINUX
 #include <sched.h>
 
-#include <cctype>
 #include <cerrno>  // IDE does not recognize errno.h as providing errno.
-#include <string>
 #endif  // HWY_OS_LINUX
 #include <stddef.h>
 #include <stdio.h>
 
 #include <algorithm>  // std::clamp
 #include <string>
 #include <thread>  // NOLINT
 #include <vector>
 
 #include "compression/io.h"  // Path
 #include "gemma/common.h"
 #include "gemma/configs.h"
 #include "gemma/gemma.h"
 #include "util/args.h"
@@ -49,14 +47,10 @@ static inline const char* CompiledConfig() {
     return "msan";
   } else if (HWY_IS_TSAN) {
     return "tsan";
-#if defined(HWY_IS_HWASAN)
   } else if (HWY_IS_HWASAN) {
     return "hwasan";
-#endif
-#if defined(HWY_IS_UBSAN)
   } else if (HWY_IS_UBSAN) {
     return "ubsan";
-#endif
   } else if (HWY_IS_DEBUG_BUILD) {
     return "dbg";
   } else {
```
```diff
@@ -172,15 +166,15 @@ class AppArgs : public ArgsBase<AppArgs> {
 struct LoaderArgs : public ArgsBase<LoaderArgs> {
   LoaderArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
 
-  gcpp::Model ModelType() const { return model_type; }
-
-  gcpp::ModelTraining ModelTraining() const { return model_training; }
-
   // Returns error string or nullptr if OK.
   const char* Validate() {
-    const char* parse_result =
-        ParseModelTypeAndTraining(model_type_str, model_type, model_training);
-    if (parse_result) return parse_result;
+    if (const char* err = ParseModelTypeAndTraining(model_type_str, model_type_,
+                                                    model_training_)) {
+      return err;
+    }
+    if (const char* err = ParseType(weight_type_str, weight_type_)) {
+      return err;
+    }
     if (tokenizer.path.empty()) {
       return "Missing --tokenizer flag, a file for the tokenizer is required.";
     }
@@ -209,8 +203,7 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
   Path weights;  // weights file location
   Path compressed_weights;
   std::string model_type_str;
-  Model model_type;
-  enum ModelTraining model_training;
+  std::string weight_type_str;
 
   template <class Visitor>
   void ForEach(const Visitor& visitor) {
@@ -227,9 +220,28 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
             "gr2b-it = griffin 2B parameters, instruction-tuned\n    "
             "gr2b-pt = griffin 2B parameters, pretrained\n    "
             "    Required argument.");
+    visitor(weight_type_str, "weight_type", std::string("sfp"),
+            "Weight type\n    f32 = float, bf16 = bfloat16, SFP = 8-bit FP\n"
+            "    Required argument.");
   }
+
+  // Uninitialized before Validate, must call after that.
+  gcpp::Model ModelType() const { return model_type_; }
+  gcpp::ModelTraining ModelTrainingType() const { return model_training_; }
+  gcpp::Type WeightType() const { return weight_type_; }
+
+ private:
+  Model model_type_;
+  ModelTraining model_training_;
+  Type weight_type_;
 };
 
+static inline Gemma CreateGemma(const LoaderArgs& loader,
+                                hwy::ThreadPool& pool) {
+  return Gemma(loader.tokenizer, loader.weights, loader.ModelType(),
+               loader.WeightType(), pool);
+}
+
 struct InferenceArgs : public ArgsBase<InferenceArgs> {
   InferenceArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
```