diff --git a/gemma/gemma_args.h b/gemma/gemma_args.h index 4fe2d33..dc4019c 100644 --- a/gemma/gemma_args.h +++ b/gemma/gemma_args.h @@ -28,13 +28,205 @@ #include "compression/shared.h" #include "gemma/common.h" #include "gemma/gemma.h" // For CreateGemma +#include "hwy/base.h" // HWY_IS_ASAN, HWY_ABORT #include "ops/matmul.h" +#include "util/allocator.h" #include "util/args.h" #include "util/basics.h" // Tristate -#include "hwy/base.h" // HWY_ABORT +#include "util/threading.h" +#include "util/threading_context.h" namespace gcpp { +static inline const char* CompiledConfig() { + if (HWY_IS_ASAN) { + return "asan"; + } else if (HWY_IS_MSAN) { + return "msan"; + } else if (HWY_IS_TSAN) { + return "tsan"; + } else if (HWY_IS_HWASAN) { + return "hwasan"; + } else if (HWY_IS_UBSAN) { + return "ubsan"; + } else if (HWY_IS_DEBUG_BUILD) { + return "dbg"; + } else { + return "opt"; + } +} +template +struct ArgsBase { + void Init() { static_cast(this)->ForEach(SetToDefault()); } + + void InitAndParse(int argc, char* argv[]) { + Init(); + static_cast(this)->ForEach(ParseOption(argc, argv)); + } + + void Print(int min_verbosity = 1) const { + static_cast(this)->ForEach(PrintOption(min_verbosity)); + } + + void Help() const { static_cast(this)->ForEach(PrintHelp()); } + + protected: + // Helper struct for printing help messages + struct PrintHelp { + template + void operator()(const T& value, const char* name, const T& default_value, + const char* description, int verbosity = 1) const { + fprintf(stderr, " --%s\n %s\n", name, description); + } + // Special case for strings to avoid template deduction issues + void operator()(const std::string& value, const char* name, + const std::string& default_value, const char* description, + int verbosity = 1) const { + fprintf(stderr, " --%s\n %s\n", name, description); + } + // Special case for Path type + void operator()(const Path& value, const char* name, + const Path& default_value, const char* description, + int verbosity = 1) const { + fprintf(stderr, " --%s\n %s\n", name, description); + } + }; + + // Helper struct for setting default values + struct SetToDefault { + template + void operator()(T& value, const char* name, const T& default_value, + const char* description, int verbosity = 1) const { + value = default_value; + } + }; + + // Helper struct for printing values + struct PrintOption { + explicit PrintOption(int min_verbosity) : min_verbosity_(min_verbosity) {} + + template + void operator()(const T& value, const char* name, const T& default_value, + const char* description, int verbosity = 1) const { + if (verbosity >= min_verbosity_) { + fprintf(stderr, "%s: %s\n", name, ToString(value).c_str()); + } + } + + private: + int min_verbosity_; + + // Helper function to convert values to string + template + static std::string ToString(const T& value) { + return std::to_string(value); + } + // Specialization for string + static std::string ToString(const std::string& value) { return value; } + // Specialization for Path + static std::string ToString(const Path& value) { return value.path; } + }; +}; +struct ThreadingArgs : public ArgsBase { + public: + ThreadingArgs(int argc, char* argv[]) { InitAndParse(argc, argv); } + ThreadingArgs() { Init(); }; + + int verbosity; + + size_t max_threads; // divided among the detected clusters + Tristate pin; // pin threads? + Tristate spin; // use spin waits? + + // For BoundedSlice: + size_t skip_packages; + size_t max_packages; + size_t skip_clusters; + size_t max_clusters; + size_t skip_lps; + size_t max_lps; + + std::string eot_line; + std::string prompt; + template + void ForEach(const Visitor& visitor) { + visitor(verbosity, "verbosity", 1, + "Show verbose developer information\n 0 = only print generation " + "output\n 1 = standard user-facing terminal ui\n 2 = show " + "developer/debug info).\n Default = 1.", + 2); + + // The exact meaning is more subtle: see the comment at NestedPools ctor. + visitor(max_threads, "num_threads", size_t{0}, + "Maximum number of threads to use; default 0 = unlimited.", 2); + visitor(pin, "pin", Tristate::kDefault, + "Pin threads? -1 = auto, 0 = no, 1 = yes.", 2); + visitor(spin, "spin", Tristate::kDefault, + "Use spin waits? -1 = auto, 0 = no, 1 = yes.", 2); + // These can be used to partition CPU sockets/packages and their + // clusters/CCXs across several program instances. The default is to use + // all available resources. + visitor(skip_packages, "skip_packages", size_t{0}, + "Index of the first socket to use; default 0 = unlimited.", 2); + visitor(max_packages, "max_packages", size_t{0}, + "Maximum number of sockets to use; default 0 = unlimited.", 2); + visitor(skip_clusters, "skip_clusters", size_t{0}, + "Index of the first CCX to use; default 0 = unlimited.", 2); + visitor(max_clusters, "max_clusters", size_t{0}, + "Maximum number of CCXs to use; default 0 = unlimited.", 2); + // These are only used when CPU topology is unknown. + visitor(skip_lps, "skip_lps", size_t{0}, + "Index of the first LP to use; default 0 = unlimited.", 2); + visitor(max_lps, "max_lps", size_t{0}, + "Maximum number of LPs to use; default 0 = unlimited.", 2); + + visitor( + eot_line, "eot_line", std::string(""), + "End of turn line. " + "When you specify this, the prompt will be all lines " + "before the line where only the given string appears.\n Default = " + "When a newline is encountered, that signals the end of the turn.", + 2); + + visitor(prompt, "prompt", std::string(""), + "Prompt string for non-interactive mode. When provided, the model " + "generates a response and exits.", + 2); + } +}; +static inline BoundedTopology CreateTopology(const ThreadingArgs& threading) { + return BoundedTopology( + BoundedSlice(threading.skip_packages, threading.max_packages), + BoundedSlice(threading.skip_clusters, threading.max_clusters), + BoundedSlice(threading.skip_lps, threading.max_lps)); +} + +static inline MatMulEnv MakeMatMulEnv(const ThreadingArgs& threading) { + ThreadingContext2::SetArgs(threading); + return MatMulEnv(ThreadingContext2::Get()); +} +// Note: These functions may need adjustments depending on your specific class +// definitions +static inline BoundedTopology CreateTopology(const ThreadingArgs& app) { + return BoundedTopology(BoundedSlice(app.skip_packages, app.max_packages), + BoundedSlice(app.skip_clusters, app.max_clusters), + BoundedSlice(app.skip_lps, app.max_lps)); +} + +// This function may need to be adjusted based on your NestedPools constructor +// signature +static inline NestedPools CreatePools(const BoundedTopology& topology, + const ThreadingArgs& threading) { + // Make sure Allocator::Init() is properly declared/defined + const Allocator2& allocator = ThreadingContext2::Get().allocator; + // Allocator::Init(topology); + + // Adjust the constructor call based on your actual NestedPools constructor + // The error suggests that the constructor doesn't match these arguments + return NestedPools(topology, allocator, threading.max_threads, threading.pin); + // Alternative: return NestedPools(topology, app.max_threads, app.pin); +} + struct LoaderArgs : public ArgsBase { LoaderArgs(int argc, char* argv[], bool validate = true) { InitAndParse(argc, argv); @@ -106,9 +298,8 @@ struct LoaderArgs : public ArgsBase { "Path name of model weights (.sbs) file.\n Required argument.\n"); visitor(compressed_weights, "compressed_weights", Path(), "Deprecated alias for --weights."); - visitor( - model_type_str, "model", std::string(), - "Model type, see common.cc for valid values.\n"); + visitor(model_type_str, "model", std::string(), + "Model type, see common.cc for valid values.\n"); visitor(weight_type_str, "weight_type", std::string("sfp"), "Weight type\n f32 = float, bf16 = bfloat16, sfp = 8-bit SFP."); } @@ -231,6 +422,22 @@ struct InferenceArgs : public ArgsBase { } }; +static inline void ShowConfig(const ThreadingArgs& threading, + const LoaderArgs& loader, + const InferenceArgs& inference) { + threading.Print(); + loader.Print(); + inference.Print(); +} +static inline void ShowHelp(const ThreadingArgs& threading, + const LoaderArgs& loader, + const InferenceArgs& inference) { + fprintf(stderr, "\nUsage: gemma [flags]\n\nFlags:\n"); + threading.Help(); + loader.Help(); + inference.Help(); +} + } // namespace gcpp -#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_ARGS_H_ +#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_ARGS_H_ \ No newline at end of file diff --git a/gemma/run.cc b/gemma/run.cc index 5170b6e..381dac4 100644 --- a/gemma/run.cc +++ b/gemma/run.cc @@ -27,13 +27,14 @@ #include "evals/benchmark_helper.h" #include "gemma/common.h" #include "gemma/gemma.h" // Gemma -#include "gemma/gemma_args.h" // LoaderArgs -#include "ops/matmul.h" // MatMulEnv -#include "paligemma/image.h" -#include "util/args.h" // HasHelp -#include "util/threading_context.h" +#include "gemma/gemma_args.h" +#include "hwy/base.h" #include "hwy/highway.h" #include "hwy/profiler.h" +#include "ops/matmul.h" // MatMulEnv +#include "paligemma/image.h" +#include "util/args.h" // HasHelp +#include "util/threading.h" #if (!defined(HWY_VERSION_LT) || HWY_VERSION_LT(1, 2)) && !HWY_IDE #error "Please update to version 1.2 of github.com/google/highway." @@ -165,6 +166,16 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference, continue; } + // Wrap, tokenize and maybe log prompt tokens. + std::vector prompt = WrapAndTokenize(model.Tokenizer(), model.Info(), + abs_pos, prompt_string); + prompt_size = prompt.size(); + if constexpr (kVerboseLogTokens) { + for (int i = 0; i < prompt_size; ++i) { + fprintf(stderr, "DDD TOKEN %3d: %6d\n", i, prompt[i]); + } + } + // Set up runtime config. TimingInfo timing_info = {.verbosity = inference.verbosity}; RuntimeConfig runtime_config = {.gen = &gen, @@ -238,6 +249,22 @@ void Run(ThreadingArgs& threading, LoaderArgs& loader, KVCache kv_cache = KVCache::Create(model.GetModelConfig(), inference.prefill_tbatch_size); + if (!threading.prompt.empty()) { + std::vector prompt = + WrapAndTokenize(model.Tokenizer(), model.ChatTemplate(), model.Info(), + 0, threading.prompt); + + TimingInfo timing_info = {.verbosity = inference.verbosity}; + RuntimeConfig runtime_config = {.gen = nullptr, // Use default generator + .verbosity = inference.verbosity, + .use_spinning = threading.spin}; + inference.CopyTo(runtime_config); + + model.Generate(runtime_config, prompt, 0, 0, kv_cache, timing_info); + std::cout << "\n"; + return; // Exit after generating response + } + if (inference.verbosity >= 1) { std::string instructions = "*Usage*\n" @@ -280,6 +307,7 @@ int main(int argc, char** argv) { if (gcpp::HasHelp(argc, argv)) { std::cerr << gcpp::kAsciiArtBanner; + gcpp::ShowHelp(threading, loader, inference); return 0; }