fix case where DV % GGML_F32_EPR !=0
This commit is contained in:
parent
a1e1420b46
commit
734f76fbc4
|
|
@ -8420,9 +8420,6 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
|
||||||
|
|
||||||
static constexpr int Q_TILE_SZ = ggml_fa_tile_config::Q;
|
static constexpr int Q_TILE_SZ = ggml_fa_tile_config::Q;
|
||||||
static constexpr int KV_TILE_SZ = ggml_fa_tile_config::KV;
|
static constexpr int KV_TILE_SZ = ggml_fa_tile_config::KV;
|
||||||
#ifdef GGML_SIMD
|
|
||||||
GGML_ASSERT(DV % GGML_F32_EPR == 0);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
int ir = ir0;
|
int ir = ir0;
|
||||||
while (ir < ir1) {
|
while (ir < ir1) {
|
||||||
|
|
@ -8812,12 +8809,14 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
||||||
const int64_t dr = (nr + nchunk - 1) / nchunk;
|
const int64_t dr = (nr + nchunk - 1) / nchunk;
|
||||||
|
|
||||||
static constexpr int64_t Q_TILE_SZ = ggml_fa_tile_config::Q;
|
static constexpr int64_t Q_TILE_SZ = ggml_fa_tile_config::Q;
|
||||||
const bool use_tiled = !use_ref &&
|
bool use_tiled = !use_ref &&
|
||||||
(q->type == GGML_TYPE_F32 &&
|
(q->type == GGML_TYPE_F32 &&
|
||||||
kv_is_f32_or_f16 &&
|
kv_is_f32_or_f16 &&
|
||||||
k->type == v->type &&
|
k->type == v->type &&
|
||||||
neq1 >= Q_TILE_SZ);
|
neq1 >= Q_TILE_SZ);
|
||||||
|
#ifdef GGML_SIMD
|
||||||
|
use_tiled &= (DV % GGML_F32_EPR == 0);
|
||||||
|
#endif
|
||||||
int current_chunk = ith;
|
int current_chunk = ith;
|
||||||
|
|
||||||
while (current_chunk < nchunk) {
|
while (current_chunk < nchunk) {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue