diff --git a/src/llama-context.cpp b/src/llama-context.cpp index a808e3e454..cf29bad8ea 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -2942,7 +2942,7 @@ llama_context * llama_init_from_model( params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED; } - if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO && ggml_is_quantized(params.type_k)) { + if (params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED && ggml_is_quantized(params.type_k)) { const uint32_t blck_size = ggml_blck_size(params.type_k); for (uint32_t il = 0; il < model->hparams.n_layer; ++il) { if (model->hparams.n_embd_head_k(il) % blck_size != 0) { @@ -2953,7 +2953,7 @@ llama_context * llama_init_from_model( } } - if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO && ggml_is_quantized(params.type_v)) { + if (params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED && ggml_is_quantized(params.type_v)) { const uint32_t blck_size = ggml_blck_size(params.type_v); for (uint32_t il = 0; il < model->hparams.n_layer; ++il) { if (model->hparams.n_embd_head_v(il) % blck_size != 0) {