diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 20c912d0e9..bd3f4a487c 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -2437,7 +2437,8 @@ extern "C" { struct ggml_tensor * v, struct ggml_tensor * a, struct ggml_tensor * b, - struct ggml_tensor * state); + struct ggml_tensor * state, + bool fuse_exp); /* Solves a specific equation of the form Ax=B, where A is a triangular matrix * without zeroes on the diagonal (i.e. invertible). diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 3032783971..f4741da9d7 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -9875,9 +9875,9 @@ void ggml_compute_forward_solve_tri(const struct ggml_compute_params * params, s static void ggml_compute_forward_rwkv_wkv7_f32( const ggml_compute_params * params, ggml_tensor * dst) { - const int64_t T = dst->src[1]->ne[2]; + const int64_t T = dst->src[4]->ne[2]; const int64_t C = dst->ne[0]; - const int64_t HEADS = dst->src[1]->ne[1]; + const int64_t HEADS = dst->src[4]->ne[1]; const int64_t n_seqs = dst->src[6]->ne[1]; const int64_t head_size = C / HEADS; @@ -9902,6 +9902,9 @@ static void ggml_compute_forward_rwkv_wkv7_f32( float * a = (float *) dst->src[4]->data; float * b = (float *) dst->src[5]->data; + const bool fuse_exp = (bool) ((int32_t *) dst->op_params)[0]; + constexpr float w_scale = -0.6065306597f; // -exp(-0.5) + int64_t t_stride = HEADS * head_size; // Same to C int64_t h_stride = C / HEADS; @@ -9938,7 +9941,7 @@ static void ggml_compute_forward_rwkv_wkv7_f32( int64_t h_2d_i_j_offset = h_2d_i_offset + j; float r_val = r[t_h_j_offset]; - float w_val = w[t_h_j_offset]; + float w_val = fuse_exp ? expf(w_scale / (1.0f + expf(-w[t_h_j_offset]))) : w[t_h_j_offset]; float k_val = k[t_h_j_offset]; float b_val = b[t_h_j_offset]; float kv_val = v_val * k_val; @@ -9997,6 +10000,11 @@ static void ggml_compute_forward_rwkv_wkv7_f32( GGML_F32_VEC k_vec = GGML_F32_VEC_LOAD(&k[t_h_j_offset]); GGML_F32_VEC b_vec = GGML_F32_VEC_LOAD(&b[t_h_j_offset]); + if (fuse_exp) { + w_vec = ggml_v_sigmoid(w_vec, w_scale); + w_vec = ggml_v_expf(w_vec); + } + k_vec = GGML_F32_VEC_MUL(v_vec, k_vec); GGML_F32_VEC state_vec = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_j_offset]); @@ -10016,7 +10024,7 @@ static void ggml_compute_forward_rwkv_wkv7_f32( int64_t h_2d_i_j_offset = h_2d_i_offset + j; float r_val = r[t_h_j_offset]; - float w_val = w[t_h_j_offset]; + float w_val = fuse_exp ? expf(w_scale / (1.0f + expf(-w[t_h_j_offset]))) : w[t_h_j_offset]; float k_val = k[t_h_j_offset]; float b_val = b[t_h_j_offset]; float kv_val = v[t_h_i_offset] * k_val; @@ -10057,7 +10065,7 @@ static void ggml_compute_forward_rwkv_wkv7_f32( int64_t h_2d_i_j_offset = h_2d_i_offset + j; float r_val = r[t_h_j_offset]; - float w_val = w[t_h_j_offset]; + float w_val = fuse_exp ? expf(w_scale / (1.0f + expf(-w[t_h_j_offset]))) : w[t_h_j_offset]; float k_val = k[t_h_j_offset]; float b_val = b[t_h_j_offset]; float kv_val = v_val * k_val; diff --git a/ggml/src/ggml-cpu/vec.h b/ggml/src/ggml-cpu/vec.h index 3198b33b50..5f22b4d489 100644 --- a/ggml/src/ggml-cpu/vec.h +++ b/ggml/src/ggml-cpu/vec.h @@ -1129,16 +1129,27 @@ inline static svfloat32_t ggml_v_expf(svbool_t pg, svfloat32_t x) { svsel_f32(c, svmul_f32_x(pg, svmla_f32_x(pg, s2, s2, j), s1), svmla_f32_x(pg, k, k, j))); } -// computes silu x/(1+exp(-x)) in single precision vector -inline static svfloat32_t ggml_v_silu(svbool_t pg, svfloat32_t x) { +// computes (1+exp(-x)) in single precision vector +inline static svfloat32_t ggml_v_one_plus_exp_neg_x(svbool_t pg, svfloat32_t x) { const svfloat32_t one = svdup_n_f32_x(pg, 1.0f); const svfloat32_t zero = svdup_n_f32_x(pg, 0.0f); const svfloat32_t neg_x = svsub_f32_x(pg, zero, x); const svfloat32_t exp_neg_x = ggml_v_expf(pg, neg_x); - const svfloat32_t one_plus_exp_neg_x = svadd_f32_x(pg, one, exp_neg_x); + return svadd_f32_x(pg, one, exp_neg_x); +} + +// computes silu x/(1+exp(-x)) in single precision vector +inline static svfloat32_t ggml_v_silu(svbool_t pg, svfloat32_t x) { + const svfloat32_t one_plus_exp_neg_x = ggml_v_one_plus_exp_neg_x(pg, x); return svdiv_f32_x(pg, x, one_plus_exp_neg_x); } +// computes sigmoid 1/(1+exp(-x)) (with scale) in single precision vector +inline static svfloat32_t ggml_v_sigmoid(svbool_t pg, svfloat32_t x, float scale) { + const svfloat32_t one_plus_exp_neg_x = ggml_v_one_plus_exp_neg_x(pg, x); + return svdiv_f32_x(pg, svdup_n_f32_x(pg, scale), one_plus_exp_neg_x); +} + #elif defined(__ARM_NEON) && defined(__aarch64__) // adapted from arm limited optimized routine @@ -1168,16 +1179,27 @@ inline static float32x4_t ggml_v_expf(float32x4_t x) { vbslq_f32(c, vmulq_f32(vfmaq_f32(s2, s2, j), s1), vfmaq_f32(k, k, j))); } -// computes silu x/(1+exp(-x)) in single precision vector -inline static float32x4_t ggml_v_silu(float32x4_t x) { +// computes (1+exp(-x)) in single precision vector +inline static float32x4_t ggml_v_one_plus_exp_neg_x(float32x4_t x) { const float32x4_t one = vdupq_n_f32(1.0f); const float32x4_t zero = vdupq_n_f32(0.0f); const float32x4_t neg_x = vsubq_f32(zero, x); const float32x4_t exp_neg_x = ggml_v_expf(neg_x); - const float32x4_t one_plus_exp_neg_x = vaddq_f32(one, exp_neg_x); + return vaddq_f32(one, exp_neg_x); +} + +// computes silu x/(1+exp(-x)) in single precision vector +inline static float32x4_t ggml_v_silu(float32x4_t x) { + const float32x4_t one_plus_exp_neg_x = ggml_v_one_plus_exp_neg_x(x); return vdivq_f32(x, one_plus_exp_neg_x); } +// computes sigmoid 1/(1+exp(-x)) (with scale) in single precision vector +inline static float32x4_t ggml_v_sigmoid(float32x4_t x, float scale) { + const float32x4_t one_plus_exp_neg_x = ggml_v_one_plus_exp_neg_x(x); + return vdivq_f32(vdupq_n_f32(scale), one_plus_exp_neg_x); +} + #elif defined(__AVX512F__) && defined(__AVX512DQ__) // adapted from arm limited optimized routine @@ -1211,16 +1233,27 @@ inline static __m512 ggml_v_expf(__m512 x) { return _mm512_mask_blend_ps(d, res, alt); } -// computes silu x/(1+exp(-x)) in single precision vector -inline static __m512 ggml_v_silu(__m512 x) { +// computes (1+exp(-x)) in single precision vector +inline static __m512 ggml_v_one_plus_exp_neg_x(__m512 x) { const __m512 one = _mm512_set1_ps(1); const __m512 zero = _mm512_setzero_ps(); const __m512 neg_x = _mm512_sub_ps(zero, x); const __m512 exp_neg_x = ggml_v_expf(neg_x); - const __m512 one_plus_exp_neg_x = _mm512_add_ps(one, exp_neg_x); + return _mm512_add_ps(one, exp_neg_x); +} + +// computes silu x/(1+exp(-x)) in single precision vector +inline static __m512 ggml_v_silu(__m512 x) { + const __m512 one_plus_exp_neg_x = ggml_v_one_plus_exp_neg_x(x); return _mm512_div_ps(x, one_plus_exp_neg_x); } +// computes sigmoid 1/(1+exp(-x)) (with scale) in single precision vector +inline static __m512 ggml_v_sigmoid(__m512 x, float scale) { + const __m512 one_plus_exp_neg_x = ggml_v_one_plus_exp_neg_x(x); + return _mm512_div_ps(_mm512_set1_ps(scale), one_plus_exp_neg_x); +} + #elif defined(__AVX2__) && defined(__FMA__) // adapted from arm limited optimized routine @@ -1266,16 +1299,27 @@ inline static __m256 ggml_v_expf(__m256 x) { _mm256_andnot_ps(_mm256_castsi256_ps(c), _mm256_fmadd_ps(k, j, k))))); } -// computes silu x/(1+exp(-x)) in single precision vector -inline static __m256 ggml_v_silu(__m256 x) { +// computes (1+exp(-x)) in single precision vector +inline static __m256 ggml_v_one_plus_exp_neg_x(__m256 x) { const __m256 one = _mm256_set1_ps(1); const __m256 zero = _mm256_setzero_ps(); const __m256 neg_x = _mm256_sub_ps(zero, x); const __m256 exp_neg_x = ggml_v_expf(neg_x); - const __m256 one_plus_exp_neg_x = _mm256_add_ps(one, exp_neg_x); + return _mm256_add_ps(one, exp_neg_x); +} + +// computes silu x/(1+exp(-x)) in single precision vector +inline static __m256 ggml_v_silu(__m256 x) { + const __m256 one_plus_exp_neg_x = ggml_v_one_plus_exp_neg_x(x); return _mm256_div_ps(x, one_plus_exp_neg_x); } +// computes sigmoid 1/(1+exp(-x)) (with scale) in single precision vector +inline static __m256 ggml_v_sigmoid(__m256 x, float scale) { + const __m256 one_plus_exp_neg_x = ggml_v_one_plus_exp_neg_x(x); + return _mm256_div_ps(_mm256_set1_ps(scale), one_plus_exp_neg_x); +} + #elif defined(__SSE2__) // __AVX2__ / __ARM_NEON #if defined(__FMA__) @@ -1320,16 +1364,27 @@ inline static __m128 ggml_v_expf(__m128 x) { _mm_andnot_ps(_mm_castsi128_ps(c), MADD128(k, j, k))))); } -// computes silu x/(1+exp(-x)) in single precision vector -inline static __m128 ggml_v_silu(__m128 x) { +// computes (1+exp(-x)) in single precision vector +inline static __m128 ggml_v_one_plus_exp_neg_x(__m128 x) { const __m128 one = _mm_set1_ps(1); const __m128 zero = _mm_setzero_ps(); const __m128 neg_x = _mm_sub_ps(zero, x); const __m128 exp_neg_x = ggml_v_expf(neg_x); - const __m128 one_plus_exp_neg_x = _mm_add_ps(one, exp_neg_x); + return _mm_add_ps(one, exp_neg_x); +} + +// computes silu x/(1+exp(-x)) in single precision vector +inline static __m128 ggml_v_silu(__m128 x) { + const __m128 one_plus_exp_neg_x = ggml_v_one_plus_exp_neg_x(x); return _mm_div_ps(x, one_plus_exp_neg_x); } +// computes sigmoid 1/(1+exp(-x)) (with scale) in single precision vector +inline static __m128 ggml_v_sigmoid(__m128 x, float scale) { + const __m128 one_plus_exp_neg_x = ggml_v_one_plus_exp_neg_x(x); + return _mm_div_ps(_mm_set1_ps(scale), one_plus_exp_neg_x); +} + #elif defined(__riscv_v_intrinsic) // adapted from arm limited optimized routine @@ -1374,14 +1429,25 @@ inline static vfloat32m2_t ggml_v_expf_m2(vfloat32m2_t x, int vl) { vl); } -// computes silu x/(1+exp(-x)) in single precision vector -inline static vfloat32m2_t ggml_v_silu_m2(vfloat32m2_t x, int vl) { +// computes (1+exp(-x)) in single precision vector +inline static vfloat32m2_t ggml_v_one_plus_exp_neg_x_m2(vfloat32m2_t x, int vl) { const vfloat32m2_t neg_x = __riscv_vfneg_v_f32m2(x, vl); const vfloat32m2_t exp_neg_x = ggml_v_expf_m2(neg_x, vl); - const vfloat32m2_t one_plus_exp_neg_x = __riscv_vfadd_vf_f32m2(exp_neg_x, 1.0f, vl); + return __riscv_vfadd_vf_f32m2(exp_neg_x, 1.0f, vl); +} + +// computes silu x/(1+exp(-x)) in single precision vector +inline static vfloat32m2_t ggml_v_silu_m2(vfloat32m2_t x, int vl) { + const vfloat32m2_t one_plus_exp_neg_x = ggml_v_one_plus_exp_neg_x_m2(x, vl); return __riscv_vfdiv_vv_f32m2(x, one_plus_exp_neg_x, vl); } +// computes sigmoid 1/(1+exp(-x)) in single precision vector +inline static vfloat32m2_t ggml_v_sigmoid_m2(vfloat32m2_t x, int vl, float scale) { + const vfloat32m2_t one_plus_exp_neg_x = ggml_v_one_plus_exp_neg_x_m2(x, vl); + return __riscv_vfrdiv_vf_f32m2(one_plus_exp_neg_x, scale, vl); +} + #endif // __ARM_NEON / __AVX2__ / __SSE2__ / __riscv_v_intrinsic inline static void ggml_vec_silu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { diff --git a/ggml/src/ggml-cuda/wkv.cu b/ggml/src/ggml-cuda/wkv.cu index d2fced705e..67e4ae9221 100644 --- a/ggml/src/ggml-cuda/wkv.cu +++ b/ggml/src/ggml-cuda/wkv.cu @@ -65,7 +65,9 @@ static __global__ void rwkv_wkv_f32(const int B, const int T, const int C, const } } -template +constexpr float w_scale = -0.6065306597f; // -exp(-0.5) + +template static __global__ void rwkv_wkv7_f32(const int B, const int T, const int C, const int H, const float * r, const float * w, const float * k, const float * v, const float * a, const float * b, const float * s, float * dst) { const int tid = threadIdx.x; const int bid = blockIdx.x; @@ -89,7 +91,7 @@ static __global__ void rwkv_wkv7_f32(const int B, const int T, const int C, cons for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) { __syncthreads(); _r[tid] = r[t]; - _w[tid] = w[t]; + _w[tid] = fuse_exp ? __expf(w_scale / (1.0f + __expf(-w[t]))) : w[t]; _k[tid] = k[t]; _a[tid] = a[t]; _b[tid] = b[t]; @@ -179,9 +181,11 @@ void ggml_cuda_op_rwkv_wkv7(ggml_backend_cuda_context & ctx, ggml_tensor * dst) const float * s_d = (const float *)dst->src[6]->data; const int64_t B = dst->src[6]->ne[1]; - const int64_t T = dst->src[0]->ne[2]; + const int64_t T = dst->src[4]->ne[2]; const int64_t C = dst->ne[0]; - const int64_t H = dst->src[0]->ne[1]; + const int64_t H = dst->src[4]->ne[1]; + + const bool fuse_exp = (bool) ((int32_t *) dst->op_params)[0]; float * dst_d = (float *)dst->data; @@ -192,8 +196,16 @@ void ggml_cuda_op_rwkv_wkv7(ggml_backend_cuda_context & ctx, ggml_tensor * dst) GGML_ASSERT(C / H == CUDA_WKV_BLOCK_SIZE || C / H == CUDA_WKV_BLOCK_SIZE * 2); if (C / H == CUDA_WKV_BLOCK_SIZE) { - rwkv_wkv7_f32<<>>(B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d); + if (fuse_exp) { + rwkv_wkv7_f32<<>>(B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d); + } else { + rwkv_wkv7_f32<<>>(B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d); + } } else { - rwkv_wkv7_f32<<>>(B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d); + if (fuse_exp) { + rwkv_wkv7_f32<<>>(B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d); + } else { + rwkv_wkv7_f32<<>>(B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d); + } } } diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index a50b12b6f3..3d7cbf90c7 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -1519,6 +1519,8 @@ int ggml_metal_op_rwkv(ggml_metal_op_t ctx, int idx) { GGML_TENSOR_LOCALS( int32_t, ne, op, ne); GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + const int32_t fuse_exp = ((const int32_t *) op->op_params)[0]; + const int64_t B = op->op == GGML_OP_RWKV_WKV6 ? op->src[5]->ne[1] : op->src[6]->ne[1]; const int64_t T = op->src[0]->ne[2]; const int64_t C = op->ne[0]; @@ -1543,6 +1545,9 @@ int ggml_metal_op_rwkv(ggml_metal_op_t ctx, int idx) { ggml_metal_encoder_set_bytes (enc, (void *) &T, sizeof(T), ida++); ggml_metal_encoder_set_bytes (enc, (void *) &C, sizeof(C), ida++); ggml_metal_encoder_set_bytes (enc, (void *) &H, sizeof(H), ida++); + if (op->op == GGML_OP_RWKV_WKV7) { + ggml_metal_encoder_set_bytes (enc, (void *) &fuse_exp, sizeof(fuse_exp), ida++); + } ggml_metal_encoder_dispatch_threadgroups(enc, B * H, 1, 1, C/H, 1, 1); diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 67b30e0d93..5df81b5e86 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -2657,6 +2657,7 @@ kernel void kernel_rwkv_wkv7_f32( constant uint & T, constant uint & C, constant uint & H, + constant uint & fuse_exp, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]]) { @@ -2666,6 +2667,8 @@ kernel void kernel_rwkv_wkv7_f32( const uint head_id = tgpig.x % H; const uint tid = tpitg.x; + constexpr float w_scale = -0.6065306597f; // -exp(-0.5) + if (batch_id >= B || head_id >= H) { return; } @@ -2692,7 +2695,7 @@ kernel void kernel_rwkv_wkv7_f32( for (uint t = start_t; t < end_t; t += C) { threadgroup_barrier(mem_flags::mem_threadgroup); _r[tid] = r[t]; - _w[tid] = w[t]; + _w[tid] = fuse_exp ? exp(w_scale / (1 + exp(-w[t]))) : w[t]; _k[tid] = k[t]; _a[tid] = a[t]; _b[tid] = b[t]; diff --git a/ggml/src/ggml-sycl/wkv.cpp b/ggml/src/ggml-sycl/wkv.cpp index c10e2f7645..f8ead16d56 100644 --- a/ggml/src/ggml-sycl/wkv.cpp +++ b/ggml/src/ggml-sycl/wkv.cpp @@ -96,7 +96,9 @@ static void rwkv_wkv6_f32_kernel( } } -template +constexpr float w_scale = -0.6065306597f; // -exp(-0.5) + +template static void rwkv_wkv7_f32_kernel( const int B, const int T, const int C, const int H, const float* r, const float* w, const float* k, const float* v, @@ -132,7 +134,7 @@ static void rwkv_wkv7_f32_kernel( item_ct1.barrier(sycl::access::fence_space::local_space); _r[tid] = r[t]; - _w[tid] = w[t]; + _w[tid] = fuse_exp ? sycl::native::exp(w_scale / (1.0f + sycl::native::exp(-w[t]))) : w[t]; _k[tid] = k[t]; _a[tid] = a[t]; _b[tid] = b[t]; @@ -247,9 +249,11 @@ void ggml_sycl_op_rwkv_wkv7(ggml_backend_sycl_context& ctx, ggml_tensor* dst) { float* dst_d = (float*)dst->data; const int64_t B = dst->src[6]->ne[1]; - const int64_t T = dst->src[0]->ne[2]; + const int64_t T = dst->src[4]->ne[2]; const int64_t C = dst->ne[0]; - const int64_t H = dst->src[0]->ne[1]; + const int64_t H = dst->src[4]->ne[1]; + + const bool fuse_exp = (bool) ((int32_t *) dst->op_params)[0]; GGML_ASSERT(dst->src[6]->type == GGML_TYPE_F32); GGML_ASSERT(C % H == 0); @@ -264,30 +268,61 @@ void ggml_sycl_op_rwkv_wkv7(ggml_backend_sycl_context& ctx, ggml_tensor* dst) { // Submit kernel if (C / H == WKV_BLOCK_SIZE) { - stream->submit([&](sycl::handler& cgh) { - sycl::local_accessor shared_mem_acc(shared_mem_size, cgh); + if (fuse_exp) { + stream->submit([&](sycl::handler& cgh) { + sycl::local_accessor shared_mem_acc(shared_mem_size, cgh); - cgh.parallel_for( - sycl::nd_range<3>(grid_dims * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) { - rwkv_wkv7_f32_kernel( - B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d, - item_ct1, (float*)shared_mem_acc.get_multi_ptr().get() - ); + cgh.parallel_for( + sycl::nd_range<3>(grid_dims * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + rwkv_wkv7_f32_kernel( + B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d, + item_ct1, (float*)shared_mem_acc.get_multi_ptr().get() + ); + }); + }); + } else { + stream->submit([&](sycl::handler& cgh) { + sycl::local_accessor shared_mem_acc(shared_mem_size, cgh); + + cgh.parallel_for( + sycl::nd_range<3>(grid_dims * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + rwkv_wkv7_f32_kernel( + B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d, + item_ct1, (float*)shared_mem_acc.get_multi_ptr().get() + ); + }); }); - }); + }); + } } else { - stream->submit([&](sycl::handler& cgh) { - sycl::local_accessor shared_mem_acc(shared_mem_size, cgh); + if (fuse_exp) { + stream->submit([&](sycl::handler& cgh) { + sycl::local_accessor shared_mem_acc(shared_mem_size, cgh); - cgh.parallel_for( - sycl::nd_range<3>(grid_dims * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) { - rwkv_wkv7_f32_kernel( - B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d, - item_ct1, (float*)shared_mem_acc.get_multi_ptr().get() - ); - }); - }); + cgh.parallel_for( + sycl::nd_range<3>(grid_dims * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + rwkv_wkv7_f32_kernel( + B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d, + item_ct1, (float*)shared_mem_acc.get_multi_ptr().get() + ); + }); + }); + } else { + stream->submit([&](sycl::handler& cgh) { + sycl::local_accessor shared_mem_acc(shared_mem_size, cgh); + + cgh.parallel_for( + sycl::nd_range<3>(grid_dims * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + rwkv_wkv7_f32_kernel( + B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d, + item_ct1, (float*)shared_mem_acc.get_multi_ptr().get() + ); + }); + }); + } } } diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index d68735a040..9cd61dc817 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -780,6 +780,7 @@ struct vk_device_struct { vk_pipeline pipeline_pool2d_f32; vk_pipeline pipeline_rwkv_wkv6_f32; vk_pipeline pipeline_rwkv_wkv7_f32; + vk_pipeline pipeline_rwkv_wkv7_f32_fuse_exp; vk_pipeline pipeline_ssm_scan_f32_d128; vk_pipeline pipeline_ssm_scan_f32_d256; vk_pipeline pipeline_ssm_conv_f32; @@ -4299,6 +4300,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1); ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32, "rwkv_wkv7_f32", rwkv_wkv7_f32_len, rwkv_wkv7_f32_data, "main", 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32_fuse_exp, "rwkv_wkv7_f32_fuse_exp", rwkv_wkv7_f32_fuse_exp_len, rwkv_wkv7_f32_fuse_exp_data, "main", 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1); if (device->subgroup_arithmetic && device->subgroup_require_full_support) { ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size, 16}, 1, true, true); @@ -8992,6 +8994,9 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return nullptr; case GGML_OP_RWKV_WKV7: if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + if (dst->op_params[0] == 1) { + return ctx->device->pipeline_rwkv_wkv7_f32_fuse_exp; + } return ctx->device->pipeline_rwkv_wkv7_f32; } return nullptr; @@ -9748,10 +9753,62 @@ static void ggml_vk_add_id(ggml_backend_vk_context * ctx, vk_context& subctx, co }); } -static void ggml_vk_op_f32_wkv(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_rwkv_wkv6_push_constants&& pc, int version) { - GGML_ASSERT(version == 6 || version == 7); - int num_srcs = version == 6 ? 6 : 7; +static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) { + const size_t seq_length = dst->src[0]->ne[2]; + const size_t n_embed = dst->ne[0]; + const size_t n_heads = dst->src[0]->ne[1]; + const size_t n_seqs = dst->src[5]->ne[1]; + const vk_op_rwkv_wkv6_push_constants pc = { + (uint32_t)n_seqs, + (uint32_t)seq_length, + (uint32_t)n_embed, + (uint32_t)n_heads, + }; + + const int num_srcs = 6; + for (int i = 0; i < num_srcs; i++) { + GGML_ASSERT(!ggml_is_quantized(dst->src[i]->type)); + } + + GGML_ASSERT(dst->buffer != nullptr); + + vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, dst->src[0], dst->src[1], dst->src[2], dst, dst->op); + GGML_ASSERT(pipeline != nullptr); + + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); + + vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst); + vk_subbuffer src_buf[6] = {}; + for (int i = 0; i < num_srcs; i++) { + src_buf[i] = ggml_vk_tensor_subbuffer(ctx, dst->src[i]); + } + + std::array elements = { + (uint32_t)(pc.B * pc.H), + 1, + 1 + }; + + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, + {src_buf[0], src_buf[1], src_buf[2], src_buf[3], src_buf[4], src_buf[5], dst_buf}, + pc, elements); +} + +static void ggml_vk_rwkv_wkv7(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) { + const size_t seq_length = dst->src[0]->ne[2]; + const size_t n_embed = dst->ne[0]; + const size_t n_heads = dst->src[0]->ne[1]; + const size_t n_seqs = dst->src[6]->ne[1]; + + const vk_op_rwkv_wkv7_push_constants pc = { + (uint32_t)n_seqs, + (uint32_t)seq_length, + (uint32_t)n_embed, + (uint32_t)n_heads, + }; + + const int num_srcs = 7; for (int i = 0; i < num_srcs; i++) { GGML_ASSERT(!ggml_is_quantized(dst->src[i]->type)); } @@ -9775,54 +9832,9 @@ static void ggml_vk_op_f32_wkv(ggml_backend_vk_context * ctx, vk_context& subctx 1 }; - if (version == 6) { - ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, - {src_buf[0], src_buf[1], src_buf[2], src_buf[3], src_buf[4], src_buf[5], dst_buf}, - pc, elements); - } else if (version == 7) { - ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, - {src_buf[0], src_buf[1], src_buf[2], src_buf[3], src_buf[4], src_buf[5], src_buf[6], dst_buf}, - pc, elements); - } else { - // shouldn't happen - GGML_ASSERT(false); - } -} - -static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) { - const size_t seq_length = dst->src[0]->ne[2]; - const size_t n_embed = dst->ne[0]; - const size_t n_heads = dst->src[0]->ne[1]; - const size_t n_seqs = dst->src[5]->ne[1]; - - ggml_vk_op_f32_wkv( - ctx, subctx, dst, - { - (uint32_t)n_seqs, - (uint32_t)seq_length, - (uint32_t)n_embed, - (uint32_t)n_heads, - }, - 6 - ); -} - -static void ggml_vk_rwkv_wkv7(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) { - const size_t seq_length = dst->src[0]->ne[2]; - const size_t n_embed = dst->ne[0]; - const size_t n_heads = dst->src[0]->ne[1]; - const size_t n_seqs = dst->src[6]->ne[1]; - - ggml_vk_op_f32_wkv( - ctx, subctx, dst, - { - (uint32_t)n_seqs, - (uint32_t)seq_length, - (uint32_t)n_embed, - (uint32_t)n_heads, - }, - 7 - ); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, + {src_buf[0], src_buf[1], src_buf[2], src_buf[3], src_buf[4], src_buf[5], src_buf[6], dst_buf}, + pc, elements); } static void ggml_vk_ssm_scan(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) { @@ -15558,7 +15570,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * src_clone[2], src_clone[3], src_clone[4], src_clone[5]); } else if (tensor->op == GGML_OP_RWKV_WKV7) { tensor_clone = ggml_rwkv_wkv7(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3], - src_clone[4], src_clone[5], src_clone[6]); + src_clone[4], src_clone[5], src_clone[6], (uint32_t) tensor->op_params[0]); } else if (tensor->op == GGML_OP_OPT_STEP_ADAMW) { src_clone[0]->flags = tensor->src[0]->flags; tensor_clone = ggml_opt_step_adamw(ggml_ctx, src_clone[0], src_clone[1], diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index bbdbf9dcaa..2e7550912e 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -967,7 +967,8 @@ void process_shaders() { string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); - string_to_spv("rwkv_wkv7_f32", "wkv7.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); + string_to_spv("rwkv_wkv7_f32", "wkv7.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"FUSE_EXP", "0"}})); + string_to_spv("rwkv_wkv7_f32_fuse_exp", "wkv7.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"FUSE_EXP", "1"}})); string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); string_to_spv("opt_step_sgd_f32", "opt_step_sgd.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/wkv7.comp b/ggml/src/ggml-vulkan/vulkan-shaders/wkv7.comp index 88c1c02b32..fafa1346fc 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/wkv7.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/wkv7.comp @@ -10,6 +10,7 @@ layout(push_constant) uniform Parameters { uint T; uint C; uint H; + uint fuse_exp; }; layout(binding = 0) readonly buffer RBuf { A_TYPE r[]; }; @@ -45,10 +46,18 @@ void main() { const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid; const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid; +#if FUSE_EXP + const float w_scale = -0.6065306597f; // -exp(-0.5) +#endif + for (uint t = start_t; t < end_t; t += C) { barrier(); _r[tid] = r[t]; +#if FUSE_EXP + _w[tid] = exp(w_scale / (1 + exp(-w[t]))); +#else _w[tid] = w[t]; +#endif _k[tid] = k[t]; _a[tid] = a[t]; _b[tid] = b[t]; diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 09b8eb466d..aa1188d44c 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -5698,7 +5698,8 @@ struct ggml_tensor * ggml_rwkv_wkv7( struct ggml_tensor * v, struct ggml_tensor * a, struct ggml_tensor * b, - struct ggml_tensor * state) { + struct ggml_tensor * state, + bool fuse_exp) { GGML_ASSERT(ggml_is_contiguous(r)); GGML_ASSERT(ggml_is_contiguous(w)); GGML_ASSERT(ggml_is_contiguous(k)); @@ -5707,14 +5708,16 @@ struct ggml_tensor * ggml_rwkv_wkv7( GGML_ASSERT(ggml_is_contiguous(b)); GGML_ASSERT(ggml_is_contiguous(state)); - const int64_t S = k->ne[0]; - const int64_t H = k->ne[1]; - const int64_t n_tokens = k->ne[2]; + const int64_t S = a->ne[0]; + const int64_t H = a->ne[1]; + const int64_t n_tokens = a->ne[2]; const int64_t n_seqs = state->ne[1]; + const int64_t n_embd = S * H; { - GGML_ASSERT(w->ne[0] == S && w->ne[1] == H && w->ne[2] == n_tokens); + GGML_ASSERT(r->ne[0] == S && r->ne[1] == H && r->ne[2] == n_tokens); + GGML_ASSERT(w->ne[0] == n_embd && w->ne[1] == n_tokens); GGML_ASSERT(k->ne[0] == S && k->ne[1] == H && k->ne[2] == n_tokens); - GGML_ASSERT(v->ne[0] == S && v->ne[1] == H && v->ne[2] == n_tokens); + GGML_ASSERT(v->ne[0] == n_embd && v->ne[1] == n_tokens); GGML_ASSERT(a->ne[0] == S && a->ne[1] == H && a->ne[2] == n_tokens); GGML_ASSERT(b->ne[0] == S && b->ne[1] == H && b->ne[2] == n_tokens); GGML_ASSERT(ggml_nelements(state) == S * S * H * n_seqs); @@ -5724,6 +5727,9 @@ struct ggml_tensor * ggml_rwkv_wkv7( const int64_t ne[4] = { S * H, n_tokens + S * n_seqs, 1, 1 }; struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); + int32_t fuse_exp_i = fuse_exp ? 1 : 0; + ggml_set_op_params(result, &fuse_exp_i, sizeof(fuse_exp_i)); + result->op = GGML_OP_RWKV_WKV7; result->src[0] = r; result->src[1] = w; diff --git a/src/models/rwkv7-base.cpp b/src/models/rwkv7-base.cpp index cda4465384..09bd944b7d 100644 --- a/src/models/rwkv7-base.cpp +++ b/src/models/rwkv7-base.cpp @@ -65,7 +65,6 @@ ggml_tensor * llm_build_rwkv7_base::build_rwkv7_time_mix(llm_graph_input_rs * in ggml_tensor * w = ggml_add( ctx0, ggml_mul_mat(ctx0, layer.time_mix_w2, ggml_tanh(ctx0, ggml_mul_mat(ctx0, layer.time_mix_w1, xw))), layer.time_mix_w0); - w = ggml_exp(ctx0, ggml_scale(ctx0, ggml_sigmoid(ctx0, w), -0.606531)); ggml_tensor * k = build_lora_mm(layer.time_mix_key, xk); ggml_tensor * v = build_lora_mm(layer.time_mix_value, xv); @@ -95,14 +94,12 @@ ggml_tensor * llm_build_rwkv7_base::build_rwkv7_time_mix(llm_graph_input_rs * in k = ggml_add(ctx0, k, ggml_sub(ctx0, ggml_mul(ctx0, a, ka), ka)); r = ggml_reshape_3d(ctx0, r, head_size, head_count, n_tokens); - w = ggml_reshape_3d(ctx0, w, head_size, head_count, n_tokens); k = ggml_reshape_3d(ctx0, k, head_size, head_count, n_tokens); - v = ggml_reshape_3d(ctx0, v, head_size, head_count, n_tokens); a = ggml_reshape_3d(ctx0, a, head_size, head_count, n_tokens); ggml_tensor * wkv_state = build_rs(inp, mctx_cur->get_s_l(il), hparams.n_embd_s(), n_seqs); - ggml_tensor * wkv_output = ggml_rwkv_wkv7(ctx0, r, w, k, v, ggml_neg(ctx0, kk), ggml_mul(ctx0, kk, a), wkv_state); + ggml_tensor * wkv_output = ggml_rwkv_wkv7(ctx0, r, w, k, v, ggml_neg(ctx0, kk), ggml_mul(ctx0, kk, a), wkv_state, true); cur = ggml_view_1d(ctx0, wkv_output, n_embd * n_tokens, 0); wkv_state = ggml_view_1d(ctx0, wkv_output, n_embd * head_size * n_seqs, n_embd * n_tokens * sizeof(float)); @@ -122,6 +119,8 @@ ggml_tensor * llm_build_rwkv7_base::build_rwkv7_time_mix(llm_graph_input_rs * in } else { cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens); } + + v = ggml_reshape_3d(ctx0, v, head_size, head_count, n_tokens); ggml_tensor * rk = ggml_sum_rows( ctx0, ggml_mul(ctx0, ggml_mul(ctx0, k, r), ggml_reshape_2d(ctx0, layer.time_mix_r_k, head_size, head_count))); cur = ggml_add(ctx0, cur, ggml_reshape_2d(ctx0, ggml_mul(ctx0, v, rk), n_embd, n_tokens)); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 15567abedc..e6a72a29c5 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -3650,17 +3650,18 @@ struct test_rwkv_wkv7 : public test_case { ggml_tensor * build_graph(ggml_context * ctx) override { const int64_t n_tokens = n_seq_tokens * n_seqs; + const int64_t n_embd = head_count * head_size; ggml_tensor * r = ggml_new_tensor(ctx, type, 3, std::vector{ head_size, head_count, n_tokens }.data()); - ggml_tensor * w = ggml_new_tensor(ctx, type, 3, std::vector{ head_size, head_count, n_tokens }.data()); + ggml_tensor * w = ggml_new_tensor(ctx, type, 2, std::vector{ n_embd, n_tokens }.data()); ggml_tensor * k = ggml_new_tensor(ctx, type, 3, std::vector{ head_size, head_count, n_tokens }.data()); - ggml_tensor * v = ggml_new_tensor(ctx, type, 3, std::vector{ head_size, head_count, n_tokens }.data()); + ggml_tensor * v = ggml_new_tensor(ctx, type, 2, std::vector{ n_embd, n_tokens }.data()); ggml_tensor * a = ggml_new_tensor(ctx, type, 3, std::vector{ head_size, head_count, n_tokens }.data()); ggml_tensor * b = ggml_new_tensor(ctx, type, 3, std::vector{ head_size, head_count, n_tokens }.data()); // Outputs may become NaN with long seqlen without these normalization a = ggml_l2_norm(ctx, a, 1e-7F); b = ggml_l2_norm(ctx, b, 1e-7F); ggml_tensor * s = ggml_new_tensor(ctx, type, 2, std::vector{ head_size * head_size * head_count, n_seqs }.data()); - ggml_tensor * out = ggml_rwkv_wkv7(ctx, r, w, k, v, a, b, s); + ggml_tensor * out = ggml_rwkv_wkv7(ctx, r, w, k, v, a, b, s, true); return out; } };