fix: gate tiled GEMM and split-KV paths to preserve q8_0/q4_0 vec_dot semantics
This commit is contained in:
parent
358bd71b52
commit
23e88631c4
|
|
@ -9171,8 +9171,12 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
||||||
const bool use_ref = params->use_ref;
|
const bool use_ref = params->use_ref;
|
||||||
|
|
||||||
// Split-KV: parallelize across KV chunks for single-query decode (token generation).
|
// Split-KV: parallelize across KV chunks for single-query decode (token generation).
|
||||||
// Delegates to one_chunk which handles all supported types (F16, Q8_0, Q4_0, MXFP, etc).
|
// Only for types whose tiled/one_chunk paths produce identical results (f32, f16, MXFP).
|
||||||
|
// Standard quant types (q8_0, q4_0) must use the scalar path to preserve vec_dot semantics.
|
||||||
|
const bool kv_is_f32_f16_or_mxfp = (k->type == GGML_TYPE_F32 || k->type == GGML_TYPE_F16
|
||||||
|
|| ggml_is_type_mxfp(k->type));
|
||||||
const bool use_split_kv_path = !use_ref && (neq1 == 1 && neq3 == 1)
|
const bool use_split_kv_path = !use_ref && (neq1 == 1 && neq3 == 1)
|
||||||
|
&& kv_is_f32_f16_or_mxfp
|
||||||
&& q->type == GGML_TYPE_F32 && nek1 >= 512;
|
&& q->type == GGML_TYPE_F32 && nek1 >= 512;
|
||||||
|
|
||||||
if (use_split_kv_path) {
|
if (use_split_kv_path) {
|
||||||
|
|
@ -9230,8 +9234,14 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
||||||
const int64_t dr = (nr + nchunk - 1) / nchunk;
|
const int64_t dr = (nr + nchunk - 1) / nchunk;
|
||||||
|
|
||||||
static constexpr int64_t Q_TILE_SZ = ggml_fa_tile_config::Q;
|
static constexpr int64_t Q_TILE_SZ = ggml_fa_tile_config::Q;
|
||||||
|
// Tiled GEMM path: dequant K/V to float, then simd_gemm.
|
||||||
|
// Only for types that natively dequant to float (f32, f16, MXFP).
|
||||||
|
// Standard quant types (q8_0, q4_0) must use the scalar one_chunk path
|
||||||
|
// to preserve vec_dot semantics and produce identical results to master.
|
||||||
bool use_tiled = !use_ref &&
|
bool use_tiled = !use_ref &&
|
||||||
(q->type == GGML_TYPE_F32 &&
|
(q->type == GGML_TYPE_F32 &&
|
||||||
|
kv_is_f32_f16_or_mxfp &&
|
||||||
|
(k->type == v->type || ggml_is_type_mxfp(k->type)) &&
|
||||||
neq1 >= Q_TILE_SZ);
|
neq1 >= Q_TILE_SZ);
|
||||||
#ifdef GGML_SIMD
|
#ifdef GGML_SIMD
|
||||||
use_tiled &= (DV % GGML_F32_EPR == 0);
|
use_tiled &= (DV % GGML_F32_EPR == 0);
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue