context : add compute-type to set intermediate computation precision
This commit is contained in:
parent 3136a849db
commit e30dc63cb6
@@ -1332,6 +1332,23 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
                     string_format("error: unknown value for --flash-attn: '%s'\n", value.c_str()));
             }
         }).set_env("LLAMA_ARG_FLASH_ATTN"));
+    add_opt(common_arg({ "-ct", "--compute-type" }, "[f32|f16|bf16|default]",
+        string_format("set intermediate computation precision ('f32', 'f16', 'bf16' or 'default', default: '%s')",
+            llama_compute_type_name(params.compute_type)),
+        [](common_params & params, const std::string & value) {
+            if (value == "f32") {
+                params.compute_type = LLAMA_COMPUTE_TYPE_F32;
+            } else if (value == "f16") {
+                params.compute_type = LLAMA_COMPUTE_TYPE_F16;
+            } else if (value == "bf16") {
+                params.compute_type = LLAMA_COMPUTE_TYPE_BF16;
+            } else if (value == "default") {
+                params.compute_type = LLAMA_COMPUTE_TYPE_DEFAULT;
+            } else {
+                throw std::runtime_error(
+                    string_format("error: unknown value for --compute-type: '%s'\n", value.c_str()));
+            }
+        }).set_env("LLAMA_ARG_COMPUTE_TYPE"));
     add_opt(common_arg(
         {"-p", "--prompt"}, "PROMPT",
         "prompt to start generation with; for system message, use -sys",
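With this hunk applied, the precision can be selected as "-ct bf16" / "--compute-type bf16" or via the LLAMA_ARG_COMPUTE_TYPE environment variable; any other value throws a std::runtime_error, matching the error style of --flash-attn just above.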
@@ -1412,6 +1412,7 @@ struct llama_context_params common_context_params_to_llama(const common_params &
     cparams.pooling_type = params.pooling_type;
     cparams.attention_type = params.attention_type;
     cparams.flash_attn_type = params.flash_attn_type;
+    cparams.compute_type = params.compute_type;
     cparams.cb_eval = params.cb_eval;
     cparams.cb_eval_user_data = params.cb_eval_user_data;
     cparams.offload_kqv = !params.no_kv_offload;
@@ -402,6 +402,7 @@ struct common_params {
     enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings
     enum llama_attention_type attention_type = LLAMA_ATTENTION_TYPE_UNSPECIFIED; // attention type for embeddings
     enum llama_flash_attn_type flash_attn_type = LLAMA_FLASH_ATTN_TYPE_AUTO; // whether to use Flash Attention
+    enum llama_compute_type compute_type = LLAMA_COMPUTE_TYPE_DEFAULT; // intermediate computation precision

     struct common_params_sampling sampling;
     struct common_params_speculative speculative;
@@ -188,6 +188,15 @@ extern "C" {

     LLAMA_API const char * llama_flash_attn_type_name(enum llama_flash_attn_type flash_attn_type);
+
+    enum llama_compute_type {
+        LLAMA_COMPUTE_TYPE_DEFAULT = 0, // no override, use model's native precision.
+        LLAMA_COMPUTE_TYPE_F32     = 1,
+        LLAMA_COMPUTE_TYPE_F16     = 2,
+        LLAMA_COMPUTE_TYPE_BF16    = 3,
+    };
+
+    LLAMA_API const char * llama_compute_type_name(enum llama_compute_type compute_type);

     enum llama_split_mode {
         LLAMA_SPLIT_MODE_NONE = 0, // single GPU
         LLAMA_SPLIT_MODE_LAYER = 1, // split layers and KV across GPUs
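As a usage sketch of the new public field: a caller sets compute_type on the context params before creating the context. The loading/creation calls below are the current llama.h entry points and the model path is a placeholder; none of this is part of the patch itself.

    #include "llama.h"

    int main(void) {
        llama_model_params mparams = llama_model_default_params();
        llama_model * model = llama_model_load_from_file("model.gguf", mparams); // placeholder path

        llama_context_params cparams = llama_context_default_params();
        cparams.compute_type = LLAMA_COMPUTE_TYPE_BF16; // request bf16 intermediates (new field)

        llama_context * ctx = llama_init_from_model(model, cparams);
        // ... run inference ...
        llama_free(ctx);
        llama_model_free(model);
        return 0;
    }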
@@ -336,6 +345,7 @@ extern "C" {
         enum llama_pooling_type pooling_type; // whether to pool (sum) embedding results by sequence id
         enum llama_attention_type attention_type; // attention type to use for embeddings
         enum llama_flash_attn_type flash_attn_type; // when to enable Flash Attention
+        enum llama_compute_type compute_type; // intermediate activation precision [EXPERIMENTAL]

         // ref: https://github.com/ggml-org/llama.cpp/pull/2054
         float rope_freq_base; // RoPE base frequency, 0 = from model
@@ -148,6 +148,24 @@ llama_context::llama_context(
     cparams.flash_attn = params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED;
     cparams.auto_fa = params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO;
+
+    // convert public enum → internal ggml_type for intermediate computation precision
+    switch (params.compute_type) {
+        case LLAMA_COMPUTE_TYPE_F32:
+            cparams.compute_type = GGML_TYPE_F32;
+            break;
+        case LLAMA_COMPUTE_TYPE_F16:
+            cparams.compute_type = GGML_TYPE_F16;
+            break;
+        case LLAMA_COMPUTE_TYPE_BF16:
+            cparams.compute_type = GGML_TYPE_BF16;
+            break;
+        case LLAMA_COMPUTE_TYPE_DEFAULT:
+        default:
+            // DEFAULT = no override, use model's native precision (F32 for now)
+            cparams.compute_type = GGML_TYPE_F32;
+            break;
+    }

     // with causal attention, the batch size is limited by the context size
     cparams.n_batch = cparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;

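Note that LLAMA_COMPUTE_TYPE_DEFAULT resolves to GGML_TYPE_F32 here, just like the explicit f32 case. Keeping the two distinct at the API level presumably leaves room for DEFAULT to later follow the model's native precision, as the "F32 for now" comment suggests, without breaking the public enum.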
@@ -194,6 +212,7 @@ llama_context::llama_context(
     LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch);
     LLAMA_LOG_INFO("%s: causal_attn = %d\n", __func__, cparams.causal_attn);
     LLAMA_LOG_INFO("%s: flash_attn = %s\n", __func__, llama_flash_attn_type_name(params.flash_attn_type));
+    LLAMA_LOG_INFO("%s: compute_type = %s\n", __func__, llama_compute_type_name(params.compute_type));
     LLAMA_LOG_INFO("%s: kv_unified = %s\n", __func__, cparams.kv_unified ? "true" : "false");
     LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
     LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);
@@ -2951,6 +2970,7 @@ llama_context_params llama_context_default_params() {
         /*.pooling_type =*/ LLAMA_POOLING_TYPE_UNSPECIFIED,
         /*.attention_type =*/ LLAMA_ATTENTION_TYPE_UNSPECIFIED,
         /*.flash_attn_type =*/ LLAMA_FLASH_ATTN_TYPE_AUTO,
+        /*.compute_type =*/ LLAMA_COMPUTE_TYPE_DEFAULT,
         /*.rope_freq_base =*/ 0.0f,
         /*.rope_freq_scale =*/ 0.0f,
         /*.yarn_ext_factor =*/ -1.0f,
@@ -39,6 +39,8 @@ struct llama_cparams {

     enum llama_pooling_type pooling_type;

+    ggml_type compute_type;
+
     ggml_backend_sched_eval_callback cb_eval;
     void * cb_eval_user_data;
 };
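Internally the context stores a ggml_type rather than the public llama_compute_type, so the graph code can compare tensor types and cast directly without another translation step.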
@@ -862,6 +862,24 @@ void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const {
     }
 }
+
+ggml_tensor * llm_graph_context::build_cast_to_compute_type(
+        ggml_context * ctx,
+        ggml_tensor * cur) const {
+    if (cparams.compute_type == GGML_TYPE_F32 || cur->type == cparams.compute_type) {
+        return cur;
+    }
+    return ggml_cast(ctx, cur, cparams.compute_type);
+}
+
+ggml_tensor * llm_graph_context::build_cast_to_f32(
+        ggml_context * ctx,
+        ggml_tensor * cur) const {
+    if (cur->type == GGML_TYPE_F32) {
+        return cur;
+    }
+    return ggml_cast(ctx, cur, GGML_TYPE_F32);
+}

 ggml_tensor * llm_graph_context::build_cvec(
         ggml_tensor * cur,
         int il) const {
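These helpers are deliberate no-ops when no override is active: both return the input tensor unchanged if it is already in the target type, and build_cast_to_compute_type also bails out when the compute type is f32 (the default), so existing graphs are untouched. A hypothetical call site might look like this; the tensor and layer names are illustrative only, not taken from this patch:

    // inside some llm_graph_context::build_*() function:
    cur = build_cast_to_compute_type(ctx0, cur);            // f32 -> f16/bf16 if an override is set
    cur = ggml_mul_mat(ctx0, model.layers[il].ffn_up, cur); // matmul runs at reduced precision
    cur = build_cast_to_f32(ctx0, cur);                     // restore f32 before norms / logits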
@@ -756,6 +756,14 @@ struct llm_graph_context {

     void cb(ggml_tensor * cur, const char * name, int il) const;
+
+    ggml_tensor * build_cast_to_compute_type( // intermediate computation precision.
+            ggml_context * ctx,
+            ggml_tensor * cur) const;
+
+    ggml_tensor * build_cast_to_f32(
+            ggml_context * ctx,
+            ggml_tensor * cur) const;

     //
     // common
     //
@@ -43,6 +43,20 @@ const char * llama_flash_attn_type_name(enum llama_flash_attn_ty
     GGML_ABORT("fatal error");
 }
+
+const char * llama_compute_type_name(enum llama_compute_type compute_type) {
+    switch (compute_type) {
+        case LLAMA_COMPUTE_TYPE_DEFAULT:
+            return "default";
+        case LLAMA_COMPUTE_TYPE_F32:
+            return "f32";
+        case LLAMA_COMPUTE_TYPE_F16:
+            return "f16";
+        case LLAMA_COMPUTE_TYPE_BF16:
+            return "bf16";
+    }
+    GGML_ABORT("fatal error");
+}

 struct llama_device_memory_data {
     int64_t total;
     int64_t free;
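This mirrors llama_flash_attn_type_name above and gives the CLI help text, the --compute-type parser, and the context startup log a single source of truth for the value names.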