diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index df17cc5530..7486acc2b5 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -2350,11 +2350,15 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_OP_FLASH_ATTN_BACK: case GGML_OP_SSM_CONV: case GGML_OP_SSM_SCAN: + { + n_tasks = n_threads; + } break; case GGML_OP_RWKV_WKV6: case GGML_OP_GATED_LINEAR_ATTN: case GGML_OP_RWKV_WKV7: { - n_tasks = n_threads; + const int64_t n_heads = node->src[1]->ne[1]; + n_tasks = MIN(n_threads, n_heads); } break; case GGML_OP_WIN_PART: case GGML_OP_WIN_UNPART: diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index d950972c83..765ce07f06 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -9953,13 +9953,9 @@ static void ggml_compute_forward_rwkv_wkv6_f32( const int ith = params->ith; const int nth = params->nth; - if (ith >= HEADS) { - return; - } - - const int h_start = (HEADS * ith) / nth; - const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ? - (HEADS * (ith + 1)) / nth : HEADS; + const int h_start = (HEADS * (ith )) / nth; + const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ? + (HEADS * (ith + 1)) / nth : HEADS; float * k = (float *) dst->src[0]->data; float * v = (float *) dst->src[1]->data; @@ -10170,13 +10166,9 @@ static void ggml_compute_forward_gla_f32( const int ith = params->ith; const int nth = params->nth; - if (ith >= HEADS) { - return; - } - - const int h_start = (HEADS * ith) / nth; - const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ? - (HEADS * (ith + 1)) / nth : HEADS; + const int h_start = (HEADS * (ith )) / nth; + const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ? + (HEADS * (ith + 1)) / nth : HEADS; float * k = (float *) dst->src[0]->data; float * v = (float *) dst->src[1]->data; @@ -10633,13 +10625,9 @@ static void ggml_compute_forward_rwkv_wkv7_f32( const int ith = params->ith; const int nth = params->nth; - if (ith >= HEADS) { - return; - } - - const int h_start = (HEADS * ith) / nth; - const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ? - (HEADS * (ith + 1)) / nth : HEADS; + const int h_start = (HEADS * (ith )) / nth; + const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ? + (HEADS * (ith + 1)) / nth : HEADS; float * r = (float *) dst->src[0]->data; float * w = (float *) dst->src[1]->data;