diff --git a/tools/llama-bench/llama-bench.cpp b/tools/llama-bench/llama-bench.cpp index 7da6c3957c..dcd9512241 100644 --- a/tools/llama-bench/llama-bench.cpp +++ b/tools/llama-bench/llama-bench.cpp @@ -338,6 +338,7 @@ struct cmd_params { std::vector embeddings; std::vector no_op_offload; std::vector no_host; + uint32_t seed; ggml_numa_strategy numa; int reps; ggml_sched_priority prio; @@ -377,6 +378,7 @@ static const cmd_params cmd_params_defaults = { /* embeddings */ { false }, /* no_op_offload */ { false }, /* no_host */ { false }, + /* seed */ 1, /* numa */ GGML_NUMA_STRATEGY_DISABLED, /* reps */ 5, /* prio */ GGML_SCHED_PRIO_NORMAL, @@ -408,6 +410,7 @@ static void print_usage(int /* argc */, char ** argv) { printf(" -v, --verbose verbose output\n"); printf(" --progress print test progress indicators\n"); printf(" --no-warmup skip warmup runs before benchmarking\n"); + printf(" --seed RNG seed (default: %u)\n", cmd_params_defaults.seed); if (llama_supports_rpc()) { printf(" -rpc, --rpc register RPC devices (comma separated)\n"); } @@ -513,6 +516,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { params.delay = cmd_params_defaults.delay; params.progress = cmd_params_defaults.progress; params.no_warmup = cmd_params_defaults.no_warmup; + params.seed = cmd_params_defaults.seed; for (int i = 1; i < argc; i++) { arg = argv[i]; @@ -524,6 +528,12 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { if (arg == "-h" || arg == "--help") { print_usage(argc, argv); exit(0); + } else if (arg == "--seed") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.seed = (uint32_t) std::stoul(argv[i]); } else if (arg == "-m" || arg == "--model") { if (++i >= argc) { invalid_param = true; @@ -2003,9 +2013,9 @@ static bool test_gen(llama_context * ctx, int n_gen, int n_threads) { fprintf(stderr, "%s: failed to decode generation batch, res = %d\n", __func__, res); return false; } - llama_synchronize(ctx); token = std::rand() % n_vocab; } + llama_synchronize(ctx); return true; } @@ -2053,6 +2063,7 @@ int main(int argc, char ** argv) { ggml_backend_load_all(); cmd_params params = parse_cmd_params(argc, argv); + std::srand((unsigned) params.seed); auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); if (!cpu_dev) {