mirror of https://github.com/google/gemma.cpp.git
Merge pull request #58 from google:dev-cleanup
PiperOrigin-RevId: 610942948
This commit is contained in:
commit
f4a14bfdf2
|
|
@ -0,0 +1 @@
|
|||
BasedOnStyle: Google
|
||||
|
|
@ -71,6 +71,18 @@ The implementation code is roughly split into 4 layers, from high to low level:
|
|||
4. Backend (`highway`) - Low-level hardware interface (SIMD in the case of
|
||||
highway) supporting the implementations in (3).
|
||||
|
||||
Besides these layers, supporting utilities are:
|
||||
|
||||
- `compression/` - model compression operations. The 8-bit switched floating
|
||||
point model conversion is here.
|
||||
- `util/` - command line argument handling and any other utilities.
|
||||
|
||||
## Style and Formatting
|
||||
|
||||
A `.clang-format` configuration is provided with our defaults, please run source
|
||||
files through `clang-format` (or a formatter that produces equivalent behavior)
|
||||
before finalizing PR for submission.
|
||||
|
||||
## Compile-Time Flags (Advanced)
|
||||
|
||||
There are several compile-time flags to be aware of (note these may or may not
|
||||
|
|
|
|||
37
gemma.h
37
gemma.h
|
|
@ -26,15 +26,19 @@
|
|||
|
||||
// copybara:import_next_line:gemma_cpp
|
||||
#include "compression/compress.h" // SfpStream/NuqStream
|
||||
// copybara:end
|
||||
// copybara:import_next_line:gemma_cpp
|
||||
#include "configs.h" // kSeqLen
|
||||
// copybara:end
|
||||
// copybara:import_next_line:gemma_cpp
|
||||
#include "util/args.h" // ArgsBase
|
||||
// copybara:end
|
||||
#include "hwy/aligned_allocator.h"
|
||||
#include "hwy/base.h" // hwy::bfloat16_t
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
// copybara:import_next_line:sentencepiece
|
||||
#include "src/sentencepiece_processor.h"
|
||||
// copybara:end
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
|
|
@ -61,9 +65,9 @@ enum class Model { GEMMA_2B, GEMMA_7B };
|
|||
enum class ModelTraining { GEMMA_IT, GEMMA_PT };
|
||||
|
||||
struct LoaderArgs : public ArgsBase<LoaderArgs> {
|
||||
LoaderArgs(int argc, char *argv[]) { InitAndParse(argc, argv); }
|
||||
LoaderArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
|
||||
|
||||
static std::string ToLower(const std::string &text) {
|
||||
static std::string ToLower(const std::string& text) {
|
||||
std::string result = text;
|
||||
std::transform(begin(result), end(result), begin(result),
|
||||
[](unsigned char c) { return std::tolower(c); });
|
||||
|
|
@ -89,7 +93,7 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
|
|||
}
|
||||
|
||||
// Returns error string or nullptr if OK.
|
||||
const char *Validate() const {
|
||||
const char* Validate() const {
|
||||
const std::string model_type_lc = ToLower(model_type);
|
||||
if (model_type_lc != "2b-pt" && model_type_lc != "7b-pt" &&
|
||||
model_type_lc != "2b-it" && model_type_lc != "7b-it") {
|
||||
|
|
@ -115,7 +119,8 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
|
|||
Path cache; // compressed weights
|
||||
std::string model_type;
|
||||
|
||||
template <class Visitor> void ForEach(const Visitor &visitor) {
|
||||
template <class Visitor>
|
||||
void ForEach(const Visitor& visitor) {
|
||||
visitor(tokenizer, "tokenizer", Path(),
|
||||
"Path name of tokenizer model file. (required)");
|
||||
visitor(
|
||||
|
|
@ -138,10 +143,10 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
|
|||
struct GemmaInterface;
|
||||
|
||||
struct Gemma {
|
||||
Gemma(const LoaderArgs &args, hwy::ThreadPool &pool);
|
||||
Gemma(const LoaderArgs& args, hwy::ThreadPool& pool);
|
||||
~Gemma(); // must be defined after GemmaInterface's dtor is defined.
|
||||
|
||||
const sentencepiece::SentencePieceProcessor &Tokenizer() const;
|
||||
const sentencepiece::SentencePieceProcessor& Tokenizer() const;
|
||||
|
||||
std::unique_ptr<GemmaInterface> impl_;
|
||||
gcpp::ModelTraining model_training;
|
||||
|
|
@ -153,7 +158,7 @@ using StreamFunc = std::function<bool(int, float)>;
|
|||
using AcceptFunc = std::function<bool(int)>;
|
||||
|
||||
struct InferenceArgs : public ArgsBase<InferenceArgs> {
|
||||
InferenceArgs(int argc, char *argv[]) { InitAndParse(argc, argv); }
|
||||
InferenceArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
|
||||
|
||||
size_t max_tokens;
|
||||
size_t max_generated_tokens;
|
||||
|
|
@ -163,7 +168,7 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
|
|||
bool multiturn;
|
||||
|
||||
// Returns error string or nullptr if OK.
|
||||
const char *Validate() const {
|
||||
const char* Validate() const {
|
||||
if (max_tokens > gcpp::kSeqLen) {
|
||||
return "max_tokens is larger than the maximum sequence length (see "
|
||||
"configs.h).";
|
||||
|
|
@ -175,7 +180,8 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
template <class Visitor> void ForEach(const Visitor &visitor) {
|
||||
template <class Visitor>
|
||||
void ForEach(const Visitor& visitor) {
|
||||
visitor(max_tokens, "max_tokens", size_t{3072},
|
||||
"Maximum number of tokens in prompt + generation.");
|
||||
visitor(max_generated_tokens, "max_generated_tokens", size_t{2048},
|
||||
|
|
@ -186,15 +192,16 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
|
|||
"Make top-k sampling deterministic", 2);
|
||||
visitor(multiturn, "multiturn", false,
|
||||
"Multiturn mode (if 0, this clears the KV cache after every "
|
||||
"interaction without quitting)\n Default = 0 (conversation resets every turn)");
|
||||
"interaction without quitting)\n Default = 0 (conversation "
|
||||
"resets every turn)");
|
||||
}
|
||||
};
|
||||
|
||||
void GenerateGemma(Gemma &gemma, const InferenceArgs &args,
|
||||
const std::vector<int> &prompt, size_t start_pos,
|
||||
hwy::ThreadPool &pool, hwy::ThreadPool &inner_pool,
|
||||
const StreamFunc &stream_token,
|
||||
const AcceptFunc &accept_token, std::mt19937 &g,
|
||||
void GenerateGemma(Gemma& gemma, const InferenceArgs& args,
|
||||
const std::vector<int>& prompt, size_t start_pos,
|
||||
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
|
||||
const StreamFunc& stream_token,
|
||||
const AcceptFunc& accept_token, std::mt19937& g,
|
||||
int verbosity);
|
||||
|
||||
constexpr int EOS_ID = 1;
|
||||
|
|
|
|||
Loading…
Reference in New Issue