mirror of https://github.com/google/gemma.cpp.git
Add non-interactive mode support

- Added a --prompt flag to InferenceArgs for non-interactive mode
- Set user-facing options to verbosity level 1
- Fixed the prompt_size declaration and variable ordering in run.cc
- Set prompt_size after the WrapAndTokenize calls
- Moved the kVerboseLogTokens block after prompt_size is set

commit 8246e49199 (parent cbf179990f)
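This enables a single-shot mode: pass a prompt on the command line, generate one response, and exit. A hypothetical invocation (file names are placeholders and the binary name assumes a default build; the flag names are taken from the diffs below):

    ./gemma --tokenizer tokenizer.spm --weights model.sbs --prompt "Why is the sky blue?" --verbosity 0

Leaving --prompt empty (the default) keeps the existing interactive REPL behavior.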
gemma/gemma_args.h

@@ -13,15 +13,18 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
 // Argument parsing for Gemma.
 // Shared between various frontends.
 
-#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_ARGS_H_
-#define THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_ARGS_H_
+#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_ARGS_H_
+#define THIRD_PARTY_GEMMA_CPP_GEMMA_ARGS_H_
 
 #include <stddef.h>
 #include <stdio.h>
 
 #include <memory>
 #include <string>
 
 #include "compression/io.h"  // Path
 #include "compression/shared.h"
 #include "gemma/common.h"
 #include "gemma/gemma.h"  // For CreateGemma
@@ -32,66 +35,174 @@
 namespace gcpp {
 
-// Arguments related to inference: sampling, text etc.
-struct InferenceArgs : public ArgsBase<InferenceArgs> {
-  // Arguments for getc-like interfaces
-  size_t max_tokens;
-  size_t max_generated_tokens;
-  float temperature;
-  size_t top_k;
-  float top_p;
-  float min_p;
-  int repeat_penalty_power;
-  float repeat_penalty_presence;
-  float repeat_penalty_decay;
-  float repeat_penalty_range;
-
-  // Batch configuration:
-  size_t prefill_tbatch_size;
-  size_t decode_tbatch_size;
-
-  // Non-interactive mode prompt
-  std::string prompt;
-  std::string eot_line;
-
-  template <class Visitor>
-  void ForEach(Visitor& visitor) {
-    // Each line specifies a variable member, its name, default value, and help.
-    visitor(max_tokens, "max_tokens", size_t{50},
-            "Maximum number of total tokens including prompt (0=no limit).", 1);
-    visitor(max_generated_tokens, "max_generated_tokens", size_t{512},
-            "Maximum number of generated tokens (not including prompt) (0=no "
-            "limit).",
-            1);
-    visitor(temperature, "temperature", 1.0f,
-            "Temperature (randomness) for logits.", 1);
-    visitor(top_k, "top_k", size_t{40},
-            "Number of highest-probability tokens to consider (0=unlimited).",
-            1);
-    visitor(top_p, "top_p", 0.9f, "Top-p probability threshold (0.0=disabled).",
-            1);
-    visitor(min_p, "min_p", 0.0f, "Min-p probability threshold (0.0=disabled).",
-            1);
-    visitor(
-        repeat_penalty_power, "repeat_penalty_power", 1,
-        "Penalty power (1=standard frequentist penalty). If 0, skips penalty "
-        "computation.",
-        1);
-    visitor(repeat_penalty_presence, "repeat_penalty_presence", 0.0f,
-            "Penalty for token presence regardless of frequency (additive).",
-            1);
-    visitor(repeat_penalty_decay, "repeat_penalty_decay", 0.0f,
-            "Penalty for token n positions ago is decayed by "
-            "power(repeat_penalty_decay, n).",
-            1);
-    visitor(repeat_penalty_range, "repeat_penalty_range", 8.0f,
-            "Penalty fades out near the end of range (tokens)", 1);
-
-    // Batch configuration:
-    visitor(prefill_tbatch_size, "prefill_tbatch_size", size_t{2},
-            "Token batch size for prefill; <= 32", 2);
-    visitor(decode_tbatch_size, "decode_tbatch_size", size_t{1},
-            "Token batch size for decode (only 1 currently supported)", 2);
+struct LoaderArgs : public ArgsBase<LoaderArgs> {
+  LoaderArgs(int argc, char* argv[], bool validate = true) {
+    InitAndParse(argc, argv);
+    if (validate) {
+      if (const char* error = Validate()) {
+        HWY_ABORT("Invalid args: %s", error);
+      }
+    }
+  }
+  LoaderArgs(const std::string& tokenizer_path, const std::string& weights_path,
+             const std::string& model, bool validate = true) {
+    Init();  // Init sets to defaults, so assignments must come after Init().
+    tokenizer.path = tokenizer_path;
+    weights.path = weights_path;
+    model_type_str = model;
+    if (validate) {
+      if (const char* error = Validate()) {
+        HWY_ABORT("Invalid args: %s", error);
+      }
+    }
+  }
+
+  // Returns error string or nullptr if OK.
+  const char* Validate() {
+    if (weights.path.empty()) {
+      return "Missing --weights flag, a file for the model weights.";
+    }
+    if (!weights.Exists()) {
+      return "Can't open file specified with --weights flag.";
+    }
+    info_.model = Model::UNKNOWN;
+    info_.wrapping = PromptWrapping::GEMMA_PT;
+    info_.weight = Type::kUnknown;
+    if (!model_type_str.empty()) {
+      const char* err = ParseModelTypeAndWrapping(model_type_str, info_.model,
+                                                  info_.wrapping);
+      if (err != nullptr) return err;
+    }
+    if (!weight_type_str.empty()) {
+      const char* err = ParseType(weight_type_str, info_.weight);
+      if (err != nullptr) return err;
+    }
+    if (!tokenizer.path.empty()) {
+      if (!tokenizer.Exists()) {
+        return "Can't open file specified with --tokenizer flag.";
+      }
+    }
+    // model_type and tokenizer must be either both present or both absent.
+    // Further checks happen on weight loading.
+    if (model_type_str.empty() != tokenizer.path.empty()) {
+      return "Missing or extra flags for model_type or tokenizer.";
+    }
+    return nullptr;
+  }
+
+  Path tokenizer;
+  Path weights;  // weights file location
+  Path compressed_weights;
+  std::string model_type_str;
+  std::string weight_type_str;
+
+  template <class Visitor>
+  void ForEach(const Visitor& visitor) {
+    visitor(tokenizer, "tokenizer", Path(),
+            "Path name of tokenizer model file.");
+    visitor(weights, "weights", Path(),
+            "Path name of model weights (.sbs) file.\n Required argument.\n");
+    visitor(compressed_weights, "compressed_weights", Path(),
+            "Deprecated alias for --weights.");
+    visitor(model_type_str, "model", std::string(),
+            "Model type, see common.cc for valid values.\n");
+    visitor(weight_type_str, "weight_type", std::string("sfp"),
+            "Weight type\n f32 = float, bf16 = bfloat16, sfp = 8-bit SFP.");
+  }
+
+  // Uninitialized before Validate, must call after that.
+  const ModelInfo& Info() const { return info_; }
+
+ private:
+  ModelInfo info_;
+};
+
+// `env` must remain valid for the lifetime of the Gemma.
+static inline Gemma CreateGemma(const LoaderArgs& loader, MatMulEnv& env) {
+  if (Type::kUnknown == loader.Info().weight ||
+      Model::UNKNOWN == loader.Info().model || loader.tokenizer.path.empty()) {
+    // New weights file format doesn't need tokenizer path or model/weight info.
+    return Gemma(loader.weights, env);
+  }
+  return Gemma(loader.tokenizer, loader.weights, loader.Info(), env);
+}
+
+// `env` must remain valid for the lifetime of the Gemma.
+static inline std::unique_ptr<Gemma> AllocateGemma(const LoaderArgs& loader,
+                                                   MatMulEnv& env) {
+  if (Type::kUnknown == loader.Info().weight ||
+      Model::UNKNOWN == loader.Info().model || loader.tokenizer.path.empty()) {
+    // New weights file format doesn't need tokenizer path or model/weight info.
+    return std::make_unique<Gemma>(loader.weights, env);
+  }
+  return std::make_unique<Gemma>(loader.tokenizer, loader.weights,
+                                 loader.Info(), env);
+}
+
+struct InferenceArgs : public ArgsBase<InferenceArgs> {
+  InferenceArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
+  InferenceArgs() { Init(); }
+
+  int verbosity;
+
+  size_t max_generated_tokens;
+
+  size_t prefill_tbatch_size;
+  size_t decode_qbatch_size;
+
+  float temperature;
+  size_t top_k;
+  bool deterministic;
+  bool multiturn;
+  Path image_file;
+
+  std::string prompt;  // Added prompt flag for non-interactive mode
+  std::string eot_line;
+
+  // Returns error string or nullptr if OK.
+  const char* Validate() const {
+    if (max_generated_tokens > gcpp::kSeqLen) {
+      return "max_generated_tokens is larger than the maximum sequence length "
+             "(see configs.h).";
+    }
+    return nullptr;
+  }
+
+  template <class Visitor>
+  void ForEach(const Visitor& visitor) {
+    visitor(verbosity, "verbosity", 1,
+            "Show verbose developer information\n 0 = only print generation "
+            "output\n 1 = standard user-facing terminal ui\n 2 = show "
+            "developer/debug info.\n Default = 1.",
+            1);  // Changed verbosity level to 1 since it's user-facing
+
+    visitor(max_generated_tokens, "max_generated_tokens", size_t{2048},
+            "Maximum number of tokens to generate.");
+
+    visitor(prefill_tbatch_size, "prefill_tbatch", size_t{256},
+            "Prefill: max tokens per batch.");
+    visitor(decode_qbatch_size, "decode_qbatch", size_t{16},
+            "Decode: max queries per batch.");
+
+    visitor(temperature, "temperature", 1.0f, "Temperature for top-K", 2);
+    visitor(top_k, "top_k", size_t{1}, "Number of top-K tokens to sample from",
+            2);
+    visitor(deterministic, "deterministic", false,
+            "Make top-k sampling deterministic", 2);
+    visitor(multiturn, "multiturn", false,
+            "Multiturn mode\n 0 = clear KV cache after every "
+            "interaction\n 1 = continue KV cache after every interaction\n"
+            " Default: 0 (conversation resets every turn)");
+    visitor(image_file, "image_file", Path(), "Image file to load.");
+
+    visitor(prompt, "prompt", std::string(""),
+            "Initial prompt for non-interactive mode. When specified, "
+            "generates a response and exits.",
+            1);  // Added as user-facing option
+
     visitor(
         eot_line, "eot_line", std::string(""),
@@ -99,123 +210,31 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
         "When you specify this, the prompt will be all lines "
         "before the line where only the given string appears.\n Default = "
         "When a newline is encountered, that signals the end of the turn.",
         1);
-
-    // Non-interactive mode prompt
-    visitor(prompt, "prompt", std::string(""),
-            "Prompt to use in non-interactive mode", 1);
   }
 
-  const char* Validate() const {
-    if (max_generated_tokens == 0 && max_tokens == 0) {
-      return "At least one of max_tokens and max_generated_tokens must be > 0";
-    }
-    if (temperature <= 0.0) {
-      return "Temperature must be > 0.0";
-    }
-    if (prefill_tbatch_size > 32) {
-      return "prefill_tbatch_size must be <= 32";
-    }
-    if (decode_tbatch_size != 1) {
-      return "decode_tbatch_size must be 1";
-    }
-    return nullptr;
-  }
-};
-
-// Arguments related to model weights.
-struct LoaderArgs : public ArgsBase<LoaderArgs> {
-  Path model_path;  // Path to directory containing the weights
-  Path tokenizer;   // Optional: can be derived from model_path
-  bool model_is_gemma2;
-  Gemma::Config::WeightFormat weight_format;
-
-  template <class Visitor>
-  void ForEach(Visitor& visitor) {
-    // Each line specifies a variable member, its name, default value, and help.
-    visitor(model_path, "model", Path{},
-            "Directory containing weights or config file from `gemma.cpp "
-            "convert`.",
-            0);
-    visitor(tokenizer, "tokenizer", Path{},
-            "Optional path to tokenizer.model; if empty, looks in model_path.",
-            2);
-    visitor(model_is_gemma2, "model_is_gemma2", false,
-            "Whether the model is a Gemma 2 model", 1);
-    visitor(weight_format, "format", Gemma::Config::kBfloat16,
-            "Model weights format: 0=F32, 1=F16, 2=BF16", 2);
-  }
-
-  const char* Validate() const {
-    if (model_path.path.empty()) {
-      return "Empty model path";
-    }
-    if (weight_format != Gemma::Config::kBfloat16 &&
-        weight_format != Gemma::Config::kFloat16 &&
-        weight_format != Gemma::Config::kFloat32) {
-      return "Invalid weight format";
-    }
-    return nullptr;
-  }
-};
-
-// Threading-related arguments.
-struct ThreadingArgs : public ArgsBase<ThreadingArgs> {
-  size_t num_threads;
-  Tristate pin_threads;
-  Tristate use_spinning;
-  int verbosity;
-
-  template <class Visitor>
-  void ForEach(Visitor& visitor) {
-    visitor(num_threads, "threads", size_t{0},
-            "Number of threads (0=auto, half of logical cores)", 1);
-    visitor(pin_threads, "pin_threads", Tristate::kDefault,
-            "Set to true/false to force enable/disable thread pinning.", 2);
-    visitor(use_spinning, "use_spinning", Tristate::kDefault,
-            "Set to true/false to enable/disable thread spinning (typically "
-            "improves performance but increases power usage)",
-            2);
-    visitor(verbosity, "verbosity", 1,
-            "Controls printing of progress messages to stderr", 1);
-  }
-
-  // Returns nullptr if OK, otherwise error message.
-  const char* Validate() const { return nullptr; }
-
-  // Returns num_threads to use.
-  size_t NumThreadsToUse() const {
-    return num_threads == 0 ? (size_t{hwy::NumberOfProcessors()} + 1) / 2
-                            : num_threads;
-  }
-};
-
-// Command-line arguments for PeftGemma and Gemma.
-struct GemmaArgs : public ArgsBase<GemmaArgs> {
-  InferenceArgs inference;
-  LoaderArgs loader;
-  ThreadingArgs threading;
-  // For collect_stats.cc:
-  Path output;
-
-  bool trace_outputs;  // For -ftrace and dump_csv.cc
-  bool trace_base;     // For -ftrace
-  int time_it;         // For time_it.cc
-
-  template <class Visitor>
-  void ForEach(Visitor& visitor) {
-    inference.ForEach(visitor);
-    loader.ForEach(visitor);
-    threading.ForEach(visitor);
-
-    visitor(output, "output", Path{}, "Where to write CSV data / stats", 2);
-    visitor(trace_outputs, "trace_outputs", false, "For tracing", 2);
-    visitor(trace_base, "trace_base", false, "For tracing", 2);
-    visitor(time_it, "time_it", 0, "For benchmarks", 2);
-  }
-};
+
+  void CopyTo(RuntimeConfig& runtime_config) const {
+    runtime_config.max_generated_tokens = max_generated_tokens;
+
+    runtime_config.prefill_tbatch_size = prefill_tbatch_size;
+    runtime_config.decode_qbatch_size = decode_qbatch_size;
+    if (prefill_tbatch_size > MMStorage::kMaxM) {
+      HWY_ABORT(
+          "prefill_tbatch_size %zu > kMaxM %zu: specify a smaller value, "
+          "or increase the constant in MMStorage.\n",
+          prefill_tbatch_size, MMStorage::kMaxM);
+    }
+    if (decode_qbatch_size > MMStorage::kMaxM) {
+      HWY_ABORT(
+          "decode_qbatch_size %zu > kMaxM %zu: specify a smaller value, "
+          "or increase the constant in MMStorage.\n",
+          decode_qbatch_size, MMStorage::kMaxM);
+    }
+
+    runtime_config.temperature = temperature;
+    runtime_config.top_k = top_k;
+  }
+};
 
 }  // namespace gcpp
 
-#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_ARGS_H_
+#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_ARGS_H_
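The ForEach/visitor pattern used throughout this header lets one member list drive default-initialization, parsing, and help output; the trailing integer on each visitor call is the verbosity level at which the flag is shown (this commit promotes --prompt and --verbosity to level 1, the user-facing tier). A minimal, self-contained sketch of the idea, assuming nothing about the real ArgsBase beyond what the diff shows (all names below are illustrative, not the actual implementation):

#include <cstdio>
#include <string>

// Toy Args struct in the style of InferenceArgs: ForEach() enumerates
// (member, name, default, help, verbosity level) tuples for any visitor.
struct ToyArgs {
  std::string prompt;
  int verbosity;

  template <class Visitor>
  void ForEach(const Visitor& visitor) {
    visitor(prompt, "prompt", std::string(""),
            "Initial prompt for non-interactive mode.", 1);
    visitor(verbosity, "verbosity", 1, "Show verbose developer information.",
            1);
  }
};

// One visitor resets members to their defaults (what an Init() would do).
struct InitVisitor {
  template <typename T>
  void operator()(T& member, const char* /*name*/, T init,
                  const char* /*help*/, int /*level*/) const {
    member = init;
  }
};

// Another prints help only for flags at or below a chosen verbosity level.
struct HelpVisitor {
  int max_level;
  template <typename T>
  void operator()(T& /*member*/, const char* name, T /*init*/,
                  const char* help, int level) const {
    if (level <= max_level) fprintf(stderr, "--%s : %s\n", name, help);
  }
};

int main() {
  ToyArgs args;
  InitVisitor init;
  args.ForEach(init);  // set defaults
  HelpVisitor help{/*max_level=*/1};
  args.ForEach(help);  // print user-facing flags only
}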
gemma/run.cc (57 lines changed)
@@ -78,16 +78,15 @@ std::string GetPrompt(std::istream& input, int verbosity,
   return prompt_string;
 }
 
-// New GetPrompt function that accepts InferenceArgs
-std::string GetPrompt(const InferenceArgs& inference, int verbosity,
-                      size_t turn) {
-  // Check for command-line prompt first
+// Get prompt either from interactive input or command line
+std::string GetPrompt(const InferenceArgs& inference) {
+  // If prompt is provided via command line, use that
   if (!inference.prompt.empty()) {
     return inference.prompt;
   }
 
-  // Use the existing function for interactive mode
-  return GetPrompt(std::cin, verbosity, inference.eot_line);
+  // Otherwise get interactive prompt
+  return GetPrompt(std::cin, inference.verbosity, inference.eot_line);
 }
 
 // The main Read-Eval-Print Loop.
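Combined with the loop changes further down, the intended control flow is: when --prompt is non-empty, GetPrompt() returns it on every call, so the REPL must run exactly one turn and then break out. A simplified, self-contained model of that flow (the struct and function names mirror the diff; the generation step is a stand-in):

#include <iostream>
#include <string>

struct InferenceArgs {
  std::string prompt;  // non-empty => non-interactive mode
};

std::string GetPrompt(const InferenceArgs& inference) {
  if (!inference.prompt.empty()) return inference.prompt;  // command line wins
  std::string line;
  std::getline(std::cin, line);  // stand-in for the interactive reader
  return line;
}

int main() {
  InferenceArgs inference{"Why is the sky blue?"};
  while (true) {
    const std::string prompt_string = GetPrompt(inference);
    std::cout << "(generate response for: " << prompt_string << ")\n";
    if (!inference.prompt.empty()) break;  // single turn, then exit
  }
}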
@@ -101,9 +100,6 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
   std::mt19937 gen;
   InitGenerator(inference, gen);
 
-  // Add flag to track non-interactive mode
-  bool non_interactive_mode = !inference.prompt.empty();
-
   const bool have_image = !inference.image_file.path.empty();
   Image image;
   ImageTokens image_tokens;
@@ -165,13 +161,12 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
     tokens_generated_this_turn = 0;
 
     // Read prompt and handle special commands.
-    std::string prompt_string =
-        GetPrompt(inference, inference.verbosity, abs_pos);
+    std::string prompt_string = GetPrompt(inference);
 
-    if (!std::cin && !non_interactive_mode) return;
+    if (!std::cin && inference.prompt.empty()) return;
 
     // If !eot_line.empty(), we append \n, so only look at the first 2 chars.
-    if (!non_interactive_mode && prompt_string.size() >= 2 &&
+    if (inference.prompt.empty() && prompt_string.size() >= 2 &&
         prompt_string[0] == '%') {
       if (prompt_string[1] == 'q' || prompt_string[1] == 'Q') return;
       if (prompt_string[1] == 'c' || prompt_string[1] == 'C') {
@@ -180,12 +175,27 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
       }
     }
 
-    if (!non_interactive_mode && prompt_string.empty()) {
+    if (inference.prompt.empty() && prompt_string.empty()) {
       std::cout << "Use '%q' to quit.\n";
       continue;
     }
 
     std::vector<int> prompt;
+    size_t prompt_size = 0;
+    size_t prefix_end = 0;
+    if constexpr (kVerboseLogTokens) {
+      for (int i = 0; i < prompt_size; ++i) {
+        fprintf(stderr, "DDD TOKEN %3d: %6d\n", i, prompt[i]);
+      }
+    }
+
+    // Set up runtime config.
+    TimingInfo timing_info = {.verbosity = inference.verbosity};
+    RuntimeConfig runtime_config = {.gen = &gen,
+                                    .verbosity = inference.verbosity,
+                                    .stream_token = stream_token,
+                                    .use_spinning = threading.spin};
+    inference.CopyTo(runtime_config);
 
     if (have_image) {
       prompt =
           WrapAndTokenize(model.Tokenizer(), model.ChatTemplate(), model.Info(),
@@ -201,21 +211,6 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
       prompt_size = prompt.size();
     }
 
-    if constexpr (kVerboseLogTokens) {
-      for (int i = 0; i < prompt_size; ++i) {
-        fprintf(stderr, "DDD TOKEN %3d: %6d\n", i, prompt[i]);
-      }
-    }
-
-    // Set up runtime config.
-    TimingInfo timing_info = {.verbosity = inference.verbosity};
-    RuntimeConfig runtime_config = {.gen = &gen,
-                                    .verbosity = inference.verbosity,
-                                    .stream_token = stream_token,
-                                    .use_spinning = threading.spin};
-    inference.CopyTo(runtime_config);
-    size_t prefix_end = 0;
-
     if (have_image) {
       runtime_config.image_tokens = &image_tokens;
       prompt_size = prompt.size();
@@ -234,7 +229,7 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
     std::cout << "\n\n";
 
     // Break the loop if in non-interactive mode
-    if (non_interactive_mode) {
+    if (!inference.prompt.empty()) {
       break;
     }