Add non-interactive mode support

- Added prompt flag to InferenceArgs for non-interactive mode
- Set user-facing options to verbosity level 1
- Fixed prompt_size declaration and variable ordering in run.cc
- Properly set prompt_size after WrapAndTokenize calls
- Moved kVerboseLogTokens block after prompt_size is set
prajwalc22 2025-04-16 16:26:52 +05:30
parent cbf179990f
commit 8246e49199
2 changed files with 213 additions and 199 deletions
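For orientation, the intended usage is single-shot: when --prompt is given, the binary generates one response and exits instead of entering the interactive loop. Below is a minimal sketch of that frontend flow, assuming the InferenceArgs API from the diff; GenerateOnce() and RunRepl() are hypothetical stand-ins, not functions in this commit:

#include <cstdio>

#include "gemma/gemma_args.h"  // InferenceArgs (path assumed)

void GenerateOnce(const gcpp::InferenceArgs&);  // hypothetical, defined elsewhere
void RunRepl(const gcpp::InferenceArgs&);       // hypothetical, defined elsewhere

int main(int argc, char* argv[]) {
  gcpp::InferenceArgs inference(argc, argv);
  if (const char* error = inference.Validate()) {
    std::fprintf(stderr, "Invalid args: %s\n", error);
    return 1;
  }
  if (!inference.prompt.empty()) {
    GenerateOnce(inference);  // --prompt given: answer once, then exit
    return 0;
  }
  RunRepl(inference);  // otherwise, the usual interactive REPL
  return 0;
}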

gemma/gemma_args.h

@@ -13,15 +13,18 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// Argument parsing for Gemma.
// Shared between various frontends.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_ARGS_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_ARGS_H_
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_ARGS_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_ARGS_H_
#include <stddef.h>
#include <stdio.h>
#include <memory>
#include <string>
#include "compression/io.h" // Path
#include "compression/shared.h"
#include "gemma/common.h"
#include "gemma/gemma.h" // For CreateGemma
@@ -32,66 +35,174 @@
namespace gcpp {
// Arguments related to inference: sampling, text etc.
struct InferenceArgs : public ArgsBase<InferenceArgs> {
// Arguments for getc-like interfaces
size_t max_tokens;
size_t max_generated_tokens;
float temperature;
size_t top_k;
float top_p;
float min_p;
int repeat_penalty_power;
float repeat_penalty_presence;
float repeat_penalty_decay;
float repeat_penalty_range;
struct LoaderArgs : public ArgsBase<LoaderArgs> {
LoaderArgs(int argc, char* argv[], bool validate = true) {
InitAndParse(argc, argv);
// Batch configuration:
size_t prefill_tbatch_size;
size_t decode_tbatch_size;
if (validate) {
if (const char* error = Validate()) {
HWY_ABORT("Invalid args: %s", error);
}
}
}
LoaderArgs(const std::string& tokenizer_path, const std::string& weights_path,
const std::string& model, bool validate = true) {
Init(); // Init sets to defaults, so assignments must come after Init().
tokenizer.path = tokenizer_path;
weights.path = weights_path;
model_type_str = model;
// Non-interactive mode prompt
std::string prompt;
std::string eot_line;
if (validate) {
if (const char* error = Validate()) {
HWY_ABORT("Invalid args: %s", error);
}
}
};
// Returns error string or nullptr if OK.
const char* Validate() {
if (weights.path.empty()) {
return "Missing --weights flag, a file for the model weights.";
}
if (!weights.Exists()) {
return "Can't open file specified with --weights flag.";
}
info_.model = Model::UNKNOWN;
info_.wrapping = PromptWrapping::GEMMA_PT;
info_.weight = Type::kUnknown;
if (!model_type_str.empty()) {
const char* err = ParseModelTypeAndWrapping(model_type_str, info_.model,
info_.wrapping);
if (err != nullptr) return err;
}
if (!weight_type_str.empty()) {
const char* err = ParseType(weight_type_str, info_.weight);
if (err != nullptr) return err;
}
if (!tokenizer.path.empty()) {
if (!tokenizer.Exists()) {
return "Can't open file specified with --tokenizer flag.";
}
}
// model_type and tokenizer must be either both present or both absent.
// Further checks happen on weight loading.
if (model_type_str.empty() != tokenizer.path.empty()) {
return "Missing or extra flags for model_type or tokenizer.";
}
return nullptr;
}
Path tokenizer;
Path weights; // weights file location
Path compressed_weights;
std::string model_type_str;
std::string weight_type_str;
template <class Visitor>
void ForEach(Visitor& visitor) {
// Each line specifies a variable member, its name, default value, and help.
visitor(max_tokens, "max_tokens", size_t{50},
"Maximum number of total tokens including prompt (0=no limit).", 1);
visitor(max_generated_tokens, "max_generated_tokens", size_t{512},
"Maximum number of generated tokens (not including prompt) (0=no "
"limit).",
1);
visitor(temperature, "temperature", 1.0f,
"Temperature (randomness) for logits.", 1);
visitor(top_k, "top_k", size_t{40},
"Number of highest-probability tokens to consider (0=unlimited).",
1);
visitor(top_p, "top_p", 0.9f, "Top-p probability threshold (0.0=disabled).",
1);
visitor(min_p, "min_p", 0.0f, "Min-p probability threshold (0.0=disabled).",
1);
visitor(
repeat_penalty_power, "repeat_penalty_power", 1,
"Penalty power (1=standard frequentist penalty). If 0, skips penalty "
"computation.",
1);
visitor(repeat_penalty_presence, "repeat_penalty_presence", 0.0f,
"Penalty for token presence regardless of frequency (additive).",
1);
visitor(repeat_penalty_decay, "repeat_penalty_decay", 0.0f,
"Penalty for token n positions ago is decayed by "
"power(repeat_penalty_decay, n).",
1);
visitor(repeat_penalty_range, "repeat_penalty_range", 8.0f,
"Penalty fades out near the end of range (tokens)", 1);
void ForEach(const Visitor& visitor) {
visitor(tokenizer, "tokenizer", Path(),
"Path name of tokenizer model file.");
visitor(weights, "weights", Path(),
"Path name of model weights (.sbs) file.\n Required argument.\n");
visitor(compressed_weights, "compressed_weights", Path(),
"Deprecated alias for --weights.");
visitor(model_type_str, "model", std::string(),
"Model type, see common.cc for valid values.\n");
visitor(weight_type_str, "weight_type", std::string("sfp"),
"Weight type\n f32 = float, bf16 = bfloat16, sfp = 8-bit SFP.");
}
// Batch configuration:
visitor(prefill_tbatch_size, "prefill_tbatch_size", size_t{2},
"Token batch size for prefill; <= 32", 2);
visitor(decode_tbatch_size, "decode_tbatch_size", size_t{1},
"Token batch size for decode (only 1 currently supported)", 2);
// Uninitialized before Validate, must call after that.
const ModelInfo& Info() const { return info_; }
private:
ModelInfo info_;
};
// `env` must remain valid for the lifetime of the Gemma.
static inline Gemma CreateGemma(const LoaderArgs& loader, MatMulEnv& env) {
if (Type::kUnknown == loader.Info().weight ||
Model::UNKNOWN == loader.Info().model || loader.tokenizer.path.empty()) {
// New weights file format doesn't need tokenizer path or model/weight info.
return Gemma(loader.weights, env);
}
return Gemma(loader.tokenizer, loader.weights, loader.Info(), env);
}
// `env` must remain valid for the lifetime of the Gemma.
static inline std::unique_ptr<Gemma> AllocateGemma(const LoaderArgs& loader,
MatMulEnv& env) {
if (Type::kUnknown == loader.Info().weight ||
Model::UNKNOWN == loader.Info().model || loader.tokenizer.path.empty()) {
// New weights file format doesn't need tokenizer path or model/weight info.
return std::make_unique<Gemma>(loader.weights, env);
}
return std::make_unique<Gemma>(loader.tokenizer, loader.weights,
loader.Info(), env);
}
struct InferenceArgs : public ArgsBase<InferenceArgs> {
InferenceArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
InferenceArgs() { Init(); }
int verbosity;
size_t max_generated_tokens;
size_t prefill_tbatch_size;
size_t decode_qbatch_size;
float temperature;
size_t top_k;
bool deterministic;
bool multiturn;
Path image_file;
std::string prompt; // Added prompt flag for non-interactive mode
std::string eot_line;
// Returns error string or nullptr if OK.
const char* Validate() const {
if (max_generated_tokens > gcpp::kSeqLen) {
return "max_generated_tokens is larger than the maximum sequence length "
"(see configs.h).";
}
return nullptr;
}
template <class Visitor>
void ForEach(const Visitor& visitor) {
visitor(verbosity, "verbosity", 1,
"Show verbose developer information\n 0 = only print generation "
"output\n 1 = standard user-facing terminal ui\n 2 = show "
"developer/debug info.\n Default = 1.",
1); // Changed verbosity level to 1 since it's user-facing
visitor(max_generated_tokens, "max_generated_tokens", size_t{2048},
"Maximum number of tokens to generate.");
visitor(prefill_tbatch_size, "prefill_tbatch", size_t{256},
"Prefill: max tokens per batch.");
visitor(decode_qbatch_size, "decode_qbatch", size_t{16},
"Decode: max queries per batch.");
visitor(temperature, "temperature", 1.0f, "Temperature for top-K", 2);
visitor(top_k, "top_k", size_t{1}, "Number of top-K tokens to sample from",
2);
visitor(deterministic, "deterministic", false,
"Make top-k sampling deterministic", 2);
visitor(multiturn, "multiturn", false,
"Multiturn mode\n 0 = clear KV cache after every "
"interaction\n 1 = continue KV cache after every interaction\n "
" Default : 0 (conversation "
"resets every turn)");
visitor(image_file, "image_file", Path(), "Image file to load.");
visitor(prompt, "prompt", std::string(""),
"Initial prompt for non-interactive mode. When specified, "
"generates a response"
" and exits.",
1); // Added as user-facing option
visitor(
eot_line, "eot_line", std::string(""),
@@ -99,123 +210,31 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
"When you specify this, the prompt will be all lines "
"before the line where only the given string appears.\n Default = "
"When a newline is encountered, that signals the end of the turn.",
1);
// Non-interactive mode prompt
visitor(prompt, "prompt", std::string(""),
"Prompt to use in non-interactive mode", 1);
2);
}
const char* Validate() const {
if (max_generated_tokens == 0 && max_tokens == 0) {
return "At least one of max_tokens and max_generated_tokens must be > 0";
void CopyTo(RuntimeConfig& runtime_config) const {
runtime_config.max_generated_tokens = max_generated_tokens;
runtime_config.prefill_tbatch_size = prefill_tbatch_size;
runtime_config.decode_qbatch_size = decode_qbatch_size;
if (prefill_tbatch_size > MMStorage::kMaxM) {
HWY_ABORT(
"prefill_tbatch_size %zu > kMaxM %zu: specify a smaller value, "
"or increase the constant in MMStorage.\n",
prefill_tbatch_size, MMStorage::kMaxM);
}
if (temperature <= 0.0) {
return "Temperature must be > 0.0";
if (decode_qbatch_size > MMStorage::kMaxM) {
HWY_ABORT(
"decode_qbatch_size %zu > kMaxM %zu: specify a smaller value, "
"or increase the constant in MMStorage.\n",
decode_qbatch_size, MMStorage::kMaxM);
}
if (prefill_tbatch_size > 32) {
return "prefill_tbatch_size must be <= 32";
}
if (decode_tbatch_size != 1) {
return "decode_tbatch_size must be 1";
}
return nullptr;
}
};
// Arguments related to model weights.
struct LoaderArgs : public ArgsBase<LoaderArgs> {
Path model_path; // Path to directory containing the weights
Path tokenizer; // Optional: can be derived from model_path
bool model_is_gemma2;
Gemma::Config::WeightFormat weight_format;
template <class Visitor>
void ForEach(Visitor& visitor) {
// Each line specifies a variable member, its name, default value, and help.
visitor(model_path, "model", Path{},
"Directory containing weights or config file from `gemma.cpp "
"convert`.",
0);
visitor(tokenizer, "tokenizer", Path{},
"Optional path to tokenizer.model; if empty, looks in model_path.",
2);
visitor(model_is_gemma2, "model_is_gemma2", false,
"Whether the model is a Gemma 2 model", 1);
visitor(weight_format, "format", Gemma::Config::kBfloat16,
"Model weights format: 0=F32, 1=F16, 2=BF16", 2);
}
const char* Validate() const {
if (model_path.path.empty()) {
return "Empty model path";
}
if (weight_format != Gemma::Config::kBfloat16 &&
weight_format != Gemma::Config::kFloat16 &&
weight_format != Gemma::Config::kFloat32) {
return "Invalid weight format";
}
return nullptr;
}
};
// Threading-related arguments.
struct ThreadingArgs : public ArgsBase<ThreadingArgs> {
size_t num_threads;
Tristate pin_threads;
Tristate use_spinning;
int verbosity;
template <class Visitor>
void ForEach(Visitor& visitor) {
visitor(num_threads, "threads", size_t{0},
"Number of threads (0=auto, half of logical cores)", 1);
visitor(pin_threads, "pin_threads", Tristate::kDefault,
"Set to true/false to force enable/disable thread pinning.", 2);
visitor(use_spinning, "use_spinning", Tristate::kDefault,
"Set to true/false to enable/disable thread spinning (typically "
"improves "
"performance but increases power usage)",
2);
visitor(verbosity, "verbosity", 1,
"Controls printing of progress messages to stderr", 1);
}
// Returns nullptr if OK, otherwise error message.
const char* Validate() const { return nullptr; }
// Returns num_threads to use.
size_t NumThreadsToUse() const {
return num_threads == 0 ? (size_t{hwy::NumberOfProcessors()} + 1) / 2
: num_threads;
}
};
// Command-line arguments for PeftGemma and Gemma.
struct GemmaArgs : public ArgsBase<GemmaArgs> {
InferenceArgs inference;
LoaderArgs loader;
ThreadingArgs threading;
// For collect_stats.cc:
Path output;
bool trace_outputs; // For -ftrace and dump_csv.cc
bool trace_base; // For -ftrace
int time_it; // For time_it.cc
template <class Visitor>
void ForEach(Visitor& visitor) {
inference.ForEach(visitor);
loader.ForEach(visitor);
threading.ForEach(visitor);
visitor(output, "output", Path{}, "Where to write CSV data / stats", 2);
visitor(trace_outputs, "trace_outputs", false, "For tracing", 2);
visitor(trace_base, "trace_base", false, "For tracing", 2);
visitor(time_it, "time_it", 0, "For benchmarks", 2);
runtime_config.temperature = temperature;
runtime_config.top_k = top_k;
}
};
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_ARGS_H_
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_ARGS_H_
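The trailing integer in the visitor(...) calls above is a visibility level: 1 marks user-facing flags, 2 marks developer flags, and calls without it default to user-facing. The commit message's "set user-facing options to verbosity level 1" refers to these tags. A sketch of a visitor compatible with that contract, inferred from the diff rather than from the real ArgsBase:

#include <cstdio>

struct PrintHelpVisitor {
  int max_level = 1;  // raise to 2 to also list developer flags

  template <typename T>
  void operator()(T& /*member*/, const char* name, const T& /*init*/,
                  const char* help, int level = 1) const {
    if (level <= max_level) std::printf("--%s\n  %s\n", name, help);
  }
};

// Usage sketch: InferenceArgs args; PrintHelpVisitor v{2};
// args.ForEach(v);  // lists flags up to level 2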

gemma/run.cc

@@ -78,16 +78,15 @@ std::string GetPrompt(std::istream& input, int verbosity,
return prompt_string;
}
// New GetPrompt function that accepts InferenceArgs
std::string GetPrompt(const InferenceArgs& inference, int verbosity,
size_t turn) {
// Check for command-line prompt first
// Get prompt either from interactive input or command line
std::string GetPrompt(const InferenceArgs& inference) {
// If prompt is provided via command line, use that
if (!inference.prompt.empty()) {
return inference.prompt;
}
// Use the existing function for interactive mode
return GetPrompt(std::cin, verbosity, inference.eot_line);
// Otherwise get interactive prompt
return GetPrompt(std::cin, inference.verbosity, inference.eot_line);
}
// The main Read-Eval-Print Loop.
@@ -101,9 +100,6 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
std::mt19937 gen;
InitGenerator(inference, gen);
// Add flag to track non-interactive mode
bool non_interactive_mode = !inference.prompt.empty();
const bool have_image = !inference.image_file.path.empty();
Image image;
ImageTokens image_tokens;
@@ -165,13 +161,12 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
tokens_generated_this_turn = 0;
// Read prompt and handle special commands.
std::string prompt_string =
GetPrompt(inference, inference.verbosity, abs_pos);
std::string prompt_string = GetPrompt(inference);
if (!std::cin && !non_interactive_mode) return;
if (!std::cin && inference.prompt.empty()) return;
// If !eot_line.empty(), we append \n, so only look at the first 2 chars.
if (!non_interactive_mode && prompt_string.size() >= 2 &&
if (inference.prompt.empty() && prompt_string.size() >= 2 &&
prompt_string[0] == '%') {
if (prompt_string[1] == 'q' || prompt_string[1] == 'Q') return;
if (prompt_string[1] == 'c' || prompt_string[1] == 'C') {
@@ -180,12 +175,27 @@
}
}
if (!non_interactive_mode && prompt_string.empty()) {
if (inference.prompt.empty() && prompt_string.empty()) {
std::cout << "Use '%q' to quit.\n";
continue;
}
std::vector<int> prompt;
size_t prompt_size = 0;
size_t prefix_end = 0;
if constexpr (kVerboseLogTokens) {
for (int i = 0; i < prompt_size; ++i) {
fprintf(stderr, "DDD TOKEN %3d: %6d\n", i, prompt[i]);
}
}
// Set up runtime config.
TimingInfo timing_info = {.verbosity = inference.verbosity};
RuntimeConfig runtime_config = {.gen = &gen,
.verbosity = inference.verbosity,
.stream_token = stream_token,
.use_spinning = threading.spin};
inference.CopyTo(runtime_config);
if (have_image) {
prompt =
WrapAndTokenize(model.Tokenizer(), model.ChatTemplate(), model.Info(),
@@ -201,21 +211,6 @@
prompt_size = prompt.size();
}
if constexpr (kVerboseLogTokens) {
for (int i = 0; i < prompt_size; ++i) {
fprintf(stderr, "DDD TOKEN %3d: %6d\n", i, prompt[i]);
}
}
// Set up runtime config.
TimingInfo timing_info = {.verbosity = inference.verbosity};
RuntimeConfig runtime_config = {.gen = &gen,
.verbosity = inference.verbosity,
.stream_token = stream_token,
.use_spinning = threading.spin};
inference.CopyTo(runtime_config);
size_t prefix_end = 0;
if (have_image) {
runtime_config.image_tokens = &image_tokens;
prompt_size = prompt.size();
@@ -234,7 +229,7 @@
std::cout << "\n\n";
// Break the loop if in non-interactive mode
if (non_interactive_mode) {
if (!inference.prompt.empty()) {
break;
}
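Net effect of the run.cc changes: GetPrompt() now short-circuits on a command-line prompt before reading stdin. A self-contained analogue with simplified stand-in types (not the actual gemma.cpp signatures):

#include <iostream>
#include <string>

struct Args {            // stand-in for InferenceArgs
  std::string prompt;    // set from --prompt
  std::string eot_line;  // end-of-turn marker (unused in this sketch)
};

std::string GetPrompt(const Args& args) {
  if (!args.prompt.empty()) return args.prompt;  // non-interactive path
  std::string line;                              // interactive path:
  std::getline(std::cin, line);                  // single-line stand-in
  return line;
}

int main() {
  const Args args{.prompt = "Why is the sky blue?"};
  std::cout << GetPrompt(args) << "\n";  // prints the prompt, skips stdin
}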