Merge pull request #58 from google:dev-cleanup

PiperOrigin-RevId: 610942948
This commit is contained in:
Copybara-Service 2024-02-27 18:47:38 -08:00
commit f4a14bfdf2
4 changed files with 59 additions and 39 deletions

1
.clang-format Normal file
View File

@ -0,0 +1 @@
BasedOnStyle: Google

View File

@ -71,6 +71,18 @@ The implementation code is roughly split into 4 layers, from high to low level:
4. Backend (`highway`) - Low-level hardware interface (SIMD in the case of
highway) supporting the implementations in (3).
Besides these layers, supporting utilities are:
- `compression/` - model compression operations. The 8-bit switched floating
point model conversion is here.
- `util/` - command line argument handling and any other utilities.
## Style and Formatting
A `.clang-format` configuration is provided with our defaults, please run source
files through `clang-format` (or a formatter that produces equivalent behavior)
before finalizing PR for submission.
## Compile-Time Flags (Advanced)
There are several compile-time flags to be aware of (note these may or may not

View File

@ -21,7 +21,7 @@
// Allow changing pre-allocated kv cache size as a compiler flag
#ifndef GEMMA_MAX_SEQLEN
#define GEMMA_MAX_SEQLEN 4096
#endif // !GEMMA_MAX_SEQLEN
#endif // !GEMMA_MAX_SEQLEN
#include <stddef.h>
@ -34,10 +34,10 @@ struct ConfigGemma7B {
static constexpr int kVocabSize = 256128;
static constexpr int kLayers = 28;
static constexpr int kModelDim = 3072;
static constexpr int kFFHiddenDim = 16 * 3072 / 2; // = 24576
static constexpr int kFFHiddenDim = 16 * 3072 / 2; // = 24576
static constexpr int kHeads = 16;
static constexpr int kKVHeads = 16; // standard MHA
static constexpr int kQKVDim = 256; // query size == key size == value size
static constexpr int kKVHeads = 16; // standard MHA
static constexpr int kQKVDim = 256; // query size == key size == value size
static constexpr int kTopK = 1;
};
@ -46,13 +46,13 @@ struct ConfigGemma2B {
static constexpr int kVocabSize = 256128;
static constexpr int kLayers = 18;
static constexpr int kModelDim = 2048;
static constexpr int kFFHiddenDim = 16 * 2048 / 2; // = 16384
static constexpr int kFFHiddenDim = 16 * 2048 / 2; // = 16384
static constexpr int kHeads = 8;
static constexpr int kKVHeads = 8; // TODO(austinvhuang): add MQA support
static constexpr int kQKVDim = 256; // query size == key size == value size
static constexpr int kKVHeads = 8; // TODO(austinvhuang): add MQA support
static constexpr int kQKVDim = 256; // query size == key size == value size
static constexpr int kTopK = 1;
};
} // namespace gcpp
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_CONFIGS_H_
#endif // THIRD_PARTY_GEMMA_CPP_CONFIGS_H_

57
gemma.h
View File

@ -26,15 +26,19 @@
// copybara:import_next_line:gemma_cpp
#include "compression/compress.h" // SfpStream/NuqStream
// copybara:end
// copybara:import_next_line:gemma_cpp
#include "configs.h" // kSeqLen
#include "configs.h" // kSeqLen
// copybara:end
// copybara:import_next_line:gemma_cpp
#include "util/args.h" // ArgsBase
#include "util/args.h" // ArgsBase
// copybara:end
#include "hwy/aligned_allocator.h"
#include "hwy/base.h" // hwy::bfloat16_t
#include "hwy/contrib/thread_pool/thread_pool.h"
// copybara:import_next_line:sentencepiece
#include "src/sentencepiece_processor.h"
// copybara:end
namespace gcpp {
@ -42,7 +46,7 @@ namespace gcpp {
// float, hwy::bfloat16_t, SfpStream, NuqStream
#ifndef GEMMA_WEIGHT_T
#define GEMMA_WEIGHT_T SfpStream
#endif // !GEMMA_WEIGHT_T
#endif // !GEMMA_WEIGHT_T
using WeightT = GEMMA_WEIGHT_T;
using EmbedderInputT = hwy::bfloat16_t;
@ -51,9 +55,9 @@ constexpr bool kSystemPrompt = false;
struct KVCache {
hwy::AlignedFreeUniquePtr<float[]>
key_cache; // batch_size * kSeqLen * kLayers * kKVHeads * kQKVDim
key_cache; // batch_size * kSeqLen * kLayers * kKVHeads * kQKVDim
hwy::AlignedFreeUniquePtr<float[]>
value_cache; // batch_size * kSeqLen * kLayers * kKVHeads * kQKVDim
value_cache; // batch_size * kSeqLen * kLayers * kKVHeads * kQKVDim
};
// Model variants: see configs.h for details.
@ -61,9 +65,9 @@ enum class Model { GEMMA_2B, GEMMA_7B };
enum class ModelTraining { GEMMA_IT, GEMMA_PT };
struct LoaderArgs : public ArgsBase<LoaderArgs> {
LoaderArgs(int argc, char *argv[]) { InitAndParse(argc, argv); }
LoaderArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
static std::string ToLower(const std::string &text) {
static std::string ToLower(const std::string& text) {
std::string result = text;
std::transform(begin(result), end(result), begin(result),
[](unsigned char c) { return std::tolower(c); });
@ -89,7 +93,7 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
}
// Returns error string or nullptr if OK.
const char *Validate() const {
const char* Validate() const {
const std::string model_type_lc = ToLower(model_type);
if (model_type_lc != "2b-pt" && model_type_lc != "7b-pt" &&
model_type_lc != "2b-it" && model_type_lc != "7b-it") {
@ -111,11 +115,12 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
}
Path tokenizer;
Path model; // uncompressed weights OR
Path cache; // compressed weights
Path model; // uncompressed weights OR
Path cache; // compressed weights
std::string model_type;
template <class Visitor> void ForEach(const Visitor &visitor) {
template <class Visitor>
void ForEach(const Visitor& visitor) {
visitor(tokenizer, "tokenizer", Path(),
"Path name of tokenizer model file. (required)");
visitor(
@ -138,10 +143,10 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
struct GemmaInterface;
struct Gemma {
Gemma(const LoaderArgs &args, hwy::ThreadPool &pool);
~Gemma(); // must be defined after GemmaInterface's dtor is defined.
Gemma(const LoaderArgs& args, hwy::ThreadPool& pool);
~Gemma(); // must be defined after GemmaInterface's dtor is defined.
const sentencepiece::SentencePieceProcessor &Tokenizer() const;
const sentencepiece::SentencePieceProcessor& Tokenizer() const;
std::unique_ptr<GemmaInterface> impl_;
gcpp::ModelTraining model_training;
@ -153,7 +158,7 @@ using StreamFunc = std::function<bool(int, float)>;
using AcceptFunc = std::function<bool(int)>;
struct InferenceArgs : public ArgsBase<InferenceArgs> {
InferenceArgs(int argc, char *argv[]) { InitAndParse(argc, argv); }
InferenceArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
size_t max_tokens;
size_t max_generated_tokens;
@ -163,7 +168,7 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
bool multiturn;
// Returns error string or nullptr if OK.
const char *Validate() const {
const char* Validate() const {
if (max_tokens > gcpp::kSeqLen) {
return "max_tokens is larger than the maximum sequence length (see "
"configs.h).";
@ -175,7 +180,8 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
return nullptr;
}
template <class Visitor> void ForEach(const Visitor &visitor) {
template <class Visitor>
void ForEach(const Visitor& visitor) {
visitor(max_tokens, "max_tokens", size_t{3072},
"Maximum number of tokens in prompt + generation.");
visitor(max_generated_tokens, "max_generated_tokens", size_t{2048},
@ -186,19 +192,20 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
"Make top-k sampling deterministic", 2);
visitor(multiturn, "multiturn", false,
"Multiturn mode (if 0, this clears the KV cache after every "
"interaction without quitting)\n Default = 0 (conversation resets every turn)");
"interaction without quitting)\n Default = 0 (conversation "
"resets every turn)");
}
};
void GenerateGemma(Gemma &gemma, const InferenceArgs &args,
const std::vector<int> &prompt, size_t start_pos,
hwy::ThreadPool &pool, hwy::ThreadPool &inner_pool,
const StreamFunc &stream_token,
const AcceptFunc &accept_token, std::mt19937 &g,
void GenerateGemma(Gemma& gemma, const InferenceArgs& args,
const std::vector<int>& prompt, size_t start_pos,
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
const StreamFunc& stream_token,
const AcceptFunc& accept_token, std::mt19937& g,
int verbosity);
constexpr int EOS_ID = 1;
} // namespace gcpp
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_H_
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_H_