Add non-interactive mode support

- Added prompt flag to InferenceArgs for non-interactive mode
- Set user-facing options to verbosity level 1
- Fixed prompt_size declaration and variable ordering in run.cc
- Properly set prompt_size after WrapAndTokenize calls
- Moved kVerboseLogTokens block after prompt_size is set
prajwalc22 2025-04-16 16:26:52 +05:30
parent cbf179990f
commit 8246e49199
2 changed files with 213 additions and 199 deletions
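As a rough illustration of the behavior described above (not code from this commit): when a non-empty --prompt value is supplied, the REPL consumes it once, generates a single response, and exits; otherwise it keeps reading turns interactively. The standalone sketch below only models that control flow; GenerateResponse and GetPromptOrRead are hypothetical stand-ins, not gemma.cpp APIs.

// Standalone sketch (not gemma.cpp code) of the non-interactive flow:
// a non-empty command-line prompt short-circuits the interactive REPL.
#include <iostream>
#include <string>

// Hypothetical stand-in for the real model call.
std::string GenerateResponse(const std::string& prompt) {
  return "<model output for: " + prompt + ">";
}

// Mirrors the idea of GetPrompt(const InferenceArgs&): a command-line prompt
// wins, otherwise fall back to reading a line interactively.
std::string GetPromptOrRead(const std::string& cli_prompt) {
  if (!cli_prompt.empty()) return cli_prompt;
  std::string line;
  std::getline(std::cin, line);
  return line;
}

int main(int argc, char* argv[]) {
  // argv[1] stands in for the value of --prompt.
  const std::string cli_prompt = argc > 1 ? argv[1] : "";
  while (true) {
    const std::string prompt = GetPromptOrRead(cli_prompt);
    if (!std::cin && cli_prompt.empty()) break;  // EOF ends interactive mode.
    if (prompt.empty()) continue;
    std::cout << GenerateResponse(prompt) << "\n";
    if (!cli_prompt.empty()) break;  // Non-interactive: one turn, then exit.
  }
  return 0;
}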

View File

@@ -13,15 +13,18 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-// Argument parsing for Gemma.
-#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_ARGS_H_
-#define THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_ARGS_H_
+// Shared between various frontends.
+#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_ARGS_H_
+#define THIRD_PARTY_GEMMA_CPP_GEMMA_ARGS_H_

+#include <stddef.h>
 #include <stdio.h>

+#include <memory>
 #include <string>

+#include "compression/io.h"  // Path
 #include "compression/shared.h"
 #include "gemma/common.h"
 #include "gemma/gemma.h"  // For CreateGemma
@@ -32,66 +35,174 @@
 namespace gcpp {

-// Arguments related to inference: sampling, text etc.
-struct InferenceArgs : public ArgsBase<InferenceArgs> {
-  // Arguments for getc-like interfaces
-  size_t max_tokens;
-  size_t max_generated_tokens;
-  float temperature;
-  size_t top_k;
-  float top_p;
-  float min_p;
-  int repeat_penalty_power;
-  float repeat_penalty_presence;
-  float repeat_penalty_decay;
-  float repeat_penalty_range;
-
-  // Batch configuration:
-  size_t prefill_tbatch_size;
-  size_t decode_tbatch_size;
-
-  // Non-interactive mode prompt
-  std::string prompt;
-  std::string eot_line;
-
+struct LoaderArgs : public ArgsBase<LoaderArgs> {
+  LoaderArgs(int argc, char* argv[], bool validate = true) {
+    InitAndParse(argc, argv);
+    if (validate) {
+      if (const char* error = Validate()) {
+        HWY_ABORT("Invalid args: %s", error);
+      }
+    }
+  }
+
+  LoaderArgs(const std::string& tokenizer_path, const std::string& weights_path,
+             const std::string& model, bool validate = true) {
+    Init();  // Init sets to defaults, so assignments must come after Init().
+    tokenizer.path = tokenizer_path;
+    weights.path = weights_path;
+    model_type_str = model;
+    if (validate) {
+      if (const char* error = Validate()) {
+        HWY_ABORT("Invalid args: %s", error);
+      }
+    }
+  };
+
+  // Returns error string or nullptr if OK.
+  const char* Validate() {
+    if (weights.path.empty()) {
+      return "Missing --weights flag, a file for the model weights.";
+    }
+    if (!weights.Exists()) {
+      return "Can't open file specified with --weights flag.";
+    }
+    info_.model = Model::UNKNOWN;
+    info_.wrapping = PromptWrapping::GEMMA_PT;
+    info_.weight = Type::kUnknown;
+    if (!model_type_str.empty()) {
+      const char* err = ParseModelTypeAndWrapping(model_type_str, info_.model,
+                                                  info_.wrapping);
+      if (err != nullptr) return err;
+    }
+    if (!weight_type_str.empty()) {
+      const char* err = ParseType(weight_type_str, info_.weight);
+      if (err != nullptr) return err;
+    }
+    if (!tokenizer.path.empty()) {
+      if (!tokenizer.Exists()) {
+        return "Can't open file specified with --tokenizer flag.";
+      }
+    }
+    // model_type and tokenizer must be either both present or both absent.
+    // Further checks happen on weight loading.
+    if (model_type_str.empty() != tokenizer.path.empty()) {
+      return "Missing or extra flags for model_type or tokenizer.";
+    }
+    return nullptr;
+  }
+
+  Path tokenizer;
+  Path weights;  // weights file location
+  Path compressed_weights;
+  std::string model_type_str;
+  std::string weight_type_str;
+
   template <class Visitor>
-  void ForEach(Visitor& visitor) {
-    // Each line specifies a variable member, its name, default value, and help.
-    visitor(max_tokens, "max_tokens", size_t{50},
-            "Maximum number of total tokens including prompt (0=no limit).", 1);
-    visitor(max_generated_tokens, "max_generated_tokens", size_t{512},
-            "Maximum number of generated tokens (not including prompt) (0=no "
-            "limit).",
-            1);
-    visitor(temperature, "temperature", 1.0f,
-            "Temperature (randomness) for logits.", 1);
-    visitor(top_k, "top_k", size_t{40},
-            "Number of highest-probability tokens to consider (0=unlimited).",
-            1);
-    visitor(top_p, "top_p", 0.9f, "Top-p probability threshold (0.0=disabled).",
-            1);
-    visitor(min_p, "min_p", 0.0f, "Min-p probability threshold (0.0=disabled).",
-            1);
-    visitor(
-        repeat_penalty_power, "repeat_penalty_power", 1,
-        "Penalty power (1=standard frequentist penalty). If 0, skips penalty "
-        "computation.",
-        1);
-    visitor(repeat_penalty_presence, "repeat_penalty_presence", 0.0f,
-            "Penalty for token presence regardless of frequency (additive).",
-            1);
-    visitor(repeat_penalty_decay, "repeat_penalty_decay", 0.0f,
-            "Penalty for token n positions ago is decayed by "
-            "power(repeat_penalty_decay, n).",
-            1);
-    visitor(repeat_penalty_range, "repeat_penalty_range", 8.0f,
-            "Penalty fades out near the end of range (tokens)", 1);
-
-    // Batch configuration:
-    visitor(prefill_tbatch_size, "prefill_tbatch_size", size_t{2},
-            "Token batch size for prefill; <= 32", 2);
-    visitor(decode_tbatch_size, "decode_tbatch_size", size_t{1},
-            "Token batch size for decode (only 1 currently supported)", 2);
+  void ForEach(const Visitor& visitor) {
+    visitor(tokenizer, "tokenizer", Path(),
+            "Path name of tokenizer model file.");
+    visitor(weights, "weights", Path(),
+            "Path name of model weights (.sbs) file.\n Required argument.\n");
+    visitor(compressed_weights, "compressed_weights", Path(),
+            "Deprecated alias for --weights.");
+    visitor(model_type_str, "model", std::string(),
+            "Model type, see common.cc for valid values.\n");
+    visitor(weight_type_str, "weight_type", std::string("sfp"),
+            "Weight type\n f32 = float, bf16 = bfloat16, sfp = 8-bit SFP.");
+  }
+
+  // Uninitialized before Validate, must call after that.
+  const ModelInfo& Info() const { return info_; }
+
+ private:
+  ModelInfo info_;
+};
+
+// `env` must remain valid for the lifetime of the Gemma.
+static inline Gemma CreateGemma(const LoaderArgs& loader, MatMulEnv& env) {
+  if (Type::kUnknown == loader.Info().weight ||
+      Model::UNKNOWN == loader.Info().model || loader.tokenizer.path.empty()) {
+    // New weights file format doesn't need tokenizer path or model/weight info.
+    return Gemma(loader.weights, env);
+  }
+  return Gemma(loader.tokenizer, loader.weights, loader.Info(), env);
+}
+
+// `env` must remain valid for the lifetime of the Gemma.
+static inline std::unique_ptr<Gemma> AllocateGemma(const LoaderArgs& loader,
+                                                   MatMulEnv& env) {
+  if (Type::kUnknown == loader.Info().weight ||
+      Model::UNKNOWN == loader.Info().model || loader.tokenizer.path.empty()) {
+    // New weights file format doesn't need tokenizer path or model/weight info.
+    return std::make_unique<Gemma>(loader.weights, env);
+  }
+  return std::make_unique<Gemma>(loader.tokenizer, loader.weights,
+                                 loader.Info(), env);
+}
+
+struct InferenceArgs : public ArgsBase<InferenceArgs> {
+  InferenceArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
+  InferenceArgs() { Init(); };
+
+  int verbosity;
+  size_t max_generated_tokens;
+  size_t prefill_tbatch_size;
+  size_t decode_qbatch_size;
+  float temperature;
+  size_t top_k;
+  bool deterministic;
+  bool multiturn;
+  Path image_file;
+  std::string prompt;  // Added prompt flag for non-interactive mode
+  std::string eot_line;
+
+  // Returns error string or nullptr if OK.
+  const char* Validate() const {
+    if (max_generated_tokens > gcpp::kSeqLen) {
+      return "max_generated_tokens is larger than the maximum sequence length "
+             "(see configs.h).";
+    }
+    return nullptr;
+  }
+
+  template <class Visitor>
+  void ForEach(const Visitor& visitor) {
+    visitor(verbosity, "verbosity", 1,
+            "Show verbose developer information\n 0 = only print generation "
+            "output\n 1 = standard user-facing terminal ui\n 2 = show "
+            "developer/debug info).\n Default = 1.",
+            1);  // Changed verbosity level to 1 since it's user-facing
+    visitor(max_generated_tokens, "max_generated_tokens", size_t{2048},
+            "Maximum number of tokens to generate.");
+    visitor(prefill_tbatch_size, "prefill_tbatch", size_t{256},
+            "Prefill: max tokens per batch.");
+    visitor(decode_qbatch_size, "decode_qbatch", size_t{16},
+            "Decode: max queries per batch.");
+    visitor(temperature, "temperature", 1.0f, "Temperature for top-K", 2);
+    visitor(top_k, "top_k", size_t{1}, "Number of top-K tokens to sample from",
+            2);
+    visitor(deterministic, "deterministic", false,
+            "Make top-k sampling deterministic", 2);
+    visitor(multiturn, "multiturn", false,
+            "Multiturn mode\n 0 = clear KV cache after every "
+            "interaction\n 1 = continue KV cache after every interaction\n "
+            " Default : 0 (conversation "
+            "resets every turn)");
+    visitor(image_file, "image_file", Path(), "Image file to load.");
+    visitor(prompt, "prompt", std::string(""),
+            "Initial prompt for non-interactive mode. When specified, "
+            "generates a response"
+            " and exits.",
+            1);  // Added as user-facing option
     visitor(
         eot_line, "eot_line", std::string(""),
@@ -99,123 +210,31 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
         "When you specify this, the prompt will be all lines "
         "before the line where only the given string appears.\n Default = "
         "When a newline is encountered, that signals the end of the turn.",
-        1);
-
-    // Non-interactive mode prompt
-    visitor(prompt, "prompt", std::string(""),
-            "Prompt to use in non-interactive mode", 1);
-  }
-
-  const char* Validate() const {
-    if (max_generated_tokens == 0 && max_tokens == 0) {
-      return "At least one of max_tokens and max_generated_tokens must be > 0";
-    }
-    if (temperature <= 0.0) {
-      return "Temperature must be > 0.0";
-    }
-    if (prefill_tbatch_size > 32) {
-      return "prefill_tbatch_size must be <= 32";
-    }
-    if (decode_tbatch_size != 1) {
-      return "decode_tbatch_size must be 1";
-    }
-    return nullptr;
-  }
-};
-
-// Arguments related to model weights.
-struct LoaderArgs : public ArgsBase<LoaderArgs> {
-  Path model_path;  // Path to directory containing the weights
-  Path tokenizer;   // Optional: can be derived from model_path
-  bool model_is_gemma2;
-  Gemma::Config::WeightFormat weight_format;
-
-  template <class Visitor>
-  void ForEach(Visitor& visitor) {
-    // Each line specifies a variable member, its name, default value, and help.
-    visitor(model_path, "model", Path{},
-            "Directory containing weights or config file from `gemma.cpp "
-            "convert`.",
-            0);
-    visitor(tokenizer, "tokenizer", Path{},
-            "Optional path to tokenizer.model; if empty, looks in model_path.",
-            2);
-    visitor(model_is_gemma2, "model_is_gemma2", false,
-            "Whether the model is a Gemma 2 model", 1);
-    visitor(weight_format, "format", Gemma::Config::kBfloat16,
-            "Model weights format: 0=F32, 1=F16, 2=BF16", 2);
-  }
-
-  const char* Validate() const {
-    if (model_path.path.empty()) {
-      return "Empty model path";
-    }
-    if (weight_format != Gemma::Config::kBfloat16 &&
-        weight_format != Gemma::Config::kFloat16 &&
-        weight_format != Gemma::Config::kFloat32) {
-      return "Invalid weight format";
-    }
-    return nullptr;
-  }
-};
-
-// Threading-related arguments.
-struct ThreadingArgs : public ArgsBase<ThreadingArgs> {
-  size_t num_threads;
-  Tristate pin_threads;
-  Tristate use_spinning;
-  int verbosity;
-
-  template <class Visitor>
-  void ForEach(Visitor& visitor) {
-    visitor(num_threads, "threads", size_t{0},
-            "Number of threads (0=auto, half of logical cores)", 1);
-    visitor(pin_threads, "pin_threads", Tristate::kDefault,
-            "Set to true/false to force enable/disable thread pinning.", 2);
-    visitor(use_spinning, "use_spinning", Tristate::kDefault,
-            "Set to true/false to enable/disable thread spinning (typically "
-            "improves "
-            "performance but increases power usage)",
-            2);
-    visitor(verbosity, "verbosity", 1,
-            "Controls printing of progress messages to stderr", 1);
-  }
-
-  // Returns nullptr if OK, otherwise error message.
-  const char* Validate() const { return nullptr; }
-
-  // Returns num_threads to use.
-  size_t NumThreadsToUse() const {
-    return num_threads == 0 ? (size_t{hwy::NumberOfProcessors()} + 1) / 2
-                            : num_threads;
-  }
-};
-
-// Command-line arguments for PeftGemma and Gemma.
-struct GemmaArgs : public ArgsBase<GemmaArgs> {
-  InferenceArgs inference;
-  LoaderArgs loader;
-  ThreadingArgs threading;
-
-  // For collect_stats.cc:
-  Path output;
-  bool trace_outputs;  // For -ftrace and dump_csv.cc
-  bool trace_base;     // For -ftrace
-  int time_it;         // For time_it.cc
-
-  template <class Visitor>
-  void ForEach(Visitor& visitor) {
-    inference.ForEach(visitor);
-    loader.ForEach(visitor);
-    threading.ForEach(visitor);
-    visitor(output, "output", Path{}, "Where to write CSV data / stats", 2);
-    visitor(trace_outputs, "trace_outputs", false, "For tracing", 2);
-    visitor(trace_base, "trace_base", false, "For tracing", 2);
-    visitor(time_it, "time_it", 0, "For benchmarks", 2);
-  }
-};
+        2);
+  }
+
+  void CopyTo(RuntimeConfig& runtime_config) const {
+    runtime_config.max_generated_tokens = max_generated_tokens;
+    runtime_config.prefill_tbatch_size = prefill_tbatch_size;
+    runtime_config.decode_qbatch_size = decode_qbatch_size;
+
+    if (prefill_tbatch_size > MMStorage::kMaxM) {
+      HWY_ABORT(
+          "prefill_tbatch_size %zu > kMaxM %zu: specify a smaller value, "
+          "or increase the constant in MMStorage.\n",
+          prefill_tbatch_size, MMStorage::kMaxM);
+    }
+    if (decode_qbatch_size > MMStorage::kMaxM) {
+      HWY_ABORT(
+          "decode_qbatch_size %zu > kMaxM %zu: specify a smaller value, "
+          "or increase the constant in MMStorage.\n",
+          decode_qbatch_size, MMStorage::kMaxM);
+    }
+
+    runtime_config.temperature = temperature;
+    runtime_config.top_k = top_k;
+  }
+};

 } // namespace gcpp

-#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_ARGS_H_
+#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_ARGS_H_
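The header above registers every flag through a visitor: each visitor(...) call names a member, its flag string, its default, its help text, and optionally a verbosity tier, which is how this commit marks --prompt and --verbosity as user-facing (tier 1). The standalone sketch below is a simplified model of that pattern, not the real gcpp::ArgsBase; it assumes the tier argument only gates help output.

// Simplified model (not gcpp::ArgsBase) of the visitor-based args pattern:
// one registration per member, reused by an "init" visitor and a "help"
// visitor that filters on the verbosity tier.
#include <cstddef>
#include <iostream>
#include <string>

struct DemoArgs {
  std::size_t max_generated_tokens;
  std::string prompt;

  template <class Visitor>
  void ForEach(const Visitor& visitor) {
    visitor(max_generated_tokens, "max_generated_tokens", std::size_t{2048},
            "Maximum number of tokens to generate.", 1);
    visitor(prompt, "prompt", std::string(""),
            "Initial prompt for non-interactive mode.", 1);
  }
};

int main() {
  DemoArgs args;
  // "Init" visitor: assign each member its default value.
  args.ForEach([](auto& member, const char* /*name*/, auto init,
                  const char* /*help*/, int /*level*/) { member = init; });
  // "Help" visitor: print flags whose tier is at most the requested verbosity.
  const int verbosity = 1;
  args.ForEach([&](auto& /*member*/, const char* name, auto /*init*/,
                   const char* help, int level) {
    if (level <= verbosity) std::cout << "--" << name << " : " << help << "\n";
  });
  std::cout << "default max_generated_tokens = " << args.max_generated_tokens
            << "\n";
  return 0;
}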

View File

@@ -78,16 +78,15 @@ std::string GetPrompt(std::istream& input, int verbosity,
   return prompt_string;
 }

-// New GetPrompt function that accepts InferenceArgs
-std::string GetPrompt(const InferenceArgs& inference, int verbosity,
-                      size_t turn) {
-  // Check for command-line prompt first
+// Get prompt either from interactive input or command line
+std::string GetPrompt(const InferenceArgs& inference) {
+  // If prompt is provided via command line, use that
   if (!inference.prompt.empty()) {
     return inference.prompt;
   }
-  // Use the existing function for interactive mode
-  return GetPrompt(std::cin, verbosity, inference.eot_line);
+  // Otherwise get interactive prompt
+  return GetPrompt(std::cin, inference.verbosity, inference.eot_line);
 }

 // The main Read-Eval-Print Loop.
@@ -101,9 +100,6 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
   std::mt19937 gen;
   InitGenerator(inference, gen);

-  // Add flag to track non-interactive mode
-  bool non_interactive_mode = !inference.prompt.empty();
-
   const bool have_image = !inference.image_file.path.empty();
   Image image;
   ImageTokens image_tokens;
@@ -165,13 +161,12 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
    tokens_generated_this_turn = 0;

    // Read prompt and handle special commands.
-    std::string prompt_string =
-        GetPrompt(inference, inference.verbosity, abs_pos);
-    if (!std::cin && !non_interactive_mode) return;
+    std::string prompt_string = GetPrompt(inference);
+    if (!std::cin && inference.prompt.empty()) return;
    // If !eot_line.empty(), we append \n, so only look at the first 2 chars.
-    if (!non_interactive_mode && prompt_string.size() >= 2 &&
+    if (inference.prompt.empty() && prompt_string.size() >= 2 &&
        prompt_string[0] == '%') {
      if (prompt_string[1] == 'q' || prompt_string[1] == 'Q') return;
      if (prompt_string[1] == 'c' || prompt_string[1] == 'C') {
@@ -180,12 +175,27 @@
      }
    }

-    if (!non_interactive_mode && prompt_string.empty()) {
+    if (inference.prompt.empty() && prompt_string.empty()) {
      std::cout << "Use '%q' to quit.\n";
      continue;
    }

    std::vector<int> prompt;
+    size_t prompt_size = 0;
+    size_t prefix_end = 0;
+
+    if constexpr (kVerboseLogTokens) {
+      for (int i = 0; i < prompt_size; ++i) {
+        fprintf(stderr, "DDD TOKEN %3d: %6d\n", i, prompt[i]);
+      }
+    }
+
+    // Set up runtime config.
+    TimingInfo timing_info = {.verbosity = inference.verbosity};
+    RuntimeConfig runtime_config = {.gen = &gen,
+                                    .verbosity = inference.verbosity,
+                                    .stream_token = stream_token,
+                                    .use_spinning = threading.spin};
+    inference.CopyTo(runtime_config);
+
    if (have_image) {
      prompt =
          WrapAndTokenize(model.Tokenizer(), model.ChatTemplate(), model.Info(),
@@ -201,21 +211,6 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
      prompt_size = prompt.size();
    }

-    if constexpr (kVerboseLogTokens) {
-      for (int i = 0; i < prompt_size; ++i) {
-        fprintf(stderr, "DDD TOKEN %3d: %6d\n", i, prompt[i]);
-      }
-    }
-
-    // Set up runtime config.
-    TimingInfo timing_info = {.verbosity = inference.verbosity};
-    RuntimeConfig runtime_config = {.gen = &gen,
-                                    .verbosity = inference.verbosity,
-                                    .stream_token = stream_token,
-                                    .use_spinning = threading.spin};
-    inference.CopyTo(runtime_config);
-
-    size_t prefix_end = 0;
    if (have_image) {
      runtime_config.image_tokens = &image_tokens;
      prompt_size = prompt.size();
@@ -234,7 +229,7 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
    std::cout << "\n\n";

    // Break the loop if in non-interactive mode
-    if (non_interactive_mode) {
+    if (inference.prompt.empty()) {
      break;
    }
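For the ordering point in the commit message (dump tokens only after the prompt has been tokenized), here is a minimal standalone sketch. FakeTokenize is a hypothetical stand-in for WrapAndTokenize; the sketch only shows why prompt_size must be assigned from the tokenized prompt before any verbose token dump runs.

// Standalone sketch of the prompt_size / token-logging ordering.
#include <cstdio>
#include <string>
#include <vector>

constexpr bool kVerboseLogTokens = true;

// Stand-in tokenizer: pretend each character is one token id.
std::vector<int> FakeTokenize(const std::string& text) {
  return std::vector<int>(text.begin(), text.end());
}

int main() {
  std::vector<int> prompt;
  size_t prompt_size = 0;  // declared early, assigned after tokenization

  // Tokenize first, then record the size...
  prompt = FakeTokenize("hi");
  prompt_size = prompt.size();

  // ...so a later debug dump sees the real tokens. If this block ran before
  // tokenization, prompt_size would still be 0 and nothing would be printed.
  if constexpr (kVerboseLogTokens) {
    for (size_t i = 0; i < prompt_size; ++i) {
      std::fprintf(stderr, "TOKEN %3zu: %6d\n", i, prompt[i]);
    }
  }
  return 0;
}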