models: rwkv7: fuse `w` softplus logic into wkv7 op

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
Molly Sophia 2026-01-09 19:28:49 +08:00
parent ae9f8df778
commit f7b238d8ef
13 changed files with 280 additions and 122 deletions
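
The change replaces the graph-level decay construction for RWKV-7 — previously ggml_exp(ggml_scale(ggml_sigmoid(w), -0.606531)), see the removed line in build_rwkv7_time_mix below — with an optional in-kernel computation controlled by a new fuse_exp flag on the WKV7 op. As a minimal scalar sketch (illustrative only, assuming <math.h>), the fused expression every backend implements is:

    // Reference for the fused decay: w_scale / (1 + exp(-x)) == w_scale * sigmoid(x),
    // so this is exp(w_scale * sigmoid(x)) with w_scale = -exp(-0.5).
    static inline float rwkv7_fused_w(float x) {
        const float w_scale = -0.6065306597f;      // -exp(-0.5)
        return expf(w_scale / (1.0f + expf(-x)));  // exp(w_scale * sigmoid(x))
    }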

View File

@@ -2437,7 +2437,8 @@ extern "C" {
struct ggml_tensor * v,
struct ggml_tensor * a,
struct ggml_tensor * b,
struct ggml_tensor * state);
struct ggml_tensor * state,
bool fuse_exp);
/* Solves a specific equation of the form Ax=B, where A is a triangular matrix
* without zeroes on the diagonal (i.e. invertible).
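
The only API change in this header is the trailing fuse_exp parameter. A hedged example of a call with the new signature (tensor names are placeholders; shapes must satisfy the asserts in ggml_rwkv_wkv7() shown further below):

    // fuse_exp = true: pass w pre-activation and let the kernel apply exp(w_scale * sigmoid(w)).
    // fuse_exp = false: previous behaviour, w is expected to be the already-exponentiated decay.
    struct ggml_tensor * out = ggml_rwkv_wkv7(ctx, r, w, k, v, a, b, state, /*fuse_exp=*/true);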

View File

@@ -9875,9 +9875,9 @@ void ggml_compute_forward_solve_tri(const struct ggml_compute_params * params, s
static void ggml_compute_forward_rwkv_wkv7_f32(
const ggml_compute_params * params,
ggml_tensor * dst) {
const int64_t T = dst->src[1]->ne[2];
const int64_t T = dst->src[4]->ne[2];
const int64_t C = dst->ne[0];
const int64_t HEADS = dst->src[1]->ne[1];
const int64_t HEADS = dst->src[4]->ne[1];
const int64_t n_seqs = dst->src[6]->ne[1];
const int64_t head_size = C / HEADS;
@@ -9902,6 +9902,9 @@ static void ggml_compute_forward_rwkv_wkv7_f32(
float * a = (float *) dst->src[4]->data;
float * b = (float *) dst->src[5]->data;
const bool fuse_exp = (bool) ((int32_t *) dst->op_params)[0];
constexpr float w_scale = -0.6065306597f; // -exp(-0.5)
int64_t t_stride = HEADS * head_size; // Same to C
int64_t h_stride = C / HEADS;
@@ -9938,7 +9941,7 @@ static void ggml_compute_forward_rwkv_wkv7_f32(
int64_t h_2d_i_j_offset = h_2d_i_offset + j;
float r_val = r[t_h_j_offset];
float w_val = w[t_h_j_offset];
float w_val = fuse_exp ? expf(w_scale / (1.0f + expf(-w[t_h_j_offset]))) : w[t_h_j_offset];
float k_val = k[t_h_j_offset];
float b_val = b[t_h_j_offset];
float kv_val = v_val * k_val;
@@ -9997,6 +10000,11 @@ static void ggml_compute_forward_rwkv_wkv7_f32(
GGML_F32_VEC k_vec = GGML_F32_VEC_LOAD(&k[t_h_j_offset]);
GGML_F32_VEC b_vec = GGML_F32_VEC_LOAD(&b[t_h_j_offset]);
if (fuse_exp) {
w_vec = ggml_v_sigmoid(w_vec, w_scale);
w_vec = ggml_v_expf(w_vec);
}
k_vec = GGML_F32_VEC_MUL(v_vec, k_vec);
GGML_F32_VEC state_vec = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_j_offset]);
@@ -10016,7 +10024,7 @@ static void ggml_compute_forward_rwkv_wkv7_f32(
int64_t h_2d_i_j_offset = h_2d_i_offset + j;
float r_val = r[t_h_j_offset];
float w_val = w[t_h_j_offset];
float w_val = fuse_exp ? expf(w_scale / (1.0f + expf(-w[t_h_j_offset]))) : w[t_h_j_offset];
float k_val = k[t_h_j_offset];
float b_val = b[t_h_j_offset];
float kv_val = v[t_h_i_offset] * k_val;
@@ -10057,7 +10065,7 @@ static void ggml_compute_forward_rwkv_wkv7_f32(
int64_t h_2d_i_j_offset = h_2d_i_offset + j;
float r_val = r[t_h_j_offset];
float w_val = w[t_h_j_offset];
float w_val = fuse_exp ? expf(w_scale / (1.0f + expf(-w[t_h_j_offset]))) : w[t_h_j_offset];
float k_val = k[t_h_j_offset];
float b_val = b[t_h_j_offset];
float kv_val = v_val * k_val;
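
A hedged reading of why the index source changes above: with the fused path, w (src[1]) and v (src[3]) may arrive as 2-D [n_embd, n_tokens] tensors (see the relaxed asserts in ggml_rwkv_wkv7() and the updated test below), so the token and head counts are taken from a (src[4]), which always keeps the per-head 3-D layout:

    const int64_t T         = dst->src[4]->ne[2];  // tokens, read from a
    const int64_t HEADS     = dst->src[4]->ne[1];  // heads, read from a
    const int64_t head_size = dst->ne[0] / HEADS;  // C / HEADS, unchanged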

View File

@@ -1129,16 +1129,27 @@ inline static svfloat32_t ggml_v_expf(svbool_t pg, svfloat32_t x) {
svsel_f32(c, svmul_f32_x(pg, svmla_f32_x(pg, s2, s2, j), s1), svmla_f32_x(pg, k, k, j)));
}
// computes silu x/(1+exp(-x)) in single precision vector
inline static svfloat32_t ggml_v_silu(svbool_t pg, svfloat32_t x) {
// computes (1+exp(-x)) in single precision vector
inline static svfloat32_t ggml_v_one_plus_exp_neg_x(svbool_t pg, svfloat32_t x) {
const svfloat32_t one = svdup_n_f32_x(pg, 1.0f);
const svfloat32_t zero = svdup_n_f32_x(pg, 0.0f);
const svfloat32_t neg_x = svsub_f32_x(pg, zero, x);
const svfloat32_t exp_neg_x = ggml_v_expf(pg, neg_x);
const svfloat32_t one_plus_exp_neg_x = svadd_f32_x(pg, one, exp_neg_x);
return svadd_f32_x(pg, one, exp_neg_x);
}
// computes silu x/(1+exp(-x)) in single precision vector
inline static svfloat32_t ggml_v_silu(svbool_t pg, svfloat32_t x) {
const svfloat32_t one_plus_exp_neg_x = ggml_v_one_plus_exp_neg_x(pg, x);
return svdiv_f32_x(pg, x, one_plus_exp_neg_x);
}
// computes sigmoid 1/(1+exp(-x)) (with scale) in single precision vector
inline static svfloat32_t ggml_v_sigmoid(svbool_t pg, svfloat32_t x, float scale) {
const svfloat32_t one_plus_exp_neg_x = ggml_v_one_plus_exp_neg_x(pg, x);
return svdiv_f32_x(pg, svdup_n_f32_x(pg, scale), one_plus_exp_neg_x);
}
#elif defined(__ARM_NEON) && defined(__aarch64__)
// adapted from arm limited optimized routine
@@ -1168,16 +1179,27 @@ inline static float32x4_t ggml_v_expf(float32x4_t x) {
vbslq_f32(c, vmulq_f32(vfmaq_f32(s2, s2, j), s1), vfmaq_f32(k, k, j)));
}
// computes silu x/(1+exp(-x)) in single precision vector
inline static float32x4_t ggml_v_silu(float32x4_t x) {
// computes (1+exp(-x)) in single precision vector
inline static float32x4_t ggml_v_one_plus_exp_neg_x(float32x4_t x) {
const float32x4_t one = vdupq_n_f32(1.0f);
const float32x4_t zero = vdupq_n_f32(0.0f);
const float32x4_t neg_x = vsubq_f32(zero, x);
const float32x4_t exp_neg_x = ggml_v_expf(neg_x);
const float32x4_t one_plus_exp_neg_x = vaddq_f32(one, exp_neg_x);
return vaddq_f32(one, exp_neg_x);
}
// computes silu x/(1+exp(-x)) in single precision vector
inline static float32x4_t ggml_v_silu(float32x4_t x) {
const float32x4_t one_plus_exp_neg_x = ggml_v_one_plus_exp_neg_x(x);
return vdivq_f32(x, one_plus_exp_neg_x);
}
// computes sigmoid 1/(1+exp(-x)) (with scale) in single precision vector
inline static float32x4_t ggml_v_sigmoid(float32x4_t x, float scale) {
const float32x4_t one_plus_exp_neg_x = ggml_v_one_plus_exp_neg_x(x);
return vdivq_f32(vdupq_n_f32(scale), one_plus_exp_neg_x);
}
#elif defined(__AVX512F__) && defined(__AVX512DQ__)
// adapted from arm limited optimized routine
@@ -1211,16 +1233,27 @@ inline static __m512 ggml_v_expf(__m512 x) {
return _mm512_mask_blend_ps(d, res, alt);
}
// computes silu x/(1+exp(-x)) in single precision vector
inline static __m512 ggml_v_silu(__m512 x) {
// computes (1+exp(-x)) in single precision vector
inline static __m512 ggml_v_one_plus_exp_neg_x(__m512 x) {
const __m512 one = _mm512_set1_ps(1);
const __m512 zero = _mm512_setzero_ps();
const __m512 neg_x = _mm512_sub_ps(zero, x);
const __m512 exp_neg_x = ggml_v_expf(neg_x);
const __m512 one_plus_exp_neg_x = _mm512_add_ps(one, exp_neg_x);
return _mm512_add_ps(one, exp_neg_x);
}
// computes silu x/(1+exp(-x)) in single precision vector
inline static __m512 ggml_v_silu(__m512 x) {
const __m512 one_plus_exp_neg_x = ggml_v_one_plus_exp_neg_x(x);
return _mm512_div_ps(x, one_plus_exp_neg_x);
}
// computes sigmoid 1/(1+exp(-x)) (with scale) in single precision vector
inline static __m512 ggml_v_sigmoid(__m512 x, float scale) {
const __m512 one_plus_exp_neg_x = ggml_v_one_plus_exp_neg_x(x);
return _mm512_div_ps(_mm512_set1_ps(scale), one_plus_exp_neg_x);
}
#elif defined(__AVX2__) && defined(__FMA__)
// adapted from arm limited optimized routine
@@ -1266,16 +1299,27 @@ inline static __m256 ggml_v_expf(__m256 x) {
_mm256_andnot_ps(_mm256_castsi256_ps(c), _mm256_fmadd_ps(k, j, k)))));
}
// computes silu x/(1+exp(-x)) in single precision vector
inline static __m256 ggml_v_silu(__m256 x) {
// computes (1+exp(-x)) in single precision vector
inline static __m256 ggml_v_one_plus_exp_neg_x(__m256 x) {
const __m256 one = _mm256_set1_ps(1);
const __m256 zero = _mm256_setzero_ps();
const __m256 neg_x = _mm256_sub_ps(zero, x);
const __m256 exp_neg_x = ggml_v_expf(neg_x);
const __m256 one_plus_exp_neg_x = _mm256_add_ps(one, exp_neg_x);
return _mm256_add_ps(one, exp_neg_x);
}
// computes silu x/(1+exp(-x)) in single precision vector
inline static __m256 ggml_v_silu(__m256 x) {
const __m256 one_plus_exp_neg_x = ggml_v_one_plus_exp_neg_x(x);
return _mm256_div_ps(x, one_plus_exp_neg_x);
}
// computes sigmoid 1/(1+exp(-x)) (with scale) in single precision vector
inline static __m256 ggml_v_sigmoid(__m256 x, float scale) {
const __m256 one_plus_exp_neg_x = ggml_v_one_plus_exp_neg_x(x);
return _mm256_div_ps(_mm256_set1_ps(scale), one_plus_exp_neg_x);
}
#elif defined(__SSE2__) // __AVX2__ / __ARM_NEON
#if defined(__FMA__)
@@ -1320,16 +1364,27 @@ inline static __m128 ggml_v_expf(__m128 x) {
_mm_andnot_ps(_mm_castsi128_ps(c), MADD128(k, j, k)))));
}
// computes silu x/(1+exp(-x)) in single precision vector
inline static __m128 ggml_v_silu(__m128 x) {
// computes (1+exp(-x)) in single precision vector
inline static __m128 ggml_v_one_plus_exp_neg_x(__m128 x) {
const __m128 one = _mm_set1_ps(1);
const __m128 zero = _mm_setzero_ps();
const __m128 neg_x = _mm_sub_ps(zero, x);
const __m128 exp_neg_x = ggml_v_expf(neg_x);
const __m128 one_plus_exp_neg_x = _mm_add_ps(one, exp_neg_x);
return _mm_add_ps(one, exp_neg_x);
}
// computes silu x/(1+exp(-x)) in single precision vector
inline static __m128 ggml_v_silu(__m128 x) {
const __m128 one_plus_exp_neg_x = ggml_v_one_plus_exp_neg_x(x);
return _mm_div_ps(x, one_plus_exp_neg_x);
}
// computes sigmoid 1/(1+exp(-x)) (with scale) in single precision vector
inline static __m128 ggml_v_sigmoid(__m128 x, float scale) {
const __m128 one_plus_exp_neg_x = ggml_v_one_plus_exp_neg_x(x);
return _mm_div_ps(_mm_set1_ps(scale), one_plus_exp_neg_x);
}
#elif defined(__riscv_v_intrinsic)
// adapted from arm limited optimized routine
@@ -1374,14 +1429,25 @@ inline static vfloat32m2_t ggml_v_expf_m2(vfloat32m2_t x, int vl) {
vl);
}
// computes silu x/(1+exp(-x)) in single precision vector
inline static vfloat32m2_t ggml_v_silu_m2(vfloat32m2_t x, int vl) {
// computes (1+exp(-x)) in single precision vector
inline static vfloat32m2_t ggml_v_one_plus_exp_neg_x_m2(vfloat32m2_t x, int vl) {
const vfloat32m2_t neg_x = __riscv_vfneg_v_f32m2(x, vl);
const vfloat32m2_t exp_neg_x = ggml_v_expf_m2(neg_x, vl);
const vfloat32m2_t one_plus_exp_neg_x = __riscv_vfadd_vf_f32m2(exp_neg_x, 1.0f, vl);
return __riscv_vfadd_vf_f32m2(exp_neg_x, 1.0f, vl);
}
// computes silu x/(1+exp(-x)) in single precision vector
inline static vfloat32m2_t ggml_v_silu_m2(vfloat32m2_t x, int vl) {
const vfloat32m2_t one_plus_exp_neg_x = ggml_v_one_plus_exp_neg_x_m2(x, vl);
return __riscv_vfdiv_vv_f32m2(x, one_plus_exp_neg_x, vl);
}
// computes sigmoid 1/(1+exp(-x)) (with scale) in single precision vector
inline static vfloat32m2_t ggml_v_sigmoid_m2(vfloat32m2_t x, int vl, float scale) {
const vfloat32m2_t one_plus_exp_neg_x = ggml_v_one_plus_exp_neg_x_m2(x, vl);
return __riscv_vfrdiv_vf_f32m2(one_plus_exp_neg_x, scale, vl);
}
#endif // __ARM_NEON / __AVX2__ / __SSE2__ / __riscv_v_intrinsic
inline static void ggml_vec_silu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
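
The refactor factors the shared 1 + exp(-x) term out of ggml_v_silu and adds a scaled-sigmoid helper on top of it for every SIMD flavour; the fused wkv7 loop then feeds the result into ggml_v_expf. A hedged scalar model of the same structure (assuming <math.h>; the vector versions behave like this per lane):

    static inline float one_plus_exp_neg_x(float x)           { return 1.0f + expf(-x); }
    static inline float silu_ref(float x)                     { return x / one_plus_exp_neg_x(x); } // ggml_v_silu
    static inline float sigmoid_scaled_ref(float x, float s)  { return s / one_plus_exp_neg_x(x); } // ggml_v_sigmoid(x, s)
    // expf(sigmoid_scaled_ref(x, w_scale)) matches the scalar fallback expf(w_scale / (1.0f + expf(-x))) in ops.cpp above.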

View File

@@ -65,7 +65,9 @@ static __global__ void rwkv_wkv_f32(const int B, const int T, const int C, const
}
}
template <int block_size>
constexpr float w_scale = -0.6065306597f; // -exp(-0.5)
template <int block_size, bool fuse_exp>
static __global__ void rwkv_wkv7_f32(const int B, const int T, const int C, const int H, const float * r, const float * w, const float * k, const float * v, const float * a, const float * b, const float * s, float * dst) {
const int tid = threadIdx.x;
const int bid = blockIdx.x;
@@ -89,7 +91,7 @@ static __global__ void rwkv_wkv7_f32(const int B, const int T, const int C, cons
for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) {
__syncthreads();
_r[tid] = r[t];
_w[tid] = w[t];
_w[tid] = fuse_exp ? __expf(w_scale / (1.0f + __expf(-w[t]))) : w[t];
_k[tid] = k[t];
_a[tid] = a[t];
_b[tid] = b[t];
@@ -179,9 +181,11 @@ void ggml_cuda_op_rwkv_wkv7(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
const float * s_d = (const float *)dst->src[6]->data;
const int64_t B = dst->src[6]->ne[1];
const int64_t T = dst->src[0]->ne[2];
const int64_t T = dst->src[4]->ne[2];
const int64_t C = dst->ne[0];
const int64_t H = dst->src[0]->ne[1];
const int64_t H = dst->src[4]->ne[1];
const bool fuse_exp = (bool) ((int32_t *) dst->op_params)[0];
float * dst_d = (float *)dst->data;
@@ -192,8 +196,16 @@ void ggml_cuda_op_rwkv_wkv7(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
GGML_ASSERT(C / H == CUDA_WKV_BLOCK_SIZE || C / H == CUDA_WKV_BLOCK_SIZE * 2);
if (C / H == CUDA_WKV_BLOCK_SIZE) {
rwkv_wkv7_f32<CUDA_WKV_BLOCK_SIZE><<<B * H, C / H, 0, stream>>>(B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d);
if (fuse_exp) {
rwkv_wkv7_f32<CUDA_WKV_BLOCK_SIZE, true><<<B * H, C / H, 0, stream>>>(B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d);
} else {
rwkv_wkv7_f32<CUDA_WKV_BLOCK_SIZE, false><<<B * H, C / H, 0, stream>>>(B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d);
}
} else {
rwkv_wkv7_f32<CUDA_WKV_BLOCK_SIZE * 2><<<B * H, C / H, 0, stream>>>(B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d);
if (fuse_exp) {
rwkv_wkv7_f32<CUDA_WKV_BLOCK_SIZE * 2, true><<<B * H, C / H, 0, stream>>>(B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d);
} else {
rwkv_wkv7_f32<CUDA_WKV_BLOCK_SIZE * 2, false><<<B * H, C / H, 0, stream>>>(B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d);
}
}
}
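
fuse_exp is a template parameter here, so the fuse_exp ? ... : ... select in the inner loop is folded at kernel instantiation time, and the host picks one of the two instantiations from the runtime flag (the SYCL backend below follows the same pattern). A minimal host-only sketch of that dispatch idiom, with hypothetical names:

    template <bool fuse_exp>
    static void decay_rows(const float * w, float * out, int n) {
        const float w_scale = -0.6065306597f; // -exp(-0.5)
        for (int i = 0; i < n; ++i) {
            out[i] = fuse_exp ? expf(w_scale / (1.0f + expf(-w[i]))) : w[i];
        }
    }

    static void decay_rows_dispatch(bool fuse_exp, const float * w, float * out, int n) {
        if (fuse_exp) { decay_rows<true>(w, out, n); } else { decay_rows<false>(w, out, n); }
    }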

View File

@@ -1519,6 +1519,8 @@ int ggml_metal_op_rwkv(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
const int32_t fuse_exp = ((const int32_t *) op->op_params)[0];
const int64_t B = op->op == GGML_OP_RWKV_WKV6 ? op->src[5]->ne[1] : op->src[6]->ne[1];
const int64_t T = op->src[0]->ne[2];
const int64_t C = op->ne[0];
@@ -1543,6 +1545,9 @@ int ggml_metal_op_rwkv(ggml_metal_op_t ctx, int idx) {
ggml_metal_encoder_set_bytes (enc, (void *) &T, sizeof(T), ida++);
ggml_metal_encoder_set_bytes (enc, (void *) &C, sizeof(C), ida++);
ggml_metal_encoder_set_bytes (enc, (void *) &H, sizeof(H), ida++);
if (op->op == GGML_OP_RWKV_WKV7) {
ggml_metal_encoder_set_bytes (enc, (void *) &fuse_exp, sizeof(fuse_exp), ida++);
}
ggml_metal_encoder_dispatch_threadgroups(enc, B * H, 1, 1, C/H, 1, 1);

View File

@@ -2657,6 +2657,7 @@ kernel void kernel_rwkv_wkv7_f32(
constant uint & T,
constant uint & C,
constant uint & H,
constant uint & fuse_exp,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
@@ -2666,6 +2667,8 @@ kernel void kernel_rwkv_wkv7_f32(
const uint head_id = tgpig.x % H;
const uint tid = tpitg.x;
constexpr float w_scale = -0.6065306597f; // -exp(-0.5)
if (batch_id >= B || head_id >= H) {
return;
}
@@ -2692,7 +2695,7 @@ kernel void kernel_rwkv_wkv7_f32(
for (uint t = start_t; t < end_t; t += C) {
threadgroup_barrier(mem_flags::mem_threadgroup);
_r[tid] = r[t];
_w[tid] = w[t];
_w[tid] = fuse_exp ? exp(w_scale / (1 + exp(-w[t]))) : w[t];
_k[tid] = k[t];
_a[tid] = a[t];
_b[tid] = b[t];

View File

@@ -96,7 +96,9 @@ static void rwkv_wkv6_f32_kernel(
}
}
template <int block_size>
constexpr float w_scale = -0.6065306597f; // -exp(-0.5)
template <int block_size, bool fuse_exp>
static void rwkv_wkv7_f32_kernel(
const int B, const int T, const int C, const int H,
const float* r, const float* w, const float* k, const float* v,
@@ -132,7 +134,7 @@ static void rwkv_wkv7_f32_kernel(
item_ct1.barrier(sycl::access::fence_space::local_space);
_r[tid] = r[t];
_w[tid] = w[t];
_w[tid] = fuse_exp ? sycl::native::exp(w_scale / (1.0f + sycl::native::exp(-w[t]))) : w[t];
_k[tid] = k[t];
_a[tid] = a[t];
_b[tid] = b[t];
@@ -247,9 +249,11 @@ void ggml_sycl_op_rwkv_wkv7(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
float* dst_d = (float*)dst->data;
const int64_t B = dst->src[6]->ne[1];
const int64_t T = dst->src[0]->ne[2];
const int64_t T = dst->src[4]->ne[2];
const int64_t C = dst->ne[0];
const int64_t H = dst->src[0]->ne[1];
const int64_t H = dst->src[4]->ne[1];
const bool fuse_exp = (bool) ((int32_t *) dst->op_params)[0];
GGML_ASSERT(dst->src[6]->type == GGML_TYPE_F32);
GGML_ASSERT(C % H == 0);
@@ -264,30 +268,61 @@ void ggml_sycl_op_rwkv_wkv7(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
// Submit kernel
if (C / H == WKV_BLOCK_SIZE) {
stream->submit([&](sycl::handler& cgh) {
sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
if (fuse_exp) {
stream->submit([&](sycl::handler& cgh) {
sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
cgh.parallel_for(
sycl::nd_range<3>(grid_dims * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
rwkv_wkv7_f32_kernel<WKV_BLOCK_SIZE>(
B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d,
item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
);
cgh.parallel_for(
sycl::nd_range<3>(grid_dims * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
rwkv_wkv7_f32_kernel<WKV_BLOCK_SIZE, true>(
B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d,
item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
);
});
});
} else {
stream->submit([&](sycl::handler& cgh) {
sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
cgh.parallel_for(
sycl::nd_range<3>(grid_dims * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
rwkv_wkv7_f32_kernel<WKV_BLOCK_SIZE, false>(
B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d,
item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
);
});
});
});
});
}
} else {
stream->submit([&](sycl::handler& cgh) {
sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
if (fuse_exp) {
stream->submit([&](sycl::handler& cgh) {
sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
cgh.parallel_for(
sycl::nd_range<3>(grid_dims * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
rwkv_wkv7_f32_kernel<WKV_BLOCK_SIZE * 2>(
B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d,
item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
);
});
});
cgh.parallel_for(
sycl::nd_range<3>(grid_dims * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
rwkv_wkv7_f32_kernel<WKV_BLOCK_SIZE * 2, true>(
B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d,
item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
);
});
});
} else {
stream->submit([&](sycl::handler& cgh) {
sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
cgh.parallel_for(
sycl::nd_range<3>(grid_dims * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
rwkv_wkv7_f32_kernel<WKV_BLOCK_SIZE * 2, false>(
B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d,
item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
);
});
});
}
}
}

View File

@@ -780,6 +780,7 @@ struct vk_device_struct {
vk_pipeline pipeline_pool2d_f32;
vk_pipeline pipeline_rwkv_wkv6_f32;
vk_pipeline pipeline_rwkv_wkv7_f32;
vk_pipeline pipeline_rwkv_wkv7_f32_fuse_exp;
vk_pipeline pipeline_ssm_scan_f32_d128;
vk_pipeline pipeline_ssm_scan_f32_d256;
vk_pipeline pipeline_ssm_conv_f32;
@@ -4299,6 +4300,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32, "rwkv_wkv7_f32", rwkv_wkv7_f32_len, rwkv_wkv7_f32_data, "main", 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32_fuse_exp, "rwkv_wkv7_f32_fuse_exp", rwkv_wkv7_f32_fuse_exp_len, rwkv_wkv7_f32_fuse_exp_data, "main", 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
if (device->subgroup_arithmetic && device->subgroup_require_full_support) {
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size, 16}, 1, true, true);
@@ -8992,6 +8994,9 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return nullptr;
case GGML_OP_RWKV_WKV7:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
if (dst->op_params[0] == 1) {
return ctx->device->pipeline_rwkv_wkv7_f32_fuse_exp;
}
return ctx->device->pipeline_rwkv_wkv7_f32;
}
return nullptr;
@@ -9748,10 +9753,62 @@ static void ggml_vk_add_id(ggml_backend_vk_context * ctx, vk_context& subctx, co
});
}
static void ggml_vk_op_f32_wkv(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_rwkv_wkv6_push_constants&& pc, int version) {
GGML_ASSERT(version == 6 || version == 7);
int num_srcs = version == 6 ? 6 : 7;
static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
const size_t seq_length = dst->src[0]->ne[2];
const size_t n_embed = dst->ne[0];
const size_t n_heads = dst->src[0]->ne[1];
const size_t n_seqs = dst->src[5]->ne[1];
const vk_op_rwkv_wkv6_push_constants pc = {
(uint32_t)n_seqs,
(uint32_t)seq_length,
(uint32_t)n_embed,
(uint32_t)n_heads,
};
const int num_srcs = 6;
for (int i = 0; i < num_srcs; i++) {
GGML_ASSERT(!ggml_is_quantized(dst->src[i]->type));
}
GGML_ASSERT(dst->buffer != nullptr);
vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, dst->src[0], dst->src[1], dst->src[2], dst, dst->op);
GGML_ASSERT(pipeline != nullptr);
ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst);
vk_subbuffer src_buf[6] = {};
for (int i = 0; i < num_srcs; i++) {
src_buf[i] = ggml_vk_tensor_subbuffer(ctx, dst->src[i]);
}
std::array<uint32_t, 3> elements = {
(uint32_t)(pc.B * pc.H),
1,
1
};
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
{src_buf[0], src_buf[1], src_buf[2], src_buf[3], src_buf[4], src_buf[5], dst_buf},
pc, elements);
}
static void ggml_vk_rwkv_wkv7(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
const size_t seq_length = dst->src[0]->ne[2];
const size_t n_embed = dst->ne[0];
const size_t n_heads = dst->src[0]->ne[1];
const size_t n_seqs = dst->src[6]->ne[1];
const vk_op_rwkv_wkv7_push_constants pc = {
(uint32_t)n_seqs,
(uint32_t)seq_length,
(uint32_t)n_embed,
(uint32_t)n_heads,
};
const int num_srcs = 7;
for (int i = 0; i < num_srcs; i++) {
GGML_ASSERT(!ggml_is_quantized(dst->src[i]->type));
}
@@ -9775,54 +9832,9 @@ static void ggml_vk_op_f32_wkv(ggml_backend_vk_context * ctx, vk_context& subctx
1
};
if (version == 6) {
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
{src_buf[0], src_buf[1], src_buf[2], src_buf[3], src_buf[4], src_buf[5], dst_buf},
pc, elements);
} else if (version == 7) {
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
{src_buf[0], src_buf[1], src_buf[2], src_buf[3], src_buf[4], src_buf[5], src_buf[6], dst_buf},
pc, elements);
} else {
// shouldn't happen
GGML_ASSERT(false);
}
}
static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
const size_t seq_length = dst->src[0]->ne[2];
const size_t n_embed = dst->ne[0];
const size_t n_heads = dst->src[0]->ne[1];
const size_t n_seqs = dst->src[5]->ne[1];
ggml_vk_op_f32_wkv(
ctx, subctx, dst,
{
(uint32_t)n_seqs,
(uint32_t)seq_length,
(uint32_t)n_embed,
(uint32_t)n_heads,
},
6
);
}
static void ggml_vk_rwkv_wkv7(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
const size_t seq_length = dst->src[0]->ne[2];
const size_t n_embed = dst->ne[0];
const size_t n_heads = dst->src[0]->ne[1];
const size_t n_seqs = dst->src[6]->ne[1];
ggml_vk_op_f32_wkv(
ctx, subctx, dst,
{
(uint32_t)n_seqs,
(uint32_t)seq_length,
(uint32_t)n_embed,
(uint32_t)n_heads,
},
7
);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
{src_buf[0], src_buf[1], src_buf[2], src_buf[3], src_buf[4], src_buf[5], src_buf[6], dst_buf},
pc, elements);
}
static void ggml_vk_ssm_scan(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
@@ -15558,7 +15570,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
src_clone[2], src_clone[3], src_clone[4], src_clone[5]);
} else if (tensor->op == GGML_OP_RWKV_WKV7) {
tensor_clone = ggml_rwkv_wkv7(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3],
src_clone[4], src_clone[5], src_clone[6]);
src_clone[4], src_clone[5], src_clone[6], (uint32_t) tensor->op_params[0]);
} else if (tensor->op == GGML_OP_OPT_STEP_ADAMW) {
src_clone[0]->flags = tensor->src[0]->flags;
tensor_clone = ggml_opt_step_adamw(ggml_ctx, src_clone[0], src_clone[1],

View File

@@ -967,7 +967,8 @@ void process_shaders() {
string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
string_to_spv("rwkv_wkv7_f32", "wkv7.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
string_to_spv("rwkv_wkv7_f32", "wkv7.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"FUSE_EXP", "0"}}));
string_to_spv("rwkv_wkv7_f32_fuse_exp", "wkv7.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"FUSE_EXP", "1"}}));
string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
string_to_spv("opt_step_sgd_f32", "opt_step_sgd.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));

View File

@@ -10,6 +10,7 @@ layout(push_constant) uniform Parameters {
uint T;
uint C;
uint H;
uint fuse_exp;
};
layout(binding = 0) readonly buffer RBuf { A_TYPE r[]; };
@@ -45,10 +46,18 @@ void main() {
const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid;
const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid;
#if FUSE_EXP
const float w_scale = -0.6065306597f; // -exp(-0.5)
#endif
for (uint t = start_t; t < end_t; t += C) {
barrier();
_r[tid] = r[t];
#if FUSE_EXP
_w[tid] = exp(w_scale / (1 + exp(-w[t])));
#else
_w[tid] = w[t];
#endif
_k[tid] = k[t];
_a[tid] = a[t];
_b[tid] = b[t];

View File

@@ -5698,7 +5698,8 @@ struct ggml_tensor * ggml_rwkv_wkv7(
struct ggml_tensor * v,
struct ggml_tensor * a,
struct ggml_tensor * b,
struct ggml_tensor * state) {
struct ggml_tensor * state,
bool fuse_exp) {
GGML_ASSERT(ggml_is_contiguous(r));
GGML_ASSERT(ggml_is_contiguous(w));
GGML_ASSERT(ggml_is_contiguous(k));
@@ -5707,14 +5708,16 @@ struct ggml_tensor * ggml_rwkv_wkv7(
GGML_ASSERT(ggml_is_contiguous(b));
GGML_ASSERT(ggml_is_contiguous(state));
const int64_t S = k->ne[0];
const int64_t H = k->ne[1];
const int64_t n_tokens = k->ne[2];
const int64_t S = a->ne[0];
const int64_t H = a->ne[1];
const int64_t n_tokens = a->ne[2];
const int64_t n_seqs = state->ne[1];
const int64_t n_embd = S * H;
{
GGML_ASSERT(w->ne[0] == S && w->ne[1] == H && w->ne[2] == n_tokens);
GGML_ASSERT(r->ne[0] == S && r->ne[1] == H && r->ne[2] == n_tokens);
GGML_ASSERT(w->ne[0] == n_embd && w->ne[1] == n_tokens);
GGML_ASSERT(k->ne[0] == S && k->ne[1] == H && k->ne[2] == n_tokens);
GGML_ASSERT(v->ne[0] == S && v->ne[1] == H && v->ne[2] == n_tokens);
GGML_ASSERT(v->ne[0] == n_embd && v->ne[1] == n_tokens);
GGML_ASSERT(a->ne[0] == S && a->ne[1] == H && a->ne[2] == n_tokens);
GGML_ASSERT(b->ne[0] == S && b->ne[1] == H && b->ne[2] == n_tokens);
GGML_ASSERT(ggml_nelements(state) == S * S * H * n_seqs);
@@ -5724,6 +5727,9 @@ struct ggml_tensor * ggml_rwkv_wkv7(
const int64_t ne[4] = { S * H, n_tokens + S * n_seqs, 1, 1 };
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
int32_t fuse_exp_i = fuse_exp ? 1 : 0;
ggml_set_op_params(result, &fuse_exp_i, sizeof(fuse_exp_i));
result->op = GGML_OP_RWKV_WKV7;
result->src[0] = r;
result->src[1] = w;
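
The asserts now only require w and v to match n_embd x n_tokens, while r, k, a, b keep the [head_size, n_head, n_tokens] layout and a serves as the reference for S, H, and n_tokens. A hedged sketch of tensors that satisfy these checks (mirroring the updated test case at the end of this commit; names are illustrative):

    ggml_tensor * r = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, head_size, n_head, n_tokens);
    ggml_tensor * w = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, head_size * n_head, n_tokens); // 2-D now accepted
    ggml_tensor * k = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, head_size, n_head, n_tokens);
    ggml_tensor * v = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, head_size * n_head, n_tokens); // 2-D now accepted
    ggml_tensor * a = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, head_size, n_head, n_tokens);
    ggml_tensor * b = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, head_size, n_head, n_tokens);
    ggml_tensor * s = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, head_size * head_size * n_head, n_seqs);
    ggml_tensor * out = ggml_rwkv_wkv7(ctx, r, w, k, v, a, b, s, /*fuse_exp=*/true);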

View File

@@ -65,7 +65,6 @@ ggml_tensor * llm_build_rwkv7_base::build_rwkv7_time_mix(llm_graph_input_rs * in
ggml_tensor * w = ggml_add(
ctx0, ggml_mul_mat(ctx0, layer.time_mix_w2, ggml_tanh(ctx0, ggml_mul_mat(ctx0, layer.time_mix_w1, xw))),
layer.time_mix_w0);
w = ggml_exp(ctx0, ggml_scale(ctx0, ggml_sigmoid(ctx0, w), -0.606531));
ggml_tensor * k = build_lora_mm(layer.time_mix_key, xk);
ggml_tensor * v = build_lora_mm(layer.time_mix_value, xv);
@@ -95,14 +94,12 @@ ggml_tensor * llm_build_rwkv7_base::build_rwkv7_time_mix(llm_graph_input_rs * in
k = ggml_add(ctx0, k, ggml_sub(ctx0, ggml_mul(ctx0, a, ka), ka));
r = ggml_reshape_3d(ctx0, r, head_size, head_count, n_tokens);
w = ggml_reshape_3d(ctx0, w, head_size, head_count, n_tokens);
k = ggml_reshape_3d(ctx0, k, head_size, head_count, n_tokens);
v = ggml_reshape_3d(ctx0, v, head_size, head_count, n_tokens);
a = ggml_reshape_3d(ctx0, a, head_size, head_count, n_tokens);
ggml_tensor * wkv_state = build_rs(inp, mctx_cur->get_s_l(il), hparams.n_embd_s(), n_seqs);
ggml_tensor * wkv_output = ggml_rwkv_wkv7(ctx0, r, w, k, v, ggml_neg(ctx0, kk), ggml_mul(ctx0, kk, a), wkv_state);
ggml_tensor * wkv_output = ggml_rwkv_wkv7(ctx0, r, w, k, v, ggml_neg(ctx0, kk), ggml_mul(ctx0, kk, a), wkv_state, true);
cur = ggml_view_1d(ctx0, wkv_output, n_embd * n_tokens, 0);
wkv_state = ggml_view_1d(ctx0, wkv_output, n_embd * head_size * n_seqs, n_embd * n_tokens * sizeof(float));
@@ -122,6 +119,8 @@ ggml_tensor * llm_build_rwkv7_base::build_rwkv7_time_mix(llm_graph_input_rs * in
} else {
cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
}
v = ggml_reshape_3d(ctx0, v, head_size, head_count, n_tokens);
ggml_tensor * rk = ggml_sum_rows(
ctx0, ggml_mul(ctx0, ggml_mul(ctx0, k, r), ggml_reshape_2d(ctx0, layer.time_mix_r_k, head_size, head_count)));
cur = ggml_add(ctx0, cur, ggml_reshape_2d(ctx0, ggml_mul(ctx0, v, rk), n_embd, n_tokens));
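
With the exponent fused into the op, the graph drops the sigmoid/scale/exp chain for w as well as the w and v reshapes before the call; v is reshaped to the per-head 3-D view only afterwards, where the r·k gated residual needs it. A hedged recap of the reordered calls from the hunks above:

    // w stays 2-D and pre-activation; fuse_exp = true applies exp(w_scale * sigmoid(w)) in the kernel.
    ggml_tensor * wkv_output = ggml_rwkv_wkv7(ctx0, r, w, k, v,
            ggml_neg(ctx0, kk), ggml_mul(ctx0, kk, a), wkv_state, true);
    // ... output views and state store unchanged ...
    v = ggml_reshape_3d(ctx0, v, head_size, head_count, n_tokens); // only needed for the rk term below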

View File

@@ -3650,17 +3650,18 @@ struct test_rwkv_wkv7 : public test_case {
ggml_tensor * build_graph(ggml_context * ctx) override {
const int64_t n_tokens = n_seq_tokens * n_seqs;
const int64_t n_embd = head_count * head_size;
ggml_tensor * r = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
ggml_tensor * w = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
ggml_tensor * w = ggml_new_tensor(ctx, type, 2, std::vector<int64_t>{ n_embd, n_tokens }.data());
ggml_tensor * k = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
ggml_tensor * v = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
ggml_tensor * v = ggml_new_tensor(ctx, type, 2, std::vector<int64_t>{ n_embd, n_tokens }.data());
ggml_tensor * a = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
ggml_tensor * b = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
// Outputs may become NaN with long seqlen without these normalizations
a = ggml_l2_norm(ctx, a, 1e-7F);
b = ggml_l2_norm(ctx, b, 1e-7F);
ggml_tensor * s = ggml_new_tensor(ctx, type, 2, std::vector<int64_t>{ head_size * head_size * head_count, n_seqs }.data());
ggml_tensor * out = ggml_rwkv_wkv7(ctx, r, w, k, v, a, b, s);
ggml_tensor * out = ggml_rwkv_wkv7(ctx, r, w, k, v, a, b, s, true);
return out;
}
};