diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..ae66414 --- /dev/null +++ b/.gitattributes @@ -0,0 +1,37 @@ +*.7z filter=lfs diff=lfs merge=lfs -text +*.arrow filter=lfs diff=lfs merge=lfs -text +*.bin filter=lfs diff=lfs merge=lfs -text +*.bz2 filter=lfs diff=lfs merge=lfs -text +*.ckpt filter=lfs diff=lfs merge=lfs -text +*.ftz filter=lfs diff=lfs merge=lfs -text +*.gz filter=lfs diff=lfs merge=lfs -text +*.h5 filter=lfs diff=lfs merge=lfs -text +*.joblib filter=lfs diff=lfs merge=lfs -text +*.lfs.* filter=lfs diff=lfs merge=lfs -text +*.mlmodel filter=lfs diff=lfs merge=lfs -text +*.model filter=lfs diff=lfs merge=lfs -text +*.msgpack filter=lfs diff=lfs merge=lfs -text +*.npy filter=lfs diff=lfs merge=lfs -text +*.npz filter=lfs diff=lfs merge=lfs -text +*.onnx filter=lfs diff=lfs merge=lfs -text +*.ot filter=lfs diff=lfs merge=lfs -text +*.parquet filter=lfs diff=lfs merge=lfs -text +*.pb filter=lfs diff=lfs merge=lfs -text +*.pickle filter=lfs diff=lfs merge=lfs -text +*.pkl filter=lfs diff=lfs merge=lfs -text +*.pt filter=lfs diff=lfs merge=lfs -text +*.pth filter=lfs diff=lfs merge=lfs -text +*.rar filter=lfs diff=lfs merge=lfs -text +*.safetensors filter=lfs diff=lfs merge=lfs -text +saved_model/**/* filter=lfs diff=lfs merge=lfs -text +*.tar.* filter=lfs diff=lfs merge=lfs -text +*.tar filter=lfs diff=lfs merge=lfs -text +*.tflite filter=lfs diff=lfs merge=lfs -text +*.tgz filter=lfs diff=lfs merge=lfs -text +*.wasm filter=lfs diff=lfs merge=lfs -text +*.xz filter=lfs diff=lfs merge=lfs -text +*.zip filter=lfs diff=lfs merge=lfs -text +*.zst filter=lfs diff=lfs merge=lfs -text +*tfevents* filter=lfs diff=lfs merge=lfs -text +2b-pt-sfp.sbs filter=lfs diff=lfs merge=lfs -text +tokenizer.spm filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore index d4264cb..1c13032 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,25 @@ +# Build directories .cache/ bazel-*/ build-*/ +build/ + +# Python cache python/*/__pycache__ + +# Model files +*.sbs +*.spm +*.data +*.bin +*.weights + +# IDE and editor files +.vscode/ +.idea/ +*.swp +*~ + +# Local development +.env +.env.local \ No newline at end of file diff --git a/.vscode/c_cpp_properties.json b/.vscode/c_cpp_properties.json new file mode 100644 index 0000000..64d3f90 --- /dev/null +++ b/.vscode/c_cpp_properties.json @@ -0,0 +1,15 @@ +{ + "configurations": [ + { + "name": "Linux", + "includePath": [ + "${workspaceFolder}/**" + ], + "defines": [], + "cStandard": "c17", + "cppStandard": "c++17", + "intelliSenseMode": "linux-clang-x64" + } + ], + "version": 4 +} \ No newline at end of file diff --git a/CMakeLists.txt b/CMakeLists.txt index b572835..b9558ad 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -cmake_minimum_required(VERSION 3.11) +cmake_minimum_required(VERSION 3.11...4.0) include(FetchContent) diff --git a/build/.gitignore b/build/.gitignore deleted file mode 100644 index 3822a0b..0000000 --- a/build/.gitignore +++ /dev/null @@ -1,3 +0,0 @@ -* -!.gitignore -!.hgignore \ No newline at end of file diff --git a/gemma/gemma_args.h b/gemma/gemma_args.h index 4fe2d33..d02dece 100644 --- a/gemma/gemma_args.h +++ b/gemma/gemma_args.h @@ -28,10 +28,10 @@ #include "compression/shared.h" #include "gemma/common.h" #include "gemma/gemma.h" // For CreateGemma +#include "hwy/base.h" // HWY_ABORT #include "ops/matmul.h" #include "util/args.h" #include "util/basics.h" // Tristate -#include "hwy/base.h" // HWY_ABORT namespace gcpp { @@ -106,9 +106,8 @@ struct LoaderArgs : public ArgsBase { "Path name of model weights (.sbs) file.\n Required argument.\n"); visitor(compressed_weights, "compressed_weights", Path(), "Deprecated alias for --weights."); - visitor( - model_type_str, "model", std::string(), - "Model type, see common.cc for valid values.\n"); + visitor(model_type_str, "model", std::string(), + "Model type, see common.cc for valid values.\n"); visitor(weight_type_str, "weight_type", std::string("sfp"), "Weight type\n f32 = float, bf16 = bfloat16, sfp = 8-bit SFP."); } @@ -117,8 +116,6 @@ struct LoaderArgs : public ArgsBase { const ModelInfo& Info() const { return info_; } private: - // TODO(rays): remove this. Eventually ModelConfig will be loaded from the - // weights file, so we can remove the need for this struct entirely. ModelInfo info_; }; @@ -161,6 +158,7 @@ struct InferenceArgs : public ArgsBase { bool multiturn; Path image_file; + std::string prompt; // Added prompt flag for non-interactive mode std::string eot_line; // Returns error string or nullptr if OK. @@ -178,7 +176,7 @@ struct InferenceArgs : public ArgsBase { "Show verbose developer information\n 0 = only print generation " "output\n 1 = standard user-facing terminal ui\n 2 = show " "developer/debug info).\n Default = 1.", - 2); + 1); // Changed verbosity level to 1 since it's user-facing visitor(max_generated_tokens, "max_generated_tokens", size_t{2048}, "Maximum number of tokens to generate."); @@ -200,6 +198,12 @@ struct InferenceArgs : public ArgsBase { "resets every turn)"); visitor(image_file, "image_file", Path(), "Image file to load."); + visitor(prompt, "prompt", std::string(""), + "Initial prompt for non-interactive mode. When specified, " + "generates a response" + " and exits.", + 1); // Added as user-facing option + visitor( eot_line, "eot_line", std::string(""), "End of turn line. " @@ -233,4 +237,4 @@ struct InferenceArgs : public ArgsBase { } // namespace gcpp -#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_ARGS_H_ +#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_ARGS_H_ \ No newline at end of file diff --git a/gemma/run.cc b/gemma/run.cc index 5170b6e..2de1c1d 100644 --- a/gemma/run.cc +++ b/gemma/run.cc @@ -27,13 +27,13 @@ #include "evals/benchmark_helper.h" #include "gemma/common.h" #include "gemma/gemma.h" // Gemma -#include "gemma/gemma_args.h" // LoaderArgs -#include "ops/matmul.h" // MatMulEnv -#include "paligemma/image.h" -#include "util/args.h" // HasHelp -#include "util/threading_context.h" +#include "gemma/gemma_args.h" +#include "hwy/base.h" #include "hwy/highway.h" #include "hwy/profiler.h" +#include "ops/matmul.h" // MatMulEnv +#include "paligemma/image.h" +#include "util/args.h" // HasHelp #if (!defined(HWY_VERSION_LT) || HWY_VERSION_LT(1, 2)) && !HWY_IDE #error "Please update to version 1.2 of github.com/google/highway." @@ -77,6 +77,17 @@ std::string GetPrompt(std::istream& input, int verbosity, return prompt_string; } +// Get prompt either from interactive input or command line +std::string GetPrompt(const InferenceArgs& inference) { + // If prompt is provided via command line, use that + if (!inference.prompt.empty()) { + return inference.prompt; + } + + // Otherwise get interactive prompt + return GetPrompt(std::cin, inference.verbosity, inference.eot_line); +} + // The main Read-Eval-Print Loop. void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference, Gemma& model, KVCache& kv_cache) { @@ -149,18 +160,21 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference, tokens_generated_this_turn = 0; // Read prompt and handle special commands. - std::string prompt_string = - GetPrompt(std::cin, inference.verbosity, inference.eot_line); - if (!std::cin) return; + std::string prompt_string = GetPrompt(inference); + + if (!std::cin && inference.prompt.empty()) return; + // If !eot_line.empty(), we append \n, so only look at the first 2 chars. - if (prompt_string.size() >= 2 && prompt_string[0] == '%') { + if (inference.prompt.empty() && prompt_string.size() >= 2 && + prompt_string[0] == '%') { if (prompt_string[1] == 'q' || prompt_string[1] == 'Q') return; if (prompt_string[1] == 'c' || prompt_string[1] == 'C') { abs_pos = 0; continue; } } - if (prompt_string.empty()) { + + if (inference.prompt.empty() && prompt_string.empty()) { std::cout << "Use '%q' to quit.\n"; continue; } @@ -172,9 +186,9 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference, .stream_token = stream_token, .use_spinning = threading.spin}; inference.CopyTo(runtime_config); - size_t prefix_end = 0; - std::vector prompt; + size_t prompt_size = 0; + size_t prefix_end = 0; if (have_image) { prompt = WrapAndTokenize(model.Tokenizer(), model.ChatTemplate(), model.Info(), @@ -184,8 +198,9 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference, // The end of the prefix for prefix-LM style attention in Paligemma. // See Figure 2 of https://arxiv.org/abs/2407.07726. prefix_end = prompt_size; - // We need to look at all the tokens for the prefix. - runtime_config.prefill_tbatch_size = prompt_size; + + // REMOVED: Don't change prefill_tbatch_size for image handling + // runtime_config.prefill_tbatch_size = prompt_size; } else { prompt = WrapAndTokenize(model.Tokenizer(), model.ChatTemplate(), model.Info(), abs_pos, prompt_string); @@ -206,6 +221,11 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference, timing_info); std::cout << "\n\n"; + // Break the loop if in non-interactive mode + if (!inference.prompt.empty()) { + break; + } + // Prepare for the next turn. Works only for PaliGemma. if (!inference.multiturn || model.Info().wrapping == PromptWrapping::PALIGEMMA) { @@ -259,10 +279,13 @@ void Run(ThreadingArgs& threading, LoaderArgs& loader, instructions += multiturn; instructions += examples; - std::cout << "\033[2J\033[1;1H" // clear screen - << kAsciiArtBanner << "\n\n"; - ShowConfig(threading, loader, inference); - std::cout << "\n" << instructions << "\n"; + // Skip the banner and instructions in non-interactive mode + if (inference.prompt.empty()) { + std::cout << "\033[2J\033[1;1H" // clear screen + << kAsciiArtBanner << "\n\n"; + ShowConfig(threading, loader, inference); + std::cout << "\n" << instructions << "\n"; + } } ReplGemma(threading, inference, model, kv_cache); @@ -280,6 +303,7 @@ int main(int argc, char** argv) { if (gcpp::HasHelp(argc, argv)) { std::cerr << gcpp::kAsciiArtBanner; + gcpp::ShowHelp(threading, loader, inference); return 0; } @@ -300,4 +324,4 @@ int main(int argc, char** argv) { } PROFILER_PRINT_RESULTS(); // Must call outside the zone above. return 0; -} +} \ No newline at end of file