diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index ffa219f167..b49afc94b0 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -14,32 +14,6 @@
 #include
 #include
 
-static bool model_supports_compute_type(enum llm_arch arch, ggml_type compute_type) {
-    // F32 is always supported - it's the default/safe precision
-    if (compute_type == GGML_TYPE_F32) {
-        return true;
-    }
-
-    // Nowadays FP16 and BF16 support is model-specific.
-    // Add models here as their required ops are 'compute_type' implemented and validated.
-    // Example (uncomment when ready):
-    // if (compute_type == GGML_TYPE_F16 || compute_type == GGML_TYPE_BF16) {
-    //     switch (arch) {
-    //         case LLM_ARCH_QWEN2:
-    //         case LLM_ARCH_QWEN2MOE:
-    //         case LLM_ARCH_QWEN3:
-    //             // ... other validated models
-    //             return true;
-    //         default:
-    //             return false;
-    //     }
-    // }
-
-    // No models enabled yet for non-F32 compute types
-    (void)arch;
-    return false;
-}
-
 //
 // llama_context
 //
@@ -192,15 +166,33 @@ llama_context::llama_context(
             break;
     }
 
-    // check if the model supports the requested compute type
-    if (cparams.compute_type != GGML_TYPE_F32) {
-        if (!model_supports_compute_type(model.arch, cparams.compute_type)) {
-            LLAMA_LOG_WARN("%s: model arch '%s' does not yet support compute_type %s, "
-                    "falling back to F32. To enable, the required ops must be implemented first.\n",
-                    __func__, llm_arch_name(model.arch),
-                    ggml_type_name(cparams.compute_type));
-            cparams.compute_type = GGML_TYPE_F32;
+    // Nowadays FP16 and BF16 support is model-specific.
+    // Add models here as their required ops are 'compute_type' implemented and validated.
+    auto model_supports_compute_type = [&](ggml_type ct) -> bool {
+        if (ct == GGML_TYPE_F32) {
+            return true; // F32 is always supported
         }
+        // Example (uncomment when ready):
+        // if (ct == GGML_TYPE_F16 || ct == GGML_TYPE_BF16) {
+        //     switch (model.arch) {
+        //         case LLM_ARCH_QWEN2:
+        //         case LLM_ARCH_QWEN2MOE:
+        //         case LLM_ARCH_QWEN3:
+        //             return true;
+        //         default:
+        //             return false;
+        //     }
+        // }
+        (void)model.arch; // no models enabled yet for non-F32 compute types
+        return false;
+    };
+
+    if (!model_supports_compute_type(cparams.compute_type)) {
+        LLAMA_LOG_WARN("%s: model arch '%s' does not yet support compute_type %s, "
+                "falling back to F32. To enable, the required ops must be implemented first.\n",
+                __func__, llm_arch_name(model.arch),
+                ggml_type_name(cparams.compute_type));
+        cparams.compute_type = GGML_TYPE_F32;
     }
 
     // with causal attention, the batch size is limited by the context size
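
For reference, a minimal caller-side sketch of the resulting behaviour, assuming this change also exposes a compute_type field on the public llama_context_params (that field name and the model path below are assumptions, not shown in this diff). With no architectures whitelisted yet, requesting F16 should trigger the warning added above and the context ends up running with F32 compute:

// Hypothetical usage sketch; cparams.compute_type is assumed to be the public
// counterpart of the internal field this patch checks.
#include "llama.h"

int main() {
    llama_backend_init();

    llama_model_params mparams = llama_model_default_params();
    llama_model * model = llama_model_load_from_file("model.gguf", mparams); // placeholder path
    if (model == nullptr) {
        llama_backend_free();
        return 1;
    }

    llama_context_params cparams = llama_context_default_params();
    cparams.compute_type = GGML_TYPE_F16; // assumed field added alongside this patch

    // No architectures are enabled for non-F32 compute yet, so context creation
    // is expected to log the "does not yet support compute_type" warning and
    // silently fall back to F32.
    llama_context * ctx = llama_init_from_model(model, cparams);

    if (ctx != nullptr) {
        llama_free(ctx);
    }
    llama_model_free(model);
    llama_backend_free();
    return 0;
}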