mirror of https://github.com/google/gemma.cpp.git
Add non-interactive mode support

- Added prompt flag to InferenceArgs for non-interactive mode
- Set user-facing options to verbosity level 1
- Fixed prompt_size declaration and variable ordering in run.cc
- Properly set prompt_size after WrapAndTokenize calls
- Moved kVerboseLogTokens block after prompt_size is set

parent cbf179990f
commit 8246e49199
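With the new flag, supplying --prompt on the command line is intended to
generate a single response and exit instead of entering the interactive REPL.
A hypothetical invocation (the binary name, file names, and model value are
illustrative; --tokenizer, --weights, --model, --prompt, and --verbosity are
the flags defined in the diff below, and --verbosity 0 prints only the
generated output):

    ./gemma --tokenizer tokenizer.spm --weights gemma-weights.sbs \
        --model gemma2-2b-it --prompt "Why is the sky blue?" --verbosity 0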
@@ -13,15 +13,18 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-// Argument parsing for Gemma.
+// Shared between various frontends.

-#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_ARGS_H_
-#define THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_ARGS_H_
+#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_ARGS_H_
+#define THIRD_PARTY_GEMMA_CPP_GEMMA_ARGS_H_

+#include <stddef.h>
 #include <stdio.h>

+#include <memory>
 #include <string>

+#include "compression/io.h"  // Path
 #include "compression/shared.h"
 #include "gemma/common.h"
 #include "gemma/gemma.h"  // For CreateGemma
@@ -32,66 +35,174 @@

 namespace gcpp {

-// Arguments related to inference: sampling, text etc.
-struct InferenceArgs : public ArgsBase<InferenceArgs> {
-  // Arguments for getc-like interfaces
-  size_t max_tokens;
-  size_t max_generated_tokens;
-  float temperature;
-  size_t top_k;
-  float top_p;
-  float min_p;
-  int repeat_penalty_power;
-  float repeat_penalty_presence;
-  float repeat_penalty_decay;
-  float repeat_penalty_range;
-
-  // Batch configuration:
-  size_t prefill_tbatch_size;
-  size_t decode_tbatch_size;
-
-  // Non-interactive mode prompt
-  std::string prompt;
-  std::string eot_line;
+struct LoaderArgs : public ArgsBase<LoaderArgs> {
+  LoaderArgs(int argc, char* argv[], bool validate = true) {
+    InitAndParse(argc, argv);
+    if (validate) {
+      if (const char* error = Validate()) {
+        HWY_ABORT("Invalid args: %s", error);
+      }
+    }
+  }
+  LoaderArgs(const std::string& tokenizer_path, const std::string& weights_path,
+             const std::string& model, bool validate = true) {
+    Init();  // Init sets to defaults, so assignments must come after Init().
+    tokenizer.path = tokenizer_path;
+    weights.path = weights_path;
+    model_type_str = model;
+    if (validate) {
+      if (const char* error = Validate()) {
+        HWY_ABORT("Invalid args: %s", error);
+      }
+    }
+  }
+
+  // Returns error string or nullptr if OK.
+  const char* Validate() {
+    if (weights.path.empty()) {
+      return "Missing --weights flag, a file for the model weights.";
+    }
+    if (!weights.Exists()) {
+      return "Can't open file specified with --weights flag.";
+    }
+    info_.model = Model::UNKNOWN;
+    info_.wrapping = PromptWrapping::GEMMA_PT;
+    info_.weight = Type::kUnknown;
+    if (!model_type_str.empty()) {
+      const char* err = ParseModelTypeAndWrapping(model_type_str, info_.model,
+                                                  info_.wrapping);
+      if (err != nullptr) return err;
+    }
+    if (!weight_type_str.empty()) {
+      const char* err = ParseType(weight_type_str, info_.weight);
+      if (err != nullptr) return err;
+    }
+    if (!tokenizer.path.empty()) {
+      if (!tokenizer.Exists()) {
+        return "Can't open file specified with --tokenizer flag.";
+      }
+    }
+    // model_type and tokenizer must be either both present or both absent.
+    // Further checks happen on weight loading.
+    if (model_type_str.empty() != tokenizer.path.empty()) {
+      return "Missing or extra flags for model_type or tokenizer.";
+    }
+    return nullptr;
+  }
+
+  Path tokenizer;
+  Path weights;  // weights file location
+  Path compressed_weights;
+  std::string model_type_str;
+  std::string weight_type_str;

   template <class Visitor>
-  void ForEach(Visitor& visitor) {
-    // Each line specifies a variable member, its name, default value, and help.
-    visitor(max_tokens, "max_tokens", size_t{50},
-            "Maximum number of total tokens including prompt (0=no limit).", 1);
-    visitor(max_generated_tokens, "max_generated_tokens", size_t{512},
-            "Maximum number of generated tokens (not including prompt) (0=no "
-            "limit).",
-            1);
-    visitor(temperature, "temperature", 1.0f,
-            "Temperature (randomness) for logits.", 1);
-    visitor(top_k, "top_k", size_t{40},
-            "Number of highest-probability tokens to consider (0=unlimited).",
-            1);
-    visitor(top_p, "top_p", 0.9f, "Top-p probability threshold (0.0=disabled).",
-            1);
-    visitor(min_p, "min_p", 0.0f, "Min-p probability threshold (0.0=disabled).",
-            1);
-    visitor(
-        repeat_penalty_power, "repeat_penalty_power", 1,
-        "Penalty power (1=standard frequentist penalty). If 0, skips penalty "
-        "computation.",
-        1);
-    visitor(repeat_penalty_presence, "repeat_penalty_presence", 0.0f,
-            "Penalty for token presence regardless of frequency (additive).",
-            1);
-    visitor(repeat_penalty_decay, "repeat_penalty_decay", 0.0f,
-            "Penalty for token n positions ago is decayed by "
-            "power(repeat_penalty_decay, n).",
-            1);
-    visitor(repeat_penalty_range, "repeat_penalty_range", 8.0f,
-            "Penalty fades out near the end of range (tokens)", 1);
-
-    // Batch configuration:
-    visitor(prefill_tbatch_size, "prefill_tbatch_size", size_t{2},
-            "Token batch size for prefill; <= 32", 2);
-    visitor(decode_tbatch_size, "decode_tbatch_size", size_t{1},
-            "Token batch size for decode (only 1 currently supported)", 2);
+  void ForEach(const Visitor& visitor) {
+    visitor(tokenizer, "tokenizer", Path(),
+            "Path name of tokenizer model file.");
+    visitor(weights, "weights", Path(),
+            "Path name of model weights (.sbs) file.\n    Required argument.\n");
+    visitor(compressed_weights, "compressed_weights", Path(),
+            "Deprecated alias for --weights.");
+    visitor(model_type_str, "model", std::string(),
+            "Model type, see common.cc for valid values.\n");
+    visitor(weight_type_str, "weight_type", std::string("sfp"),
+            "Weight type\n    f32 = float, bf16 = bfloat16, sfp = 8-bit SFP.");
+  }
+
+  // Uninitialized before Validate, must call after that.
+  const ModelInfo& Info() const { return info_; }
+
+ private:
+  ModelInfo info_;
+};
+
+// `env` must remain valid for the lifetime of the Gemma.
+static inline Gemma CreateGemma(const LoaderArgs& loader, MatMulEnv& env) {
+  if (Type::kUnknown == loader.Info().weight ||
+      Model::UNKNOWN == loader.Info().model || loader.tokenizer.path.empty()) {
+    // New weights file format doesn't need tokenizer path or model/weight info.
+    return Gemma(loader.weights, env);
+  }
+  return Gemma(loader.tokenizer, loader.weights, loader.Info(), env);
+}
+
+// `env` must remain valid for the lifetime of the Gemma.
+static inline std::unique_ptr<Gemma> AllocateGemma(const LoaderArgs& loader,
+                                                   MatMulEnv& env) {
+  if (Type::kUnknown == loader.Info().weight ||
+      Model::UNKNOWN == loader.Info().model || loader.tokenizer.path.empty()) {
+    // New weights file format doesn't need tokenizer path or model/weight info.
+    return std::make_unique<Gemma>(loader.weights, env);
+  }
+  return std::make_unique<Gemma>(loader.tokenizer, loader.weights,
+                                 loader.Info(), env);
+}
+
+struct InferenceArgs : public ArgsBase<InferenceArgs> {
+  InferenceArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
+  InferenceArgs() { Init(); };
+
+  int verbosity;
+
+  size_t max_generated_tokens;
+
+  size_t prefill_tbatch_size;
+  size_t decode_qbatch_size;
+
+  float temperature;
+  size_t top_k;
+  bool deterministic;
+  bool multiturn;
+  Path image_file;
+
+  std::string prompt;  // Added prompt flag for non-interactive mode
+  std::string eot_line;
+
+  // Returns error string or nullptr if OK.
+  const char* Validate() const {
+    if (max_generated_tokens > gcpp::kSeqLen) {
+      return "max_generated_tokens is larger than the maximum sequence length "
+             "(see configs.h).";
+    }
+    return nullptr;
+  }
+
+  template <class Visitor>
+  void ForEach(const Visitor& visitor) {
+    visitor(verbosity, "verbosity", 1,
+            "Show verbose developer information\n    0 = only print generation "
+            "output\n    1 = standard user-facing terminal ui\n    2 = show "
+            "developer/debug info).\n    Default = 1.",
+            1);  // Changed verbosity level to 1 since it's user-facing
+
+    visitor(max_generated_tokens, "max_generated_tokens", size_t{2048},
+            "Maximum number of tokens to generate.");
+
+    visitor(prefill_tbatch_size, "prefill_tbatch", size_t{256},
+            "Prefill: max tokens per batch.");
+    visitor(decode_qbatch_size, "decode_qbatch", size_t{16},
+            "Decode: max queries per batch.");
+
+    visitor(temperature, "temperature", 1.0f, "Temperature for top-K", 2);
+    visitor(top_k, "top_k", size_t{1}, "Number of top-K tokens to sample from",
+            2);
+    visitor(deterministic, "deterministic", false,
+            "Make top-k sampling deterministic", 2);
+    visitor(multiturn, "multiturn", false,
+            "Multiturn mode\n    0 = clear KV cache after every "
+            "interaction\n    1 = continue KV cache after every interaction\n "
+            "   Default : 0 (conversation "
+            "resets every turn)");
+    visitor(image_file, "image_file", Path(), "Image file to load.");
+
+    visitor(prompt, "prompt", std::string(""),
+            "Initial prompt for non-interactive mode. When specified, "
+            "generates a response"
+            " and exits.",
+            1);  // Added as user-facing option

     visitor(
         eot_line, "eot_line", std::string(""),
@@ -99,123 +210,31 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
         "When you specify this, the prompt will be all lines "
         "before the line where only the given string appears.\n    Default = "
         "When a newline is encountered, that signals the end of the turn.",
-        1);
+        2);

-    // Non-interactive mode prompt
-    visitor(prompt, "prompt", std::string(""),
-            "Prompt to use in non-interactive mode", 1);
   }

-  const char* Validate() const {
-    if (max_generated_tokens == 0 && max_tokens == 0) {
-      return "At least one of max_tokens and max_generated_tokens must be > 0";
+  void CopyTo(RuntimeConfig& runtime_config) const {
+    runtime_config.max_generated_tokens = max_generated_tokens;
+    runtime_config.prefill_tbatch_size = prefill_tbatch_size;
+    runtime_config.decode_qbatch_size = decode_qbatch_size;
+    if (prefill_tbatch_size > MMStorage::kMaxM) {
+      HWY_ABORT(
+          "prefill_tbatch_size %zu > kMaxM %zu: specify a smaller value, "
+          "or increase the constant in MMStorage.\n",
+          prefill_tbatch_size, MMStorage::kMaxM);
     }
-    if (temperature <= 0.0) {
-      return "Temperature must be > 0.0";
+    if (decode_qbatch_size > MMStorage::kMaxM) {
+      HWY_ABORT(
+          "decode_qbatch_size %zu > kMaxM %zu: specify a smaller value, "
+          "or increase the constant in MMStorage.\n",
+          decode_qbatch_size, MMStorage::kMaxM);
     }
-    if (prefill_tbatch_size > 32) {
-      return "prefill_tbatch_size must be <= 32";
-    }
-    if (decode_tbatch_size != 1) {
-      return "decode_tbatch_size must be 1";
-    }
-    return nullptr;
-  }
-};
-
-// Arguments related to model weights.
-struct LoaderArgs : public ArgsBase<LoaderArgs> {
-  Path model_path;  // Path to directory containing the weights
-  Path tokenizer;   // Optional: can be derived from model_path
-  bool model_is_gemma2;
-  Gemma::Config::WeightFormat weight_format;
-
-  template <class Visitor>
-  void ForEach(Visitor& visitor) {
-    // Each line specifies a variable member, its name, default value, and help.
-    visitor(model_path, "model", Path{},
-            "Directory containing weights or config file from `gemma.cpp "
-            "convert`.",
-            0);
-    visitor(tokenizer, "tokenizer", Path{},
-            "Optional path to tokenizer.model; if empty, looks in model_path.",
-            2);
-    visitor(model_is_gemma2, "model_is_gemma2", false,
-            "Whether the model is a Gemma 2 model", 1);
-    visitor(weight_format, "format", Gemma::Config::kBfloat16,
-            "Model weights format: 0=F32, 1=F16, 2=BF16", 2);
-  }
-
-  const char* Validate() const {
-    if (model_path.path.empty()) {
-      return "Empty model path";
-    }
-    if (weight_format != Gemma::Config::kBfloat16 &&
-        weight_format != Gemma::Config::kFloat16 &&
-        weight_format != Gemma::Config::kFloat32) {
-      return "Invalid weight format";
-    }
-    return nullptr;
-  }
-};
-
-// Threading-related arguments.
-struct ThreadingArgs : public ArgsBase<ThreadingArgs> {
-  size_t num_threads;
-  Tristate pin_threads;
-  Tristate use_spinning;
-  int verbosity;
-
-  template <class Visitor>
-  void ForEach(Visitor& visitor) {
-    visitor(num_threads, "threads", size_t{0},
-            "Number of threads (0=auto, half of logical cores)", 1);
-    visitor(pin_threads, "pin_threads", Tristate::kDefault,
-            "Set to true/false to force enable/disable thread pinning.", 2);
-    visitor(use_spinning, "use_spinning", Tristate::kDefault,
-            "Set to true/false to enable/disable thread spinning (typically "
-            "improves "
-            "performance but increases power usage)",
-            2);
-    visitor(verbosity, "verbosity", 1,
-            "Controls printing of progress messages to stderr", 1);
-  }
-
-  // Returns nullptr if OK, otherwise error message.
-  const char* Validate() const { return nullptr; }
-
-  // Returns num_threads to use.
-  size_t NumThreadsToUse() const {
-    return num_threads == 0 ? (size_t{hwy::NumberOfProcessors()} + 1) / 2
-                            : num_threads;
-  }
-};
-
-// Command-line arguments for PeftGemma and Gemma.
-struct GemmaArgs : public ArgsBase<GemmaArgs> {
-  InferenceArgs inference;
-  LoaderArgs loader;
-  ThreadingArgs threading;
-  // For collect_stats.cc:
-  Path output;
-
-  bool trace_outputs;  // For -ftrace and dump_csv.cc
-  bool trace_base;     // For -ftrace
-  int time_it;         // For time_it.cc
-
-  template <class Visitor>
-  void ForEach(Visitor& visitor) {
-    inference.ForEach(visitor);
-    loader.ForEach(visitor);
-    threading.ForEach(visitor);
-
-    visitor(output, "output", Path{}, "Where to write CSV data / stats", 2);
-    visitor(trace_outputs, "trace_outputs", false, "For tracing", 2);
-    visitor(trace_base, "trace_base", false, "For tracing", 2);
-    visitor(time_it, "time_it", 0, "For benchmarks", 2);
+    runtime_config.temperature = temperature;
+    runtime_config.top_k = top_k;
   }
 };

 }  // namespace gcpp

-#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_ARGS_H_
+#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_ARGS_H_
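The ForEach/visitor convention above is how ArgsBase-derived structs register
flags: each visitor(...) call supplies a member, its flag name, default value,
help text, and an optional trailing numeric level, and the commit's "verbosity
level 1" bullet refers to that level. What follows is a minimal self-contained
sketch of the convention, not gemma.cpp's actual ArgsBase implementation; all
names here (HelpVisitor, DemoArgs, max_level) are illustrative:

    #include <cstdio>
    #include <string>

    // Visitor that mimics the pattern: assigns each member its default and
    // prints help only for flags whose level is within the verbosity budget.
    struct HelpVisitor {
      int max_level;
      template <typename T>
      void operator()(T& member, const char* name, const T& init,
                      const char* help, int level = 1) const {
        member = init;             // the "Init sets to defaults" step
        if (level <= max_level) {  // level 1 = user-facing, 2 = developer
          std::printf("--%s\n    %s\n", name, help);
        }
      }
    };

    struct DemoArgs {
      std::string prompt;
      size_t top_k;
      template <class Visitor>
      void ForEach(const Visitor& visitor) {
        visitor(prompt, "prompt", std::string(""),
                "Initial prompt for non-interactive mode.", 1);  // user-facing
        visitor(top_k, "top_k", size_t{1},
                "Number of top-K tokens to sample from", 2);     // developer
      }
    };

    int main() {
      DemoArgs args;
      args.ForEach(HelpVisitor{/*max_level=*/1});  // prints only --prompt
    }

Run with max_level 1, this prints only the user-facing flag, which is the
effect the commit message describes for options marked with level 1.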
gemma/run.cc (57 changed lines)
@@ -78,16 +78,15 @@ std::string GetPrompt(std::istream& input, int verbosity,
   return prompt_string;
 }

-// New GetPrompt function that accepts InferenceArgs
-std::string GetPrompt(const InferenceArgs& inference, int verbosity,
-                      size_t turn) {
-  // Check for command-line prompt first
+// Get prompt either from interactive input or command line
+std::string GetPrompt(const InferenceArgs& inference) {
+  // If prompt is provided via command line, use that
   if (!inference.prompt.empty()) {
     return inference.prompt;
   }

-  // Use the existing function for interactive mode
-  return GetPrompt(std::cin, verbosity, inference.eot_line);
+  // Otherwise get interactive prompt
+  return GetPrompt(std::cin, inference.verbosity, inference.eot_line);
 }

 // The main Read-Eval-Print Loop.
@@ -101,9 +100,6 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
   std::mt19937 gen;
   InitGenerator(inference, gen);

-  // Add flag to track non-interactive mode
-  bool non_interactive_mode = !inference.prompt.empty();
-
   const bool have_image = !inference.image_file.path.empty();
   Image image;
   ImageTokens image_tokens;
@@ -165,13 +161,12 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
     tokens_generated_this_turn = 0;

     // Read prompt and handle special commands.
-    std::string prompt_string =
-        GetPrompt(inference, inference.verbosity, abs_pos);
+    std::string prompt_string = GetPrompt(inference);

-    if (!std::cin && !non_interactive_mode) return;
+    if (!std::cin && inference.prompt.empty()) return;

     // If !eot_line.empty(), we append \n, so only look at the first 2 chars.
-    if (!non_interactive_mode && prompt_string.size() >= 2 &&
+    if (inference.prompt.empty() && prompt_string.size() >= 2 &&
         prompt_string[0] == '%') {
       if (prompt_string[1] == 'q' || prompt_string[1] == 'Q') return;
       if (prompt_string[1] == 'c' || prompt_string[1] == 'C') {
@@ -180,12 +175,27 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
       }
     }

-    if (!non_interactive_mode && prompt_string.empty()) {
+    if (inference.prompt.empty() && prompt_string.empty()) {
       std::cout << "Use '%q' to quit.\n";
       continue;
     }

     std::vector<int> prompt;
+    size_t prompt_size = 0;
+    size_t prefix_end = 0;
+    if constexpr (kVerboseLogTokens) {
+      for (int i = 0; i < prompt_size; ++i) {
+        fprintf(stderr, "DDD TOKEN %3d: %6d\n", i, prompt[i]);
+      }
+    }
+
+    // Set up runtime config.
+    TimingInfo timing_info = {.verbosity = inference.verbosity};
+    RuntimeConfig runtime_config = {.gen = &gen,
+                                    .verbosity = inference.verbosity,
+                                    .stream_token = stream_token,
+                                    .use_spinning = threading.spin};
+    inference.CopyTo(runtime_config);
+
     if (have_image) {
       prompt =
           WrapAndTokenize(model.Tokenizer(), model.ChatTemplate(), model.Info(),
@@ -201,21 +211,6 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
       prompt_size = prompt.size();
     }

-    if constexpr (kVerboseLogTokens) {
-      for (int i = 0; i < prompt_size; ++i) {
-        fprintf(stderr, "DDD TOKEN %3d: %6d\n", i, prompt[i]);
-      }
-    }
-
-    // Set up runtime config.
-    TimingInfo timing_info = {.verbosity = inference.verbosity};
-    RuntimeConfig runtime_config = {.gen = &gen,
-                                    .verbosity = inference.verbosity,
-                                    .stream_token = stream_token,
-                                    .use_spinning = threading.spin};
-    inference.CopyTo(runtime_config);
-
-    size_t prefix_end = 0;
-
     if (have_image) {
       runtime_config.image_tokens = &image_tokens;
       prompt_size = prompt.size();
@@ -234,7 +229,7 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
     std::cout << "\n\n";

     // Break the loop if in non-interactive mode
-    if (non_interactive_mode) {
+    if (!inference.prompt.empty()) {
       break;
     }
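Taken together, the run.cc changes route the prompt through the new
single-argument GetPrompt and end the loop after one turn when a command-line
prompt was given. Since the removed local was non_interactive_mode =
!inference.prompt.empty(), the final break must test that a prompt IS present,
i.e. !inference.prompt.empty(), as shown in the hunk above. A condensed sketch
of that control flow (simplified; tokenization and generation are elided, and
only the names taken from the diff are real):

    #include <iostream>
    #include <string>

    struct InferenceArgs {
      std::string prompt;  // --prompt; empty means interactive mode
    };

    // Command-line prompt wins; otherwise read a line from stdin.
    std::string GetPrompt(const InferenceArgs& inference) {
      if (!inference.prompt.empty()) return inference.prompt;
      std::string line;
      std::getline(std::cin, line);
      return line;
    }

    void ReplSketch(const InferenceArgs& inference) {
      while (true) {
        const std::string prompt_string = GetPrompt(inference);
        if (!std::cin && inference.prompt.empty()) return;  // stdin closed
        // ... tokenize prompt_string, generate and stream a reply ...
        if (!inference.prompt.empty()) break;  // non-interactive: one turn only
      }
    }

    int main() { ReplSketch(InferenceArgs{"Why is the sky blue?"}); }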