mirror of https://github.com/google/gemma.cpp.git
Implement the Griffin model.
Also implement support for some model variations:
- Local attention.
- Biases.
- RoPE applied to only half of each vector.
- A different ordering of the QKV weights.

Co-authored-by: Andrey Mikhaylov <amik@google.com>
Co-authored-by: Martin Bruse <zondolfin@gmail.com>
Co-authored-by: Zoltan Szabadka <szabadka@google.com>
parent 4326249d0a
commit 9c3f969405
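For context before the hunks: the four variations named in the commit message become per-model compile-time flags in configs.h (kUseLocalAttention, kFFBiases / kSoftmaxAttnOutputBiases, kUseHalfRope, kInterleaveQKV). A minimal, self-contained sketch of that pattern; the struct names here are illustrative stand-ins, while the flag names and values match the configs.h hunks below:

```cpp
// Sketch of the per-model feature flags this commit introduces (see the
// configs.h hunks below). Each model config is a struct of constexpr knobs,
// so code for unused variations compiles away entirely.
#include <cstdio>

struct ConfigGemma2BSketch {   // stand-in for ConfigGemma2B
  static constexpr bool kUseLocalAttention = false;  // full causal attention
  static constexpr bool kFFBiases = false;           // no FFW biases
  static constexpr bool kUseHalfRope = false;        // RoPE on all dims
  static constexpr bool kInterleaveQKV = true;       // [q;k;v] per head
};

struct ConfigGriffin2BSketch {  // stand-in for ConfigGriffin2B
  static constexpr bool kUseLocalAttention = true;   // sliding window
  static constexpr bool kFFBiases = true;
  static constexpr bool kUseHalfRope = true;         // RoPE on half the dims
  static constexpr bool kInterleaveQKV = false;      // all Q, then K, then V
};

template <class TConfig>
void Describe(const char* name) {
  std::printf("%s: local_attn=%d ff_biases=%d half_rope=%d interleave_qkv=%d\n",
              name, TConfig::kUseLocalAttention, TConfig::kFFBiases,
              TConfig::kUseHalfRope, TConfig::kInterleaveQKV);
}

int main() {
  Describe<ConfigGemma2BSketch>("gemma-2b");
  Describe<ConfigGriffin2BSketch>("griffin-2b");
}
```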
README.md | 23

@@ -241,6 +241,24 @@ Example invocation for the following configuration:
 --model 2b-it
 ```
 
+### RecurrentGemma
+
+This repository includes a version of Gemma based on Griffin
+([paper](https://arxiv.org/abs/2402.19427),
+[code](https://github.com/google-deepmind/recurrentgemma)). Its architecture
+includes both recurrent layers and local attention, thus it is more efficient
+for longer sequences and has a smaller memory footprint than standard Gemma. We
+here provide a C++ implementation of this model based on the paper.
+
+To use the recurrent version of Gemma included in this repository, build the
+gemma binary as noted above in Step 3. Download the compressed weights and
+tokenizer from
+[Kaggle](https://www.kaggle.com/models/google/recurrentgemma/gemmaCpp) as in
+Step 1, and run the binary as follows:
+
+`./gemma --tokenizer tokenizer.spm --model gr2b-it --compressed_weights 2b-it-sfp.sbs`
+
+
 ### Troubleshooting and FAQs
 
 **Running `./gemma` fails with "Failed to read cache gating_ein_0 (error 294) ..."**

@@ -478,4 +496,9 @@ gemma.cpp was started in fall 2023 by [Austin Huang](mailto:austinvhuang@google.
 and [Jan Wassenberg](mailto:janwas@google.com), and subsequently released February 2024
 thanks to contributions from Phil Culliton, Paul Chang, and Dan Zheng.
 
+Griffin support was implemented in April 2024 thanks to contributions by Andrey
+Mikhaylov, Eugene Kliuchnikov, Jan Wassenberg, Jyrki Alakuijala, Lode
+Vandevenne, Luca Versari, Martin Bruse, Phil Culliton, Sami Boukortt, Thomas
+Fischbacher and Zoltan Szabadka.
+
 This is not an officially supported Google product.
benchmark.cc | 17

@@ -10,14 +10,14 @@
 #include "nlohmann/json.hpp"
 // copybara:import_next_line:gemma_cpp
 #include "gemma.h"
-// copybara:import_next_line:gemma_cpp
-#include "util/app.h"
-// copybara:import_next_line:gemma_cpp
-#include "util/args.h"
 #include "hwy/base.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
 #include "hwy/highway.h"
 #include "hwy/timer.h"
+// copybara:import_next_line:gemma_cpp
+#include "util/app.h"
+// copybara:import_next_line:gemma_cpp
+#include "util/args.h"
 
 using json = nlohmann::json;
 
@@ -259,6 +259,13 @@ int main(int argc, char** argv) {
   gcpp::AppArgs app(argc, argv);
   BenchmarkArgs benchmark_args(argc, argv);
 
+  if (const char* error = loader.Validate()) {
+    HWY_ABORT("\nInvalid loader args: %s", error);
+  }
+  if (const char* error = args.Validate()) {
+    HWY_ABORT("\nInvalid inference args: %s", error);
+  }
+
   hwy::ThreadPool inner_pool(0);
   hwy::ThreadPool pool(app.num_threads);
   // For many-core, pinning threads to cores helps.
@@ -275,7 +282,7 @@ int main(int argc, char** argv) {
   if (!benchmark_args.goldens.path.empty()) {
     const std::string golden_path =
-        benchmark_args.goldens.path + "/" + loader.model_type + ".txt";
+        benchmark_args.goldens.path + "/" + loader.model_type_str + ".txt";
     return BenchmarkGoldens(model, args, app, kv_cache, inner_pool, pool,
                             golden_path);
   } else if (!benchmark_args.summarize_text.path.empty()) {
compress_weights.cc
@@ -44,35 +44,14 @@ struct Args : public ArgsBase<Args> {
     ChooseNumThreads();
   }
 
-  static std::string ToLower(const std::string& text) {
-    std::string result = text;
-    std::transform(begin(result), end(result), begin(result),
-                   [](unsigned char c) { return std::tolower(c); });
-    return result;
-  }
-
-  gcpp::Model ModelType() const {
-    const std::string model_type_lc = ToLower(model_type);
-    if (model_type_lc.substr(0, 2) == "2b") {
-      return gcpp::Model::GEMMA_2B;
-    } else if (model_type_lc.substr(0, 2) == "7b") {
-      return gcpp::Model::GEMMA_7B;
-    } else {
-      HWY_ABORT("Unknown model type %s", model_type_lc.c_str());
-    }
-  }
+  gcpp::Model ModelType() const { return model_type; }
 
   // Returns error string or nullptr if OK.
-  const char* Validate() const {
-    const std::string model_type_lc = ToLower(model_type);
-    if (model_type.empty()) {
-      return "Missing --model flag, need to specify either 2b-pt, 7b-pt, "
-             "2b-it, 7b-it.";
-    }
-    if (model_type_lc != "2b-pt" && model_type_lc != "7b-pt" &&
-        model_type_lc != "2b-it" && model_type_lc != "7b-it") {
-      return "Model type must be 2b-pt, 7b-pt, 2b-it, 7b-it.";
-    }
+  const char* Validate() {
+    ModelTraining model_training;
+    const char* parse_result =
+        ParseModelTypeAndTraining(model_type_str, model_type, model_training);
+    if (parse_result) return parse_result;
     if (weights.path.empty()) {
       return "Missing --weights flag, a file for the uncompressed model.";
     }

@@ -88,7 +67,8 @@ struct Args : public ArgsBase<Args> {
 
   Path weights;             // uncompressed weights file location
   Path compressed_weights;  // compressed weights file location
-  std::string model_type;
+  std::string model_type_str;
+  Model model_type;
   size_t num_threads;
 
   template <class Visitor>

@@ -96,10 +76,12 @@ struct Args : public ArgsBase<Args> {
     visitor(weights, "weights", Path(),
             "Path name of model weights (.sbs) file.\n"
             "    Required argument.");
-    visitor(model_type, "model", std::string(),
+    visitor(model_type_str, "model", std::string(),
             "Model type\n    2b-it = 2B parameters, instruction-tuned\n    "
             "2b-pt = 2B parameters, pretrained\n    7b-it = 7B parameters "
             "instruction-tuned\n    7b-pt = 7B parameters, pretrained\n    "
+            "gr2b-it = griffin 2B parameters, instruction-tuned\n    "
+            "gr2b-pt = griffin 2B parameters, pretrained\n    "
             "    Required argument.");
     visitor(compressed_weights, "compressed_weights", Path(),
             "Path name where compressed weights file will be written.\n"

@@ -115,7 +97,7 @@ struct Args : public ArgsBase<Args> {
 void ShowHelp(gcpp::Args& args) {
   std::cerr
       << "Usage:\n./compress_weights --weights <path to uncompressed weights> "
         " --model <model type> --compressed_weights <output path>\n";
   std::cerr << "\n*Arguments*\n\n";
   args.Help();
   std::cerr << "\n";
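For orientation, the new flow above is: flags are parsed into model_type_str, Validate() resolves it into the Model enum via ParseModelTypeAndTraining (defined elsewhere in the repo; only its signature is visible in this diff), and ModelType() simply returns the cached enum. A hedged sketch of that flow; the parser body below is a stand-in written for illustration, not the repository's implementation:

```cpp
// Minimal sketch of the new --model parsing flow. ParseModelTypeAndTraining's
// body and both enums are stand-ins mirroring only what this diff shows.
#include <cstdio>
#include <string>

enum class Model { GEMMA_2B, GEMMA_7B, GRIFFIN_2B };
enum class ModelTraining { GEMMA_IT, GEMMA_PT };

// Stand-in for the real helper: maps e.g. "gr2b-it" -> (GRIFFIN_2B, GEMMA_IT).
// Returns an error string, or nullptr on success.
const char* ParseModelTypeAndTraining(const std::string& s, Model& model,
                                      ModelTraining& training) {
  training = (s.size() >= 2 && s.substr(s.size() - 2) == "it")
                 ? ModelTraining::GEMMA_IT
                 : ModelTraining::GEMMA_PT;
  if (s.rfind("gr2b", 0) == 0) { model = Model::GRIFFIN_2B; return nullptr; }
  if (s.rfind("2b", 0) == 0)   { model = Model::GEMMA_2B;   return nullptr; }
  if (s.rfind("7b", 0) == 0)   { model = Model::GEMMA_7B;   return nullptr; }
  return "unknown --model value";
}

int main() {
  std::string model_type_str = "gr2b-it";  // as passed via --model
  Model model_type;
  ModelTraining model_training;
  if (const char* err = ParseModelTypeAndTraining(model_type_str, model_type,
                                                  model_training)) {
    std::fprintf(stderr, "%s\n", err);
    return 1;
  }
  std::printf("parsed model enum: %d\n", static_cast<int>(model_type));
}
```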
configs.h | 93

@@ -30,6 +30,8 @@
 
 #include <stddef.h>
 
+#include <array>
+
 // copybara:import_next_line:gemma_cpp
 #include "compression/sfp.h"
 #include "hwy/base.h"  // hwy::bfloat16_t

@@ -45,16 +47,41 @@ namespace gcpp {
 static constexpr size_t kSeqLen = GEMMA_MAX_SEQLEN;
 static constexpr size_t kTopK = GEMMA_TOPK;
 
+enum class LayerAttentionType {
+  kGemma,
+  kGriffinRecurrentBlock,
+};
+
+template <size_t kNum>
+constexpr std::array<LayerAttentionType, kNum> FixedLayerConfig(
+    LayerAttentionType type) {
+  std::array<LayerAttentionType, kNum> config = {};
+  for (LayerAttentionType& l : config) {
+    l = type;
+  }
+  return config;
+}
+
 struct ConfigGemma7B {
   static constexpr int kSeqLen = gcpp::kSeqLen;
   static constexpr int kVocabSize = 256000;
-  static constexpr int kLayers = 28;
+  static constexpr std::array<LayerAttentionType, 28> kLayerConfig =
+      FixedLayerConfig<28>(LayerAttentionType::kGemma);
+  static constexpr int kLayers = kLayerConfig.size();
   static constexpr int kModelDim = 3072;
   static constexpr int kFFHiddenDim = 16 * 3072 / 2;  // = 24576
   static constexpr int kHeads = 16;
   static constexpr int kKVHeads = 16;  // standard MHA
   static constexpr int kQKVDim = 256;  // query size == key size == value size
   static constexpr int kTopK = gcpp::kTopK;
+
+  // SSM config.
+  static constexpr int kConv1dWidth = 0;
+  static constexpr bool kFFBiases = false;
+  static constexpr bool kSoftmaxAttnOutputBiases = false;
+  static constexpr bool kUseHalfRope = false;
+  static constexpr bool kUseLocalAttention = false;
+  static constexpr bool kInterleaveQKV = true;
   static constexpr int kNumTensorScales = 0;
   using WeightT = GEMMA_WEIGHT_T;
 };

@@ -62,17 +89,79 @@ struct ConfigGemma7B {
 struct ConfigGemma2B {
   static constexpr int kSeqLen = gcpp::kSeqLen;
   static constexpr int kVocabSize = 256000;
-  static constexpr int kLayers = 18;
+  static constexpr std::array<LayerAttentionType, 18> kLayerConfig =
+      FixedLayerConfig<18>(LayerAttentionType::kGemma);
+  static constexpr int kLayers = kLayerConfig.size();
   static constexpr int kModelDim = 2048;
   static constexpr int kFFHiddenDim = 16 * 2048 / 2;  // = 16384
   static constexpr int kHeads = 8;
   static constexpr int kKVHeads = 1;
   static constexpr int kQKVDim = 256;  // query size == key size == value size
   static constexpr int kTopK = gcpp::kTopK;
+
+  // SSM config.
+  static constexpr int kConv1dWidth = 0;
+  static constexpr bool kFFBiases = false;
+  static constexpr bool kSoftmaxAttnOutputBiases = false;
+  static constexpr bool kUseHalfRope = false;
+  static constexpr bool kUseLocalAttention = false;
+  static constexpr bool kInterleaveQKV = true;
   static constexpr int kNumTensorScales = 0;
   using WeightT = GEMMA_WEIGHT_T;
 };
 
+struct ConfigGriffin2B {
+  // Griffin uses local attention, so kSeqLen is actually the local attention
+  // window.
+  static constexpr int kSeqLen = 2048;
+  static constexpr int kVocabSize = 256000;
+  static constexpr std::array<LayerAttentionType, 26> kLayerConfig = {
+      LayerAttentionType::kGriffinRecurrentBlock,
+      LayerAttentionType::kGriffinRecurrentBlock,
+      LayerAttentionType::kGemma,
+      LayerAttentionType::kGriffinRecurrentBlock,
+      LayerAttentionType::kGriffinRecurrentBlock,
+      LayerAttentionType::kGemma,
+      LayerAttentionType::kGriffinRecurrentBlock,
+      LayerAttentionType::kGriffinRecurrentBlock,
+      LayerAttentionType::kGemma,
+      LayerAttentionType::kGriffinRecurrentBlock,
+      LayerAttentionType::kGriffinRecurrentBlock,
+      LayerAttentionType::kGemma,
+      LayerAttentionType::kGriffinRecurrentBlock,
+      LayerAttentionType::kGriffinRecurrentBlock,
+      LayerAttentionType::kGemma,
+      LayerAttentionType::kGriffinRecurrentBlock,
+      LayerAttentionType::kGriffinRecurrentBlock,
+      LayerAttentionType::kGemma,
+      LayerAttentionType::kGriffinRecurrentBlock,
+      LayerAttentionType::kGriffinRecurrentBlock,
+      LayerAttentionType::kGemma,
+      LayerAttentionType::kGriffinRecurrentBlock,
+      LayerAttentionType::kGriffinRecurrentBlock,
+      LayerAttentionType::kGemma,
+      LayerAttentionType::kGriffinRecurrentBlock,
+      LayerAttentionType::kGriffinRecurrentBlock,
+  };
+  static constexpr int kLayers = kLayerConfig.size();
+  static constexpr int kModelDim = 2560;
+  static constexpr int kFFHiddenDim = 7680;
+  static constexpr int kHeads = 10;
+  static constexpr int kKVHeads = 1;
+  static constexpr int kQKVDim = 256;  // query size == key size == value size
+  static constexpr int kTopK = gcpp::kTopK;
+
+  // SSM config.
+  static constexpr int kConv1dWidth = 4;
+  static constexpr bool kFFBiases = true;
+  static constexpr bool kSoftmaxAttnOutputBiases = true;
+  static constexpr bool kUseHalfRope = true;
+  static constexpr bool kUseLocalAttention = true;
+  static constexpr bool kInterleaveQKV = false;
+  static constexpr int kNumTensorScales = 140;
+  using WeightT = GEMMA_WEIGHT_T;
+};
+
 }  // namespace gcpp
 
 #endif  // THIRD_PARTY_GEMMA_CPP_CONFIGS_H_
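Since kLayerConfig is constexpr, the layer-type mix can be checked at compile time. A self-contained sketch; the counting helper mirrors NumLayersOfTypeBefore from the gemma.cc hunks below, and for the Griffin-2B table above it yields 8 Gemma attention layers and 18 recurrent blocks:

```cpp
#include <array>
#include <cstddef>

enum class LayerAttentionType { kGemma, kGriffinRecurrentBlock };

// Same counting idea as NumLayersOfTypeBefore in the gemma.cc hunks below.
template <size_t kNum>
constexpr size_t CountLayers(const std::array<LayerAttentionType, kNum>& layers,
                             LayerAttentionType type) {
  size_t count = 0;
  for (size_t i = 0; i < kNum; ++i) {
    if (layers[i] == type) ++count;
  }
  return count;
}

// Griffin-2B pattern from above: (recurrent, recurrent, attention) x 8,
// followed by two final recurrent blocks = 26 layers.
constexpr std::array<LayerAttentionType, 26> kGriffinLayers = [] {
  std::array<LayerAttentionType, 26> config = {};
  for (size_t i = 0; i < 26; ++i) {
    config[i] = (i % 3 == 2) ? LayerAttentionType::kGemma
                             : LayerAttentionType::kGriffinRecurrentBlock;
  }
  return config;
}();

static_assert(CountLayers(kGriffinLayers, LayerAttentionType::kGemma) == 8);
static_assert(CountLayers(kGriffinLayers,
                          LayerAttentionType::kGriffinRecurrentBlock) == 18);
```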
gemma.cc | 519

@@ -25,12 +25,12 @@
 #include "compression/compress-inl.h"
 // copybara:import_next_line:gemma_cpp
 #include "ops.h"
-// copybara:import_next_line:gemma_cpp
-#include "util/args.h"  // Path
 #include "hwy/contrib/matvec/matvec-inl.h"
 #include "hwy/highway.h"
 #include "hwy/profiler.h"
 #include "hwy/timer.h"
+// copybara:import_next_line:gemma_cpp
+#include "util/args.h"  // Path
 // copybara:import_next_line:sentencepiece
 #include "src/sentencepiece_processor.h"
 // copybara:end

@@ -43,11 +43,12 @@
 #include <math.h>  // sqrtf
 #include <stddef.h>
 #include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
 
 #include <algorithm>
 #include <array>
 #include <cmath>
-#include <cstdlib>
 #include <filesystem>  // NOLINT
 #include <iostream>
 #include <memory>

@@ -74,8 +75,35 @@ constexpr bool kDryRunFread = false;
 // Setting this to false will load and use uncompressed weights.
 constexpr bool kWeightsAreCompressed = true;
 
+// Set this to true to debug tokenizer tokens.
+constexpr bool kShowTokenization = false;
+
 namespace gcpp {
 
+template <size_t kNumLayers>
+constexpr size_t NumLayersOfTypeBefore(
+    const std::array<LayerAttentionType, kNumLayers>& layers,
+    LayerAttentionType type, size_t num) {
+  size_t count = 0;
+  for (size_t i = 0; i < num; i++) {
+    if (layers[i] == type) count++;
+  }
+  return count;
+}
+
+template <typename TConfig>
+constexpr size_t NumGemmaLayers() {
+  return NumLayersOfTypeBefore(TConfig::kLayerConfig,
+                               LayerAttentionType::kGemma, TConfig::kLayers);
+}
+
+template <typename TConfig>
+constexpr size_t NumGriffinLayers() {
+  return NumLayersOfTypeBefore(TConfig::kLayerConfig,
+                               LayerAttentionType::kGriffinRecurrentBlock,
+                               TConfig::kLayers);
+}
+
 template <class TConfig>
 struct Layer {
   Layer() = default;

@@ -96,6 +124,25 @@ struct Layer {
   std::array<float, kModelDim * kFFHiddenDim> linear_w;
   std::array<float, kModelDim> pre_attention_norm_scale;
   std::array<float, kModelDim> pre_ffw_norm_scale;
+  // These fields are only used by Griffin, and do not affect loading of the
+  // model as it is done per-member.
+  // TODO(veluca): pull weights that are never used at the same time into a
+  // union or otherwise reduce the memory usage.
+  std::array<float, 2 * kFFHiddenDim> ffw_gating_biases;
+  std::array<float, kModelDim> ffw_output_biases;
+  std::array<float, kModelDim> attention_output_biases;
+
+  std::array<float, kModelDim * kModelDim> griffin_linear_y_w;
+  std::array<float, kModelDim> griffin_linear_y_biases;
+  std::array<float, kModelDim * kModelDim> griffin_linear_x_w;
+  std::array<float, kModelDim> griffin_linear_x_biases;
+  std::array<float, kModelDim * kModelDim> griffin_linear_out_w;
+  std::array<float, kModelDim> griffin_linear_out_biases;
+  std::array<float, kModelDim> griffin_conv_biases;
+  std::array<float, kModelDim * kModelDim / TConfig::kHeads * 2> griffin_gate_w;
+  std::array<float, kModelDim * 2> griffin_gate_biases;
+  std::array<float, kModelDim> griffin_a;
+  std::array<float, TConfig::kConv1dWidth * kModelDim> griffin_conv_w;
 };
 
 float ScaleWeights(float* data, size_t len) {
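To make the new tensor sizes concrete: with the ConfigGriffin2B constants (kModelDim = 2560, kHeads = 10, kConv1dWidth = 4), the per-layer Griffin arrays declared above work out to the element counts below. The gate matrix is block-diagonal per head and has two gates, hence the division by kHeads and factor of 2; this sketch just evaluates the expressions:

```cpp
// Element counts of the per-layer Griffin tensors declared above, evaluated
// for ConfigGriffin2B (kModelDim = 2560, kHeads = 10, kConv1dWidth = 4).
#include <cstddef>
#include <cstdio>

int main() {
  constexpr size_t kModelDim = 2560, kHeads = 10, kConv1dWidth = 4;
  constexpr size_t linear = kModelDim * kModelDim;  // x/y/out projections
  constexpr size_t gate =
      kModelDim * kModelDim / kHeads * 2;           // per-head blocks, 2 gates
  constexpr size_t conv = kConv1dWidth * kModelDim; // depthwise conv taps
  std::printf("linear: %zu, gate: %zu, conv: %zu floats\n", linear, gate, conv);
  // Prints: linear: 6553600, gate: 1310720, conv: 10240 floats
}
```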
@@ -196,6 +243,7 @@ hwy::AlignedFreeUniquePtr<uint8_t[]> LoadWeights(
   do_fread(&(weights->final_norm_scale), -1, "final_norm_scale",
            sizeof(weights->final_norm_scale));
   for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
+    auto type = TConfig::kLayerConfig[layer];
     Layer<TConfig>* layer_view = weights->GetLayer(layer);
 
 #define READ_WEIGHTS(name) \

@@ -212,16 +260,42 @@ hwy::AlignedFreeUniquePtr<uint8_t[]> LoadWeights(
   } while (0)
     // Make sure we don't have uninitialized memory.
     hwy::ZeroBytes(layer_view, sizeof(*layer_view));
-    READ_WEIGHTS(attn_vec_einsum_w);
-    READ_WEIGHTS(qkv_einsum_w);
-    SCALE_WEIGHTS(attn_vec_einsum_w);
-    SCALE_WEIGHTS(qkv_einsum_w);
+    if (type == LayerAttentionType::kGemma) {
+      READ_WEIGHTS(attn_vec_einsum_w);
+      READ_WEIGHTS(qkv_einsum_w);
+      SCALE_WEIGHTS(attn_vec_einsum_w);
+      SCALE_WEIGHTS(qkv_einsum_w);
+    } else {
+      READ_WEIGHTS(griffin_linear_x_w);
+      READ_WEIGHTS(griffin_linear_x_biases);
+      READ_WEIGHTS(griffin_linear_y_w);
+      READ_WEIGHTS(griffin_linear_y_biases);
+      READ_WEIGHTS(griffin_linear_out_w);
+      READ_WEIGHTS(griffin_linear_out_biases);
+      READ_WEIGHTS(griffin_conv_w);
+      READ_WEIGHTS(griffin_conv_biases);
+      READ_WEIGHTS(griffin_gate_w);
+      READ_WEIGHTS(griffin_gate_biases);
+      READ_WEIGHTS(griffin_a);
+      SCALE_WEIGHTS(griffin_linear_x_w);
+      SCALE_WEIGHTS(griffin_linear_y_w);
+      SCALE_WEIGHTS(griffin_linear_out_w);
+      SCALE_WEIGHTS(griffin_gate_w);
+    }
     READ_WEIGHTS(gating_einsum_w);
     READ_WEIGHTS(linear_w);
     SCALE_WEIGHTS(gating_einsum_w);
     SCALE_WEIGHTS(linear_w);
     READ_WEIGHTS(pre_attention_norm_scale);
     READ_WEIGHTS(pre_ffw_norm_scale);
+    if (TConfig::kFFBiases) {
+      READ_WEIGHTS(ffw_gating_biases);
+      READ_WEIGHTS(ffw_output_biases);
+    }
+    if (TConfig::kSoftmaxAttnOutputBiases &&
+        type == LayerAttentionType::kGemma) {
+      READ_WEIGHTS(attention_output_biases);
+    }
 #undef READ_WEIGHTS
   }
   if (!ok) {

@@ -253,14 +327,30 @@ struct CompressedLayer {
   // We don't yet have an RMSNorm that accepts all WeightT.
   CompressedArray<hwy::bfloat16_t, kModelDim> pre_attention_norm_scale;
   CompressedArray<hwy::bfloat16_t, kModelDim> pre_ffw_norm_scale;
+  CompressedArray<float, 2 * kFFHiddenDim> ffw_gating_biases;
+  CompressedArray<float, kModelDim> ffw_output_biases;
+  CompressedArray<float, kModelDim> attention_output_biases;
   CompressedArray<WeightT, TLayer::kGatingEinsumWSize> gating_einsum_w;
   CompressedArray<WeightT, kModelDim * kFFHiddenDim> linear_w;
   CompressedArray<WeightT, TLayer::kQKVEinsumWSize> qkv_einsum_w;
   CompressedArray<WeightT, TLayer::kAttVecEinsumWSize> attn_vec_einsum_w;
+
+  CompressedArray<WeightT, kModelDim * kModelDim> griffin_linear_y_w;
+  CompressedArray<WeightT, kModelDim * kModelDim> griffin_linear_x_w;
+  CompressedArray<WeightT, kModelDim * kModelDim> griffin_linear_out_w;
+  CompressedArray<WeightT, kModelDim * kModelDim / TConfig::kHeads * 2>
+      griffin_gate_w;
+  CompressedArray<float, kModelDim> griffin_a;
+  CompressedArray<float, kModelDim> griffin_linear_y_biases;
+  CompressedArray<float, kModelDim> griffin_linear_x_biases;
+  CompressedArray<float, kModelDim> griffin_linear_out_biases;
+  CompressedArray<float, kModelDim> griffin_conv_biases;
+  CompressedArray<float, kModelDim * 2> griffin_gate_biases;
+  CompressedArray<float, TConfig::kConv1dWidth * kModelDim> griffin_conv_w;
 };
 
-// Array instead of single large allocation for parallel mem init. Split out of
-// CompressedWeights so that only these pointers are initialized, not the
+// Array instead of single large allocation for parallel mem init. Split out
+// of CompressedWeights so that only these pointers are initialized, not the
 // CompressedArray.
 template <class TConfig>
 struct CompressedLayerPointers {

@@ -307,7 +397,8 @@ struct Activations {
   static constexpr size_t kQKVDim = TConfig::kQKVDim;
   static constexpr size_t kHeads = TConfig::kHeads;
   static constexpr size_t kKVHeads = TConfig::kKVHeads;
-  static constexpr size_t kCachePosSize = TConfig::kLayers * kKVHeads * kQKVDim;
+  static constexpr size_t kCachePosSize =
+      NumGemmaLayers<TConfig>() * kKVHeads * kQKVDim;
   static constexpr size_t kCacheLayerSize = kKVHeads * kQKVDim;
 
   std::array<float, kBatchSize * kModelDim> x;  // input

@@ -327,6 +418,12 @@ struct Activations {
   // bf_ffw_hidden;
   std::array<float, kBatchSize * kModelDim> ffw_out;
   std::array<float, kBatchSize * TConfig::kVocabSize> logits;
+
+  // Griffin layer internal activations
+  std::array<float, kBatchSize * kModelDim> griffin_x;
+  std::array<float, kBatchSize * kModelDim> griffin_y;
+  std::array<float, kBatchSize * kModelDim> griffin_gate_x;
+  std::array<float, kBatchSize * kModelDim> griffin_multiplier;
 };
 
 // GemmaImpl is a template and thus cannot be exposed in gemma.h, hence we

@@ -353,8 +450,13 @@ struct GemmaInterface {
 
 template <class Config>
 KVCache CreateKVCache() {
-  return CreateKVCache(Config::kLayers * Config::kKVHeads * Config::kQKVDim,
-                       Config::kSeqLen);
+  constexpr size_t kConv1dWidth = Config::kConv1dWidth;
+  return CreateKVCache(
+      NumGemmaLayers<Config>() * Config::kKVHeads * Config::kQKVDim,
+      Config::kSeqLen,
+      NumGriffinLayers<Config>() * (kConv1dWidth == 0 ? 0 : kConv1dWidth - 1) *
+          Config::kModelDim,
+      NumGriffinLayers<Config>() * Config::kModelDim);
 }
 
 KVCache CreateKVCache(Model type) {

@@ -363,6 +465,8 @@ KVCache CreateKVCache(Model type) {
     case Model::GEMMA_2B:
       return CreateKVCache<ConfigGemma2B>();
     case Model::GEMMA_7B:
      return CreateKVCache<ConfigGemma7B>();
+    case Model::GRIFFIN_2B:
+      return CreateKVCache<ConfigGriffin2B>();
    default:
      HWY_ABORT("Model type %d unknown.", static_cast<int>(type));
  }
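A worked sizing example for the CreateKVCache() change above, using the ConfigGriffin2B constants: only the 8 Gemma attention layers get KV entries, while each of the 18 recurrent blocks instead gets a (kConv1dWidth - 1)-deep conv1d ring buffer and one RG-LRU state vector. The factor of 2 below is an assumption that key and value are stored in separate arrays, as the key_cache/value_cache fields used later suggest:

```cpp
#include <cstddef>
#include <cstdio>

int main() {
  // ConfigGriffin2B constants from the configs.h hunks above.
  constexpr size_t kGemmaLayers = 8, kGriffinLayers = 18;
  constexpr size_t kKVHeads = 1, kQKVDim = 256, kModelDim = 2560;
  constexpr size_t kSeqLen = 2048, kConv1dWidth = 4;

  // Per the CreateKVCache<Config>() hunk above (x2 assumes separate K and V).
  constexpr size_t kv_floats = kGemmaLayers * kKVHeads * kQKVDim * kSeqLen * 2;
  constexpr size_t conv_floats = kGriffinLayers * (kConv1dWidth - 1) * kModelDim;
  constexpr size_t rglru_floats = kGriffinLayers * kModelDim;

  std::printf("KV: %zu floats, conv1d: %zu, rglru: %zu\n", kv_floats,
              conv_floats, rglru_floats);
  // Prints: KV: 8388608 floats, conv1d: 138240, rglru: 46080
}
```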
@@ -379,7 +483,15 @@ class GemmaTokenizerImpl : public GemmaTokenizer {
   }
   bool Encode(const std::string& input,
               std::vector<int>* pieces) const override {
-    return impl_->Encode(input, pieces).ok();
+    if constexpr (kShowTokenization) {
+      bool is_ok = impl_->Encode(input, pieces).ok();
+      for (int i = 0; i < static_cast<int>(pieces->size()); i++) {
+        fprintf(stderr, "%3d: %d\n", i, (*pieces)[i]);
+      }
+      return is_ok;
+    } else {
+      return impl_->Encode(input, pieces).ok();
+    }
   }
   // Given a sequence of ids, decodes it into a detokenized output.
   bool Decode(const std::vector<int>& ids,
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
namespace HWY_NAMESPACE {
|
namespace HWY_NAMESPACE {
|
||||||
|
|
||||||
|
template <size_t kBatchSize, typename LayerT, class TConfig>
|
||||||
|
HWY_NOINLINE void GriffinRecurrent(
|
||||||
|
size_t batch_start, size_t batch_idx, size_t layer,
|
||||||
|
Activations<TConfig, kBatchSize>& activations, const LayerT* layer_weights,
|
||||||
|
KVCache& kv_cache, hwy::ThreadPool& pool) {
|
||||||
|
PROFILER_ZONE("Gen.Griffin");
|
||||||
|
namespace hn = hwy::HWY_NAMESPACE;
|
||||||
|
using D = hn::ScalableTag<float>;
|
||||||
|
HWY_DASSERT(batch_idx < kBatchSize);
|
||||||
|
static constexpr size_t kModelDim =
|
||||||
|
gcpp::Activations<TConfig, kBatchSize>::kModelDim;
|
||||||
|
static constexpr size_t kConv1dWidth = TConfig::kConv1dWidth;
|
||||||
|
static constexpr size_t kHeads = TConfig::kHeads;
|
||||||
|
const size_t batch_offset = batch_idx * kModelDim;
|
||||||
|
const size_t pos = batch_start + batch_idx;
|
||||||
|
|
||||||
|
// X / Y linear layers.
|
||||||
|
float* HWY_RESTRICT y = activations.griffin_y.data() + batch_offset;
|
||||||
|
float* HWY_RESTRICT x = activations.griffin_x.data() + batch_offset;
|
||||||
|
TwoMatVecAdd<true, kModelDim, kModelDim>(
|
||||||
|
layer_weights->griffin_linear_x_w, layer_weights->griffin_linear_y_w, 0,
|
||||||
|
activations.pre_att_rms_out.data() + batch_offset,
|
||||||
|
/*add0=*/layer_weights->griffin_linear_x_biases.data(),
|
||||||
|
/*add1=*/layer_weights->griffin_linear_y_biases.data(), /*out0=*/x,
|
||||||
|
/*out1=*/y, pool);
|
||||||
|
Gelu(y, kModelDim);
|
||||||
|
|
||||||
|
// Conv1D.
|
||||||
|
{
|
||||||
|
HWY_FULL(float) df;
|
||||||
|
HWY_DASSERT(kModelDim % Lanes(df) == 0);
|
||||||
|
const size_t layer_offset = layer * kModelDim * (kConv1dWidth - 1);
|
||||||
|
|
||||||
|
// cache[i] = input at time t-i.
|
||||||
|
float* HWY_RESTRICT cache[HWY_MAX(kConv1dWidth, 1)];
|
||||||
|
cache[0] = x;
|
||||||
|
for (size_t i = 1; i < kConv1dWidth; i++) {
|
||||||
|
cache[i] =
|
||||||
|
kv_cache.conv1d_cache.get() + layer_offset +
|
||||||
|
((pos + kConv1dWidth - 1 - i) % (kConv1dWidth - 1)) * kModelDim;
|
||||||
|
}
|
||||||
|
for (size_t i = 0; i < kModelDim; i += Lanes(df)) {
|
||||||
|
auto xv = hn::Load(df, x + i);
|
||||||
|
auto accum0 = hn::Load(df, layer_weights->griffin_conv_biases.data() + i);
|
||||||
|
auto accum1 = hn::Zero(df);
|
||||||
|
static_assert(kConv1dWidth % 2 == 0, "Conv width must be even");
|
||||||
|
for (size_t l = 0; 2 * l < kConv1dWidth; l++) {
|
||||||
|
auto wv0 = hn::Load(df, layer_weights->griffin_conv_w.data() +
|
||||||
|
(kConv1dWidth - 1 - 2 * l) * kModelDim + i);
|
||||||
|
auto wv1 = hn::Load(df, layer_weights->griffin_conv_w.data() +
|
||||||
|
(kConv1dWidth - 2 - 2 * l) * kModelDim + i);
|
||||||
|
accum0 = hn::MulAdd(wv0, hn::Load(df, cache[l * 2] + i), accum0);
|
||||||
|
accum1 = hn::MulAdd(wv1, hn::Load(df, cache[l * 2 + 1] + i), accum1);
|
||||||
|
}
|
||||||
|
hn::Store(hn::Add(accum0, accum1), df, x + i);
|
||||||
|
hn::Store(xv, df, cache[HWY_MAX(kConv1dWidth, 1) - 1] + i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RGLRU
|
||||||
|
float* HWY_RESTRICT gate_x = activations.griffin_gate_x.data() + batch_offset;
|
||||||
|
float* HWY_RESTRICT a = activations.griffin_multiplier.data() + batch_offset;
|
||||||
|
float* HWY_RESTRICT rnn_state =
|
||||||
|
kv_cache.rglru_cache.get() + layer * kModelDim;
|
||||||
|
|
||||||
|
pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
|
||||||
|
constexpr size_t kHeadDim = kModelDim / kHeads;
|
||||||
|
constexpr size_t kMatrixSize = kHeadDim * kHeadDim;
|
||||||
|
size_t head_offset = head * kHeadDim;
|
||||||
|
TwoOfsMatVecAddLoop<true, kHeadDim, kHeadDim>(
|
||||||
|
layer_weights->griffin_gate_w, kMatrixSize * head,
|
||||||
|
kMatrixSize * (kHeads + head), x + head_offset,
|
||||||
|
/*add0=*/layer_weights->griffin_gate_biases.data() + head_offset,
|
||||||
|
/*add1=*/layer_weights->griffin_gate_biases.data() + kModelDim +
|
||||||
|
head_offset,
|
||||||
|
/*out0=*/gate_x + head_offset, /*out1=*/a + head_offset);
|
||||||
|
Sigmoid(gate_x + head_offset, kHeadDim);
|
||||||
|
Sigmoid(a + head_offset, kHeadDim);
|
||||||
|
const auto fn_mul = [](D d, hn::Vec<D> x, hn::Vec<D> gate_x)
|
||||||
|
HWY_ATTR { return hn::Mul(x, gate_x); };
|
||||||
|
hn::Transform1(D(), a + head_offset, kHeadDim,
|
||||||
|
layer_weights->griffin_a.data() + head_offset, fn_mul);
|
||||||
|
hn::Transform1(D(), x + head_offset, kHeadDim, gate_x + head_offset,
|
||||||
|
fn_mul);
|
||||||
|
// RNN scan
|
||||||
|
HWY_FULL(float) df;
|
||||||
|
HWY_DASSERT(kHeadDim % Lanes(df) == 0);
|
||||||
|
for (size_t i = 0; i < kHeadDim; i += Lanes(df)) {
|
||||||
|
auto log_a = hn::Load(df, a + head_offset + i);
|
||||||
|
auto gated_x = hn::Load(df, x + head_offset + i);
|
||||||
|
auto rnn = hn::Load(df, rnn_state + head_offset + i);
|
||||||
|
auto a = hn::Exp(df, log_a);
|
||||||
|
auto x_multiplier = hn::Sqrt(hn::NegMulAdd(a, a, hn::Set(df, 1.0)));
|
||||||
|
if (pos == 0) {
|
||||||
|
x_multiplier = hn::Set(df, 1.0);
|
||||||
|
}
|
||||||
|
auto new_x = hn::MulAdd(x_multiplier, gated_x, hn::Mul(a, rnn));
|
||||||
|
hn::Store(new_x, df, rnn_state + head_offset + i);
|
||||||
|
|
||||||
|
// Join branches.
|
||||||
|
auto yv = hn::Load(df, y + head_offset + i);
|
||||||
|
auto pre_out = hn::Mul(yv, new_x);
|
||||||
|
hn::Store(pre_out, df, x + head_offset + i);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// Final linear layer.
|
||||||
|
float* out_ptr = activations.att_post2.data() + batch_idx * kModelDim;
|
||||||
|
MatVecAdd<true, kModelDim, kModelDim>(
|
||||||
|
layer_weights->griffin_linear_out_w, 0, x,
|
||||||
|
layer_weights->griffin_linear_out_biases.data(), out_ptr, pool);
|
||||||
|
}
|
||||||
|
|
||||||
template <size_t kBatchSize, typename LayerT, class TConfig>
|
template <size_t kBatchSize, typename LayerT, class TConfig>
|
||||||
HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
|
HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
|
||||||
Activations<TConfig, kBatchSize>& activations,
|
Activations<TConfig, kBatchSize>& activations,
|
||||||
|
|
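For readers cross-checking against the Griffin paper: written element-wise, the RNN-scan loop above implements the RG-LRU recurrence. With $i_t$ the sigmoid input gate (the gate_x branch), $r_t$ the sigmoid recurrence gate (the a branch), and $\Lambda$ the learned log-decay vector (griffin_a):

$$\log a_t = r_t \odot \Lambda, \qquad h_t = a_t \odot h_{t-1} + \sqrt{1 - a_t^2}\,(i_t \odot x_t)$$

The $\sqrt{1 - a_t^2}$ input normalizer is forced to 1 at the first position (the pos == 0 branch), and the block output is $y \odot h_t$, which then goes through the final linear layer.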
@@ -462,6 +687,13 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
   static const float kQueryScale =
       static_cast<float>(1.0 / sqrt(static_cast<double>(kQKVDim)));
 
+  size_t cache_pos = pos;
+  size_t cache_num = pos + 1;
+  if constexpr (TConfig::kUseLocalAttention) {
+    cache_pos %= TConfig::kSeqLen;
+    cache_num = std::min(cache_num, static_cast<size_t>(TConfig::kSeqLen));
+  }
+
   float* x = activations.pre_att_rms_out.data() + batch_idx * kModelDim;
 
   auto ProjQ = [&](uint64_t head, size_t head_offset) HWY_ATTR {

@@ -480,7 +712,7 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
     TwoOfsMatVecLoop<kQKVDim, kModelDim>(layer_weights->qkv_einsum_w, k_offset,
                                          v_offset, x, k, v);
 
-    Rope(k, kQKVDim, pos);
+    Rope(k, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
   };
 
   auto Attn = [&](uint64_t head, size_t head_offset) HWY_ATTR {

@@ -491,24 +723,24 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
                             head * TConfig::kSeqLen +
                             batch_idx * kHeads * kQKVDim;
 
-    Rope(q, kQKVDim, pos);
+    Rope(q, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
     MulByConst(kQueryScale, q, kQKVDim);
 
     // Compute Q dot K scores
-    for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
+    for (size_t pos2 = 0; pos2 < cache_num; ++pos2) {
       const size_t cache_offset =
           pos2 * kCachePosSize + layer * kCacheLayerSize + head_offset;
       const float* HWY_RESTRICT k2 = kv_cache.key_cache.get() + cache_offset;
       const float score = Dot(q, k2, kQKVDim);
       head_att[pos2] = score;
     }
-    Softmax(head_att, pos + 1);
+    Softmax(head_att, cache_num);
 
     // Weighted summation
     float* HWY_RESTRICT att_out = activations.att_out.data() + head * kQKVDim +
                                   batch_idx * kHeads * kQKVDim;
     hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out));
-    for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
+    for (size_t pos2 = 0; pos2 < cache_num; ++pos2) {
       const size_t cache_offset =
           pos2 * kCachePosSize + layer * kCacheLayerSize + head_offset;
       float* HWY_RESTRICT v2 = kv_cache.value_cache.get() + cache_offset;
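Two behavioral notes on the hunks above. First, with kUseHalfRope the Rope() calls rotate only the first kQKVDim/2 dimensions of each query and key vector, leaving the other half unrotated. Second, local attention turns the KV cache into a ring buffer over the last kSeqLen positions; a standalone sketch of the index arithmetic:

```cpp
// Ring-buffer indexing used by local attention above: cache_pos is the write
// slot, cache_num is how many cached entries are valid to attend over.
#include <algorithm>
#include <cstddef>
#include <cstdio>

int main() {
  constexpr size_t kSeqLen = 2048;  // Griffin's local attention window
  for (size_t pos : {0u, 5u, 2047u, 2048u, 5000u}) {
    const size_t cache_pos = pos % kSeqLen;
    const size_t cache_num = std::min<size_t>(pos + 1, kSeqLen);
    std::printf("pos=%zu -> write slot %zu, attend over %zu entries\n", pos,
                cache_pos, cache_num);
  }
}
```

So once pos reaches kSeqLen, new entries overwrite the oldest slot and attention always covers exactly the trailing window.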
@@ -520,22 +752,34 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
         head == 0
             ? activations.att_post2.data() + batch_idx * kModelDim
             : activations.att_post1.data() + head * kBatchSize * kModelDim;
-    MatVecLoop<kModelDim, kQKVDim>(layer_weights->attn_vec_einsum_w,
-                                   head * kModelDim * kQKVDim, att_out,
-                                   head_out);
+    if (head == 0) {
+      MatVecAddLoop<TConfig::kSoftmaxAttnOutputBiases, kModelDim, kQKVDim>(
+          layer_weights->attn_vec_einsum_w, head * kModelDim * kQKVDim, att_out,
+          layer_weights->attention_output_biases.data(), head_out);
+    } else {
+      MatVecLoop<kModelDim, kQKVDim>(layer_weights->attn_vec_einsum_w,
+                                     head * kModelDim * kQKVDim, att_out,
+                                     head_out);
+    }
   };
 
   if constexpr (kHeads == kKVHeads) {
     // Multi-Head Attention
     pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
-      const size_t head_offset = head * 3 * kQKVDim * kModelDim;
+      // linear projections to QKV
+      const size_t head_offset = TConfig::kInterleaveQKV
+                                     ? 3 * kQKVDim * kModelDim
+                                     : kQKVDim * kModelDim;
+      const size_t mat_offset =
+          TConfig::kInterleaveQKV ? kQKVDim * kModelDim : kModelDim * kModelDim;
+      const size_t q_offset = head * head_offset + 0 * mat_offset;
+      const size_t k_offset = head * head_offset + 1 * mat_offset;
+      const size_t v_offset = head * head_offset + 2 * mat_offset;
 
-      ProjQ(head, head_offset);
+      ProjQ(head, q_offset);
 
-      const size_t k_offset = head_offset + 1 * kQKVDim * kModelDim;
-      const size_t v_offset = head_offset + 2 * kQKVDim * kModelDim;
       const size_t kv_offset =
-          pos * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim;
+          cache_pos * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim;
 
       ProjKV(k_offset, v_offset, kv_offset);
 

@@ -546,7 +790,9 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
     constexpr const size_t q_offset = kHeads * kQKVDim * kModelDim;
     constexpr const size_t k_offset = q_offset + 0 * kQKVDim * kModelDim;
     constexpr const size_t v_offset = q_offset + 1 * kQKVDim * kModelDim;
-    const size_t kv_offset = pos * kCachePosSize + layer * kCacheLayerSize;
+    const size_t kv_offset =
+        cache_pos * kCachePosSize + layer * kCacheLayerSize;
 
     ProjKV(k_offset, v_offset, kv_offset);
 
     pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
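The offsets above encode the two weight layouts distinguished by kInterleaveQKV: per-head [q; k; v] blocks (the Gemma layout) versus all query heads first, then the K and V matrices (the Griffin layout, which also assumes kHeads * kQKVDim == kModelDim). A small sketch of the arithmetic, using kModelDim = 2560 and kQKVDim = 256 for concreteness:

```cpp
// Offset arithmetic from the MHA branch above, for both QKV weight layouts.
#include <cstddef>
#include <cstdio>

int main() {
  constexpr size_t kModelDim = 2560, kQKVDim = 256;
  for (bool interleave : {true, false}) {
    // Stride between heads, and stride between the Q/K/V matrices of a head.
    const size_t head_stride =
        interleave ? 3 * kQKVDim * kModelDim : kQKVDim * kModelDim;
    const size_t mat_stride =
        interleave ? kQKVDim * kModelDim : kModelDim * kModelDim;
    const size_t head = 1;  // second head, as an example
    std::printf("interleave=%d: q=%zu k=%zu v=%zu\n", interleave,
                head * head_stride + 0 * mat_stride,
                head * head_stride + 1 * mat_stride,
                head * head_stride + 2 * mat_stride);
  }
}
```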
@@ -581,13 +827,13 @@ HWY_NOINLINE void FFW(Activations<TConfig, kBatchSize>& activations,
 
     // Same matrix, first and second half of rows. Could fuse into one MatVec,
     // but separating them could help on NUMA e.g. multiple sockets.
-    MatVec<kFFHiddenDim, kModelDim>(layer_weights->gating_einsum_w,
-                                    kFFHiddenDim * kModelDim, vec, out_mul,
-                                    pool);
+    MatVecAdd<TConfig::kFFBiases, kFFHiddenDim, kModelDim>(
+        layer_weights->gating_einsum_w, kFFHiddenDim * kModelDim, vec,
+        layer_weights->ffw_gating_biases.data() + kFFHiddenDim, out_mul, pool);
 
     // Gate, will go through the nonlinearity.
-    MatVec<kFFHiddenDim, kModelDim>(layer_weights->gating_einsum_w, 0, vec, out,
-                                    pool);
+    MatVecAdd<TConfig::kFFBiases, kFFHiddenDim, kModelDim>(
+        layer_weights->gating_einsum_w, 0, vec,
+        layer_weights->ffw_gating_biases.data(), out, pool);
 
     namespace hn = hwy::HWY_NAMESPACE;
     using DF = hn::ScalableTag<float>;

@@ -598,8 +844,9 @@ HWY_NOINLINE void FFW(Activations<TConfig, kBatchSize>& activations,
   }
 
   PROFILER_ZONE("Gen.FFW\\GatedGELU");
-  MatVec<kModelDim, kFFHiddenDim>(
+  MatVecAdd<TConfig::kFFBiases, kModelDim, kFFHiddenDim>(
       layer_weights->linear_w, 0, activations.ffw_hidden.data() + hidden_offset,
+      layer_weights->ffw_output_biases.data(),
      activations.ffw_out.data() + batch_idx * kModelDim, pool);
 }
 
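In formula form, the FFW block above is a gated-GELU MLP, and this change merely threads optional biases through it. Writing $W_1$ and $W_3$ for the two halves of gating_einsum_w and $W_2$ for linear_w:

$$\mathrm{FFW}(x) = W_2\,\big(\mathrm{GELU}(W_1 x + b_1) \odot (W_3 x + b_3)\big) + b_2$$

where $b_1, b_3$ come from ffw_gating_biases, $b_2$ from ffw_output_biases, and all three are only added when TConfig::kFFBiases is set (true for Griffin, false for the Gemma configs).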
@@ -639,15 +886,23 @@ HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
       });
 
   for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
+    auto type = TConfig::kLayerConfig[layer];
     const auto* layer_weights = weights.GetLayer(layer);
+    size_t layer_of_type =
+        NumLayersOfTypeBefore(TConfig::kLayerConfig, type, layer);
+
     for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
       RMSNorm(activations.x.data() + token_idx * kModelDim,
               layer_weights->pre_attention_norm_scale.data(),
               activations.pre_att_rms_out.data() + token_idx * kModelDim,
               kModelDim);
-      Attention<kBatchSize>(pos, token_idx, layer, activations, layer_weights,
-                            kv_cache, pool);
+      if (type == LayerAttentionType::kGemma) {
+        Attention<kBatchSize>(pos, token_idx, layer_of_type, activations,
+                              layer_weights, kv_cache, pool);
+      } else {
+        GriffinRecurrent<kBatchSize>(pos, token_idx, layer_of_type, activations,
+                                     layer_weights, kv_cache, pool);
+      }
     }
 
     // TODO: sink the loop into these functions, i.e. make them matmuls.
@@ -678,9 +933,7 @@ template <typename WeightArrayT, class TConfig>
 void Transformer(int token, size_t pos, const WeightArrayT& weights,
                  Activations<TConfig, 1>& activations, KVCache& kv_cache,
                  hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool) {
-  static constexpr size_t kLayers = TConfig::kLayers;
   static constexpr size_t kModelDim = TConfig::kModelDim;
 
   Decompress(weights.embedder_input_embedding, token * kModelDim,
              activations.x.data(), kModelDim);

@@ -688,12 +941,21 @@ void Transformer(int token, size_t pos, const WeightArrayT& weights,
       EmbeddingScaling<TConfig>();
   MulByConst(kEmbScaling, activations.x.data(), kModelDim);
 
-  for (size_t layer = 0; layer < kLayers; ++layer) {
+  for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
+    auto type = TConfig::kLayerConfig[layer];
     const auto* layer_weights = weights.GetLayer(layer);
+    size_t layer_of_type =
+        NumLayersOfTypeBefore(TConfig::kLayerConfig, type, layer);
     RMSNorm(activations.x.data(),
             layer_weights->pre_attention_norm_scale.data(),
             activations.pre_att_rms_out.data(), kModelDim);
-    Attention<1>(pos, 0, layer, activations, layer_weights, kv_cache, pool);
+    if (type == LayerAttentionType::kGemma) {
+      Attention<1>(pos, 0, layer_of_type, activations, layer_weights, kv_cache,
+                   pool);
+    } else {
+      GriffinRecurrent<1>(pos, 0, layer_of_type, activations, layer_weights,
+                          kv_cache, pool);
+    }
     AddFrom(activations.att_post2.data(), activations.x.data(), kModelDim);
     RMSNorm(activations.x.data(), layer_weights->pre_ffw_norm_scale.data(),
             activations.bf_pre_ffw_rms_out.data(), kModelDim);
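Both Prefill and Transformer above renumber each layer within its own type before dispatching, so attention layers index the KV cache contiguously while recurrent blocks index the conv1d/RG-LRU caches contiguously. A standalone sketch of the renumbering over the start of the Griffin-2B layer table:

```cpp
// Per-type layer renumbering (the layer_of_type computation above), shown on
// the first six entries of the Griffin-2B table: R, R, G, R, R, G.
#include <array>
#include <cstddef>
#include <cstdio>

enum class LayerAttentionType { kGemma, kGriffinRecurrentBlock };

int main() {
  constexpr std::array<LayerAttentionType, 6> kLayerConfig = {
      LayerAttentionType::kGriffinRecurrentBlock,
      LayerAttentionType::kGriffinRecurrentBlock,
      LayerAttentionType::kGemma,
      LayerAttentionType::kGriffinRecurrentBlock,
      LayerAttentionType::kGriffinRecurrentBlock,
      LayerAttentionType::kGemma,
  };
  for (size_t layer = 0; layer < kLayerConfig.size(); ++layer) {
    const auto type = kLayerConfig[layer];
    size_t layer_of_type = 0;  // = NumLayersOfTypeBefore(config, type, layer)
    for (size_t i = 0; i < layer; ++i) {
      if (kLayerConfig[i] == type) ++layer_of_type;
    }
    std::printf("global layer %zu (%s) -> per-type index %zu\n", layer,
                type == LayerAttentionType::kGemma ? "attention" : "recurrent",
                layer_of_type);
  }
  // E.g. global layer 2 is attention index 0; global layer 5 is attention 1.
}
```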
@@ -707,10 +969,12 @@ void Transformer(int token, size_t pos, const WeightArrayT& weights,
 template <class TConfig>
 void RangeChecks(size_t& max_tokens, size_t& max_generated_tokens,
                  size_t& prompt_size) {
-  if (max_tokens > TConfig::kSeqLen) {
-    fprintf(stderr, "WARNING: max_tokens %zu > kSeqLen %d, truncating.\n",
-            max_tokens, TConfig::kSeqLen);
-    max_tokens = static_cast<size_t>(TConfig::kSeqLen);
+  if (!TConfig::kUseLocalAttention) {
+    if (max_tokens > TConfig::kSeqLen) {
+      fprintf(stderr, "WARNING: max_tokens %zu > kSeqLen %d, truncating.\n",
+              max_tokens, TConfig::kSeqLen);
+      max_tokens = static_cast<size_t>(TConfig::kSeqLen);
+    }
   }
 
   if (max_generated_tokens > max_tokens) {

@@ -720,12 +984,14 @@ void RangeChecks(size_t& max_tokens, size_t& max_generated_tokens,
     max_generated_tokens = max_tokens - 1;
   }
 
-  if (prompt_size + max_generated_tokens > max_tokens) {
-    fprintf(stderr,
-            "WARNING: prompt_size %zu + max_generated_tokens %zu > kSeqLen "
-            "%d, truncating.\n",
-            prompt_size, max_generated_tokens, TConfig::kSeqLen);
-    prompt_size = max_tokens - max_generated_tokens;
+  if (!TConfig::kUseLocalAttention) {
+    if (prompt_size + max_generated_tokens > max_tokens) {
+      fprintf(stderr,
+              "WARNING: prompt_size %zu + max_generated_tokens %zu > kSeqLen "
+              "%d, truncating.\n",
+              prompt_size, max_generated_tokens, TConfig::kSeqLen);
+      prompt_size = max_tokens - max_generated_tokens;
+    }
   }
 }
 
@@ -935,6 +1201,19 @@ void Generate7B(GemmaImpl<ConfigGemma7B>& gemma, size_t max_tokens,
                accept_token, gen, verbosity);
 }
 
+void GenerateGriffin2B(GemmaImpl<ConfigGriffin2B>& gemma, size_t max_tokens,
+                       size_t max_generated_tokens, float temperature,
+                       const std::vector<int>& prompt, size_t start_pos,
+                       KVCache& kv_cache, hwy::ThreadPool& pool,
+                       hwy::ThreadPool& inner_pool,
+                       const StreamFunc& stream_token,
+                       const AcceptFunc& accept_token, std::mt19937& gen,
+                       int verbosity) {
+  GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt,
+               start_pos, kv_cache, pool, inner_pool, stream_token,
+               accept_token, gen, verbosity);
+}
+
 float ComputeCrossEntropy2B(GemmaImpl<ConfigGemma2B>& gemma, size_t max_tokens,
                             const std::vector<int>& prompt, KVCache& kv_cache,
                             hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
@@ -951,6 +1230,15 @@ float ComputeCrossEntropy7B(GemmaImpl<ConfigGemma7B>& gemma, size_t max_tokens,
                             inner_pool, verbosity);
 }
 
+float ComputeCrossEntropyGriffin2B(GemmaImpl<ConfigGriffin2B>& gemma,
+                                   size_t max_tokens,
+                                   const std::vector<int>& prompt,
+                                   KVCache& kv_cache, hwy::ThreadPool& pool,
+                                   hwy::ThreadPool& inner_pool, int verbosity) {
+  return ComputeCrossEntropyImpl(gemma, max_tokens, prompt, kv_cache, pool,
+                                 inner_pool, verbosity);
+}
+
 // Calls func(name, float*, CompressedArray&) for each tensor. float* is null
 // if weights = null, which happens during the first call where we attempt to
 // load from cache.
@ -967,6 +1255,7 @@ void ForEachTensor(const Weights<TConfig>* weights,
|
||||||
|
|
||||||
char name_buf[16];
|
char name_buf[16];
|
||||||
for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) {
|
for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) {
|
||||||
|
auto type = TConfig::kLayerConfig[layer_idx];
|
||||||
const size_t idx = static_cast<size_t>(layer_idx);
|
const size_t idx = static_cast<size_t>(layer_idx);
|
||||||
const Layer<TConfig>* layer = weights ? weights->GetLayer(idx) : nullptr;
|
const Layer<TConfig>* layer = weights ? weights->GetLayer(idx) : nullptr;
|
||||||
CompressedLayer<TConfig>* layer_weights = c_weights.GetLayer(idx);
|
CompressedLayer<TConfig>* layer_weights = c_weights.GetLayer(idx);
|
||||||
|
|
@ -978,9 +1267,33 @@ void ForEachTensor(const Weights<TConfig>* weights,
|
||||||
CALL_FUNC("pre_ff_ns", pre_ffw_norm_scale);
|
CALL_FUNC("pre_ff_ns", pre_ffw_norm_scale);
|
||||||
CALL_FUNC("gating_ein", gating_einsum_w);
|
CALL_FUNC("gating_ein", gating_einsum_w);
|
||||||
CALL_FUNC("linear_w", linear_w);
|
CALL_FUNC("linear_w", linear_w);
|
||||||
CALL_FUNC("qkv_ein", qkv_einsum_w);
|
if (type == LayerAttentionType::kGemma) {
|
||||||
CALL_FUNC("att_ein", attn_vec_einsum_w);
|
CALL_FUNC("qkv_ein", qkv_einsum_w);
|
||||||
|
CALL_FUNC("att_ein", attn_vec_einsum_w);
|
||||||
|
} else {
|
||||||
|
CALL_FUNC("gr_lin_x_w", griffin_linear_x_w);
|
||||||
|
CALL_FUNC("gr_lin_x_b", griffin_linear_x_biases);
|
||||||
|
CALL_FUNC("gr_lin_y_w", griffin_linear_y_w);
|
||||||
|
CALL_FUNC("gr_lin_y_b", griffin_linear_y_biases);
|
||||||
|
CALL_FUNC("gr_lin_out_w", griffin_linear_out_w);
|
||||||
|
CALL_FUNC("gr_lin_out_b", griffin_linear_out_biases);
|
||||||
|
CALL_FUNC("gr_conv_w", griffin_conv_w);
|
||||||
|
CALL_FUNC("gr_conv_b", griffin_conv_biases);
|
||||||
|
CALL_FUNC("gr_gate_w", griffin_gate_w);
|
||||||
|
CALL_FUNC("gr_gate_b", griffin_gate_biases);
|
||||||
|
CALL_FUNC("gr_a", griffin_a);
|
||||||
|
}
|
||||||
CALL_FUNC("pre_att_ns", pre_attention_norm_scale);
|
CALL_FUNC("pre_att_ns", pre_attention_norm_scale);
|
||||||
|
|
||||||
|
if (TConfig::kFFBiases) {
|
||||||
|
CALL_FUNC("ffw_gat_b", ffw_gating_biases);
|
||||||
|
CALL_FUNC("ffw_out_b", ffw_output_biases);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (TConfig::kSoftmaxAttnOutputBiases &&
|
||||||
|
type == LayerAttentionType::kGemma) {
|
||||||
|
CALL_FUNC("attn_ob", attention_output_biases);
|
||||||
|
}
|
||||||
#undef CALL_FUNC
|
#undef CALL_FUNC
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
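The `ForEachTensor` pattern above is worth noting: a single callback receives a short tensor name plus the tensor itself, so loading, compression, and scale assignment can all reuse one traversal, and Griffin layers simply contribute a different set of (name, tensor) pairs. A minimal standalone sketch of the idea, with simplified hypothetical types rather than the repository's actual template machinery:

```cpp
#include <cstdio>
#include <functional>
#include <string>
#include <vector>

// Hypothetical simplified stand-ins for the real config/weight types.
enum class LayerAttentionType { kGemma, kGriffinRecurrentBlock };

struct Tensor {
  std::string label;
  std::vector<float> data;
};

struct Layer {
  LayerAttentionType type = LayerAttentionType::kGemma;
  Tensor qkv_einsum_w, attn_vec_einsum_w;     // Gemma attention tensors
  Tensor griffin_linear_x_w, griffin_conv_w;  // subset of Griffin tensors
};

// One traversal, many uses: the visitor decides what "visiting" means
// (load from cache, compress, set scale, ...).
void ForEachLayerTensor(std::vector<Layer>& layers,
                        const std::function<void(const char*, Tensor&)>& func) {
  for (Layer& layer : layers) {
    if (layer.type == LayerAttentionType::kGemma) {
      func("qkv_ein", layer.qkv_einsum_w);
      func("att_ein", layer.attn_vec_einsum_w);
    } else {
      func("gr_lin_x_w", layer.griffin_linear_x_w);
      func("gr_conv_w", layer.griffin_conv_w);
    }
  }
}

int main() {
  std::vector<Layer> layers(2);
  layers[1].type = LayerAttentionType::kGriffinRecurrentBlock;
  ForEachLayerTensor(layers, [](const char* name, Tensor&) {
    std::printf("tensor: %s\n", name);  // prints each visited tensor name
  });
  return 0;
}
```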
@@ -1011,10 +1324,18 @@ hwy::AlignedFreeUniquePtr<uint8_t[]> LoadCompressedWeights(
   if (TConfig::kNumTensorScales > 0) {
     size_t scale_pos = 0;
     for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) {
+      auto type = TConfig::kLayerConfig[layer_idx];
       const size_t idx = static_cast<size_t>(layer_idx);
       CompressedLayer<TConfig>* layer_weights = c_weights->GetLayer(idx);
-      layer_weights->attn_vec_einsum_w.set_scale(scales[scale_pos++]);
-      layer_weights->qkv_einsum_w.set_scale(scales[scale_pos++]);
+      if (type == LayerAttentionType::kGemma) {
+        layer_weights->attn_vec_einsum_w.set_scale(scales[scale_pos++]);
+        layer_weights->qkv_einsum_w.set_scale(scales[scale_pos++]);
+      } else {
+        layer_weights->griffin_linear_x_w.set_scale(scales[scale_pos++]);
+        layer_weights->griffin_linear_y_w.set_scale(scales[scale_pos++]);
+        layer_weights->griffin_linear_out_w.set_scale(scales[scale_pos++]);
+        layer_weights->griffin_gate_w.set_scale(scales[scale_pos++]);
+      }
       layer_weights->gating_einsum_w.set_scale(scales[scale_pos++]);
       layer_weights->linear_w.set_scale(scales[scale_pos++]);
     }
@@ -1031,6 +1352,8 @@ hwy::AlignedFreeUniquePtr<uint8_t[]> LoadCompressedWeightsT(
       return LoadCompressedWeights<ConfigGemma2B>(weights, pool);
     case Model::GEMMA_7B:
       return LoadCompressedWeights<ConfigGemma7B>(weights, pool);
+    case Model::GRIFFIN_2B:
+      return LoadCompressedWeights<ConfigGriffin2B>(weights, pool);
     default:
       HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
   }
@@ -1044,6 +1367,8 @@ hwy::AlignedFreeUniquePtr<uint8_t[]> LoadWeightsT(gcpp::Model model,
       return LoadWeights<ConfigGemma2B>(weights, pool);
     case Model::GEMMA_7B:
       return LoadWeights<ConfigGemma7B>(weights, pool);
+    case Model::GRIFFIN_2B:
+      return LoadWeights<ConfigGriffin2B>(weights, pool);
     default:
       HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
   }
@@ -1089,6 +1414,9 @@ void CompressWeightsT(gcpp::Model model, const Path& weights,
     case Model::GEMMA_7B:
       CompressWeights<ConfigGemma7B>(weights, compressed_weights, pool);
       break;
+    case Model::GRIFFIN_2B:
+      CompressWeights<ConfigGriffin2B>(weights, compressed_weights, pool);
+      break;
     default:
       HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
   }
@@ -1106,13 +1434,29 @@ HWY_EXPORT(LoadWeightsT);
 HWY_EXPORT(CompressWeightsT);
 HWY_EXPORT(Generate2B);
 HWY_EXPORT(Generate7B);
+HWY_EXPORT(GenerateGriffin2B);
 HWY_EXPORT(ComputeCrossEntropy2B);
 HWY_EXPORT(ComputeCrossEntropy7B);
+HWY_EXPORT(ComputeCrossEntropyGriffin2B);
 
-KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len) {
+KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len,
+                      size_t conv_cache_size, size_t rglru_cache_size) {
   KVCache kv_cache = {};
-  kv_cache.key_cache = hwy::AllocateAligned<float>(seq_len * size_cache_pos);
-  kv_cache.value_cache = hwy::AllocateAligned<float>(seq_len * size_cache_pos);
+  if (size_cache_pos != 0) {
+    kv_cache.key_cache = hwy::AllocateAligned<float>(seq_len * size_cache_pos);
+    kv_cache.value_cache =
+        hwy::AllocateAligned<float>(seq_len * size_cache_pos);
+  }
+  if (conv_cache_size != 0) {
+    kv_cache.conv1d_cache = hwy::AllocateAligned<float>(conv_cache_size);
+    hwy::ZeroBytes(kv_cache.conv1d_cache.get(),
+                   conv_cache_size * sizeof(kv_cache.conv1d_cache[0]));
+  }
+  if (rglru_cache_size != 0) {
+    kv_cache.rglru_cache = hwy::AllocateAligned<float>(rglru_cache_size);
+    hwy::ZeroBytes(kv_cache.rglru_cache.get(),
+                   rglru_cache_size * sizeof(kv_cache.rglru_cache[0]));
+  }
   return kv_cache;
 }
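For pure-attention Gemma models only the key/value caches are allocated; Griffin's recurrent blocks instead need a convolution state and an RG-LRU state, which is why the zero-size guards above exist. A hedged sketch of how a caller might size these buffers, using the formulas from the `KVCache` comments in gemma.h; the dimension constants below are illustrative placeholders, not the actual `ConfigGriffin2B` values:

```cpp
#include <cstddef>

// Hypothetical placeholder dimensions; the real values come from the
// model's config struct (e.g. ConfigGriffin2B in configs.h).
constexpr size_t kSeqLen = 4096;
constexpr size_t kModelDim = 2560;
constexpr size_t kConv1dWidth = 4;
constexpr size_t kNumGemmaLayers = 6;     // local-attention layers
constexpr size_t kNumGriffinLayers = 20;  // recurrent layers
constexpr size_t kKVHeads = 1;
constexpr size_t kQKVDim = 256;

// Per-position KV size covers only the attention layers.
constexpr size_t kSizeCachePos = kNumGemmaLayers * kKVHeads * kQKVDim;
// The conv state keeps the last (kConv1dWidth - 1) inputs per Griffin layer.
constexpr size_t kConvCacheSize =
    (kConv1dWidth - 1) * kModelDim * kNumGriffinLayers;
// The RG-LRU keeps one recurrent state vector per Griffin layer.
constexpr size_t kRGLRUCacheSize = kModelDim * kNumGriffinLayers;

// A caller would then allocate all four buffers in one shot:
// KVCache kv_cache = CreateKVCache(kSizeCachePos, kSeqLen,
//                                  kConvCacheSize, kRGLRUCacheSize);
```

Note that the recurrent state is zero-initialized explicitly, since unlike the KV caches it is read before the first position is written.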
@@ -1136,6 +1480,7 @@ void GemmaImpl<ConfigGemma2B>::Generate(
   (*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos,
    kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity);
 }
 
 template <>
 void GemmaImpl<ConfigGemma7B>::Generate(
     size_t max_tokens, size_t max_generated_tokens, float temperature,
@@ -1148,6 +1493,18 @@ void GemmaImpl<ConfigGemma7B>::Generate(
   kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity);
 }
 
+template <>
+void GemmaImpl<ConfigGriffin2B>::Generate(
+    size_t max_tokens, size_t max_generated_tokens, float temperature,
+    const std::vector<int>& prompt, size_t start_pos, KVCache& kv_cache,
+    hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
+    const StreamFunc& stream_token, const AcceptFunc& accept_token,
+    std::mt19937& gen, int verbosity) {
+  HWY_DYNAMIC_DISPATCH(GenerateGriffin2B)
+  (*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos,
+   kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity);
+}
+
 template <>
 float GemmaImpl<ConfigGemma2B>::ComputeCrossEntropy(
     size_t max_tokens, const std::vector<int>& prompt, KVCache& kv_cache,
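These specializations route through Highway's runtime dispatch: the translation unit is compiled once per SIMD target via foreach_target.h, `HWY_EXPORT` collects the per-target function pointers into a table, and `HWY_DYNAMIC_DISPATCH` selects the best one for the running CPU. A minimal sketch of that pattern in isolation, with a hypothetical `AddOne` function standing in for the generation kernels:

```cpp
// minimal_dispatch.cc -- sketch of the HWY_EXPORT / HWY_DYNAMIC_DISPATCH
// pattern; AddOne is a hypothetical example function, not a gemma.cpp API.
#include <cstddef>

// Re-include this file once per SIMD target.
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "minimal_dispatch.cc"
#include "hwy/foreach_target.h"
#include "hwy/highway.h"

HWY_BEFORE_NAMESPACE();
namespace demo {
namespace HWY_NAMESPACE {
namespace hn = hwy::HWY_NAMESPACE;

// Compiled once per target; HWY_NAMESPACE differs on each pass.
// For brevity, assumes n is a multiple of the vector length.
void AddOne(float* HWY_RESTRICT data, size_t n) {
  const hn::ScalableTag<float> d;
  const auto one = hn::Set(d, 1.0f);
  for (size_t i = 0; i < n; i += hn::Lanes(d)) {
    hn::Store(hn::Add(hn::Load(d, data + i), one), d, data + i);
  }
}

}  // namespace HWY_NAMESPACE
}  // namespace demo
HWY_AFTER_NAMESPACE();

#if HWY_ONCE
namespace demo {
HWY_EXPORT(AddOne);  // builds the table of per-target AddOne pointers

void CallAddOne(float* data, size_t n) {
  // Picks the best compiled implementation for the current CPU at runtime.
  HWY_DYNAMIC_DISPATCH(AddOne)(data, n);
}
}  // namespace demo
#endif  // HWY_ONCE
```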
@@ -1164,6 +1521,14 @@ float GemmaImpl<ConfigGemma7B>::ComputeCrossEntropy(
       *this, max_tokens, prompt, kv_cache, pool, inner_pool, verbosity);
 }
 
+template <>
+float GemmaImpl<ConfigGriffin2B>::ComputeCrossEntropy(
+    size_t max_tokens, const std::vector<int>& prompt, KVCache& kv_cache,
+    hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool, int verbosity) {
+  return HWY_DYNAMIC_DISPATCH(ComputeCrossEntropyGriffin2B)(
+      *this, max_tokens, prompt, kv_cache, pool, inner_pool, verbosity);
+}
+
 Gemma::Gemma(const Path& tokenizer_path, const Path& weights, Model model_type,
              hwy::ThreadPool& pool) {
   std::unique_ptr<sentencepiece::SentencePieceProcessor> tokenizer;
@@ -1190,6 +1555,9 @@ Gemma::Gemma(const Path& tokenizer_path, const Path& weights, Model model_type,
     case Model::GEMMA_7B:
       impl_.reset(new GemmaImpl<ConfigGemma7B>(tokenizer, weights_u8, pool));
       break;
+    case Model::GRIFFIN_2B:
+      impl_.reset(new GemmaImpl<ConfigGriffin2B>(tokenizer, weights_u8, pool));
+      break;
     default:
       HWY_ABORT("Model type %d unknown.", static_cast<int>(model_type));
   }
@@ -1240,5 +1608,42 @@ float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
   return result;
 }
 
+namespace {
+constexpr const char* kModelFlags[] = {"2b-pt", "7b-pt", "gr2b-pt",
+                                       "2b-it", "7b-it", "gr2b-it"};
+constexpr Model kModelTypes[] = {Model::GEMMA_2B, Model::GEMMA_7B,
+                                 Model::GRIFFIN_2B, Model::GEMMA_2B,
+                                 Model::GEMMA_7B, Model::GRIFFIN_2B};
+constexpr ModelTraining kModelTraining[] = {
+    ModelTraining::GEMMA_PT, ModelTraining::GEMMA_PT, ModelTraining::GEMMA_PT,
+    ModelTraining::GEMMA_IT, ModelTraining::GEMMA_IT, ModelTraining::GEMMA_IT};
+}  // namespace
+
+const char* ParseModelTypeAndTraining(const std::string& model_flag,
+                                      Model& model, ModelTraining& training) {
+  constexpr size_t kNum = std::end(kModelFlags) - std::begin(kModelFlags);
+  static char kErrorMessageBuffer[kNum * 8 + 1024];
+  kErrorMessageBuffer[0] = 0;
+  strcat(kErrorMessageBuffer,
+         "Invalid or missing model flag, need to specify one of ");
+  for (size_t i = 0; i + 1 < kNum; i++) {
+    strcat(kErrorMessageBuffer, kModelFlags[i]);
+    strcat(kErrorMessageBuffer, ", ");
+  }
+  strcat(kErrorMessageBuffer, kModelFlags[kNum - 1]);
+  strcat(kErrorMessageBuffer, ".");
+  std::string model_type_lc = model_flag;
+  std::transform(begin(model_type_lc), end(model_type_lc), begin(model_type_lc),
+                 [](unsigned char c) { return std::tolower(c); });
+  for (size_t i = 0; i < kNum; i++) {
+    if (kModelFlags[i] == model_type_lc) {
+      model = kModelTypes[i];
+      training = kModelTraining[i];
+      return nullptr;
+    }
+  }
+  return kErrorMessageBuffer;
+}
+
 }  // namespace gcpp
 #endif  // HWY_ONCE
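`ParseModelTypeAndTraining` gives the flag parsing a single source of truth: three parallel arrays map each accepted flag string to its model and training variant, and the error message is rebuilt from the same list (hence "thread-hostile": it writes into a shared static buffer). A usage sketch, assuming only the declarations this commit adds to gemma.h:

```cpp
#include <cstdio>
// copybara:import_next_line:gemma_cpp
#include "gemma.h"  // ParseModelTypeAndTraining

int main() {
  gcpp::Model model;
  gcpp::ModelTraining training;
  if (const char* err =
          gcpp::ParseModelTypeAndTraining("gr2b-it", model, training)) {
    std::fprintf(stderr, "%s\n", err);  // lists every accepted flag
    return 1;
  }
  // Here model == gcpp::Model::GRIFFIN_2B and
  // training == gcpp::ModelTraining::GEMMA_IT.
  return 0;
}
```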
18 gemma.h

@@ -42,15 +42,24 @@ constexpr bool kSystemPrompt = false;
 
 struct KVCache {
   hwy::AlignedFreeUniquePtr<float[]>
-      key_cache;  // batch_size * kSeqLen * kLayers * kKVHeads * kQKVDim
+      key_cache;  // kSeqLen * kNumGemmaLayers * kKVHeads * kQKVDim
   hwy::AlignedFreeUniquePtr<float[]>
-      value_cache;  // batch_size * kSeqLen * kLayers * kKVHeads * kQKVDim
+      value_cache;  // kSeqLen * kNumGemmaLayers * kKVHeads * kQKVDim
+  hwy::AlignedFreeUniquePtr<float[]>
+      conv1d_cache;  // (kConv1dWidth - 1) * kModelDim * kNumGriffinLayers
+  hwy::AlignedFreeUniquePtr<float[]>
+      rglru_cache;  // kModelDim * kNumGriffinLayers
 };
 
 // Model variants: see configs.h for details.
-enum class Model { GEMMA_2B, GEMMA_7B };
+enum class Model { GEMMA_2B, GEMMA_7B, GRIFFIN_2B };
 enum class ModelTraining { GEMMA_IT, GEMMA_PT };
 
+// Returns error string or nullptr if OK.
+// Thread-hostile.
+const char* ParseModelTypeAndTraining(const std::string& model_flag,
+                                      Model& model, ModelTraining& training);
+
 struct RuntimeConfig {
   size_t max_tokens;
   size_t max_generated_tokens;
@@ -79,7 +88,8 @@ struct Gemma {
 };
 
 KVCache CreateKVCache(Model type);  // convenient workaround for now
-KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len);
+KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len,
+                      size_t conv1d_cache_size, size_t rglru_cache_size);
 
 // StreamFunc is called with (token, probability). For prompt tokens,
 // probability is 0.0f.
12 run.cc

@@ -27,8 +27,6 @@
 #include "compression/compress.h"
 // copybara:import_next_line:gemma_cpp
 #include "gemma.h"  // Gemma
-// copybara:import_next_line:gemma_cpp
-#include "util/app.h"
 #include "hwy/base.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
 #include "hwy/highway.h"
@@ -36,8 +34,12 @@
 #include "hwy/profiler.h"
 #include "hwy/timer.h"
 // copybara:import_next_line:gemma_cpp
+#include "util/app.h"
+// copybara:import_next_line:gemma_cpp
 #include "util/args.h"  // HasHelp
 
+static constexpr bool kVerboseLogTokens = false;
+
 namespace gcpp {
 
 static constexpr std::string_view kAsciiArtBanner = R""(
@@ -203,6 +205,12 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training,
   std::cerr << "\n"
             << "[ Reading prompt ] " << std::flush;
 
+  if constexpr (kVerboseLogTokens) {
+    for (int i = 0; i < static_cast<int>(prompt.size()); ++i) {
+      fprintf(stderr, "DDD TOKEN %3d: %6d\n", i, prompt[i]);
+    }
+  }
+
   const double time_start = hwy::platform::Now();
   GenerateGemma(model, args.max_tokens, args.max_generated_tokens,
                 args.temperature, prompt, abs_pos, kv_cache, pool, inner_pool,
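The token logging above is gated by `if constexpr` on a compile-time constant, so with `kVerboseLogTokens = false` the loop is discarded during compilation rather than branch-checked at runtime. The same idiom in isolation:

```cpp
#include <cstdio>
#include <vector>

static constexpr bool kVerboseLogTokens = false;

void MaybeLogTokens(const std::vector<int>& prompt) {
  // Resolved at compile time: when the constant is false, no logging
  // code is emitted at all.
  if constexpr (kVerboseLogTokens) {
    for (int i = 0; i < static_cast<int>(prompt.size()); ++i) {
      std::fprintf(stderr, "TOKEN %3d: %6d\n", i, prompt[i]);
    }
  }
}
```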
51 util/app.h

@@ -125,46 +125,21 @@ class AppArgs : public ArgsBase<AppArgs> {
 struct LoaderArgs : public ArgsBase<LoaderArgs> {
   LoaderArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
 
-  static std::string ToLower(const std::string& text) {
-    std::string result = text;
-    std::transform(begin(result), end(result), begin(result),
-                   [](unsigned char c) { return std::tolower(c); });
-    return result;
-  }
-
-  gcpp::Model ModelType() const {
-    const std::string model_type_lc = ToLower(model_type);
-    if (model_type_lc == "2b-pt" || model_type_lc == "2b-it") {
-      return gcpp::Model::GEMMA_2B;
-    } else {
-      return gcpp::Model::GEMMA_7B;
-    }
-  }
-
-  gcpp::ModelTraining ModelTraining() const {
-    const std::string model_type_lc = ToLower(model_type);
-    if (model_type_lc == "7b-pt" || model_type_lc == "2b-pt") {
-      return gcpp::ModelTraining::GEMMA_PT;
-    } else {
-      return gcpp::ModelTraining::GEMMA_IT;
-    }
-  }
+  gcpp::Model ModelType() const { return model_type; }
+  gcpp::ModelTraining ModelTraining() const { return model_training; }
 
   // Returns error string or nullptr if OK.
   const char* Validate() {
-    const std::string model_type_lc = ToLower(model_type);
-    if (model_type.empty()) {
-      return "Missing --model flag, need to specify either 2b-pt, 7b-pt, "
-             "2b-it, or 7b-it.";
-    }
-    if (model_type_lc != "2b-pt" && model_type_lc != "7b-pt" &&
-        model_type_lc != "2b-it" && model_type_lc != "7b-it") {
-      return "Model type must be 2b-pt, 7b-pt, 2b-it, or "
-             "7b-it.";
-    }
+    const char* parse_result =
+        ParseModelTypeAndTraining(model_type_str, model_type, model_training);
+    if (parse_result) return parse_result;
     if (tokenizer.path.empty()) {
       return "Missing --tokenizer flag, a file for the tokenizer is required.";
     }
+    if (!tokenizer.exists()) {
+      return "Can't open file specified with --tokenizer flag.";
+    }
     if (!compressed_weights.path.empty()) {
       if (weights.path.empty()) {
         weights = compressed_weights;
@@ -186,7 +161,9 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
   Path tokenizer;
   Path weights;  // weights file location
   Path compressed_weights;
-  std::string model_type;
+  std::string model_type_str;
+  Model model_type;
+  enum ModelTraining model_training;
 
   template <class Visitor>
   void ForEach(const Visitor& visitor) {
@@ -196,10 +173,12 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
         "Path name of model weights (.sbs) file.\n    Required argument.");
     visitor(compressed_weights, "compressed_weights", Path(),
             "Alias for --weights.");
-    visitor(model_type, "model", std::string(),
+    visitor(model_type_str, "model", std::string(),
             "Model type\n    2b-it = 2B parameters, instruction-tuned\n    "
             "2b-pt = 2B parameters, pretrained\n    7b-it = 7B parameters "
-            "instruction-tuned\n    7b-pt = 7B parameters, pretrained\n"
+            "instruction-tuned\n    7b-pt = 7B parameters, pretrained\n    "
+            "gr2b-it = griffin 2B parameters, instruction-tuned\n    "
+            "gr2b-pt = griffin 2B parameters, pretrained\n    "
            "    Required argument.");
  }
};
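With this refactor, `LoaderArgs` parses and validates the model flag once in `Validate()` and callers just read the cached result instead of re-deriving it from strings. A hedged usage sketch, assuming the gemma.cpp headers as laid out in this commit:

```cpp
#include <cstdio>
// copybara:import_next_line:gemma_cpp
#include "util/app.h"  // LoaderArgs

int main(int argc, char* argv[]) {
  gcpp::LoaderArgs loader(argc, argv);
  if (const char* err = loader.Validate()) {
    std::fprintf(stderr, "%s\n", err);  // e.g. the accepted-flags message
    return 1;
  }
  // After Validate(), e.g. --model gr2b-it yields
  // Model::GRIFFIN_2B / ModelTraining::GEMMA_IT.
  const gcpp::Model type = loader.ModelType();
  const gcpp::ModelTraining training = loader.ModelTraining();
  (void)type;
  (void)training;
  return 0;
}
```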