diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 0b6cfdcfcd..0ef0bc7458 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -8420,9 +8420,6 @@ static void ggml_compute_forward_flash_attn_ext_tiled( static constexpr int Q_TILE_SZ = ggml_fa_tile_config::Q; static constexpr int KV_TILE_SZ = ggml_fa_tile_config::KV; -#ifdef GGML_SIMD - GGML_ASSERT(DV % GGML_F32_EPR == 0); -#endif int ir = ir0; while (ir < ir1) { @@ -8812,12 +8809,14 @@ static void ggml_compute_forward_flash_attn_ext_f16( const int64_t dr = (nr + nchunk - 1) / nchunk; static constexpr int64_t Q_TILE_SZ = ggml_fa_tile_config::Q; - const bool use_tiled = !use_ref && + bool use_tiled = !use_ref && (q->type == GGML_TYPE_F32 && kv_is_f32_or_f16 && k->type == v->type && neq1 >= Q_TILE_SZ); - +#ifdef GGML_SIMD + use_tiled &= (DV % GGML_F32_EPR == 0); +#endif int current_chunk = ith; while (current_chunk < nchunk) {