diff --git a/gemma/gemma_args.h b/gemma/gemma_args.h
index 2c15986..d02dece 100644
--- a/gemma/gemma_args.h
+++ b/gemma/gemma_args.h
@@ -13,15 +13,18 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-// Argument parsing for Gemma.
+// Shared between various frontends.
 
-#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_ARGS_H_
-#define THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_ARGS_H_
+#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_ARGS_H_
+#define THIRD_PARTY_GEMMA_CPP_GEMMA_ARGS_H_
 
+#include <stddef.h>
 #include <stdio.h>
+#include <memory>
 #include <string>
 
+#include "compression/io.h"  // Path
 #include "compression/shared.h"
 #include "gemma/common.h"
 #include "gemma/gemma.h"  // For CreateGemma
@@ -32,66 +35,174 @@
 
 namespace gcpp {
 
-// Arguments related to inference: sampling, text etc.
-struct InferenceArgs : public ArgsBase<InferenceArgs> {
-  // Arguments for getc-like interfaces
-  size_t max_tokens;
-  size_t max_generated_tokens;
-  float temperature;
-  size_t top_k;
-  float top_p;
-  float min_p;
-  int repeat_penalty_power;
-  float repeat_penalty_presence;
-  float repeat_penalty_decay;
-  float repeat_penalty_range;
+struct LoaderArgs : public ArgsBase<LoaderArgs> {
+  LoaderArgs(int argc, char* argv[], bool validate = true) {
+    InitAndParse(argc, argv);
 
-  // Batch configuration:
-  size_t prefill_tbatch_size;
-  size_t decode_tbatch_size;
+    if (validate) {
+      if (const char* error = Validate()) {
+        HWY_ABORT("Invalid args: %s", error);
+      }
+    }
+  }
+  LoaderArgs(const std::string& tokenizer_path,
+             const std::string& weights_path, const std::string& model,
+             bool validate = true) {
+    Init();  // Init sets to defaults, so assignments must come after Init().
+    tokenizer.path = tokenizer_path;
+    weights.path = weights_path;
+    model_type_str = model;
 
-  // Non-interactive mode prompt
-  std::string prompt;
-  std::string eot_line;
+    if (validate) {
+      if (const char* error = Validate()) {
+        HWY_ABORT("Invalid args: %s", error);
+      }
+    }
+  }
+
+  // Returns error string or nullptr if OK.
+  const char* Validate() {
+    if (weights.path.empty()) {
+      return "Missing --weights flag, a file for the model weights.";
+    }
+    if (!weights.Exists()) {
+      return "Can't open file specified with --weights flag.";
+    }
+    info_.model = Model::UNKNOWN;
+    info_.wrapping = PromptWrapping::GEMMA_PT;
+    info_.weight = Type::kUnknown;
+    if (!model_type_str.empty()) {
+      const char* err = ParseModelTypeAndWrapping(model_type_str, info_.model,
+                                                  info_.wrapping);
+      if (err != nullptr) return err;
+    }
+    if (!weight_type_str.empty()) {
+      const char* err = ParseType(weight_type_str, info_.weight);
+      if (err != nullptr) return err;
+    }
+    if (!tokenizer.path.empty()) {
+      if (!tokenizer.Exists()) {
+        return "Can't open file specified with --tokenizer flag.";
+      }
+    }
+    // model_type and tokenizer must be either both present or both absent.
+    // Further checks happen on weight loading.
+    if (model_type_str.empty() != tokenizer.path.empty()) {
+      return "Missing or extra flags for model_type or tokenizer.";
+    }
+    return nullptr;
+  }
+
+  Path tokenizer;
+  Path weights;  // weights file location
+  Path compressed_weights;
+  std::string model_type_str;
+  std::string weight_type_str;
 
   template <class Visitor>
-  void ForEach(Visitor& visitor) {
-    // Each line specifies a variable member, its name, default value, and
-    // help.
-    visitor(max_tokens, "max_tokens", size_t{50},
-            "Maximum number of total tokens including prompt (0=no limit).",
-            1);
-    visitor(max_generated_tokens, "max_generated_tokens", size_t{512},
-            "Maximum number of generated tokens (not including prompt) (0=no "
-            "limit).",
-            1);
-    visitor(temperature, "temperature", 1.0f,
-            "Temperature (randomness) for logits.", 1);
-    visitor(top_k, "top_k", size_t{40},
-            "Number of highest-probability tokens to consider (0=unlimited).",
-            1);
-    visitor(top_p, "top_p", 0.9f,
-            "Top-p probability threshold (0.0=disabled).", 1);
-    visitor(min_p, "min_p", 0.0f,
-            "Min-p probability threshold (0.0=disabled).", 1);
-    visitor(
-        repeat_penalty_power, "repeat_penalty_power", 1,
-        "Penalty power (1=standard frequentist penalty). If 0, skips penalty "
-        "computation.",
-        1);
-    visitor(repeat_penalty_presence, "repeat_penalty_presence", 0.0f,
-            "Penalty for token presence regardless of frequency (additive).",
-            1);
-    visitor(repeat_penalty_decay, "repeat_penalty_decay", 0.0f,
-            "Penalty for token n positions ago is decayed by "
-            "power(repeat_penalty_decay, n).",
-            1);
-    visitor(repeat_penalty_range, "repeat_penalty_range", 8.0f,
-            "Penalty fades out near the end of range (tokens)", 1);
+  void ForEach(const Visitor& visitor) {
+    visitor(tokenizer, "tokenizer", Path(),
+            "Path name of tokenizer model file.");
+    visitor(weights, "weights", Path(),
+            "Path name of model weights (.sbs) file.\n    Required argument.\n");
+    visitor(compressed_weights, "compressed_weights", Path(),
+            "Deprecated alias for --weights.");
+    visitor(model_type_str, "model", std::string(),
+            "Model type, see common.cc for valid values.\n");
+    visitor(weight_type_str, "weight_type", std::string("sfp"),
+            "Weight type\n    f32 = float, bf16 = bfloat16, sfp = 8-bit SFP.");
+  }
 
-    // Batch configuration:
-    visitor(prefill_tbatch_size, "prefill_tbatch_size", size_t{2},
-            "Token batch size for prefill; <= 32", 2);
-    visitor(decode_tbatch_size, "decode_tbatch_size", size_t{1},
-            "Token batch size for decode (only 1 currently supported)", 2);
+  // Uninitialized before Validate, must call that first.
+  const ModelInfo& Info() const { return info_; }
+
+ private:
+  ModelInfo info_;
+};
+
+// `env` must remain valid for the lifetime of the Gemma.
+static inline Gemma CreateGemma(const LoaderArgs& loader, MatMulEnv& env) {
+  if (Type::kUnknown == loader.Info().weight ||
+      Model::UNKNOWN == loader.Info().model || loader.tokenizer.path.empty()) {
+    // New weights file format doesn't need tokenizer path or model/weight
+    // info.
+    return Gemma(loader.weights, env);
+  }
+  return Gemma(loader.tokenizer, loader.weights, loader.Info(), env);
+}
+
+// `env` must remain valid for the lifetime of the Gemma.
+static inline std::unique_ptr<Gemma> AllocateGemma(const LoaderArgs& loader,
+                                                   MatMulEnv& env) {
+  if (Type::kUnknown == loader.Info().weight ||
+      Model::UNKNOWN == loader.Info().model || loader.tokenizer.path.empty()) {
+    // New weights file format doesn't need tokenizer path or model/weight
+    // info.
+    return std::make_unique<Gemma>(loader.weights, env);
+  }
+  return std::make_unique<Gemma>(loader.tokenizer, loader.weights,
+                                 loader.Info(), env);
+}
+
+struct InferenceArgs : public ArgsBase<InferenceArgs> {
+  InferenceArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
+  InferenceArgs() { Init(); }
+
+  int verbosity;
+
+  size_t max_generated_tokens;
+
+  size_t prefill_tbatch_size;
+  size_t decode_qbatch_size;
+
+  float temperature;
+  size_t top_k;
+  bool deterministic;
+  bool multiturn;
+  Path image_file;
+
+  std::string prompt;  // Added prompt flag for non-interactive mode
+  std::string eot_line;
+
+  // Returns error string or nullptr if OK.
+  const char* Validate() const {
+    if (max_generated_tokens > gcpp::kSeqLen) {
+      return "max_generated_tokens is larger than the maximum sequence "
+             "length (see configs.h).";
+    }
+    return nullptr;
+  }
+
+  template <class Visitor>
+  void ForEach(const Visitor& visitor) {
+    visitor(verbosity, "verbosity", 1,
+            "Show verbose developer information\n    0 = only print "
+            "generation output\n    1 = standard user-facing terminal UI\n"
+            "    2 = show developer/debug info.\n    Default = 1.",
+            1);  // Changed verbosity level to 1 since it's user-facing
+
+    visitor(max_generated_tokens, "max_generated_tokens", size_t{2048},
+            "Maximum number of tokens to generate.");
+
+    visitor(prefill_tbatch_size, "prefill_tbatch", size_t{256},
+            "Prefill: max tokens per batch.");
+    visitor(decode_qbatch_size, "decode_qbatch", size_t{16},
+            "Decode: max queries per batch.");
+
+    visitor(temperature, "temperature", 1.0f, "Temperature for top-K", 2);
+    visitor(top_k, "top_k", size_t{1},
+            "Number of top-K tokens to sample from", 2);
+    visitor(deterministic, "deterministic", false,
+            "Make top-k sampling deterministic", 2);
+    visitor(multiturn, "multiturn", false,
+            "Multiturn mode\n    0 = clear KV cache after every "
+            "interaction\n    1 = continue KV cache after every "
+            "interaction\n    Default: 0 (conversation resets every turn)");
+    visitor(image_file, "image_file", Path(), "Image file to load.");
+
+    visitor(prompt, "prompt", std::string(""),
+            "Initial prompt for non-interactive mode. When specified, "
+            "generates a response and exits.",
+            1);  // Added as user-facing option
 
     visitor(
         eot_line, "eot_line", std::string(""),
@@ -99,123 +210,31 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
         "When you specify this, the prompt will be all lines "
         "before the line where only the given string appears.\n    Default = "
         "When a newline is encountered, that signals the end of the turn.",
-        1);
-
-    // Non-interactive mode prompt
-    visitor(prompt, "prompt", std::string(""),
-            "Prompt to use in non-interactive mode", 1);
+        2);
   }
 
-  const char* Validate() const {
-    if (max_generated_tokens == 0 && max_tokens == 0) {
-      return "At least one of max_tokens and max_generated_tokens must be > 0";
+  void CopyTo(RuntimeConfig& runtime_config) const {
+    runtime_config.max_generated_tokens = max_generated_tokens;
+    runtime_config.prefill_tbatch_size = prefill_tbatch_size;
+    runtime_config.decode_qbatch_size = decode_qbatch_size;
+    if (prefill_tbatch_size > MMStorage::kMaxM) {
+      HWY_ABORT(
+          "prefill_tbatch_size %zu > kMaxM %zu: specify a smaller value, "
+          "or increase the constant in MMStorage.\n",
+          prefill_tbatch_size, MMStorage::kMaxM);
     }
-    if (temperature <= 0.0) {
-      return "Temperature must be > 0.0";
+    if (decode_qbatch_size > MMStorage::kMaxM) {
+      HWY_ABORT(
+          "decode_qbatch_size %zu > kMaxM %zu: specify a smaller value, "
+          "or increase the constant in MMStorage.\n",
+          decode_qbatch_size, MMStorage::kMaxM);
     }
-    if (prefill_tbatch_size > 32) {
-      return "prefill_tbatch_size must be <= 32";
-    }
-    if (decode_tbatch_size != 1) {
-      return "decode_tbatch_size must be 1";
-    }
-    return nullptr;
-  }
-};
 
-// Arguments related to model weights.
-struct LoaderArgs : public ArgsBase<LoaderArgs> {
-  Path model_path;  // Path to directory containing the weights
-  Path tokenizer;   // Optional: can be derived from model_path
-  bool model_is_gemma2;
-  Gemma::Config::WeightFormat weight_format;
-
-  template <class Visitor>
-  void ForEach(Visitor& visitor) {
-    // Each line specifies a variable member, its name, default value, and
-    // help.
-    visitor(model_path, "model", Path{},
-            "Directory containing weights or config file from `gemma.cpp "
-            "convert`.",
-            0);
-    visitor(tokenizer, "tokenizer", Path{},
-            "Optional path to tokenizer.model; if empty, looks in model_path.",
-            2);
-    visitor(model_is_gemma2, "model_is_gemma2", false,
-            "Whether the model is a Gemma 2 model", 1);
-    visitor(weight_format, "format", Gemma::Config::kBfloat16,
-            "Model weights format: 0=F32, 1=F16, 2=BF16", 2);
-  }
-
-  const char* Validate() const {
-    if (model_path.path.empty()) {
-      return "Empty model path";
-    }
-    if (weight_format != Gemma::Config::kBfloat16 &&
-        weight_format != Gemma::Config::kFloat16 &&
-        weight_format != Gemma::Config::kFloat32) {
-      return "Invalid weight format";
-    }
-    return nullptr;
-  }
-};
-
-// Threading-related arguments.
-struct ThreadingArgs : public ArgsBase<ThreadingArgs> {
-  size_t num_threads;
-  Tristate pin_threads;
-  Tristate use_spinning;
-  int verbosity;
-
-  template <class Visitor>
-  void ForEach(Visitor& visitor) {
-    visitor(num_threads, "threads", size_t{0},
-            "Number of threads (0=auto, half of logical cores)", 1);
-    visitor(pin_threads, "pin_threads", Tristate::kDefault,
-            "Set to true/false to force enable/disable thread pinning.", 2);
-    visitor(use_spinning, "use_spinning", Tristate::kDefault,
-            "Set to true/false to enable/disable thread spinning (typically "
-            "improves performance but increases power usage)",
-            2);
-    visitor(verbosity, "verbosity", 1,
-            "Controls printing of progress messages to stderr", 1);
-  }
-
-  // Returns nullptr if OK, otherwise error message.
-  const char* Validate() const { return nullptr; }
-
-  // Returns num_threads to use.
-  size_t NumThreadsToUse() const {
-    return num_threads == 0 ? (size_t{hwy::NumberOfProcessors()} + 1) / 2
-                            : num_threads;
-  }
-};
-
-// Command-line arguments for PeftGemma and Gemma.
-struct GemmaArgs : public ArgsBase<GemmaArgs> {
-  InferenceArgs inference;
-  LoaderArgs loader;
-  ThreadingArgs threading;
-  // For collect_stats.cc:
-  Path output;
-
-  bool trace_outputs;  // For -ftrace and dump_csv.cc
-  bool trace_base;     // For -ftrace
-  int time_it;         // For time_it.cc
-
-  template <class Visitor>
-  void ForEach(Visitor& visitor) {
-    inference.ForEach(visitor);
-    loader.ForEach(visitor);
-    threading.ForEach(visitor);
-
-    visitor(output, "output", Path{}, "Where to write CSV data / stats", 2);
-    visitor(trace_outputs, "trace_outputs", false, "For tracing", 2);
-    visitor(trace_base, "trace_base", false, "For tracing", 2);
-    visitor(time_it, "time_it", 0, "For benchmarks", 2);
+    runtime_config.temperature = temperature;
+    runtime_config.top_k = top_k;
   }
 };
 
 }  // namespace gcpp
 
-#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_ARGS_H_
\ No newline at end of file
+#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_ARGS_H_
\ No newline at end of file
diff --git a/gemma/run.cc b/gemma/run.cc
index 32eb5ff..b7e8fa1 100644
--- a/gemma/run.cc
+++ b/gemma/run.cc
@@ -78,16 +78,15 @@ std::string GetPrompt(std::istream& input, int verbosity,
   return prompt_string;
 }
 
-// New GetPrompt function that accepts InferenceArgs
-std::string GetPrompt(const InferenceArgs& inference, int verbosity,
-                      size_t turn) {
-  // Check for command-line prompt first
+// Get the prompt either from the command line or from interactive input.
+std::string GetPrompt(const InferenceArgs& inference) {
+  // If a prompt was provided via the command line, use it.
   if (!inference.prompt.empty()) {
    return inference.prompt;
  }
 
-  // Use the existing function for interactive mode
-  return GetPrompt(std::cin, verbosity, inference.eot_line);
+  // Otherwise read an interactive prompt.
+  return GetPrompt(std::cin, inference.verbosity, inference.eot_line);
 }
 
 // The main Read-Eval-Print Loop.
@@ -101,9 +100,6 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
   std::mt19937 gen;
   InitGenerator(inference, gen);
 
-  // Add flag to track non-interactive mode
-  bool non_interactive_mode = !inference.prompt.empty();
-
   const bool have_image = !inference.image_file.path.empty();
   Image image;
   ImageTokens image_tokens;
@@ -165,13 +161,12 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
     tokens_generated_this_turn = 0;
 
     // Read prompt and handle special commands.
-    std::string prompt_string =
-        GetPrompt(inference, inference.verbosity, abs_pos);
+    std::string prompt_string = GetPrompt(inference);
 
-    if (!std::cin && !non_interactive_mode) return;
+    if (!std::cin && inference.prompt.empty()) return;
 
     // If !eot_line.empty(), we append \n, so only look at the first 2 chars.
-    if (!non_interactive_mode && prompt_string.size() >= 2 &&
+    if (inference.prompt.empty() && prompt_string.size() >= 2 &&
         prompt_string[0] == '%') {
       if (prompt_string[1] == 'q' || prompt_string[1] == 'Q') return;
       if (prompt_string[1] == 'c' || prompt_string[1] == 'C') {
@@ -180,12 +175,27 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
       }
     }
 
-    if (!non_interactive_mode && prompt_string.empty()) {
+    if (inference.prompt.empty() && prompt_string.empty()) {
       std::cout << "Use '%q' to quit.\n";
       continue;
     }
 
     std::vector<int> prompt;
+    size_t prompt_size = 0;
+    size_t prefix_end = 0;
+
+    // Set up runtime config (before tokenization; the token debug dump
+    // below runs after `prompt` is filled).
+    TimingInfo timing_info = {.verbosity = inference.verbosity};
+    RuntimeConfig runtime_config = {.gen = &gen,
+                                    .verbosity = inference.verbosity,
+                                    .stream_token = stream_token,
+                                    .use_spinning = threading.spin};
+    inference.CopyTo(runtime_config);
+
     if (have_image) {
       prompt = WrapAndTokenize(model.Tokenizer(), model.ChatTemplate(),
                                model.Info(),
@@ -201,21 +211,6 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
       prompt_size = prompt.size();
     }
 
     if constexpr (kVerboseLogTokens) {
       for (int i = 0; i < prompt_size; ++i) {
         fprintf(stderr, "DDD TOKEN %3d: %6d\n", i, prompt[i]);
       }
     }
 
-    // Set up runtime config.
-    TimingInfo timing_info = {.verbosity = inference.verbosity};
-    RuntimeConfig runtime_config = {.gen = &gen,
-                                    .verbosity = inference.verbosity,
-                                    .stream_token = stream_token,
-                                    .use_spinning = threading.spin};
-    inference.CopyTo(runtime_config);
-
-    size_t prefix_end = 0;
-
     if (have_image) {
       runtime_config.image_tokens = &image_tokens;
       prompt_size = prompt.size();
@@ -234,7 +229,7 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
     std::cout << "\n\n";
 
     // Break the loop if in non-interactive mode
-    if (non_interactive_mode) {
+    if (!inference.prompt.empty()) {
      break;
    }
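---
Editor's note (not part of the patch): the sketch below illustrates how a frontend would combine the LoaderArgs/InferenceArgs API introduced above. It is a minimal example under stated assumptions: MatMulEnv is treated as default-constructible because its setup is not shown in this diff, and the final generation call is elided since its signature does not appear here.

#include <memory>
#include <random>

#include "gemma/gemma_args.h"  // LoaderArgs, InferenceArgs, AllocateGemma

int main(int argc, char* argv[]) {
  // Parses and validates --tokenizer/--weights/--model; aborts on error.
  gcpp::LoaderArgs loader(argc, argv);
  // Parses --prompt, --max_generated_tokens, --top_k, etc.
  gcpp::InferenceArgs inference(argc, argv);
  if (const char* error = inference.Validate()) {
    HWY_ABORT("Invalid args: %s", error);
  }

  gcpp::MatMulEnv env;  // Assumed default-constructible for this sketch.
  // `env` must outlive `gemma` (see comment on AllocateGemma above).
  std::unique_ptr<gcpp::Gemma> gemma = gcpp::AllocateGemma(loader, env);

  std::mt19937 gen(/*seed=*/42);
  gcpp::RuntimeConfig runtime_config = {.gen = &gen,
                                        .verbosity = inference.verbosity};
  inference.CopyTo(runtime_config);  // temperature, top_k, batch sizes, ...
  // ... wrap/tokenize inference.prompt and run generation as in run.cc.
  return 0;
}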