Merge pull request #539 from prajwalc22:feature-prompt-flag

PiperOrigin-RevId: 750118715
Committed by Copybara-Service on 2025-04-22 03:09:19 -07:00 as commit f20da328de.

7 changed files with 129 additions and 31 deletions
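
For context, the feature being merged is a --prompt flag that lets gemma.cpp generate a single response and exit instead of starting its interactive REPL. A hypothetical invocation (--tokenizer, --weights, --model, and --verbosity are pre-existing gemma.cpp flags; the file and model names here are placeholders):

    ./gemma --tokenizer tokenizer.spm --weights 2b-it-sfp.sbs --model 2b-it \
        --prompt "Summarize this commit in one sentence." --verbosity 0

With --prompt set, the binary prints one response and exits, which makes it usable from scripts and pipelines.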

.gitattributes (vendored, new file, +37)

@@ -0,0 +1,37 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
2b-pt-sfp.sbs filter=lfs diff=lfs merge=lfs -text
tokenizer.spm filter=lfs diff=lfs merge=lfs -text
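
These patterns route large binary artifacts (archives, tensors, checkpoints, plus the repo's own 2b-pt-sfp.sbs weights and tokenizer.spm) through Git LFS. Entries of this shape are usually generated rather than hand-written; assuming the git-lfs extension is installed,

    git lfs track "*.safetensors"

appends the corresponding "*.safetensors filter=lfs diff=lfs merge=lfs -text" line shown above to .gitattributes.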

.gitignore (vendored, +21)

@@ -1,4 +1,25 @@
# Build directories
.cache/
bazel-*/
build-*/
build/

# Python cache
python/*/__pycache__

# Model files
*.sbs
*.spm
*.data
*.bin
*.weights

# IDE and editor files
.vscode/
.idea/
*.swp
*~

# Local development
.env
.env.local

.vscode/c_cpp_properties.json (vendored, new file, +15)

@@ -0,0 +1,15 @@
{
  "configurations": [
    {
      "name": "Linux",
      "includePath": [
        "${workspaceFolder}/**"
      ],
      "defines": [],
      "cStandard": "c17",
      "cppStandard": "c++17",
      "intelliSenseMode": "linux-clang-x64"
    }
  ],
  "version": 4
}

CMakeLists.txt

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-cmake_minimum_required(VERSION 3.11)
+cmake_minimum_required(VERSION 3.11...4.0)
 include(FetchContent)
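
The "<min>...<max>" range form is standard CMake syntax: CMake 3.11, which predates the range feature, still parses it as a plain 3.11 minimum, while newer CMake applies its policies as of 4.0 rather than 3.11-era behavior, keeping the build compatible with CMake 4.x without raising the minimum. A generic sketch of the idiom (not this project's actual file):

    cmake_minimum_required(VERSION 3.11...4.0)
    project(example LANGUAGES CXX)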

build/.gitignore (vendored, deleted, -3)

@@ -1,3 +0,0 @@
-*
-!.gitignore
-!.hgignore

gemma/gemma_args.h

@@ -28,10 +28,10 @@
#include "compression/shared.h"
#include "gemma/common.h"
#include "gemma/gemma.h" // For CreateGemma
#include "hwy/base.h" // HWY_ABORT
#include "ops/matmul.h"
#include "util/args.h"
#include "util/basics.h" // Tristate
#include "hwy/base.h" // HWY_ABORT
namespace gcpp {
@@ -106,8 +106,7 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
             "Path name of model weights (.sbs) file.\n Required argument.\n");
     visitor(compressed_weights, "compressed_weights", Path(),
             "Deprecated alias for --weights.");
-    visitor(
-        model_type_str, "model", std::string(),
+    visitor(model_type_str, "model", std::string(),
             "Model type, see common.cc for valid values.\n");
     visitor(weight_type_str, "weight_type", std::string("sfp"),
             "Weight type\n f32 = float, bf16 = bfloat16, sfp = 8-bit SFP.");
@@ -117,8 +116,6 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
   const ModelInfo& Info() const { return info_; }

  private:
-  // TODO(rays): remove this. Eventually ModelConfig will be loaded from the
-  // weights file, so we can remove the need for this struct entirely.
   ModelInfo info_;
 };
@@ -161,6 +158,7 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
   bool multiturn;
   Path image_file;
+  std::string prompt;  // Added prompt flag for non-interactive mode
   std::string eot_line;

   // Returns error string or nullptr if OK.
@@ -178,7 +176,7 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
             "Show verbose developer information\n 0 = only print generation "
             "output\n 1 = standard user-facing terminal ui\n 2 = show "
             "developer/debug info).\n Default = 1.",
-            2);
+            1);  // Changed verbosity level to 1 since it's user-facing
     visitor(max_generated_tokens, "max_generated_tokens", size_t{2048},
             "Maximum number of tokens to generate.");
@@ -200,6 +198,12 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
             "resets every turn)");
     visitor(image_file, "image_file", Path(), "Image file to load.");
+    visitor(prompt, "prompt", std::string(""),
+            "Initial prompt for non-interactive mode. When specified, "
+            "generates a response"
+            " and exits.",
+            1);  // Added as user-facing option
     visitor(
         eot_line, "eot_line", std::string(""),
         "End of turn line. "

gemma/run.cc

@@ -27,13 +27,13 @@
#include "evals/benchmark_helper.h"
#include "gemma/common.h"
#include "gemma/gemma.h" // Gemma
#include "gemma/gemma_args.h" // LoaderArgs
#include "gemma/gemma_args.h"
#include "hwy/base.h"
#include "hwy/highway.h"
#include "hwy/profiler.h"
#include "ops/matmul.h" // MatMulEnv
#include "paligemma/image.h"
#include "util/args.h" // HasHelp
#include "util/threading_context.h"
#include "hwy/highway.h"
#include "hwy/profiler.h"
#if (!defined(HWY_VERSION_LT) || HWY_VERSION_LT(1, 2)) && !HWY_IDE
#error "Please update to version 1.2 of github.com/google/highway."
@@ -77,6 +77,17 @@ std::string GetPrompt(std::istream& input, int verbosity,
   return prompt_string;
 }

+// Get prompt either from interactive input or command line
+std::string GetPrompt(const InferenceArgs& inference) {
+  // If prompt is provided via command line, use that
+  if (!inference.prompt.empty()) {
+    return inference.prompt;
+  }
+
+  // Otherwise get interactive prompt
+  return GetPrompt(std::cin, inference.verbosity, inference.eot_line);
+}
+
 // The main Read-Eval-Print Loop.
 void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
                Gemma& model, KVCache& kv_cache) {
@@ -149,18 +160,21 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
     tokens_generated_this_turn = 0;

     // Read prompt and handle special commands.
-    std::string prompt_string =
-        GetPrompt(std::cin, inference.verbosity, inference.eot_line);
-    if (!std::cin) return;
+    std::string prompt_string = GetPrompt(inference);
+    if (!std::cin && inference.prompt.empty()) return;

     // If !eot_line.empty(), we append \n, so only look at the first 2 chars.
-    if (prompt_string.size() >= 2 && prompt_string[0] == '%') {
+    if (inference.prompt.empty() && prompt_string.size() >= 2 &&
+        prompt_string[0] == '%') {
       if (prompt_string[1] == 'q' || prompt_string[1] == 'Q') return;
       if (prompt_string[1] == 'c' || prompt_string[1] == 'C') {
         abs_pos = 0;
         continue;
       }
     }
-    if (prompt_string.empty()) {
+    if (inference.prompt.empty() && prompt_string.empty()) {
       std::cout << "Use '%q' to quit.\n";
       continue;
     }
@@ -172,9 +186,9 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
                                     .stream_token = stream_token,
                                     .use_spinning = threading.spin};
     inference.CopyTo(runtime_config);
-    size_t prefix_end = 0;
     std::vector<int> prompt;
     size_t prompt_size = 0;
+    size_t prefix_end = 0;
     if (have_image) {
       prompt =
           WrapAndTokenize(model.Tokenizer(), model.ChatTemplate(), model.Info(),
@@ -184,8 +198,9 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
       // The end of the prefix for prefix-LM style attention in Paligemma.
       // See Figure 2 of https://arxiv.org/abs/2407.07726.
       prefix_end = prompt_size;
-      // We need to look at all the tokens for the prefix.
-      runtime_config.prefill_tbatch_size = prompt_size;
+      // REMOVED: Don't change prefill_tbatch_size for image handling
+      // runtime_config.prefill_tbatch_size = prompt_size;
     } else {
       prompt = WrapAndTokenize(model.Tokenizer(), model.ChatTemplate(),
                                model.Info(), abs_pos, prompt_string);
@@ -206,6 +221,11 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
                 timing_info);
     std::cout << "\n\n";

+    // Break the loop if in non-interactive mode
+    if (!inference.prompt.empty()) {
+      break;
+    }
+
     // Prepare for the next turn. Works only for PaliGemma.
     if (!inference.multiturn ||
         model.Info().wrapping == PromptWrapping::PALIGEMMA) {
@@ -259,11 +279,14 @@ void Run(ThreadingArgs& threading, LoaderArgs& loader,
     instructions += multiturn;
     instructions += examples;

+    // Skip the banner and instructions in non-interactive mode
+    if (inference.prompt.empty()) {
       std::cout << "\033[2J\033[1;1H"  // clear screen
                 << kAsciiArtBanner << "\n\n";
       ShowConfig(threading, loader, inference);
       std::cout << "\n" << instructions << "\n";
+    }
   }

   ReplGemma(threading, inference, model, kv_cache);
 }
@@ -280,6 +303,7 @@ int main(int argc, char** argv) {
   if (gcpp::HasHelp(argc, argv)) {
     std::cerr << gcpp::kAsciiArtBanner;
     gcpp::ShowHelp(threading, loader, inference);
+    return 0;
   }
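
Taken together, the run.cc changes reduce to one small dispatch pattern: a fixed --prompt short-circuits the interactive read, and its presence also ends the loop after a single turn. A self-contained sketch of that control flow (standalone C++; Args and the printed placeholder stand in for gemma.cpp's InferenceArgs and real generation):

    #include <iostream>
    #include <string>

    struct Args {
      std::string prompt;  // non-empty => non-interactive (one-shot) mode
    };

    std::string GetPromptInteractive(std::istream& in) {
      std::cout << "> ";
      std::string line;
      std::getline(in, line);
      return line;
    }

    std::string GetPrompt(const Args& args) {
      // A prompt given on the command line wins over interactive input.
      if (!args.prompt.empty()) return args.prompt;
      return GetPromptInteractive(std::cin);
    }

    void Repl(const Args& args) {
      while (true) {
        const std::string prompt = GetPrompt(args);
        // EOF only terminates the loop in interactive mode.
        if (!std::cin && args.prompt.empty()) return;
        std::cout << "[response to: " << prompt << "]\n";
        if (!args.prompt.empty()) break;  // one-shot: generate once, then exit
      }
    }

    int main() {
      Repl(Args{"Hello"});  // prints one placeholder response and exits
    }

Note the subtlety the diff also handles: every interactive special case ('%q', '%c', the empty-prompt hint) is guarded with inference.prompt.empty(), so a literal command-line prompt such as "%q" is sent to the model instead of being treated as a REPL command.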