Merge pull request #58 from google:dev-cleanup

PiperOrigin-RevId: 610942948
This commit is contained in:
Copybara-Service 2024-02-27 18:47:38 -08:00
commit f4a14bfdf2
4 changed files with 59 additions and 39 deletions

1
.clang-format Normal file
View File

@ -0,0 +1 @@
BasedOnStyle: Google

View File

@ -71,19 +71,31 @@ The implementation code is roughly split into 4 layers, from high to low level:
4. Backend (`highway`) - Low-level hardware interface (SIMD in the case of 4. Backend (`highway`) - Low-level hardware interface (SIMD in the case of
highway) supporting the implementations in (3). highway) supporting the implementations in (3).
Besides these layers, supporting utilities are:
- `compression/` - model compression operations. The 8-bit switched floating
point model conversion is here.
- `util/` - command line argument handling and any other utilities.
## Style and Formatting
A `.clang-format` configuration is provided with our defaults, please run source
files through `clang-format` (or a formatter that produces equivalent behavior)
before finalizing PR for submission.
## Compile-Time Flags (Advanced) ## Compile-Time Flags (Advanced)
There are several compile-time flags to be aware of (note these may or may not There are several compile-time flags to be aware of (note these may or may not
be exposed to the build system): be exposed to the build system):
- `GEMMA_WEIGHT_T` : Sets the level of compression for weights (surfaced as - `GEMMA_WEIGHT_T` : Sets the level of compression for weights (surfaced as
WEIGHT_TYPE in CMakeLists.txt). Currently this should be set to `SfpStream` WEIGHT_TYPE in CMakeLists.txt). Currently this should be set to `SfpStream`
(default, if no flag is specified) for 8-bit SFP, or `hwy::bfloat16_t` to (default, if no flag is specified) for 8-bit SFP, or `hwy::bfloat16_t` to
enable for higher-fidelity (but slower) bfloat16 support. This is defined in enable for higher-fidelity (but slower) bfloat16 support. This is defined in
`gemma.h`. `gemma.h`.
- `GEMMA_MAX_SEQ_LEN` : Sets maximum sequence length to preallocate for the KV - `GEMMA_MAX_SEQ_LEN` : Sets maximum sequence length to preallocate for the KV
Cache. The default is 4096 tokens but can be overridden. This is not exposed Cache. The default is 4096 tokens but can be overridden. This is not exposed
through `CMakeLists.txt` yet. through `CMakeLists.txt` yet.
In the medium term both of these will likely be deprecated in favor of handling In the medium term both of these will likely be deprecated in favor of handling
options at runtime - allowing for multiple weight compression schemes in a single options at runtime - allowing for multiple weight compression schemes in a single

View File

@ -21,7 +21,7 @@
// Allow changing pre-allocated kv cache size as a compiler flag // Allow changing pre-allocated kv cache size as a compiler flag
#ifndef GEMMA_MAX_SEQLEN #ifndef GEMMA_MAX_SEQLEN
#define GEMMA_MAX_SEQLEN 4096 #define GEMMA_MAX_SEQLEN 4096
#endif // !GEMMA_MAX_SEQLEN #endif // !GEMMA_MAX_SEQLEN
#include <stddef.h> #include <stddef.h>
@ -34,10 +34,10 @@ struct ConfigGemma7B {
static constexpr int kVocabSize = 256128; static constexpr int kVocabSize = 256128;
static constexpr int kLayers = 28; static constexpr int kLayers = 28;
static constexpr int kModelDim = 3072; static constexpr int kModelDim = 3072;
static constexpr int kFFHiddenDim = 16 * 3072 / 2; // = 24576 static constexpr int kFFHiddenDim = 16 * 3072 / 2; // = 24576
static constexpr int kHeads = 16; static constexpr int kHeads = 16;
static constexpr int kKVHeads = 16; // standard MHA static constexpr int kKVHeads = 16; // standard MHA
static constexpr int kQKVDim = 256; // query size == key size == value size static constexpr int kQKVDim = 256; // query size == key size == value size
static constexpr int kTopK = 1; static constexpr int kTopK = 1;
}; };
@ -46,13 +46,13 @@ struct ConfigGemma2B {
static constexpr int kVocabSize = 256128; static constexpr int kVocabSize = 256128;
static constexpr int kLayers = 18; static constexpr int kLayers = 18;
static constexpr int kModelDim = 2048; static constexpr int kModelDim = 2048;
static constexpr int kFFHiddenDim = 16 * 2048 / 2; // = 16384 static constexpr int kFFHiddenDim = 16 * 2048 / 2; // = 16384
static constexpr int kHeads = 8; static constexpr int kHeads = 8;
static constexpr int kKVHeads = 8; // TODO(austinvhuang): add MQA support static constexpr int kKVHeads = 8; // TODO(austinvhuang): add MQA support
static constexpr int kQKVDim = 256; // query size == key size == value size static constexpr int kQKVDim = 256; // query size == key size == value size
static constexpr int kTopK = 1; static constexpr int kTopK = 1;
}; };
} // namespace gcpp } // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_CONFIGS_H_ #endif // THIRD_PARTY_GEMMA_CPP_CONFIGS_H_

57
gemma.h
View File

@ -26,15 +26,19 @@
// copybara:import_next_line:gemma_cpp // copybara:import_next_line:gemma_cpp
#include "compression/compress.h" // SfpStream/NuqStream #include "compression/compress.h" // SfpStream/NuqStream
// copybara:end
// copybara:import_next_line:gemma_cpp // copybara:import_next_line:gemma_cpp
#include "configs.h" // kSeqLen #include "configs.h" // kSeqLen
// copybara:end
// copybara:import_next_line:gemma_cpp // copybara:import_next_line:gemma_cpp
#include "util/args.h" // ArgsBase #include "util/args.h" // ArgsBase
// copybara:end
#include "hwy/aligned_allocator.h" #include "hwy/aligned_allocator.h"
#include "hwy/base.h" // hwy::bfloat16_t #include "hwy/base.h" // hwy::bfloat16_t
#include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/contrib/thread_pool/thread_pool.h"
// copybara:import_next_line:sentencepiece // copybara:import_next_line:sentencepiece
#include "src/sentencepiece_processor.h" #include "src/sentencepiece_processor.h"
// copybara:end
namespace gcpp { namespace gcpp {
@ -42,7 +46,7 @@ namespace gcpp {
// float, hwy::bfloat16_t, SfpStream, NuqStream // float, hwy::bfloat16_t, SfpStream, NuqStream
#ifndef GEMMA_WEIGHT_T #ifndef GEMMA_WEIGHT_T
#define GEMMA_WEIGHT_T SfpStream #define GEMMA_WEIGHT_T SfpStream
#endif // !GEMMA_WEIGHT_T #endif // !GEMMA_WEIGHT_T
using WeightT = GEMMA_WEIGHT_T; using WeightT = GEMMA_WEIGHT_T;
using EmbedderInputT = hwy::bfloat16_t; using EmbedderInputT = hwy::bfloat16_t;
@ -51,9 +55,9 @@ constexpr bool kSystemPrompt = false;
struct KVCache { struct KVCache {
hwy::AlignedFreeUniquePtr<float[]> hwy::AlignedFreeUniquePtr<float[]>
key_cache; // batch_size * kSeqLen * kLayers * kKVHeads * kQKVDim key_cache; // batch_size * kSeqLen * kLayers * kKVHeads * kQKVDim
hwy::AlignedFreeUniquePtr<float[]> hwy::AlignedFreeUniquePtr<float[]>
value_cache; // batch_size * kSeqLen * kLayers * kKVHeads * kQKVDim value_cache; // batch_size * kSeqLen * kLayers * kKVHeads * kQKVDim
}; };
// Model variants: see configs.h for details. // Model variants: see configs.h for details.
@ -61,9 +65,9 @@ enum class Model { GEMMA_2B, GEMMA_7B };
enum class ModelTraining { GEMMA_IT, GEMMA_PT }; enum class ModelTraining { GEMMA_IT, GEMMA_PT };
struct LoaderArgs : public ArgsBase<LoaderArgs> { struct LoaderArgs : public ArgsBase<LoaderArgs> {
LoaderArgs(int argc, char *argv[]) { InitAndParse(argc, argv); } LoaderArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
static std::string ToLower(const std::string &text) { static std::string ToLower(const std::string& text) {
std::string result = text; std::string result = text;
std::transform(begin(result), end(result), begin(result), std::transform(begin(result), end(result), begin(result),
[](unsigned char c) { return std::tolower(c); }); [](unsigned char c) { return std::tolower(c); });
@ -89,7 +93,7 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
} }
// Returns error string or nullptr if OK. // Returns error string or nullptr if OK.
const char *Validate() const { const char* Validate() const {
const std::string model_type_lc = ToLower(model_type); const std::string model_type_lc = ToLower(model_type);
if (model_type_lc != "2b-pt" && model_type_lc != "7b-pt" && if (model_type_lc != "2b-pt" && model_type_lc != "7b-pt" &&
model_type_lc != "2b-it" && model_type_lc != "7b-it") { model_type_lc != "2b-it" && model_type_lc != "7b-it") {
@ -111,11 +115,12 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
} }
Path tokenizer; Path tokenizer;
Path model; // uncompressed weights OR Path model; // uncompressed weights OR
Path cache; // compressed weights Path cache; // compressed weights
std::string model_type; std::string model_type;
template <class Visitor> void ForEach(const Visitor &visitor) { template <class Visitor>
void ForEach(const Visitor& visitor) {
visitor(tokenizer, "tokenizer", Path(), visitor(tokenizer, "tokenizer", Path(),
"Path name of tokenizer model file. (required)"); "Path name of tokenizer model file. (required)");
visitor( visitor(
@ -138,10 +143,10 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
struct GemmaInterface; struct GemmaInterface;
struct Gemma { struct Gemma {
Gemma(const LoaderArgs &args, hwy::ThreadPool &pool); Gemma(const LoaderArgs& args, hwy::ThreadPool& pool);
~Gemma(); // must be defined after GemmaInterface's dtor is defined. ~Gemma(); // must be defined after GemmaInterface's dtor is defined.
const sentencepiece::SentencePieceProcessor &Tokenizer() const; const sentencepiece::SentencePieceProcessor& Tokenizer() const;
std::unique_ptr<GemmaInterface> impl_; std::unique_ptr<GemmaInterface> impl_;
gcpp::ModelTraining model_training; gcpp::ModelTraining model_training;
@ -153,7 +158,7 @@ using StreamFunc = std::function<bool(int, float)>;
using AcceptFunc = std::function<bool(int)>; using AcceptFunc = std::function<bool(int)>;
struct InferenceArgs : public ArgsBase<InferenceArgs> { struct InferenceArgs : public ArgsBase<InferenceArgs> {
InferenceArgs(int argc, char *argv[]) { InitAndParse(argc, argv); } InferenceArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
size_t max_tokens; size_t max_tokens;
size_t max_generated_tokens; size_t max_generated_tokens;
@ -163,7 +168,7 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
bool multiturn; bool multiturn;
// Returns error string or nullptr if OK. // Returns error string or nullptr if OK.
const char *Validate() const { const char* Validate() const {
if (max_tokens > gcpp::kSeqLen) { if (max_tokens > gcpp::kSeqLen) {
return "max_tokens is larger than the maximum sequence length (see " return "max_tokens is larger than the maximum sequence length (see "
"configs.h)."; "configs.h).";
@ -175,7 +180,8 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
return nullptr; return nullptr;
} }
template <class Visitor> void ForEach(const Visitor &visitor) { template <class Visitor>
void ForEach(const Visitor& visitor) {
visitor(max_tokens, "max_tokens", size_t{3072}, visitor(max_tokens, "max_tokens", size_t{3072},
"Maximum number of tokens in prompt + generation."); "Maximum number of tokens in prompt + generation.");
visitor(max_generated_tokens, "max_generated_tokens", size_t{2048}, visitor(max_generated_tokens, "max_generated_tokens", size_t{2048},
@ -186,19 +192,20 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
"Make top-k sampling deterministic", 2); "Make top-k sampling deterministic", 2);
visitor(multiturn, "multiturn", false, visitor(multiturn, "multiturn", false,
"Multiturn mode (if 0, this clears the KV cache after every " "Multiturn mode (if 0, this clears the KV cache after every "
"interaction without quitting)\n Default = 0 (conversation resets every turn)"); "interaction without quitting)\n Default = 0 (conversation "
"resets every turn)");
} }
}; };
void GenerateGemma(Gemma &gemma, const InferenceArgs &args, void GenerateGemma(Gemma& gemma, const InferenceArgs& args,
const std::vector<int> &prompt, size_t start_pos, const std::vector<int>& prompt, size_t start_pos,
hwy::ThreadPool &pool, hwy::ThreadPool &inner_pool, hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
const StreamFunc &stream_token, const StreamFunc& stream_token,
const AcceptFunc &accept_token, std::mt19937 &g, const AcceptFunc& accept_token, std::mt19937& g,
int verbosity); int verbosity);
constexpr int EOS_ID = 1; constexpr int EOS_ID = 1;
} // namespace gcpp } // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_H_ #endif // THIRD_PARTY_GEMMA_CPP_GEMMA_H_