mirror of https://github.com/google/gemma.cpp.git
Merge pull request #539 from prajwalc22:feature-prompt-flag
PiperOrigin-RevId: 750118715
This commit is contained in:
commit
f20da328de
|
|
@ -0,0 +1,37 @@
|
|||
*.7z filter=lfs diff=lfs merge=lfs -text
|
||||
*.arrow filter=lfs diff=lfs merge=lfs -text
|
||||
*.bin filter=lfs diff=lfs merge=lfs -text
|
||||
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
||||
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
||||
*.ftz filter=lfs diff=lfs merge=lfs -text
|
||||
*.gz filter=lfs diff=lfs merge=lfs -text
|
||||
*.h5 filter=lfs diff=lfs merge=lfs -text
|
||||
*.joblib filter=lfs diff=lfs merge=lfs -text
|
||||
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
||||
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
||||
*.model filter=lfs diff=lfs merge=lfs -text
|
||||
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
||||
*.npy filter=lfs diff=lfs merge=lfs -text
|
||||
*.npz filter=lfs diff=lfs merge=lfs -text
|
||||
*.onnx filter=lfs diff=lfs merge=lfs -text
|
||||
*.ot filter=lfs diff=lfs merge=lfs -text
|
||||
*.parquet filter=lfs diff=lfs merge=lfs -text
|
||||
*.pb filter=lfs diff=lfs merge=lfs -text
|
||||
*.pickle filter=lfs diff=lfs merge=lfs -text
|
||||
*.pkl filter=lfs diff=lfs merge=lfs -text
|
||||
*.pt filter=lfs diff=lfs merge=lfs -text
|
||||
*.pth filter=lfs diff=lfs merge=lfs -text
|
||||
*.rar filter=lfs diff=lfs merge=lfs -text
|
||||
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
||||
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
||||
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
||||
*.tar filter=lfs diff=lfs merge=lfs -text
|
||||
*.tflite filter=lfs diff=lfs merge=lfs -text
|
||||
*.tgz filter=lfs diff=lfs merge=lfs -text
|
||||
*.wasm filter=lfs diff=lfs merge=lfs -text
|
||||
*.xz filter=lfs diff=lfs merge=lfs -text
|
||||
*.zip filter=lfs diff=lfs merge=lfs -text
|
||||
*.zst filter=lfs diff=lfs merge=lfs -text
|
||||
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
||||
2b-pt-sfp.sbs filter=lfs diff=lfs merge=lfs -text
|
||||
tokenizer.spm filter=lfs diff=lfs merge=lfs -text
|
||||
|
|
@ -1,4 +1,25 @@
|
|||
# Build directories
|
||||
.cache/
|
||||
bazel-*/
|
||||
build-*/
|
||||
build/
|
||||
|
||||
# Python cache
|
||||
python/*/__pycache__
|
||||
|
||||
# Model files
|
||||
*.sbs
|
||||
*.spm
|
||||
*.data
|
||||
*.bin
|
||||
*.weights
|
||||
|
||||
# IDE and editor files
|
||||
.vscode/
|
||||
.idea/
|
||||
*.swp
|
||||
*~
|
||||
|
||||
# Local development
|
||||
.env
|
||||
.env.local
|
||||
|
|
@ -0,0 +1,15 @@
|
|||
{
|
||||
"configurations": [
|
||||
{
|
||||
"name": "Linux",
|
||||
"includePath": [
|
||||
"${workspaceFolder}/**"
|
||||
],
|
||||
"defines": [],
|
||||
"cStandard": "c17",
|
||||
"cppStandard": "c++17",
|
||||
"intelliSenseMode": "linux-clang-x64"
|
||||
}
|
||||
],
|
||||
"version": 4
|
||||
}
|
||||
|
|
@ -12,7 +12,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
cmake_minimum_required(VERSION 3.11)
|
||||
cmake_minimum_required(VERSION 3.11...4.0)
|
||||
|
||||
include(FetchContent)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,3 +0,0 @@
|
|||
*
|
||||
!.gitignore
|
||||
!.hgignore
|
||||
|
|
@ -28,10 +28,10 @@
|
|||
#include "compression/shared.h"
|
||||
#include "gemma/common.h"
|
||||
#include "gemma/gemma.h" // For CreateGemma
|
||||
#include "hwy/base.h" // HWY_ABORT
|
||||
#include "ops/matmul.h"
|
||||
#include "util/args.h"
|
||||
#include "util/basics.h" // Tristate
|
||||
#include "hwy/base.h" // HWY_ABORT
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
|
|
@ -106,8 +106,7 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
|
|||
"Path name of model weights (.sbs) file.\n Required argument.\n");
|
||||
visitor(compressed_weights, "compressed_weights", Path(),
|
||||
"Deprecated alias for --weights.");
|
||||
visitor(
|
||||
model_type_str, "model", std::string(),
|
||||
visitor(model_type_str, "model", std::string(),
|
||||
"Model type, see common.cc for valid values.\n");
|
||||
visitor(weight_type_str, "weight_type", std::string("sfp"),
|
||||
"Weight type\n f32 = float, bf16 = bfloat16, sfp = 8-bit SFP.");
|
||||
|
|
@ -117,8 +116,6 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
|
|||
const ModelInfo& Info() const { return info_; }
|
||||
|
||||
private:
|
||||
// TODO(rays): remove this. Eventually ModelConfig will be loaded from the
|
||||
// weights file, so we can remove the need for this struct entirely.
|
||||
ModelInfo info_;
|
||||
};
|
||||
|
||||
|
|
@ -161,6 +158,7 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
|
|||
bool multiturn;
|
||||
Path image_file;
|
||||
|
||||
std::string prompt; // Added prompt flag for non-interactive mode
|
||||
std::string eot_line;
|
||||
|
||||
// Returns error string or nullptr if OK.
|
||||
|
|
@ -178,7 +176,7 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
|
|||
"Show verbose developer information\n 0 = only print generation "
|
||||
"output\n 1 = standard user-facing terminal ui\n 2 = show "
|
||||
"developer/debug info).\n Default = 1.",
|
||||
2);
|
||||
1); // Changed verbosity level to 1 since it's user-facing
|
||||
|
||||
visitor(max_generated_tokens, "max_generated_tokens", size_t{2048},
|
||||
"Maximum number of tokens to generate.");
|
||||
|
|
@ -200,6 +198,12 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
|
|||
"resets every turn)");
|
||||
visitor(image_file, "image_file", Path(), "Image file to load.");
|
||||
|
||||
visitor(prompt, "prompt", std::string(""),
|
||||
"Initial prompt for non-interactive mode. When specified, "
|
||||
"generates a response"
|
||||
" and exits.",
|
||||
1); // Added as user-facing option
|
||||
|
||||
visitor(
|
||||
eot_line, "eot_line", std::string(""),
|
||||
"End of turn line. "
|
||||
|
|
|
|||
50
gemma/run.cc
50
gemma/run.cc
|
|
@ -27,13 +27,13 @@
|
|||
#include "evals/benchmark_helper.h"
|
||||
#include "gemma/common.h"
|
||||
#include "gemma/gemma.h" // Gemma
|
||||
#include "gemma/gemma_args.h" // LoaderArgs
|
||||
#include "gemma/gemma_args.h"
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/highway.h"
|
||||
#include "hwy/profiler.h"
|
||||
#include "ops/matmul.h" // MatMulEnv
|
||||
#include "paligemma/image.h"
|
||||
#include "util/args.h" // HasHelp
|
||||
#include "util/threading_context.h"
|
||||
#include "hwy/highway.h"
|
||||
#include "hwy/profiler.h"
|
||||
|
||||
#if (!defined(HWY_VERSION_LT) || HWY_VERSION_LT(1, 2)) && !HWY_IDE
|
||||
#error "Please update to version 1.2 of github.com/google/highway."
|
||||
|
|
@ -77,6 +77,17 @@ std::string GetPrompt(std::istream& input, int verbosity,
|
|||
return prompt_string;
|
||||
}
|
||||
|
||||
// Get prompt either from interactive input or command line
|
||||
std::string GetPrompt(const InferenceArgs& inference) {
|
||||
// If prompt is provided via command line, use that
|
||||
if (!inference.prompt.empty()) {
|
||||
return inference.prompt;
|
||||
}
|
||||
|
||||
// Otherwise get interactive prompt
|
||||
return GetPrompt(std::cin, inference.verbosity, inference.eot_line);
|
||||
}
|
||||
|
||||
// The main Read-Eval-Print Loop.
|
||||
void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
|
||||
Gemma& model, KVCache& kv_cache) {
|
||||
|
|
@ -149,18 +160,21 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
|
|||
tokens_generated_this_turn = 0;
|
||||
|
||||
// Read prompt and handle special commands.
|
||||
std::string prompt_string =
|
||||
GetPrompt(std::cin, inference.verbosity, inference.eot_line);
|
||||
if (!std::cin) return;
|
||||
std::string prompt_string = GetPrompt(inference);
|
||||
|
||||
if (!std::cin && inference.prompt.empty()) return;
|
||||
|
||||
// If !eot_line.empty(), we append \n, so only look at the first 2 chars.
|
||||
if (prompt_string.size() >= 2 && prompt_string[0] == '%') {
|
||||
if (inference.prompt.empty() && prompt_string.size() >= 2 &&
|
||||
prompt_string[0] == '%') {
|
||||
if (prompt_string[1] == 'q' || prompt_string[1] == 'Q') return;
|
||||
if (prompt_string[1] == 'c' || prompt_string[1] == 'C') {
|
||||
abs_pos = 0;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
if (prompt_string.empty()) {
|
||||
|
||||
if (inference.prompt.empty() && prompt_string.empty()) {
|
||||
std::cout << "Use '%q' to quit.\n";
|
||||
continue;
|
||||
}
|
||||
|
|
@ -172,9 +186,9 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
|
|||
.stream_token = stream_token,
|
||||
.use_spinning = threading.spin};
|
||||
inference.CopyTo(runtime_config);
|
||||
size_t prefix_end = 0;
|
||||
|
||||
std::vector<int> prompt;
|
||||
size_t prompt_size = 0;
|
||||
size_t prefix_end = 0;
|
||||
if (have_image) {
|
||||
prompt =
|
||||
WrapAndTokenize(model.Tokenizer(), model.ChatTemplate(), model.Info(),
|
||||
|
|
@ -184,8 +198,9 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
|
|||
// The end of the prefix for prefix-LM style attention in Paligemma.
|
||||
// See Figure 2 of https://arxiv.org/abs/2407.07726.
|
||||
prefix_end = prompt_size;
|
||||
// We need to look at all the tokens for the prefix.
|
||||
runtime_config.prefill_tbatch_size = prompt_size;
|
||||
|
||||
// REMOVED: Don't change prefill_tbatch_size for image handling
|
||||
// runtime_config.prefill_tbatch_size = prompt_size;
|
||||
} else {
|
||||
prompt = WrapAndTokenize(model.Tokenizer(), model.ChatTemplate(),
|
||||
model.Info(), abs_pos, prompt_string);
|
||||
|
|
@ -206,6 +221,11 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
|
|||
timing_info);
|
||||
std::cout << "\n\n";
|
||||
|
||||
// Break the loop if in non-interactive mode
|
||||
if (!inference.prompt.empty()) {
|
||||
break;
|
||||
}
|
||||
|
||||
// Prepare for the next turn. Works only for PaliGemma.
|
||||
if (!inference.multiturn ||
|
||||
model.Info().wrapping == PromptWrapping::PALIGEMMA) {
|
||||
|
|
@ -259,11 +279,14 @@ void Run(ThreadingArgs& threading, LoaderArgs& loader,
|
|||
instructions += multiturn;
|
||||
instructions += examples;
|
||||
|
||||
// Skip the banner and instructions in non-interactive mode
|
||||
if (inference.prompt.empty()) {
|
||||
std::cout << "\033[2J\033[1;1H" // clear screen
|
||||
<< kAsciiArtBanner << "\n\n";
|
||||
ShowConfig(threading, loader, inference);
|
||||
std::cout << "\n" << instructions << "\n";
|
||||
}
|
||||
}
|
||||
|
||||
ReplGemma(threading, inference, model, kv_cache);
|
||||
}
|
||||
|
|
@ -280,6 +303,7 @@ int main(int argc, char** argv) {
|
|||
|
||||
if (gcpp::HasHelp(argc, argv)) {
|
||||
std::cerr << gcpp::kAsciiArtBanner;
|
||||
|
||||
gcpp::ShowHelp(threading, loader, inference);
|
||||
return 0;
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue