mirror of https://github.com/google/gemma.cpp.git
Merge pull request #539 from prajwalc22:feature-prompt-flag
PiperOrigin-RevId: 750118715
This commit is contained in:
commit
f20da328de
|
|
@ -0,0 +1,37 @@
|
||||||
|
*.7z filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.arrow filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.bin filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.ftz filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.gz filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.h5 filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.joblib filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.model filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.npy filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.npz filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.onnx filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.ot filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.parquet filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.pb filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.pickle filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.pkl filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.pt filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.pth filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.rar filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
||||||
|
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.tar filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.tflite filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.tgz filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.wasm filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.xz filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.zip filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.zst filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
||||||
|
2b-pt-sfp.sbs filter=lfs diff=lfs merge=lfs -text
|
||||||
|
tokenizer.spm filter=lfs diff=lfs merge=lfs -text
|
||||||
|
|
@ -1,4 +1,25 @@
|
||||||
|
# Build directories
|
||||||
.cache/
|
.cache/
|
||||||
bazel-*/
|
bazel-*/
|
||||||
build-*/
|
build-*/
|
||||||
|
build/
|
||||||
|
|
||||||
|
# Python cache
|
||||||
python/*/__pycache__
|
python/*/__pycache__
|
||||||
|
|
||||||
|
# Model files
|
||||||
|
*.sbs
|
||||||
|
*.spm
|
||||||
|
*.data
|
||||||
|
*.bin
|
||||||
|
*.weights
|
||||||
|
|
||||||
|
# IDE and editor files
|
||||||
|
.vscode/
|
||||||
|
.idea/
|
||||||
|
*.swp
|
||||||
|
*~
|
||||||
|
|
||||||
|
# Local development
|
||||||
|
.env
|
||||||
|
.env.local
|
||||||
|
|
@ -0,0 +1,15 @@
|
||||||
|
{
|
||||||
|
"configurations": [
|
||||||
|
{
|
||||||
|
"name": "Linux",
|
||||||
|
"includePath": [
|
||||||
|
"${workspaceFolder}/**"
|
||||||
|
],
|
||||||
|
"defines": [],
|
||||||
|
"cStandard": "c17",
|
||||||
|
"cppStandard": "c++17",
|
||||||
|
"intelliSenseMode": "linux-clang-x64"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"version": 4
|
||||||
|
}
|
||||||
|
|
@ -12,7 +12,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
cmake_minimum_required(VERSION 3.11)
|
cmake_minimum_required(VERSION 3.11...4.0)
|
||||||
|
|
||||||
include(FetchContent)
|
include(FetchContent)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,3 +0,0 @@
|
||||||
*
|
|
||||||
!.gitignore
|
|
||||||
!.hgignore
|
|
||||||
|
|
@ -28,10 +28,10 @@
|
||||||
#include "compression/shared.h"
|
#include "compression/shared.h"
|
||||||
#include "gemma/common.h"
|
#include "gemma/common.h"
|
||||||
#include "gemma/gemma.h" // For CreateGemma
|
#include "gemma/gemma.h" // For CreateGemma
|
||||||
|
#include "hwy/base.h" // HWY_ABORT
|
||||||
#include "ops/matmul.h"
|
#include "ops/matmul.h"
|
||||||
#include "util/args.h"
|
#include "util/args.h"
|
||||||
#include "util/basics.h" // Tristate
|
#include "util/basics.h" // Tristate
|
||||||
#include "hwy/base.h" // HWY_ABORT
|
|
||||||
|
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
||||||
|
|
@ -106,8 +106,7 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
|
||||||
"Path name of model weights (.sbs) file.\n Required argument.\n");
|
"Path name of model weights (.sbs) file.\n Required argument.\n");
|
||||||
visitor(compressed_weights, "compressed_weights", Path(),
|
visitor(compressed_weights, "compressed_weights", Path(),
|
||||||
"Deprecated alias for --weights.");
|
"Deprecated alias for --weights.");
|
||||||
visitor(
|
visitor(model_type_str, "model", std::string(),
|
||||||
model_type_str, "model", std::string(),
|
|
||||||
"Model type, see common.cc for valid values.\n");
|
"Model type, see common.cc for valid values.\n");
|
||||||
visitor(weight_type_str, "weight_type", std::string("sfp"),
|
visitor(weight_type_str, "weight_type", std::string("sfp"),
|
||||||
"Weight type\n f32 = float, bf16 = bfloat16, sfp = 8-bit SFP.");
|
"Weight type\n f32 = float, bf16 = bfloat16, sfp = 8-bit SFP.");
|
||||||
|
|
@ -117,8 +116,6 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
|
||||||
const ModelInfo& Info() const { return info_; }
|
const ModelInfo& Info() const { return info_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// TODO(rays): remove this. Eventually ModelConfig will be loaded from the
|
|
||||||
// weights file, so we can remove the need for this struct entirely.
|
|
||||||
ModelInfo info_;
|
ModelInfo info_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -161,6 +158,7 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
|
||||||
bool multiturn;
|
bool multiturn;
|
||||||
Path image_file;
|
Path image_file;
|
||||||
|
|
||||||
|
std::string prompt; // Added prompt flag for non-interactive mode
|
||||||
std::string eot_line;
|
std::string eot_line;
|
||||||
|
|
||||||
// Returns error string or nullptr if OK.
|
// Returns error string or nullptr if OK.
|
||||||
|
|
@ -178,7 +176,7 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
|
||||||
"Show verbose developer information\n 0 = only print generation "
|
"Show verbose developer information\n 0 = only print generation "
|
||||||
"output\n 1 = standard user-facing terminal ui\n 2 = show "
|
"output\n 1 = standard user-facing terminal ui\n 2 = show "
|
||||||
"developer/debug info).\n Default = 1.",
|
"developer/debug info).\n Default = 1.",
|
||||||
2);
|
1); // Changed verbosity level to 1 since it's user-facing
|
||||||
|
|
||||||
visitor(max_generated_tokens, "max_generated_tokens", size_t{2048},
|
visitor(max_generated_tokens, "max_generated_tokens", size_t{2048},
|
||||||
"Maximum number of tokens to generate.");
|
"Maximum number of tokens to generate.");
|
||||||
|
|
@ -200,6 +198,12 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
|
||||||
"resets every turn)");
|
"resets every turn)");
|
||||||
visitor(image_file, "image_file", Path(), "Image file to load.");
|
visitor(image_file, "image_file", Path(), "Image file to load.");
|
||||||
|
|
||||||
|
visitor(prompt, "prompt", std::string(""),
|
||||||
|
"Initial prompt for non-interactive mode. When specified, "
|
||||||
|
"generates a response"
|
||||||
|
" and exits.",
|
||||||
|
1); // Added as user-facing option
|
||||||
|
|
||||||
visitor(
|
visitor(
|
||||||
eot_line, "eot_line", std::string(""),
|
eot_line, "eot_line", std::string(""),
|
||||||
"End of turn line. "
|
"End of turn line. "
|
||||||
|
|
|
||||||
50
gemma/run.cc
50
gemma/run.cc
|
|
@ -27,13 +27,13 @@
|
||||||
#include "evals/benchmark_helper.h"
|
#include "evals/benchmark_helper.h"
|
||||||
#include "gemma/common.h"
|
#include "gemma/common.h"
|
||||||
#include "gemma/gemma.h" // Gemma
|
#include "gemma/gemma.h" // Gemma
|
||||||
#include "gemma/gemma_args.h" // LoaderArgs
|
#include "gemma/gemma_args.h"
|
||||||
|
#include "hwy/base.h"
|
||||||
|
#include "hwy/highway.h"
|
||||||
|
#include "hwy/profiler.h"
|
||||||
#include "ops/matmul.h" // MatMulEnv
|
#include "ops/matmul.h" // MatMulEnv
|
||||||
#include "paligemma/image.h"
|
#include "paligemma/image.h"
|
||||||
#include "util/args.h" // HasHelp
|
#include "util/args.h" // HasHelp
|
||||||
#include "util/threading_context.h"
|
|
||||||
#include "hwy/highway.h"
|
|
||||||
#include "hwy/profiler.h"
|
|
||||||
|
|
||||||
#if (!defined(HWY_VERSION_LT) || HWY_VERSION_LT(1, 2)) && !HWY_IDE
|
#if (!defined(HWY_VERSION_LT) || HWY_VERSION_LT(1, 2)) && !HWY_IDE
|
||||||
#error "Please update to version 1.2 of github.com/google/highway."
|
#error "Please update to version 1.2 of github.com/google/highway."
|
||||||
|
|
@ -77,6 +77,17 @@ std::string GetPrompt(std::istream& input, int verbosity,
|
||||||
return prompt_string;
|
return prompt_string;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Returns the prompt text for the next turn.
// If `inference.prompt` (the --prompt flag) is non-empty, it is returned
// unchanged and nothing is read from stdin. Otherwise this forwards to the
// GetPrompt(std::istream&, ...) overload, reading from std::cin with
// `inference.verbosity` and `inference.eot_line`.
|
||||||
|
std::string GetPrompt(const InferenceArgs& inference) {
|
||||||
|
// Non-interactive mode: a prompt given on the command line takes precedence
// over interactive input.
|
||||||
|
if (!inference.prompt.empty()) {
|
||||||
|
return inference.prompt;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Interactive mode: no --prompt was given, so read the prompt from std::cin.
|
||||||
|
return GetPrompt(std::cin, inference.verbosity, inference.eot_line);
|
||||||
|
}
|
||||||
|
|
||||||
// The main Read-Eval-Print Loop.
|
// The main Read-Eval-Print Loop.
|
||||||
void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
|
void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
|
||||||
Gemma& model, KVCache& kv_cache) {
|
Gemma& model, KVCache& kv_cache) {
|
||||||
|
|
@ -149,18 +160,21 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
|
||||||
tokens_generated_this_turn = 0;
|
tokens_generated_this_turn = 0;
|
||||||
|
|
||||||
// Read prompt and handle special commands.
|
// Read prompt and handle special commands.
|
||||||
std::string prompt_string =
|
std::string prompt_string = GetPrompt(inference);
|
||||||
GetPrompt(std::cin, inference.verbosity, inference.eot_line);
|
|
||||||
if (!std::cin) return;
|
if (!std::cin && inference.prompt.empty()) return;
|
||||||
|
|
||||||
// If !eot_line.empty(), we append \n, so only look at the first 2 chars.
|
// If !eot_line.empty(), we append \n, so only look at the first 2 chars.
|
||||||
if (prompt_string.size() >= 2 && prompt_string[0] == '%') {
|
if (inference.prompt.empty() && prompt_string.size() >= 2 &&
|
||||||
|
prompt_string[0] == '%') {
|
||||||
if (prompt_string[1] == 'q' || prompt_string[1] == 'Q') return;
|
if (prompt_string[1] == 'q' || prompt_string[1] == 'Q') return;
|
||||||
if (prompt_string[1] == 'c' || prompt_string[1] == 'C') {
|
if (prompt_string[1] == 'c' || prompt_string[1] == 'C') {
|
||||||
abs_pos = 0;
|
abs_pos = 0;
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (prompt_string.empty()) {
|
|
||||||
|
if (inference.prompt.empty() && prompt_string.empty()) {
|
||||||
std::cout << "Use '%q' to quit.\n";
|
std::cout << "Use '%q' to quit.\n";
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
@ -172,9 +186,9 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
|
||||||
.stream_token = stream_token,
|
.stream_token = stream_token,
|
||||||
.use_spinning = threading.spin};
|
.use_spinning = threading.spin};
|
||||||
inference.CopyTo(runtime_config);
|
inference.CopyTo(runtime_config);
|
||||||
size_t prefix_end = 0;
|
|
||||||
|
|
||||||
std::vector<int> prompt;
|
std::vector<int> prompt;
|
||||||
|
size_t prompt_size = 0;
|
||||||
|
size_t prefix_end = 0;
|
||||||
if (have_image) {
|
if (have_image) {
|
||||||
prompt =
|
prompt =
|
||||||
WrapAndTokenize(model.Tokenizer(), model.ChatTemplate(), model.Info(),
|
WrapAndTokenize(model.Tokenizer(), model.ChatTemplate(), model.Info(),
|
||||||
|
|
@ -184,8 +198,9 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
|
||||||
// The end of the prefix for prefix-LM style attention in Paligemma.
|
// The end of the prefix for prefix-LM style attention in Paligemma.
|
||||||
// See Figure 2 of https://arxiv.org/abs/2407.07726.
|
// See Figure 2 of https://arxiv.org/abs/2407.07726.
|
||||||
prefix_end = prompt_size;
|
prefix_end = prompt_size;
|
||||||
// We need to look at all the tokens for the prefix.
|
|
||||||
runtime_config.prefill_tbatch_size = prompt_size;
|
// REMOVED: Don't change prefill_tbatch_size for image handling
|
||||||
|
// runtime_config.prefill_tbatch_size = prompt_size;
|
||||||
} else {
|
} else {
|
||||||
prompt = WrapAndTokenize(model.Tokenizer(), model.ChatTemplate(),
|
prompt = WrapAndTokenize(model.Tokenizer(), model.ChatTemplate(),
|
||||||
model.Info(), abs_pos, prompt_string);
|
model.Info(), abs_pos, prompt_string);
|
||||||
|
|
@ -206,6 +221,11 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
|
||||||
timing_info);
|
timing_info);
|
||||||
std::cout << "\n\n";
|
std::cout << "\n\n";
|
||||||
|
|
||||||
|
// Break the loop if in non-interactive mode
|
||||||
|
if (!inference.prompt.empty()) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
// Prepare for the next turn. Works only for PaliGemma.
|
// Prepare for the next turn. Works only for PaliGemma.
|
||||||
if (!inference.multiturn ||
|
if (!inference.multiturn ||
|
||||||
model.Info().wrapping == PromptWrapping::PALIGEMMA) {
|
model.Info().wrapping == PromptWrapping::PALIGEMMA) {
|
||||||
|
|
@ -259,11 +279,14 @@ void Run(ThreadingArgs& threading, LoaderArgs& loader,
|
||||||
instructions += multiturn;
|
instructions += multiturn;
|
||||||
instructions += examples;
|
instructions += examples;
|
||||||
|
|
||||||
|
// Skip the banner and instructions in non-interactive mode
|
||||||
|
if (inference.prompt.empty()) {
|
||||||
std::cout << "\033[2J\033[1;1H" // clear screen
|
std::cout << "\033[2J\033[1;1H" // clear screen
|
||||||
<< kAsciiArtBanner << "\n\n";
|
<< kAsciiArtBanner << "\n\n";
|
||||||
ShowConfig(threading, loader, inference);
|
ShowConfig(threading, loader, inference);
|
||||||
std::cout << "\n" << instructions << "\n";
|
std::cout << "\n" << instructions << "\n";
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
ReplGemma(threading, inference, model, kv_cache);
|
ReplGemma(threading, inference, model, kv_cache);
|
||||||
}
|
}
|
||||||
|
|
@ -280,6 +303,7 @@ int main(int argc, char** argv) {
|
||||||
|
|
||||||
if (gcpp::HasHelp(argc, argv)) {
|
if (gcpp::HasHelp(argc, argv)) {
|
||||||
std::cerr << gcpp::kAsciiArtBanner;
|
std::cerr << gcpp::kAsciiArtBanner;
|
||||||
|
|
||||||
gcpp::ShowHelp(threading, loader, inference);
|
gcpp::ShowHelp(threading, loader, inference);
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue