llama-bench: add --seed and avoid per-token synchronize
parent 388ce82241
commit 16c4aba272
@@ -338,6 +338,7 @@ struct cmd_params {
     std::vector<bool> embeddings;
     std::vector<bool> no_op_offload;
     std::vector<bool> no_host;
+    uint32_t seed;
     ggml_numa_strategy numa;
     int reps;
     ggml_sched_priority prio;
@@ -377,6 +378,7 @@ static const cmd_params cmd_params_defaults = {
     /* embeddings */ { false },
     /* no_op_offload */ { false },
     /* no_host */ { false },
+    /* seed */ 1,
     /* numa */ GGML_NUMA_STRATEGY_DISABLED,
     /* reps */ 5,
     /* prio */ GGML_SCHED_PRIO_NORMAL,
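
Note (not part of the patch): unlike the swept std::vector<bool> benchmark parameters around it, seed is a single scalar that applies to the whole run. Because cmd_params_defaults is built with positional aggregate initialization, the /* seed */ 1 entry has to sit at the same position as the new member declaration, between no_host and numa. A minimal sketch with hypothetical types illustrates the pattern:

    // Illustration only, hypothetical struct: the comment labels are purely
    // cosmetic, the compiler matches initializers to members by position.
    #include <cstdint>
    #include <vector>

    struct example_params {
        std::vector<bool> no_host;  // swept per-test option
        uint32_t          seed;     // single scalar, applies to the whole run
        int               reps;
    };

    static const example_params example_defaults = {
        /* no_host */ { false },
        /* seed    */ 1,
        /* reps    */ 5,
    };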
@@ -408,6 +410,7 @@ static void print_usage(int /* argc */, char ** argv) {
     printf(" -v, --verbose verbose output\n");
     printf(" --progress print test progress indicators\n");
     printf(" --no-warmup skip warmup runs before benchmarking\n");
+    printf(" --seed <n> RNG seed (default: %u)\n", cmd_params_defaults.seed);
     if (llama_supports_rpc()) {
         printf(" -rpc, --rpc <rpc_servers> register RPC devices (comma separated)\n");
     }
@@ -513,6 +516,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
     params.delay = cmd_params_defaults.delay;
     params.progress = cmd_params_defaults.progress;
     params.no_warmup = cmd_params_defaults.no_warmup;
+    params.seed = cmd_params_defaults.seed;

     for (int i = 1; i < argc; i++) {
         arg = argv[i];
@@ -524,6 +528,12 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
         if (arg == "-h" || arg == "--help") {
             print_usage(argc, argv);
             exit(0);
+        } else if (arg == "--seed") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.seed = (uint32_t) std::stoul(argv[i]);
         } else if (arg == "-m" || arg == "--model") {
             if (++i >= argc) {
                 invalid_param = true;
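
A caveat on the parsing above, with a hypothetical alternative (not part of this patch): std::stoul throws std::invalid_argument for non-numeric input and std::out_of_range for values that do not fit in unsigned long, so something like --seed abc would terminate llama-bench with an uncaught exception instead of flowing through the existing invalid_param handling. A tolerant variant could look roughly like this:

    // Hypothetical helper, not in the patch: report bad input via a bool
    // instead of letting std::stoul's exceptions escape.
    #include <cstdint>
    #include <stdexcept>
    #include <string>

    static bool parse_seed(const std::string & s, uint32_t & out) {
        try {
            const unsigned long v = std::stoul(s);
            if (v > UINT32_MAX) {
                return false;  // does not fit in uint32_t
            }
            out = (uint32_t) v;
            return true;
        } catch (const std::invalid_argument &) {
            return false;      // not a number
        } catch (const std::out_of_range &) {
            return false;      // larger than unsigned long
        }
    }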
@@ -2003,9 +2013,9 @@ static bool test_gen(llama_context * ctx, int n_gen, int n_threads) {
             fprintf(stderr, "%s: failed to decode generation batch, res = %d\n", __func__, res);
             return false;
         }
-        llama_synchronize(ctx);
         token = std::rand() % n_vocab;
     }
+    llama_synchronize(ctx);
     return true;
 }

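
Why dropping the per-token llama_synchronize() is safe here: test_gen() feeds back tokens drawn from std::rand() and never reads the logits, so the decode calls can be left queued and waited on once after the loop. A generation loop that actually sampled from the model would have to consume each token's logits, which waits for that token's decode anyway. A hedged contrast sketch, assuming the current llama.h API (not code from this patch):

    // Hedged sketch, not part of the patch: greedy sampling reads the logits
    // of each decoded token, so it cannot defer synchronization the way the
    // random-token benchmark loop can.
    #include <algorithm>
    #include "llama.h"

    static bool gen_greedy_sketch(llama_context * ctx, int n_vocab, llama_token token, int n_gen) {
        for (int i = 0; i < n_gen; i++) {
            if (llama_decode(ctx, llama_batch_get_one(&token, 1)) != 0) {
                return false;
            }
            // reading the logits waits for the pending decode to finish
            const float * logits = llama_get_logits_ith(ctx, -1);
            token = (llama_token) (std::max_element(logits, logits + n_vocab) - logits);
        }
        return true;
    }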
@@ -2053,6 +2063,7 @@ int main(int argc, char ** argv) {
     ggml_backend_load_all();

     cmd_params params = parse_cmd_params(argc, argv);
+    std::srand((unsigned) params.seed);

     auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
     if (!cpu_dev) {
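
A note on the default: the C standard specifies that std::rand() behaves as if std::srand(1) had been called when no seed is ever set, so the default seed of 1 keeps the token stream llama-bench produced before this change, while passing e.g. --seed 42 gives a different but repeatable stream across runs. A standalone check of that guarantee (illustration only, not part of the patch):

    // Illustration only: the first value from an unseeded std::rand() matches
    // the first value after std::srand(1), per the C standard's rand() spec.
    #include <cstdio>
    #include <cstdlib>

    int main() {
        const int unseeded = std::rand();  // no srand() call yet
        std::srand(1);
        const int reseeded = std::rand();  // stream restarted with seed 1
        std::printf("%d %d -> %s\n", unseeded, reseeded,
                    unseeded == reseeded ? "identical" : "different");
        return 0;
    }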