mirror of https://github.com/google/gemma.cpp.git
Merge pull request #87 from google:refactor-tidy
PiperOrigin-RevId: 615204427
This commit is contained in:
commit
0221956b2e
23
BUILD.bazel
23
BUILD.bazel
|
|
@ -46,17 +46,6 @@ cc_library(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
|
||||||
name = "app",
|
|
||||||
hdrs = [
|
|
||||||
"util/app.h",
|
|
||||||
],
|
|
||||||
deps = [
|
|
||||||
":args",
|
|
||||||
"@hwy//:hwy",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "gemma_lib",
|
name = "gemma_lib",
|
||||||
srcs = [
|
srcs = [
|
||||||
|
|
@ -80,6 +69,18 @@ cc_library(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "app",
|
||||||
|
hdrs = [
|
||||||
|
"util/app.h",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":args",
|
||||||
|
":gemma_lib",
|
||||||
|
"@hwy//:hwy",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_binary(
|
cc_binary(
|
||||||
name = "gemma",
|
name = "gemma",
|
||||||
srcs = [
|
srcs = [
|
||||||
|
|
|
||||||
|
|
@ -33,7 +33,7 @@ if (BUILD_MODE STREQUAL "local")
|
||||||
# Relative path to gemma.cpp from examples/hello_world/build/
|
# Relative path to gemma.cpp from examples/hello_world/build/
|
||||||
FetchContent_Declare(gemma SOURCE_DIR ../../..)
|
FetchContent_Declare(gemma SOURCE_DIR ../../..)
|
||||||
else()
|
else()
|
||||||
FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp.git GIT_TAG 8c7b2cf61b9794b806de091685dc6739dd3db837)
|
FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp.git GIT_TAG a9aa63fd2ea6b786ed0706d619588bfe2d43370e)
|
||||||
endif()
|
endif()
|
||||||
FetchContent_MakeAvailable(gemma)
|
FetchContent_MakeAvailable(gemma)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -21,6 +21,9 @@
|
||||||
// copybara:import_next_line:gemma_cpp
|
// copybara:import_next_line:gemma_cpp
|
||||||
#include "util/args.h"
|
#include "util/args.h"
|
||||||
// copybara:end
|
// copybara:end
|
||||||
|
// copybara:import_next_line:gemma_cpp
|
||||||
|
#include "util/app.h" // LoaderArgs
|
||||||
|
// copybara:end
|
||||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||||
|
|
||||||
std::vector<int> tokenize(
|
std::vector<int> tokenize(
|
||||||
|
|
|
||||||
118
gemma.h
118
gemma.h
|
|
@ -71,124 +71,6 @@ struct RuntimeConfig {
|
||||||
int verbosity;
|
int verbosity;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct LoaderArgs : public ArgsBase<LoaderArgs> {
|
|
||||||
LoaderArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
|
|
||||||
|
|
||||||
static std::string ToLower(const std::string& text) {
|
|
||||||
std::string result = text;
|
|
||||||
std::transform(begin(result), end(result), begin(result),
|
|
||||||
[](unsigned char c) { return std::tolower(c); });
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
gcpp::Model ModelType() const {
|
|
||||||
const std::string model_type_lc = ToLower(model_type);
|
|
||||||
if (model_type_lc == "2b-pt" || model_type_lc == "2b-it") {
|
|
||||||
return gcpp::Model::GEMMA_2B;
|
|
||||||
} else {
|
|
||||||
return gcpp::Model::GEMMA_7B;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
gcpp::ModelTraining ModelTraining() const {
|
|
||||||
const std::string model_type_lc = ToLower(model_type);
|
|
||||||
if (model_type_lc == "7b-pt" || model_type_lc == "2b-pt") {
|
|
||||||
return gcpp::ModelTraining::GEMMA_PT;
|
|
||||||
} else {
|
|
||||||
return gcpp::ModelTraining::GEMMA_IT;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns error string or nullptr if OK.
|
|
||||||
const char* Validate() const {
|
|
||||||
const std::string model_type_lc = ToLower(model_type);
|
|
||||||
if (model_type.empty()) {
|
|
||||||
return "Missing --model flag, need to specify either 2b-pt, 7b-pt, "
|
|
||||||
"2b-it, or 7b-it.";
|
|
||||||
}
|
|
||||||
if (model_type_lc != "2b-pt" && model_type_lc != "7b-pt" &&
|
|
||||||
model_type_lc != "2b-it" && model_type_lc != "7b-it") {
|
|
||||||
return "Model type must be 2b-pt, 7b-pt, 2b-it, or "
|
|
||||||
"7b-it.";
|
|
||||||
}
|
|
||||||
if (tokenizer.path.empty()) {
|
|
||||||
return "Missing --tokenizer flag, a file for the tokenizer is required.";
|
|
||||||
}
|
|
||||||
if (compressed_weights.path.empty()) {
|
|
||||||
return "Missing --compressed_weights flag, a file for the compressed "
|
|
||||||
"model.";
|
|
||||||
}
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
Path tokenizer;
|
|
||||||
Path weights; // uncompressed weights file location
|
|
||||||
Path compressed_weights; // compressed weights file location
|
|
||||||
std::string model_type;
|
|
||||||
|
|
||||||
template <class Visitor>
|
|
||||||
void ForEach(const Visitor& visitor) {
|
|
||||||
visitor(tokenizer, "tokenizer", Path(),
|
|
||||||
"Path name of tokenizer model file.\n Required argument.");
|
|
||||||
visitor(
|
|
||||||
compressed_weights, "compressed_weights", Path(),
|
|
||||||
"Path name of compressed weights file, regenerated from `--weights` "
|
|
||||||
"file if "
|
|
||||||
"the compressed weights file does not exist.\n Required argument.");
|
|
||||||
visitor(model_type, "model", std::string(),
|
|
||||||
"Model type\n 2b-it = 2B parameters, instruction-tuned\n "
|
|
||||||
"2b-pt = 2B parameters, pretrained\n 7b-it = 7B parameters "
|
|
||||||
"instruction-tuned\n 7b-pt = 7B parameters, pretrained\n"
|
|
||||||
" Required argument.");
|
|
||||||
visitor(weights, "weights", Path(),
|
|
||||||
"Path name of model weights (.sbs) file. Only required if "
|
|
||||||
"compressed_weights file is not present and needs to be "
|
|
||||||
"regenerated. This parameter is only required for compressing"
|
|
||||||
"new model weight exports, otherwise it is not needed.");
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
struct InferenceArgs : public ArgsBase<InferenceArgs> {
|
|
||||||
InferenceArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
|
|
||||||
|
|
||||||
size_t max_tokens;
|
|
||||||
size_t max_generated_tokens;
|
|
||||||
|
|
||||||
float temperature;
|
|
||||||
bool deterministic;
|
|
||||||
bool multiturn;
|
|
||||||
|
|
||||||
// Returns error string or nullptr if OK.
|
|
||||||
const char* Validate() const {
|
|
||||||
if (max_tokens > gcpp::kSeqLen) {
|
|
||||||
return "max_tokens is larger than the maximum sequence length (see "
|
|
||||||
"configs.h).";
|
|
||||||
}
|
|
||||||
if (max_generated_tokens > max_tokens) {
|
|
||||||
return "Maximum number of generated tokens is larger than the maximum "
|
|
||||||
"total tokens.";
|
|
||||||
}
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <class Visitor>
|
|
||||||
void ForEach(const Visitor& visitor) {
|
|
||||||
visitor(max_tokens, "max_tokens", size_t{3072},
|
|
||||||
"Maximum number of tokens in prompt + generation.");
|
|
||||||
visitor(max_generated_tokens, "max_generated_tokens", size_t{2048},
|
|
||||||
"Maximum number of tokens to generate.");
|
|
||||||
|
|
||||||
visitor(temperature, "temperature", 1.0f, "Temperature for top-K", 2);
|
|
||||||
visitor(deterministic, "deterministic", false,
|
|
||||||
"Make top-k sampling deterministic", 2);
|
|
||||||
visitor(multiturn, "multiturn", false,
|
|
||||||
"Multiturn mode\n 0 = clear KV cache after every "
|
|
||||||
"interaction\n 1 = continue KV cache after every interaction\n "
|
|
||||||
" Default : 0 (conversation "
|
|
||||||
"resets every turn)");
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
struct GemmaInterface;
|
struct GemmaInterface;
|
||||||
|
|
||||||
struct Gemma {
|
struct Gemma {
|
||||||
|
|
|
||||||
2
run.cc
2
run.cc
|
|
@ -119,7 +119,7 @@ void ReplGemma(gcpp::Gemma& model, gcpp::KVCache& kv_cache,
|
||||||
verbosity](int token, float) {
|
verbosity](int token, float) {
|
||||||
++abs_pos;
|
++abs_pos;
|
||||||
++current_pos;
|
++current_pos;
|
||||||
if (current_pos <= prompt_size) {
|
if (current_pos < prompt_size) {
|
||||||
std::cerr << "." << std::flush;
|
std::cerr << "." << std::flush;
|
||||||
} else if (token == gcpp::EOS_ID) {
|
} else if (token == gcpp::EOS_ID) {
|
||||||
if (!args.multiturn) {
|
if (!args.multiturn) {
|
||||||
|
|
|
||||||
129
util/app.h
129
util/app.h
|
|
@ -18,10 +18,13 @@
|
||||||
#ifndef THIRD_PARTY_GEMMA_CPP_UTIL_APP_H_
|
#ifndef THIRD_PARTY_GEMMA_CPP_UTIL_APP_H_
|
||||||
#define THIRD_PARTY_GEMMA_CPP_UTIL_APP_H_
|
#define THIRD_PARTY_GEMMA_CPP_UTIL_APP_H_
|
||||||
|
|
||||||
|
#include <iterator>
|
||||||
#if HWY_OS_LINUX
|
#if HWY_OS_LINUX
|
||||||
#include <sched.h>
|
#include <sched.h>
|
||||||
|
|
||||||
|
#include <cctype>
|
||||||
#include <cerrno> // IDE does not recognize errno.h as providing errno.
|
#include <cerrno> // IDE does not recognize errno.h as providing errno.
|
||||||
|
#include <string>
|
||||||
#endif
|
#endif
|
||||||
#include <stddef.h>
|
#include <stddef.h>
|
||||||
#include <stdio.h>
|
#include <stdio.h>
|
||||||
|
|
@ -29,6 +32,14 @@
|
||||||
#include <algorithm> // std::clamp
|
#include <algorithm> // std::clamp
|
||||||
#include <thread> // NOLINT>
|
#include <thread> // NOLINT>
|
||||||
|
|
||||||
|
// copybara:import_next_line:gemma_cpp
|
||||||
|
#include "configs.h"
|
||||||
|
// copybara:end
|
||||||
|
|
||||||
|
// copybara:import_next_line:gemma_cpp
|
||||||
|
#include "gemma.h"
|
||||||
|
// copybara:end
|
||||||
|
|
||||||
// copybara:import_next_line:gemma_cpp
|
// copybara:import_next_line:gemma_cpp
|
||||||
#include "util/args.h"
|
#include "util/args.h"
|
||||||
// copybara:end
|
// copybara:end
|
||||||
|
|
@ -116,6 +127,124 @@ class AppArgs : public ArgsBase<AppArgs> {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct LoaderArgs : public ArgsBase<LoaderArgs> {
|
||||||
|
LoaderArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
|
||||||
|
|
||||||
|
static std::string ToLower(const std::string& text) {
|
||||||
|
std::string result = text;
|
||||||
|
std::transform(begin(result), end(result), begin(result),
|
||||||
|
[](unsigned char c) { return std::tolower(c); });
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
gcpp::Model ModelType() const {
|
||||||
|
const std::string model_type_lc = ToLower(model_type);
|
||||||
|
if (model_type_lc == "2b-pt" || model_type_lc == "2b-it") {
|
||||||
|
return gcpp::Model::GEMMA_2B;
|
||||||
|
} else {
|
||||||
|
return gcpp::Model::GEMMA_7B;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
gcpp::ModelTraining ModelTraining() const {
|
||||||
|
const std::string model_type_lc = ToLower(model_type);
|
||||||
|
if (model_type_lc == "7b-pt" || model_type_lc == "2b-pt") {
|
||||||
|
return gcpp::ModelTraining::GEMMA_PT;
|
||||||
|
} else {
|
||||||
|
return gcpp::ModelTraining::GEMMA_IT;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns error string or nullptr if OK.
|
||||||
|
const char* Validate() const {
|
||||||
|
const std::string model_type_lc = ToLower(model_type);
|
||||||
|
if (model_type.empty()) {
|
||||||
|
return "Missing --model flag, need to specify either 2b-pt, 7b-pt, "
|
||||||
|
"2b-it, or 7b-it.";
|
||||||
|
}
|
||||||
|
if (model_type_lc != "2b-pt" && model_type_lc != "7b-pt" &&
|
||||||
|
model_type_lc != "2b-it" && model_type_lc != "7b-it") {
|
||||||
|
return "Model type must be 2b-pt, 7b-pt, 2b-it, or "
|
||||||
|
"7b-it.";
|
||||||
|
}
|
||||||
|
if (tokenizer.path.empty()) {
|
||||||
|
return "Missing --tokenizer flag, a file for the tokenizer is required.";
|
||||||
|
}
|
||||||
|
if (compressed_weights.path.empty()) {
|
||||||
|
return "Missing --compressed_weights flag, a file for the compressed "
|
||||||
|
"model.";
|
||||||
|
}
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
Path tokenizer;
|
||||||
|
Path weights; // uncompressed weights file location
|
||||||
|
Path compressed_weights; // compressed weights file location
|
||||||
|
std::string model_type;
|
||||||
|
|
||||||
|
template <class Visitor>
|
||||||
|
void ForEach(const Visitor& visitor) {
|
||||||
|
visitor(tokenizer, "tokenizer", Path(),
|
||||||
|
"Path name of tokenizer model file.\n Required argument.");
|
||||||
|
visitor(
|
||||||
|
compressed_weights, "compressed_weights", Path(),
|
||||||
|
"Path name of compressed weights file, regenerated from `--weights` "
|
||||||
|
"file if "
|
||||||
|
"the compressed weights file does not exist.\n Required argument.");
|
||||||
|
visitor(model_type, "model", std::string(),
|
||||||
|
"Model type\n 2b-it = 2B parameters, instruction-tuned\n "
|
||||||
|
"2b-pt = 2B parameters, pretrained\n 7b-it = 7B parameters "
|
||||||
|
"instruction-tuned\n 7b-pt = 7B parameters, pretrained\n"
|
||||||
|
" Required argument.");
|
||||||
|
visitor(weights, "weights", Path(),
|
||||||
|
"Path name of model weights (.sbs) file. Only required if "
|
||||||
|
"compressed_weights file is not present and needs to be "
|
||||||
|
"regenerated. This parameter is only required for compressing"
|
||||||
|
"new model weight exports, otherwise it is not needed.");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct InferenceArgs : public ArgsBase<InferenceArgs> {
|
||||||
|
InferenceArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
|
||||||
|
|
||||||
|
size_t max_tokens;
|
||||||
|
size_t max_generated_tokens;
|
||||||
|
|
||||||
|
float temperature;
|
||||||
|
bool deterministic;
|
||||||
|
bool multiturn;
|
||||||
|
|
||||||
|
// Returns error string or nullptr if OK.
|
||||||
|
const char* Validate() const {
|
||||||
|
if (max_tokens > gcpp::kSeqLen) {
|
||||||
|
return "max_tokens is larger than the maximum sequence length (see "
|
||||||
|
"configs.h).";
|
||||||
|
}
|
||||||
|
if (max_generated_tokens > max_tokens) {
|
||||||
|
return "Maximum number of generated tokens is larger than the maximum "
|
||||||
|
"total tokens.";
|
||||||
|
}
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class Visitor>
|
||||||
|
void ForEach(const Visitor& visitor) {
|
||||||
|
visitor(max_tokens, "max_tokens", size_t{3072},
|
||||||
|
"Maximum number of tokens in prompt + generation.");
|
||||||
|
visitor(max_generated_tokens, "max_generated_tokens", size_t{2048},
|
||||||
|
"Maximum number of tokens to generate.");
|
||||||
|
|
||||||
|
visitor(temperature, "temperature", 1.0f, "Temperature for top-K", 2);
|
||||||
|
visitor(deterministic, "deterministic", false,
|
||||||
|
"Make top-k sampling deterministic", 2);
|
||||||
|
visitor(multiturn, "multiturn", false,
|
||||||
|
"Multiturn mode\n 0 = clear KV cache after every "
|
||||||
|
"interaction\n 1 = continue KV cache after every interaction\n "
|
||||||
|
" Default : 0 (conversation "
|
||||||
|
"resets every turn)");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
} // namespace gcpp
|
} // namespace gcpp
|
||||||
|
|
||||||
#endif // THIRD_PARTY_GEMMA_CPP_UTIL_APP_H_
|
#endif // THIRD_PARTY_GEMMA_CPP_UTIL_APP_H_
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue