Molly Sophia 2026-02-01 15:44:35 -08:00 committed by GitHub
commit d7d754e43b
25 changed files with 901 additions and 144 deletions

View File

@ -556,6 +556,7 @@ extern "C" {
GGML_OP_GATED_LINEAR_ATTN,
GGML_OP_RWKV_WKV7,
GGML_OP_SOLVE_TRI,
GGML_OP_LERP,
GGML_OP_UNARY,
@ -583,6 +584,7 @@ extern "C" {
GGML_UNARY_OP_TANH,
GGML_UNARY_OP_ELU,
GGML_UNARY_OP_RELU,
GGML_UNARY_OP_RELU_SQR,
GGML_UNARY_OP_SIGMOID,
GGML_UNARY_OP_GELU,
GGML_UNARY_OP_GELU_QUICK,
@ -1133,6 +1135,14 @@ extern "C" {
struct ggml_context * ctx,
struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_relu_sqr(
struct ggml_context * ctx,
struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_relu_sqr_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_sigmoid(
struct ggml_context * ctx,
struct ggml_tensor * a);
@ -2443,7 +2453,8 @@ extern "C" {
struct ggml_tensor * v,
struct ggml_tensor * a,
struct ggml_tensor * b,
struct ggml_tensor * state);
struct ggml_tensor * state,
bool fuse_exp);
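When fuse_exp is true, callers pass the raw decay logits in w and the wkv7 kernels apply the RWKV7 decay transform internally, replacing the ggml_sigmoid/ggml_scale/ggml_exp chain that was previously built in the graph. A minimal scalar sketch of that transform (illustrative helper, not part of the API; the constant matches the w_scale used by the kernels below):

#include <math.h>

// exp(w_scale * sigmoid(w_raw)); sigmoid(w_raw) is in (0, 1), so the result is in (exp(w_scale), 1)
static inline float rwkv7_fused_decay(float w_raw) {
    const float w_scale = -0.6065306597f; // -exp(-0.5)
    return expf(w_scale / (1.0f + expf(-w_raw)));
}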
/* Solves a specific equation of the form Ax=B, where A is a triangular matrix
* without zeroes on the diagonal (i.e. invertible).
@ -2466,6 +2477,14 @@ extern "C" {
bool lower,
bool uni);
// a + (b - a) * t
// used in rwkv7
GGML_API struct ggml_tensor * ggml_lerp(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
struct ggml_tensor * t);
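ggml_lerp mixes a and b element-wise with weight t. Per the shape checks added to ggml.c further down, t may broadcast over a/b along dims 1 and 2, while a/b may broadcast along dim 3 to match t. The per-element math, as a scalar reference (illustrative only):

// dst = a + (b - a) * t, i.e. (1 - t) * a + t * b
static inline float lerp_ref(float a, float b, float t) {
    return a + (b - a) * t;
}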
// custom operators
typedef void (*ggml_custom1_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int nth, void * userdata);

View File

@ -2019,6 +2019,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
ggml_compute_forward_solve_tri(params, tensor);
} break;
case GGML_OP_LERP:
{
ggml_compute_forward_lerp(params, tensor);
}
break;
case GGML_OP_MAP_CUSTOM1:
{
ggml_compute_forward_map_custom1(params, tensor);
@ -2180,6 +2185,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
case GGML_OP_CUMSUM:
case GGML_OP_TRI:
case GGML_OP_FILL:
case GGML_OP_LERP:
{
n_tasks = n_threads;
} break;
@ -2216,6 +2222,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
case GGML_UNARY_OP_TANH:
case GGML_UNARY_OP_ELU:
case GGML_UNARY_OP_RELU:
case GGML_UNARY_OP_RELU_SQR:
case GGML_UNARY_OP_SIGMOID:
case GGML_UNARY_OP_HARDSWISH:
case GGML_UNARY_OP_HARDSIGMOID:

View File

@ -9458,6 +9458,10 @@ void ggml_compute_forward_unary(
{
ggml_compute_forward_relu(params, dst);
} break;
case GGML_UNARY_OP_RELU_SQR:
{
ggml_compute_forward_relu_sqr(params, dst);
} break;
case GGML_UNARY_OP_SIGMOID:
{
ggml_compute_forward_sigmoid(params, dst);
@ -10189,9 +10193,9 @@ void ggml_compute_forward_solve_tri(const struct ggml_compute_params * params, s
static void ggml_compute_forward_rwkv_wkv7_f32(
const ggml_compute_params * params,
ggml_tensor * dst) {
const int64_t T = dst->src[1]->ne[2];
const int64_t T = dst->src[4]->ne[2];
const int64_t C = dst->ne[0];
const int64_t HEADS = dst->src[1]->ne[1];
const int64_t HEADS = dst->src[4]->ne[1];
const int64_t n_seqs = dst->src[6]->ne[1];
const int64_t head_size = C / HEADS;
@ -10216,6 +10220,9 @@ static void ggml_compute_forward_rwkv_wkv7_f32(
float * a = (float *) dst->src[4]->data;
float * b = (float *) dst->src[5]->data;
const bool fuse_exp = (bool) ((int32_t *) dst->op_params)[0];
constexpr float w_scale = -0.6065306597f; // -exp(-0.5)
int64_t t_stride = HEADS * head_size; // Same as C
int64_t h_stride = C / HEADS;
@ -10252,7 +10259,7 @@ static void ggml_compute_forward_rwkv_wkv7_f32(
int64_t h_2d_i_j_offset = h_2d_i_offset + j;
float r_val = r[t_h_j_offset];
float w_val = w[t_h_j_offset];
float w_val = fuse_exp ? expf(w_scale / (1.0f + expf(-w[t_h_j_offset]))) : w[t_h_j_offset];
float k_val = k[t_h_j_offset];
float b_val = b[t_h_j_offset];
float kv_val = v_val * k_val;
@ -10311,6 +10318,11 @@ static void ggml_compute_forward_rwkv_wkv7_f32(
GGML_F32_VEC k_vec = GGML_F32_VEC_LOAD(&k[t_h_j_offset]);
GGML_F32_VEC b_vec = GGML_F32_VEC_LOAD(&b[t_h_j_offset]);
if (fuse_exp) {
w_vec = ggml_v_sigmoid(w_vec, w_scale);
w_vec = ggml_v_expf(w_vec);
}
k_vec = GGML_F32_VEC_MUL(v_vec, k_vec);
GGML_F32_VEC state_vec = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_j_offset]);
@ -10330,7 +10342,7 @@ static void ggml_compute_forward_rwkv_wkv7_f32(
int64_t h_2d_i_j_offset = h_2d_i_offset + j;
float r_val = r[t_h_j_offset];
float w_val = w[t_h_j_offset];
float w_val = fuse_exp ? expf(w_scale / (1.0f + expf(-w[t_h_j_offset]))) : w[t_h_j_offset];
float k_val = k[t_h_j_offset];
float b_val = b[t_h_j_offset];
float kv_val = v[t_h_i_offset] * k_val;
@ -10371,7 +10383,7 @@ static void ggml_compute_forward_rwkv_wkv7_f32(
int64_t h_2d_i_j_offset = h_2d_i_offset + j;
float r_val = r[t_h_j_offset];
float w_val = w[t_h_j_offset];
float w_val = fuse_exp ? expf(w_scale / (1.0f + expf(-w[t_h_j_offset]))) : w[t_h_j_offset];
float k_val = k[t_h_j_offset];
float b_val = b[t_h_j_offset];
float kv_val = v_val * k_val;
@ -10405,6 +10417,177 @@ void ggml_compute_forward_rwkv_wkv7(
}
}
static void ggml_compute_forward_lerp_f32(
const ggml_compute_params * params,
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
const ggml_tensor * src2 = dst->src[2];
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT(src2->type == GGML_TYPE_F32);
GGML_ASSERT(src0->nb[0] == sizeof(float));
GGML_ASSERT(src1->nb[0] == sizeof(float));
GGML_ASSERT(src2->nb[0] == sizeof(float));
const int ith = params->ith;
const int nth = params->nth;
GGML_TENSOR_TERNARY_OP_LOCALS
GGML_ASSERT( nb0 == sizeof(float));
GGML_ASSERT(nb10 == sizeof(float));
GGML_ASSERT(nb20 == sizeof(float));
GGML_ASSERT(ggml_are_same_shape(src0, src1));
GGML_ASSERT(ne00 == ne10 && ne00 == ne20 && ne00 == ne0);
GGML_ASSERT(ne01 % ne21 == 0);
GGML_ASSERT(ne02 % ne22 == 0);
GGML_ASSERT(ne23 % ne03 == 0);
GGML_ASSERT(ne23 % ne13 == 0);
GGML_ASSERT(ne0 == ne00);
GGML_ASSERT(ne1 == ne01);
GGML_ASSERT(ne2 == ne02);
GGML_ASSERT(ne3 == ne23);
const int nr = ggml_nrows(dst);
const int dr = (nr + nth - 1)/nth;
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
for (int ir = ir0; ir < ir1; ++ir) {
const int i3 = ir/(ne2*ne1);
const int i2 = (ir - i3*ne2*ne1)/ne1;
const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
const int i03 = i3 % ne03;
const int i13 = i3 % ne13;
const int i21 = i1 % ne21;
const int i22 = i2 % ne22;
float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
const float * src0_ptr = (const float *) ((char *) src0->data + i03*nb03 + i2*nb02 + i1*nb01);
const float * src1_ptr = (const float *) ((char *) src1->data + i13*nb13 + i2*nb12 + i1*nb11);
const float * src2_ptr = (const float *) ((char *) src2->data + i3*nb23 + i22*nb22 + i21*nb21);
for (int64_t i0 = 0; i0 < ne0; ++i0) {
const float s0 = src0_ptr[i0];
const float s1 = src1_ptr[i0];
const float s2 = src2_ptr[i0];
dst_ptr[i0] = s0 + (s1 - s0) * s2;
}
}
}
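The f32 path above splits rows evenly across threads and unravels each flat row index into (i1, i2, i3) before applying the broadcast indices. The same arithmetic as a standalone sketch with a worked example (variable names are hypothetical):

// nr = ne1 * ne2 * ne3 rows in total; recover the 3D row coordinates from a flat index ir
const int i3 = ir / (ne2 * ne1);
const int i2 = (ir - i3 * ne2 * ne1) / ne1;
const int i1 =  ir - i3 * ne2 * ne1 - i2 * ne1;
// e.g. ne1 = 16, ne2 = 1, ne3 = 6, ir = 37  ->  i3 = 2, i2 = 0, i1 = 5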
static void ggml_compute_forward_lerp_f32_f32_f16(
const ggml_compute_params * params,
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
const ggml_tensor * src2 = dst->src[2];
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT(src2->type == GGML_TYPE_F16);
GGML_ASSERT(src0->nb[0] == sizeof(float));
GGML_ASSERT(src1->nb[0] == sizeof(float));
GGML_ASSERT(src2->nb[0] == sizeof(ggml_fp16_t));
const int ith = params->ith;
const int nth = params->nth;
GGML_TENSOR_TERNARY_OP_LOCALS
GGML_ASSERT( nb0 == sizeof(float));
GGML_ASSERT(nb10 == sizeof(float));
GGML_ASSERT(nb20 == sizeof(ggml_fp16_t));
GGML_ASSERT(ggml_are_same_shape(src0, src1));
GGML_ASSERT(ne00 == ne10 && ne00 == ne20 && ne00 == ne0);
GGML_ASSERT(ne01 % ne21 == 0);
GGML_ASSERT(ne02 % ne22 == 0);
GGML_ASSERT(ne23 % ne03 == 0);
GGML_ASSERT(ne23 % ne13 == 0);
GGML_ASSERT(ne0 == ne00);
GGML_ASSERT(ne1 == ne01);
GGML_ASSERT(ne2 == ne02);
GGML_ASSERT(ne3 == ne23);
const int nr = ggml_nrows(dst);
// rows per thread
const int dr = (nr + nth - 1)/nth;
// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
for (int ir = ir0; ir < ir1; ++ir) {
const int i3 = ir/(ne2*ne1);
const int i2 = (ir - i3*ne2*ne1)/ne1;
const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
const int i03 = i3 % ne03;
const int i13 = i3 % ne13;
const int i21 = i1 % ne21;
const int i22 = i2 % ne22;
float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
const float * src0_ptr = (const float *) ((char *) src0->data + i03*nb03 + i2*nb02 + i1*nb01);
const float * src1_ptr = (const float *) ((char *) src1->data + i13*nb13 + i2*nb12 + i1*nb11);
const ggml_fp16_t * src2_ptr = (const ggml_fp16_t *) ((char *) src2->data + i3*nb23 + i22*nb22 + i21*nb21);
for (int64_t i0 = 0; i0 < ne0; ++i0) {
const float s0 = src0_ptr[i0];
const float s1 = src1_ptr[i0];
const float s2 = GGML_FP16_TO_FP32(src2_ptr[i0]);
dst_ptr[i0] = s0 + (s1 - s0) * s2;
}
}
}
void ggml_compute_forward_lerp(
const ggml_compute_params * params,
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src2 = dst->src[2];
switch (src0->type) {
case GGML_TYPE_F32:
{
if (src2->type == GGML_TYPE_F32) {
ggml_compute_forward_lerp_f32(params, dst);
} else if (src2->type == GGML_TYPE_F16) {
ggml_compute_forward_lerp_f32_f32_f16(params, dst);
} else {
GGML_ABORT("fatal error");
}
} break;
default:
{
GGML_ABORT("fatal error");
}
}
}
// ggml_compute_forward_map_custom1
void ggml_compute_forward_map_custom1(

View File

@ -101,6 +101,7 @@ void ggml_compute_forward_add_rel_pos(const struct ggml_compute_params * params,
void ggml_compute_forward_rwkv_wkv6(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_rwkv_wkv7(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_solve_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_lerp(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_gla(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_map_custom1(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_map_custom2(const struct ggml_compute_params * params, struct ggml_tensor * dst);

View File

@ -28,6 +28,11 @@ static inline float op_relu(float x) {
return (x > 0.f) ? x : 0.f;
}
static inline float op_relu_sqr(float x) {
float r = (x > 0.f) ? x : 0.f;
return r * r;
}
static inline float op_sigmoid(float x) {
return 1.f / (1.f + expf(-x));
}
@ -262,6 +267,10 @@ void ggml_compute_forward_relu(const ggml_compute_params * params, ggml_tensor *
unary_op<op_relu>(params, dst);
}
void ggml_compute_forward_relu_sqr(const ggml_compute_params * params, ggml_tensor * dst) {
unary_op<op_relu_sqr>(params, dst);
}
void ggml_compute_forward_sigmoid(const ggml_compute_params * params, ggml_tensor * dst) {
unary_op<op_sigmoid>(params, dst);
}

View File

@ -13,6 +13,7 @@ void ggml_compute_forward_step(const struct ggml_compute_params * params, struct
void ggml_compute_forward_tanh(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_elu(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_relu(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_relu_sqr(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_sigmoid(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_hardsigmoid(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_exp(const struct ggml_compute_params * params, struct ggml_tensor * dst);

View File

@ -1129,16 +1129,27 @@ inline static svfloat32_t ggml_v_expf(svbool_t pg, svfloat32_t x) {
svsel_f32(c, svmul_f32_x(pg, svmla_f32_x(pg, s2, s2, j), s1), svmla_f32_x(pg, k, k, j)));
}
// computes silu x/(1+exp(-x)) in single precision vector
inline static svfloat32_t ggml_v_silu(svbool_t pg, svfloat32_t x) {
// computes (1+exp(-x)) in single precision vector
inline static svfloat32_t ggml_v_one_plus_exp_neg_x(svbool_t pg, svfloat32_t x) {
const svfloat32_t one = svdup_n_f32_x(pg, 1.0f);
const svfloat32_t zero = svdup_n_f32_x(pg, 0.0f);
const svfloat32_t neg_x = svsub_f32_x(pg, zero, x);
const svfloat32_t exp_neg_x = ggml_v_expf(pg, neg_x);
const svfloat32_t one_plus_exp_neg_x = svadd_f32_x(pg, one, exp_neg_x);
return svadd_f32_x(pg, one, exp_neg_x);
}
// computes silu x/(1+exp(-x)) in single precision vector
inline static svfloat32_t ggml_v_silu(svbool_t pg, svfloat32_t x) {
const svfloat32_t one_plus_exp_neg_x = ggml_v_one_plus_exp_neg_x(pg, x);
return svdiv_f32_x(pg, x, one_plus_exp_neg_x);
}
// computes sigmoid 1/(1+exp(-x)) (with scale) in single precision vector
inline static svfloat32_t ggml_v_sigmoid(svbool_t pg, svfloat32_t x, float scale) {
const svfloat32_t one_plus_exp_neg_x = ggml_v_one_plus_exp_neg_x(pg, x);
return svdiv_f32_x(pg, svdup_n_f32_x(pg, scale), one_plus_exp_neg_x);
}
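Each SIMD branch in this file receives the same refactor: the shared (1 + exp(-x)) term is factored into ggml_v_one_plus_exp_neg_x, and silu/sigmoid are expressed through it, with the new sigmoid taking a numerator scale so the wkv7 path can evaluate w_scale * sigmoid(x) in one call. The scalar equivalents, for reference (sketch, not part of the file):

#include <math.h>

static inline float one_plus_exp_neg_x(float x)      { return 1.0f + expf(-x); }
static inline float silu_ref(float x)                { return x / one_plus_exp_neg_x(x); } // x * sigmoid(x)
static inline float sigmoid_scaled(float x, float s) { return s / one_plus_exp_neg_x(x); } // s * sigmoid(x)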
#elif defined(__ARM_NEON) && defined(__aarch64__)
// adapted from arm limited optimized routine
@ -1168,16 +1179,27 @@ inline static float32x4_t ggml_v_expf(float32x4_t x) {
vbslq_f32(c, vmulq_f32(vfmaq_f32(s2, s2, j), s1), vfmaq_f32(k, k, j)));
}
// computes silu x/(1+exp(-x)) in single precision vector
inline static float32x4_t ggml_v_silu(float32x4_t x) {
// computes (1+exp(-x)) in single precision vector
inline static float32x4_t ggml_v_one_plus_exp_neg_x(float32x4_t x) {
const float32x4_t one = vdupq_n_f32(1.0f);
const float32x4_t zero = vdupq_n_f32(0.0f);
const float32x4_t neg_x = vsubq_f32(zero, x);
const float32x4_t exp_neg_x = ggml_v_expf(neg_x);
const float32x4_t one_plus_exp_neg_x = vaddq_f32(one, exp_neg_x);
return vaddq_f32(one, exp_neg_x);
}
// computes silu x/(1+exp(-x)) in single precision vector
inline static float32x4_t ggml_v_silu(float32x4_t x) {
const float32x4_t one_plus_exp_neg_x = ggml_v_one_plus_exp_neg_x(x);
return vdivq_f32(x, one_plus_exp_neg_x);
}
// computes sigmoid 1/(1+exp(-x)) (with scale) in single precision vector
inline static float32x4_t ggml_v_sigmoid(float32x4_t x, float scale) {
const float32x4_t one_plus_exp_neg_x = ggml_v_one_plus_exp_neg_x(x);
return vdivq_f32(vdupq_n_f32(scale), one_plus_exp_neg_x);
}
#elif defined(__AVX512F__) && defined(__AVX512DQ__)
// adapted from arm limited optimized routine
@ -1211,16 +1233,27 @@ inline static __m512 ggml_v_expf(__m512 x) {
return _mm512_mask_blend_ps(d, res, alt);
}
// computes silu x/(1+exp(-x)) in single precision vector
inline static __m512 ggml_v_silu(__m512 x) {
// computes (1+exp(-x)) in single precision vector
inline static __m512 ggml_v_one_plus_exp_neg_x(__m512 x) {
const __m512 one = _mm512_set1_ps(1);
const __m512 zero = _mm512_setzero_ps();
const __m512 neg_x = _mm512_sub_ps(zero, x);
const __m512 exp_neg_x = ggml_v_expf(neg_x);
const __m512 one_plus_exp_neg_x = _mm512_add_ps(one, exp_neg_x);
return _mm512_add_ps(one, exp_neg_x);
}
// computes silu x/(1+exp(-x)) in single precision vector
inline static __m512 ggml_v_silu(__m512 x) {
const __m512 one_plus_exp_neg_x = ggml_v_one_plus_exp_neg_x(x);
return _mm512_div_ps(x, one_plus_exp_neg_x);
}
// computes sigmoid 1/(1+exp(-x)) (with scale) in single precision vector
inline static __m512 ggml_v_sigmoid(__m512 x, float scale) {
const __m512 one_plus_exp_neg_x = ggml_v_one_plus_exp_neg_x(x);
return _mm512_div_ps(_mm512_set1_ps(scale), one_plus_exp_neg_x);
}
#elif defined(__AVX2__) && defined(__FMA__)
// adapted from arm limited optimized routine
@ -1266,16 +1299,27 @@ inline static __m256 ggml_v_expf(__m256 x) {
_mm256_andnot_ps(_mm256_castsi256_ps(c), _mm256_fmadd_ps(k, j, k)))));
}
// computes silu x/(1+exp(-x)) in single precision vector
inline static __m256 ggml_v_silu(__m256 x) {
// computes (1+exp(-x)) in single precision vector
inline static __m256 ggml_v_one_plus_exp_neg_x(__m256 x) {
const __m256 one = _mm256_set1_ps(1);
const __m256 zero = _mm256_setzero_ps();
const __m256 neg_x = _mm256_sub_ps(zero, x);
const __m256 exp_neg_x = ggml_v_expf(neg_x);
const __m256 one_plus_exp_neg_x = _mm256_add_ps(one, exp_neg_x);
return _mm256_add_ps(one, exp_neg_x);
}
// computes silu x/(1+exp(-x)) in single precision vector
inline static __m256 ggml_v_silu(__m256 x) {
const __m256 one_plus_exp_neg_x = ggml_v_one_plus_exp_neg_x(x);
return _mm256_div_ps(x, one_plus_exp_neg_x);
}
// computes sigmoid 1/(1+exp(-x)) (with scale) in single precision vector
inline static __m256 ggml_v_sigmoid(__m256 x, float scale) {
const __m256 one_plus_exp_neg_x = ggml_v_one_plus_exp_neg_x(x);
return _mm256_div_ps(_mm256_set1_ps(scale), one_plus_exp_neg_x);
}
#elif defined(__SSE2__) // __AVX2__ / __ARM_NEON
#if defined(__FMA__)
@ -1320,16 +1364,27 @@ inline static __m128 ggml_v_expf(__m128 x) {
_mm_andnot_ps(_mm_castsi128_ps(c), MADD128(k, j, k)))));
}
// computes silu x/(1+exp(-x)) in single precision vector
inline static __m128 ggml_v_silu(__m128 x) {
// computes (1+exp(-x)) in single precision vector
inline static __m128 ggml_v_one_plus_exp_neg_x(__m128 x) {
const __m128 one = _mm_set1_ps(1);
const __m128 zero = _mm_setzero_ps();
const __m128 neg_x = _mm_sub_ps(zero, x);
const __m128 exp_neg_x = ggml_v_expf(neg_x);
const __m128 one_plus_exp_neg_x = _mm_add_ps(one, exp_neg_x);
return _mm_add_ps(one, exp_neg_x);
}
// computes silu x/(1+exp(-x)) in single precision vector
inline static __m128 ggml_v_silu(__m128 x) {
const __m128 one_plus_exp_neg_x = ggml_v_one_plus_exp_neg_x(x);
return _mm_div_ps(x, one_plus_exp_neg_x);
}
// computes sigmoid 1/(1+exp(-x)) (with scale) in single precision vector
inline static __m128 ggml_v_sigmoid(__m128 x, float scale) {
const __m128 one_plus_exp_neg_x = ggml_v_one_plus_exp_neg_x(x);
return _mm_div_ps(_mm_set1_ps(scale), one_plus_exp_neg_x);
}
#elif defined(__riscv_v_intrinsic)
// adapted from arm limited optimized routine
@ -1374,14 +1429,25 @@ inline static vfloat32m2_t ggml_v_expf_m2(vfloat32m2_t x, int vl) {
vl);
}
// computes silu x/(1+exp(-x)) in single precision vector
inline static vfloat32m2_t ggml_v_silu_m2(vfloat32m2_t x, int vl) {
// computes (1+exp(-x)) in single precision vector
inline static vfloat32m2_t ggml_v_one_plus_exp_neg_x_m2(vfloat32m2_t x, int vl) {
const vfloat32m2_t neg_x = __riscv_vfneg_v_f32m2(x, vl);
const vfloat32m2_t exp_neg_x = ggml_v_expf_m2(neg_x, vl);
const vfloat32m2_t one_plus_exp_neg_x = __riscv_vfadd_vf_f32m2(exp_neg_x, 1.0f, vl);
return __riscv_vfadd_vf_f32m2(exp_neg_x, 1.0f, vl);
}
// computes silu x/(1+exp(-x)) in single precision vector
inline static vfloat32m2_t ggml_v_silu_m2(vfloat32m2_t x, int vl) {
const vfloat32m2_t one_plus_exp_neg_x = ggml_v_one_plus_exp_neg_x_m2(x, vl);
return __riscv_vfdiv_vv_f32m2(x, one_plus_exp_neg_x, vl);
}
// computes scaled sigmoid scale/(1+exp(-x)) in single precision vector
inline static vfloat32m2_t ggml_v_sigmoid_m2(vfloat32m2_t x, int vl, float scale) {
const vfloat32m2_t one_plus_exp_neg_x = ggml_v_one_plus_exp_neg_x_m2(x, vl);
return __riscv_vfrdiv_vf_f32m2(one_plus_exp_neg_x, scale, vl);
}
#endif // __ARM_NEON / __AVX2__ / __SSE2__ / __riscv_v_intrinsic
inline static void ggml_vec_silu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {

View File

@ -52,6 +52,7 @@
#include "ggml-cuda/unary.cuh"
#include "ggml-cuda/upscale.cuh"
#include "ggml-cuda/wkv.cuh"
#include "ggml-cuda/lerp.cuh"
#include "ggml-cuda/gla.cuh"
#include "ggml-cuda/set.cuh"
#include "ggml-cuda/set-rows.cuh"
@ -2504,6 +2505,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_UNARY_OP_RELU:
ggml_cuda_op_relu(ctx, dst);
break;
case GGML_UNARY_OP_RELU_SQR:
ggml_cuda_op_relu_sqr(ctx, dst);
break;
case GGML_UNARY_OP_SIGMOID:
ggml_cuda_op_sigmoid(ctx, dst);
break;
@ -2727,6 +2731,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_RWKV_WKV7:
ggml_cuda_op_rwkv_wkv7(ctx, dst);
break;
case GGML_OP_LERP:
ggml_cuda_op_lerp(ctx, dst);
break;
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
ggml_cuda_cross_entropy_loss_back(ctx, dst);
break;
@ -4522,6 +4529,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_UNARY_OP_GELU:
case GGML_UNARY_OP_SILU:
case GGML_UNARY_OP_RELU:
case GGML_UNARY_OP_RELU_SQR:
case GGML_UNARY_OP_SIGMOID:
case GGML_UNARY_OP_HARDSIGMOID:
case GGML_UNARY_OP_HARDSWISH:
@ -4839,6 +4847,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_RWKV_WKV6:
case GGML_OP_GATED_LINEAR_ATTN:
case GGML_OP_RWKV_WKV7:
case GGML_OP_LERP:
return true;
case GGML_OP_FLASH_ATTN_EXT:
return ggml_cuda_flash_attn_ext_supported(dev_ctx->device, op);

ggml/src/ggml-cuda/lerp.cu (new file, 283 additions)
View File

@ -0,0 +1,283 @@
#include "lerp.cuh"
#include <cstdint>
template <typename src0_t, typename src1_t, typename src2_t, typename dst_t>
static __global__ void k_lerp(
const src0_t * src0,
const src1_t * src1,
const src2_t * src2,
dst_t * dst,
const int ne0,
const int ne1,
const int ne2,
const uint3 ne3,
const uint3 ne03,
const uint3 ne13,
const uint3 ne21,
const uint3 ne22,
const int s1,
const int s2,
const int s3,
const int s01,
const int s02,
const int s03,
const int s11,
const int s12,
const int s13,
const int s21,
const int s22,
const int s23) {
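// each block covers a tile of (i0, i1, i2*i3): dim0 is walked with a grid-stride loop below,
// while ne2 and ne3 are fused into the z grid dimension and split apart again with fastdiv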
const uint32_t i0s = blockDim.x * blockIdx.x + threadIdx.x;
const uint32_t i1 = (blockDim.y * blockIdx.y + threadIdx.y);
const uint32_t i2 = fastdiv((blockDim.z * blockIdx.z + threadIdx.z), ne3);
const uint32_t i3 = (blockDim.z * blockIdx.z + threadIdx.z) - (i2 * ne3.z);
if (i0s >= (uint32_t)ne0 || i1 >= (uint32_t)ne1 || i2 >= (uint32_t)ne2 || i3 >= ne3.z) {
return;
}
// src0/src1 broadcast in dim3
const uint32_t i03 = fastmodulo(i3, ne03);
const uint32_t i13 = fastmodulo(i3, ne13);
// src2 broadcast in dim1, dim2
const uint32_t i21 = fastmodulo(i1, ne21);
const uint32_t i22 = fastmodulo(i2, ne22);
const size_t i_src0 = i03*s03 + i2*s02 + i1*s01;
const size_t i_src1 = i13*s13 + i2*s12 + i1*s11;
const size_t i_src2 = i3*s23 + i22*s22 + i21*s21;
const size_t i_dst = i3*s3 + i2*s2 + i1*s1;
const src0_t * src0_row = src0 + i_src0;
const src1_t * src1_row = src1 + i_src1;
const src2_t * src2_row = src2 + i_src2;
dst_t * dst_row = dst + i_dst;
for (int i0 = i0s; i0 < ne0; i0 += blockDim.x * gridDim.x) {
const float v0 = (float) src0_row[i0];
const float v1 = (float) src1_row[i0];
const float v2 = (float) src2_row[i0];
dst_row[i0] = (dst_t) (v0 + (v1 - v0) * v2);
}
}
template <typename src0_t, typename src1_t, typename src2_t, typename dst_t>
static __global__ void k_lerp_unravel(
const src0_t * src0,
const src1_t * src1,
const src2_t * src2,
dst_t * dst,
const uint3 ne0,
const uint3 ne1,
const uint3 ne2,
const uint32_t ne3,
const uint3 prod_012,
const uint3 prod_01,
const uint3 ne03,
const uint3 ne13,
const uint3 ne21,
const uint3 ne22,
const int s1,
const int s2,
const int s3,
const int s01,
const int s02,
const int s03,
const int s11,
const int s12,
const int s13,
const int s21,
const int s22,
const int s23) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
const uint32_t i3 = fastdiv(i, prod_012);
const uint32_t i2 = fastdiv(i - i3 * prod_012.z, prod_01);
const uint32_t i1 = fastdiv(i - i3 * prod_012.z - i2 * prod_01.z, ne0);
const uint32_t i0 = i - i3 * prod_012.z - i2 * prod_01.z - i1 * ne0.z;
if (i0 >= ne0.z || i1 >= ne1.z || i2 >= ne2.z || i3 >= ne3) {
return;
}
// src0/src1 broadcast in dim3
const int i03 = fastmodulo(i3, ne03);
const int i13 = fastmodulo(i3, ne13);
// src2 broadcast in dim1, dim2
const int i21 = fastmodulo(i1, ne21);
const int i22 = fastmodulo(i2, ne22);
const size_t i_src0 = i03*s03 + i2*s02 + i1*s01;
const size_t i_src1 = i13*s13 + i2*s12 + i1*s11;
const size_t i_src2 = i3*s23 + i22*s22 + i21*s21;
const size_t i_dst = i3*s3 + i2*s2 + i1*s1;
const src0_t * src0_row = src0 + i_src0;
const src1_t * src1_row = src1 + i_src1;
const src2_t * src2_row = src2 + i_src2;
dst_t * dst_row = dst + i_dst;
const float v0 = (float) src0_row[i0];
const float v1 = (float) src1_row[i0];
const float v2 = (float) src2_row[i0];
// dst = src0 + (src1 - src0) * src2
dst_row[i0] = (dst_t) (v0 + (v1 - v0) * v2);
}
template <typename src0_t, typename src1_t, typename src2_t, typename dst_t>
static void launch_lerp(
const ggml_tensor * src0,
const ggml_tensor * src1,
const ggml_tensor * src2,
ggml_tensor * dst,
const src0_t * src0_dd,
const src1_t * src1_dd,
const src2_t * src2_dd,
dst_t * dst_dd,
cudaStream_t stream) {
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne)
GGML_TENSOR_LOCALS(int64_t, ne2, src2, ne)
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
GGML_TENSOR_LOCALS(size_t, nb1, src1, nb)
GGML_TENSOR_LOCALS(size_t, nb2, src2, nb)
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
GGML_ASSERT(ne00 == ne10 && ne00 == ne20 && ne00 == ne0);
GGML_ASSERT(ne01 % ne21 == 0);
GGML_ASSERT(ne02 % ne22 == 0);
GGML_ASSERT(ne3 % ne03 == 0);
GGML_ASSERT(ne3 % ne13 == 0);
GGML_ASSERT(ne0 == ne00);
GGML_ASSERT(ne1 == ne01);
GGML_ASSERT(ne2 == ne02);
GGML_ASSERT(ne3 == ne23);
size_t s1 = nb1 / sizeof(dst_t);
size_t s2 = nb2 / sizeof(dst_t);
size_t s3 = nb3 / sizeof(dst_t);
size_t s01 = nb01 / sizeof(src0_t);
size_t s02 = nb02 / sizeof(src0_t);
size_t s03 = nb03 / sizeof(src0_t);
size_t s11 = nb11 / sizeof(src1_t);
size_t s12 = nb12 / sizeof(src1_t);
size_t s13 = nb13 / sizeof(src1_t);
size_t s21 = nb21 / sizeof(src2_t);
size_t s22 = nb22 / sizeof(src2_t);
size_t s23 = nb23 / sizeof(src2_t);
GGML_ASSERT(nb0 % sizeof(dst_t) == 0);
GGML_ASSERT(nb1 % sizeof(dst_t) == 0);
GGML_ASSERT(nb2 % sizeof(dst_t) == 0);
GGML_ASSERT(nb3 % sizeof(dst_t) == 0);
GGML_ASSERT(nb00 % sizeof(src0_t) == 0);
GGML_ASSERT(nb01 % sizeof(src0_t) == 0);
GGML_ASSERT(nb02 % sizeof(src0_t) == 0);
GGML_ASSERT(nb03 % sizeof(src0_t) == 0);
GGML_ASSERT(nb10 % sizeof(src1_t) == 0);
GGML_ASSERT(nb11 % sizeof(src1_t) == 0);
GGML_ASSERT(nb12 % sizeof(src1_t) == 0);
GGML_ASSERT(nb13 % sizeof(src1_t) == 0);
GGML_ASSERT(nb20 % sizeof(src2_t) == 0);
GGML_ASSERT(nb21 % sizeof(src2_t) == 0);
GGML_ASSERT(nb22 % sizeof(src2_t) == 0);
GGML_ASSERT(nb23 % sizeof(src2_t) == 0);
const int block_size = CUDA_LERP_BLOCK_SIZE;
int64_t hne0 = std::max(ne0 / 2LL, 1LL);
dim3 block_dims;
block_dims.x = std::min<unsigned int>(hne0, block_size);
block_dims.y = std::min<unsigned int>(ne1, block_size / block_dims.x);
block_dims.z = std::min(std::min<unsigned int>(ne2 * ne3, block_size / block_dims.x / block_dims.y), 64U);
dim3 block_nums(
(hne0 + block_dims.x - 1) / block_dims.x,
(ne1 + block_dims.y - 1) / block_dims.y,
(ne2 * ne3 + block_dims.z - 1) / block_dims.z);
const uint3 ne03_fastdiv = init_fastdiv_values((uint32_t) ne03);
const uint3 ne13_fastdiv = init_fastdiv_values((uint32_t) ne13);
const uint3 ne21_fastdiv = init_fastdiv_values((uint32_t) ne21);
const uint3 ne22_fastdiv = init_fastdiv_values((uint32_t) ne22);
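// CUDA limits grid dimensions y and z to 65535 blocks; when the launch would exceed that,
// fall back to a flat 1D launch and let k_lerp_unravel recover (i0, i1, i2, i3) from the linear index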
if (block_nums.z > 65535 || block_nums.y > 65535) {
int block_num = (ne0 * ne1 * ne2 * ne3 + block_size - 1) / block_size;
const uint3 prod_012 = init_fastdiv_values((uint32_t) (ne0 * ne1 * ne2));
const uint3 prod_01 = init_fastdiv_values((uint32_t) (ne0 * ne1));
const uint3 ne0_fastdiv = init_fastdiv_values((uint32_t) ne0);
const uint3 ne1_fastdiv = init_fastdiv_values((uint32_t) ne1);
const uint3 ne2_fastdiv = init_fastdiv_values((uint32_t) ne2);
k_lerp_unravel<src0_t, src1_t, src2_t, dst_t>
<<<block_num, block_size, 0, stream>>>(
src0_dd, src1_dd, src2_dd, dst_dd,
ne0_fastdiv, ne1_fastdiv, ne2_fastdiv, ne3,
prod_012, prod_01,
ne03_fastdiv, ne13_fastdiv, ne21_fastdiv, ne22_fastdiv,
s1, s2, s3,
s01, s02, s03,
s11, s12, s13,
s21, s22, s23);
} else {
const uint3 ne3_fastdiv = init_fastdiv_values((uint32_t) ne3);
k_lerp<src0_t, src1_t, src2_t, dst_t>
<<<block_nums, block_dims, 0, stream>>>(
src0_dd, src1_dd, src2_dd, dst_dd,
ne0, ne1, ne2, ne3_fastdiv,
ne03_fastdiv, ne13_fastdiv, ne21_fastdiv, ne22_fastdiv,
s1, s2, s3,
s01, s02, s03,
s11, s12, s13,
s21, s22, s23);
}
}
void ggml_cuda_op_lerp(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
const ggml_tensor * src2 = dst->src[2];
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
if (src2->type == GGML_TYPE_F32) {
launch_lerp<float, float, float, float>(
src0, src1, src2, dst,
(const float *) src0->data,
(const float *) src1->data,
(const float *) src2->data,
(float *) dst->data,
stream);
} else if (src2->type == GGML_TYPE_F16) {
launch_lerp<float, float, half, float>(
src0, src1, src2, dst,
(const float *) src0->data,
(const float *) src1->data,
(const half *) src2->data,
(float *) dst->data,
stream);
} else {
GGML_ABORT("fatal error");
}
}

View File

@ -0,0 +1,5 @@
#include "common.cuh"
#define CUDA_LERP_BLOCK_SIZE 256
void ggml_cuda_op_lerp(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@ -45,6 +45,11 @@ static __device__ __forceinline__ float op_relu(float x) {
return fmaxf(x, 0);
}
static __device__ __forceinline__ float op_relu_sqr(float x) {
float r = fmaxf(x, 0);
return r * r;
}
static __device__ __forceinline__ float op_sigmoid(float x) {
return 1.0f / (1.0f + expf(-x));
}
@ -186,6 +191,10 @@ void ggml_cuda_op_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_unary<op_relu>(ctx, dst);
}
void ggml_cuda_op_relu_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_unary<op_relu_sqr>(ctx, dst);
}
void ggml_cuda_op_sigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_unary<op_sigmoid>(ctx, dst);
}

View File

@ -41,6 +41,8 @@ void ggml_cuda_op_tanh(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_relu_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_sigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_hardsigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@ -65,7 +65,9 @@ static __global__ void rwkv_wkv_f32(const int B, const int T, const int C, const
}
}
template <int block_size>
constexpr float w_scale = -0.6065306597f; // -exp(-0.5)
template <int block_size, bool fuse_exp>
static __global__ void rwkv_wkv7_f32(const int B, const int T, const int C, const int H, const float * r, const float * w, const float * k, const float * v, const float * a, const float * b, const float * s, float * dst) {
const int tid = threadIdx.x;
const int bid = blockIdx.x;
@ -89,7 +91,7 @@ static __global__ void rwkv_wkv7_f32(const int B, const int T, const int C, cons
for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) {
__syncthreads();
_r[tid] = r[t];
_w[tid] = w[t];
_w[tid] = fuse_exp ? __expf(w_scale / (1.0f + __expf(-w[t]))) : w[t];
_k[tid] = k[t];
_a[tid] = a[t];
_b[tid] = b[t];
@ -179,9 +181,11 @@ void ggml_cuda_op_rwkv_wkv7(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
const float * s_d = (const float *)dst->src[6]->data;
const int64_t B = dst->src[6]->ne[1];
const int64_t T = dst->src[0]->ne[2];
const int64_t T = dst->src[4]->ne[2];
const int64_t C = dst->ne[0];
const int64_t H = dst->src[0]->ne[1];
const int64_t H = dst->src[4]->ne[1];
const bool fuse_exp = (bool) ((int32_t *) dst->op_params)[0];
float * dst_d = (float *)dst->data;
@ -192,8 +196,16 @@ void ggml_cuda_op_rwkv_wkv7(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
GGML_ASSERT(C / H == CUDA_WKV_BLOCK_SIZE || C / H == CUDA_WKV_BLOCK_SIZE * 2);
if (C / H == CUDA_WKV_BLOCK_SIZE) {
rwkv_wkv7_f32<CUDA_WKV_BLOCK_SIZE><<<B * H, C / H, 0, stream>>>(B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d);
if (fuse_exp) {
rwkv_wkv7_f32<CUDA_WKV_BLOCK_SIZE, true><<<B * H, C / H, 0, stream>>>(B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d);
} else {
rwkv_wkv7_f32<CUDA_WKV_BLOCK_SIZE, false><<<B * H, C / H, 0, stream>>>(B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d);
}
} else {
rwkv_wkv7_f32<CUDA_WKV_BLOCK_SIZE * 2><<<B * H, C / H, 0, stream>>>(B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d);
if (fuse_exp) {
rwkv_wkv7_f32<CUDA_WKV_BLOCK_SIZE * 2, true><<<B * H, C / H, 0, stream>>>(B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d);
} else {
rwkv_wkv7_f32<CUDA_WKV_BLOCK_SIZE * 2, false><<<B * H, C / H, 0, stream>>>(B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d);
}
}
}

View File

@ -1527,6 +1527,8 @@ int ggml_metal_op_rwkv(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
const int32_t fuse_exp = ((const int32_t *) op->op_params)[0];
const int64_t B = op->op == GGML_OP_RWKV_WKV6 ? op->src[5]->ne[1] : op->src[6]->ne[1];
const int64_t T = op->src[0]->ne[2];
const int64_t C = op->ne[0];
@ -1551,6 +1553,9 @@ int ggml_metal_op_rwkv(ggml_metal_op_t ctx, int idx) {
ggml_metal_encoder_set_bytes (enc, (void *) &T, sizeof(T), ida++);
ggml_metal_encoder_set_bytes (enc, (void *) &C, sizeof(C), ida++);
ggml_metal_encoder_set_bytes (enc, (void *) &H, sizeof(H), ida++);
if (op->op == GGML_OP_RWKV_WKV7) {
ggml_metal_encoder_set_bytes (enc, (void *) &fuse_exp, sizeof(fuse_exp), ida++);
}
ggml_metal_encoder_dispatch_threadgroups(enc, B * H, 1, 1, C/H, 1, 1);

View File

@ -2657,6 +2657,7 @@ kernel void kernel_rwkv_wkv7_f32(
constant uint & T,
constant uint & C,
constant uint & H,
constant uint & fuse_exp,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
@ -2666,6 +2667,8 @@ kernel void kernel_rwkv_wkv7_f32(
const uint head_id = tgpig.x % H;
const uint tid = tpitg.x;
constexpr float w_scale = -0.6065306597f; // -exp(-0.5)
if (batch_id >= B || head_id >= H) {
return;
}
@ -2692,7 +2695,7 @@ kernel void kernel_rwkv_wkv7_f32(
for (uint t = start_t; t < end_t; t += C) {
threadgroup_barrier(mem_flags::mem_threadgroup);
_r[tid] = r[t];
_w[tid] = w[t];
_w[tid] = fuse_exp ? exp(w_scale / (1 + exp(-w[t]))) : w[t];
_k[tid] = k[t];
_a[tid] = a[t];
_b[tid] = b[t];

View File

@ -96,7 +96,9 @@ static void rwkv_wkv6_f32_kernel(
}
}
template <int block_size>
constexpr float w_scale = -0.6065306597f; // -exp(-0.5)
template <int block_size, bool fuse_exp>
static void rwkv_wkv7_f32_kernel(
const int B, const int T, const int C, const int H,
const float* r, const float* w, const float* k, const float* v,
@ -132,7 +134,7 @@ static void rwkv_wkv7_f32_kernel(
item_ct1.barrier(sycl::access::fence_space::local_space);
_r[tid] = r[t];
_w[tid] = w[t];
_w[tid] = fuse_exp ? sycl::native::exp(w_scale / (1.0f + sycl::native::exp(-w[t]))) : w[t];
_k[tid] = k[t];
_a[tid] = a[t];
_b[tid] = b[t];
@ -247,9 +249,11 @@ void ggml_sycl_op_rwkv_wkv7(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
float* dst_d = (float*)dst->data;
const int64_t B = dst->src[6]->ne[1];
const int64_t T = dst->src[0]->ne[2];
const int64_t T = dst->src[4]->ne[2];
const int64_t C = dst->ne[0];
const int64_t H = dst->src[0]->ne[1];
const int64_t H = dst->src[4]->ne[1];
const bool fuse_exp = (bool) ((int32_t *) dst->op_params)[0];
GGML_ASSERT(dst->src[6]->type == GGML_TYPE_F32);
GGML_ASSERT(C % H == 0);
@ -264,30 +268,61 @@ void ggml_sycl_op_rwkv_wkv7(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
// Submit kernel
if (C / H == WKV_BLOCK_SIZE) {
stream->submit([&](sycl::handler& cgh) {
sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
if (fuse_exp) {
stream->submit([&](sycl::handler& cgh) {
sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
cgh.parallel_for(
sycl::nd_range<3>(grid_dims * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
rwkv_wkv7_f32_kernel<WKV_BLOCK_SIZE>(
B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d,
item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
);
cgh.parallel_for(
sycl::nd_range<3>(grid_dims * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
rwkv_wkv7_f32_kernel<WKV_BLOCK_SIZE, true>(
B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d,
item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
);
});
});
} else {
stream->submit([&](sycl::handler& cgh) {
sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
cgh.parallel_for(
sycl::nd_range<3>(grid_dims * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
rwkv_wkv7_f32_kernel<WKV_BLOCK_SIZE, false>(
B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d,
item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
);
});
});
});
});
}
} else {
stream->submit([&](sycl::handler& cgh) {
sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
if (fuse_exp) {
stream->submit([&](sycl::handler& cgh) {
sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
cgh.parallel_for(
sycl::nd_range<3>(grid_dims * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
rwkv_wkv7_f32_kernel<WKV_BLOCK_SIZE * 2>(
B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d,
item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
);
});
});
cgh.parallel_for(
sycl::nd_range<3>(grid_dims * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
rwkv_wkv7_f32_kernel<WKV_BLOCK_SIZE * 2, true>(
B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d,
item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
);
});
});
} else {
stream->submit([&](sycl::handler& cgh) {
sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
cgh.parallel_for(
sycl::nd_range<3>(grid_dims * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
rwkv_wkv7_f32_kernel<WKV_BLOCK_SIZE * 2, false>(
B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d,
item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
);
});
});
}
}
}

View File

@ -789,6 +789,7 @@ struct vk_device_struct {
vk_pipeline pipeline_pool2d_f32;
vk_pipeline pipeline_rwkv_wkv6_f32;
vk_pipeline pipeline_rwkv_wkv7_f32;
vk_pipeline pipeline_rwkv_wkv7_f32_fuse_exp;
vk_pipeline pipeline_ssm_scan_f32_d128;
vk_pipeline pipeline_ssm_scan_f32_d256;
vk_pipeline pipeline_ssm_conv_f32;
@ -4362,6 +4363,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32, "rwkv_wkv7_f32", rwkv_wkv7_f32_len, rwkv_wkv7_f32_data, "main", 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32_fuse_exp, "rwkv_wkv7_f32_fuse_exp", rwkv_wkv7_f32_fuse_exp_len, rwkv_wkv7_f32_fuse_exp_data, "main", 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
if (device->subgroup_arithmetic && device->subgroup_require_full_support) {
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size}, 1, true, true);
@ -9143,6 +9145,9 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return nullptr;
case GGML_OP_RWKV_WKV7:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
if (dst->op_params[0] == 1) {
return ctx->device->pipeline_rwkv_wkv7_f32_fuse_exp;
}
return ctx->device->pipeline_rwkv_wkv7_f32;
}
return nullptr;
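Note that, unlike the Metal path above, which passes fuse_exp to the kernel at dispatch time, the Vulkan backend bakes the choice into two shader variants (FUSE_EXP compiled as 0 or 1) and selects the pipeline from op_params[0] here.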
@ -9899,10 +9904,62 @@ static void ggml_vk_add_id(ggml_backend_vk_context * ctx, vk_context& subctx, co
});
}
static void ggml_vk_op_f32_wkv(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_rwkv_wkv6_push_constants&& pc, int version) {
GGML_ASSERT(version == 6 || version == 7);
int num_srcs = version == 6 ? 6 : 7;
static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
const size_t seq_length = dst->src[0]->ne[2];
const size_t n_embed = dst->ne[0];
const size_t n_heads = dst->src[0]->ne[1];
const size_t n_seqs = dst->src[5]->ne[1];
const vk_op_rwkv_wkv6_push_constants pc = {
(uint32_t)n_seqs,
(uint32_t)seq_length,
(uint32_t)n_embed,
(uint32_t)n_heads,
};
const int num_srcs = 6;
for (int i = 0; i < num_srcs; i++) {
GGML_ASSERT(!ggml_is_quantized(dst->src[i]->type));
}
GGML_ASSERT(dst->buffer != nullptr);
vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, dst->src[0], dst->src[1], dst->src[2], dst, dst->op);
GGML_ASSERT(pipeline != nullptr);
ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst);
vk_subbuffer src_buf[6] = {};
for (int i = 0; i < num_srcs; i++) {
src_buf[i] = ggml_vk_tensor_subbuffer(ctx, dst->src[i]);
}
std::array<uint32_t, 3> elements = {
(uint32_t)(pc.B * pc.H),
1,
1
};
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
{src_buf[0], src_buf[1], src_buf[2], src_buf[3], src_buf[4], src_buf[5], dst_buf},
pc, elements);
}
static void ggml_vk_rwkv_wkv7(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
const size_t seq_length = dst->src[0]->ne[2];
const size_t n_embed = dst->ne[0];
const size_t n_heads = dst->src[0]->ne[1];
const size_t n_seqs = dst->src[6]->ne[1];
const vk_op_rwkv_wkv7_push_constants pc = {
(uint32_t)n_seqs,
(uint32_t)seq_length,
(uint32_t)n_embed,
(uint32_t)n_heads,
};
const int num_srcs = 7;
for (int i = 0; i < num_srcs; i++) {
GGML_ASSERT(!ggml_is_quantized(dst->src[i]->type));
}
@ -9926,54 +9983,9 @@ static void ggml_vk_op_f32_wkv(ggml_backend_vk_context * ctx, vk_context& subctx
1
};
if (version == 6) {
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
{src_buf[0], src_buf[1], src_buf[2], src_buf[3], src_buf[4], src_buf[5], dst_buf},
pc, elements);
} else if (version == 7) {
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
{src_buf[0], src_buf[1], src_buf[2], src_buf[3], src_buf[4], src_buf[5], src_buf[6], dst_buf},
pc, elements);
} else {
// shouldn't happen
GGML_ASSERT(false);
}
}
static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
const size_t seq_length = dst->src[0]->ne[2];
const size_t n_embed = dst->ne[0];
const size_t n_heads = dst->src[0]->ne[1];
const size_t n_seqs = dst->src[5]->ne[1];
ggml_vk_op_f32_wkv(
ctx, subctx, dst,
{
(uint32_t)n_seqs,
(uint32_t)seq_length,
(uint32_t)n_embed,
(uint32_t)n_heads,
},
6
);
}
static void ggml_vk_rwkv_wkv7(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
const size_t seq_length = dst->src[0]->ne[2];
const size_t n_embed = dst->ne[0];
const size_t n_heads = dst->src[0]->ne[1];
const size_t n_seqs = dst->src[6]->ne[1];
ggml_vk_op_f32_wkv(
ctx, subctx, dst,
{
(uint32_t)n_seqs,
(uint32_t)seq_length,
(uint32_t)n_embed,
(uint32_t)n_heads,
},
7
);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
{src_buf[0], src_buf[1], src_buf[2], src_buf[3], src_buf[4], src_buf[5], src_buf[6], dst_buf},
pc, elements);
}
static void ggml_vk_ssm_scan(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
@ -15734,7 +15746,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
src_clone[2], src_clone[3], src_clone[4], src_clone[5]);
} else if (tensor->op == GGML_OP_RWKV_WKV7) {
tensor_clone = ggml_rwkv_wkv7(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3],
src_clone[4], src_clone[5], src_clone[6]);
src_clone[4], src_clone[5], src_clone[6], (uint32_t) tensor->op_params[0]);
} else if (tensor->op == GGML_OP_OPT_STEP_ADAMW) {
src_clone[0]->flags = tensor->src[0]->flags;
tensor_clone = ggml_opt_step_adamw(ggml_ctx, src_clone[0], src_clone[1],

View File

@ -967,7 +967,8 @@ void process_shaders() {
string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
string_to_spv("rwkv_wkv7_f32", "wkv7.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
string_to_spv("rwkv_wkv7_f32", "wkv7.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"FUSE_EXP", "0"}}));
string_to_spv("rwkv_wkv7_f32_fuse_exp", "wkv7.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"FUSE_EXP", "1"}}));
string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
string_to_spv("opt_step_sgd_f32", "opt_step_sgd.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));

View File

@ -10,6 +10,7 @@ layout(push_constant) uniform Parameters {
uint T;
uint C;
uint H;
uint fuse_exp;
};
layout(binding = 0) readonly buffer RBuf { A_TYPE r[]; };
@ -45,10 +46,18 @@ void main() {
const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid;
const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid;
#if FUSE_EXP
const float w_scale = -0.6065306597f; // -exp(-0.5)
#endif
for (uint t = start_t; t < end_t; t += C) {
barrier();
_r[tid] = r[t];
#if FUSE_EXP
_w[tid] = exp(w_scale / (1 + exp(-w[t])));
#else
_w[tid] = w[t];
#endif
_k[tid] = k[t];
_a[tid] = a[t];
_b[tid] = b[t];

View File

@ -1030,6 +1030,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"GATED_LINEAR_ATTN",
"RWKV_WKV7",
"SOLVE_TRI",
"LERP",
"UNARY",
@ -1047,7 +1048,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"GLU",
};
static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95");
static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT != 96");
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
@ -1139,6 +1140,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"gated_linear_attn(k, v, q, gate, s)",
"rwkv_wkv7(r, w, k, v, a, b, s)",
"A X = B, A triangular, solve X",
"x+(y-x)*t",
"unary(x)",
@ -1156,7 +1158,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"glu(x)",
};
static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95");
static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT != 96");
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
@ -1168,6 +1170,7 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
"TANH",
"ELU",
"RELU",
"RELU_SQR",
"SIGMOID",
"GELU",
"GELU_QUICK",
@ -1185,7 +1188,7 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
"TRUNC",
};
static_assert(GGML_UNARY_OP_COUNT == 22, "GGML_UNARY_OP_COUNT != 22");
static_assert(GGML_UNARY_OP_COUNT == 23, "GGML_UNARY_OP_COUNT != 23");
static const char * GGML_GLU_OP_NAME[GGML_GLU_OP_COUNT] = {
"REGLU",
@ -2643,6 +2646,20 @@ struct ggml_tensor * ggml_relu_inplace(
return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_RELU);
}
// ggml_relu_sqr
struct ggml_tensor * ggml_relu_sqr(
struct ggml_context * ctx,
struct ggml_tensor * a) {
return ggml_unary(ctx, a, GGML_UNARY_OP_RELU_SQR);
}
struct ggml_tensor * ggml_relu_sqr_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a) {
return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_RELU_SQR);
}
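ggml_relu_sqr collapses the ggml_sqr(ggml_relu(x)) pair into a single unary node; the llama.cpp FFN builders later in this diff switch to it. Usage sketch (ctx0 and cur as in llm_graph_context::build_ffn):

// before: cur = ggml_relu(ctx0, cur); cur = ggml_sqr(ctx0, cur);
cur = ggml_relu_sqr(ctx0, cur);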
// ggml_leaky_relu
struct ggml_tensor * ggml_leaky_relu(
@ -5704,7 +5721,8 @@ struct ggml_tensor * ggml_rwkv_wkv7(
struct ggml_tensor * v,
struct ggml_tensor * a,
struct ggml_tensor * b,
struct ggml_tensor * state) {
struct ggml_tensor * state,
bool fuse_exp) {
GGML_ASSERT(ggml_is_contiguous(r));
GGML_ASSERT(ggml_is_contiguous(w));
GGML_ASSERT(ggml_is_contiguous(k));
@ -5713,14 +5731,16 @@ struct ggml_tensor * ggml_rwkv_wkv7(
GGML_ASSERT(ggml_is_contiguous(b));
GGML_ASSERT(ggml_is_contiguous(state));
const int64_t S = k->ne[0];
const int64_t H = k->ne[1];
const int64_t n_tokens = k->ne[2];
const int64_t S = a->ne[0];
const int64_t H = a->ne[1];
const int64_t n_tokens = a->ne[2];
const int64_t n_seqs = state->ne[1];
const int64_t n_embd = S * H;
{
GGML_ASSERT(w->ne[0] == S && w->ne[1] == H && w->ne[2] == n_tokens);
GGML_ASSERT(r->ne[0] == S && r->ne[1] == H && r->ne[2] == n_tokens);
GGML_ASSERT(w->ne[0] == n_embd && w->ne[1] == n_tokens);
GGML_ASSERT(k->ne[0] == S && k->ne[1] == H && k->ne[2] == n_tokens);
GGML_ASSERT(v->ne[0] == S && v->ne[1] == H && v->ne[2] == n_tokens);
GGML_ASSERT(v->ne[0] == n_embd && v->ne[1] == n_tokens);
GGML_ASSERT(a->ne[0] == S && a->ne[1] == H && a->ne[2] == n_tokens);
GGML_ASSERT(b->ne[0] == S && b->ne[1] == H && b->ne[2] == n_tokens);
GGML_ASSERT(ggml_nelements(state) == S * S * H * n_seqs);
@ -5730,6 +5750,9 @@ struct ggml_tensor * ggml_rwkv_wkv7(
const int64_t ne[4] = { S * H, n_tokens + S * n_seqs, 1, 1 };
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
int32_t fuse_exp_i = fuse_exp ? 1 : 0;
ggml_set_op_params(result, &fuse_exp_i, sizeof(fuse_exp_i));
result->op = GGML_OP_RWKV_WKV7;
result->src[0] = r;
result->src[1] = w;
@ -5742,6 +5765,34 @@ struct ggml_tensor * ggml_rwkv_wkv7(
return result;
}
// ggml_lerp
struct ggml_tensor * ggml_lerp(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
struct ggml_tensor * t) {
// assume a and b are the same shape for now
GGML_ASSERT(ggml_are_same_shape(a, b));
GGML_ASSERT(t->ne[0] == a->ne[0]);
GGML_ASSERT(a->ne[1] % t->ne[1] == 0);
GGML_ASSERT(a->ne[2] % t->ne[2] == 0);
// a/b can be broadcast to t along dim3 for rwkv7
GGML_ASSERT(t->ne[3] % a->ne[3] == 0);
const int64_t ne[4] = { a->ne[0], a->ne[1], a->ne[2], t->ne[3] };
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
result->op = GGML_OP_LERP;
result->src[0] = a;
result->src[1] = b;
result->src[2] = t;
return result;
}
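A minimal usage sketch for the new op (assumes an initialized ggml_context ctx; the sizes n_embd/n_tokens and the per-channel mixing weight are illustrative):

struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_tokens);
struct ggml_tensor * b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_tokens);
struct ggml_tensor * t = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, 1); // broadcasts over rows
struct ggml_tensor * x = ggml_lerp(ctx, a, b, t);                           // x = a + (b - a) * t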
// ggml_unary
static struct ggml_tensor * ggml_unary_impl(

View File

@ -1001,10 +1001,7 @@ ggml_tensor * llm_graph_context::build_ffn(
} break;
case LLM_FFN_RELU_SQR:
{
cur = ggml_relu(ctx0, cur);
cb(cur, "ffn_relu", il);
cur = ggml_sqr(ctx0, cur);
cur = ggml_relu_sqr(ctx0, cur);
cb(cur, "ffn_sqr(relu)", il);
} break;
case LLM_FFN_SWIGLU:
@ -1307,8 +1304,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
// TODO: add support for gated squared relu
GGML_ABORT("fatal error: gated squared relu not implemented");
} else {
cur = ggml_relu(ctx0, cur);
cur = ggml_sqr(ctx0, cur);
cur = ggml_relu_sqr(ctx0, cur);
cb(cur, "ffn_moe_relu_sqr", il);
} break;
default:

View File

@ -49,7 +49,8 @@ struct llm_build_rwkv7_base : public llm_graph_context {
ggml_tensor * build_rwkv7_channel_mix(const llama_layer * layer,
ggml_tensor * cur,
ggml_tensor * x_prev,
llm_arch arch) const;
llm_arch arch,
int il) const;
ggml_tensor * build_rwkv7_time_mix(llm_graph_input_rs * inp,
ggml_tensor * cur,
ggml_tensor * x_prev,

View File

@ -7,16 +7,24 @@ llm_build_rwkv7_base::llm_build_rwkv7_base(const llama_model & model, const llm_
ggml_tensor * llm_build_rwkv7_base::build_rwkv7_channel_mix(const llama_layer * layer,
ggml_tensor * cur,
ggml_tensor * x_prev,
llm_arch arch) const {
ggml_tensor * sx = ggml_sub(ctx0, x_prev, cur);
llm_arch arch,
int il) const {
switch (arch) {
case LLM_ARCH_RWKV7:
{
ggml_tensor * xk = ggml_add(ctx0, ggml_mul(ctx0, sx, layer->channel_mix_lerp_k), cur);
// cur + (x_prev - cur) * layer->channel_mix_lerp_k
ggml_tensor * xk = ggml_lerp(ctx0, cur, x_prev, layer->channel_mix_lerp_k);
ggml_tensor * k = ggml_sqr(ctx0, ggml_relu(ctx0, build_lora_mm(layer->channel_mix_key, xk)));
cur = build_lora_mm(layer->channel_mix_value, k);
cur = build_ffn(
xk,
layer->channel_mix_key, nullptr, nullptr, // up
nullptr, nullptr, nullptr, // gate
layer->channel_mix_value, nullptr, nullptr, // down
nullptr,
LLM_FFN_RELU_SQR,
LLM_FFN_SEQ,
il
);
}
break;
default:
@ -46,11 +54,7 @@ ggml_tensor * llm_build_rwkv7_base::build_rwkv7_time_mix(llm_graph_input_rs * in
bool has_gating = layer.time_mix_g1 && layer.time_mix_g2;
ggml_tensor * sx = ggml_sub(ctx0, x_prev, cur);
ggml_tensor * dummy = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_embd, n_seq_tokens, n_seqs, has_gating ? 6 : 5);
sx = ggml_repeat(ctx0, sx, dummy);
ggml_tensor * xxx = ggml_add(ctx0, ggml_mul(ctx0, sx, layer.time_mix_lerp_fused), cur);
ggml_tensor * xxx = ggml_lerp(ctx0, cur, x_prev, layer.time_mix_lerp_fused);
ggml_tensor * xr = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], 0);
ggml_tensor * xw = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * sizeof(float));
@ -65,7 +69,6 @@ ggml_tensor * llm_build_rwkv7_base::build_rwkv7_time_mix(llm_graph_input_rs * in
ggml_tensor * w = ggml_add(
ctx0, ggml_mul_mat(ctx0, layer.time_mix_w2, ggml_tanh(ctx0, ggml_mul_mat(ctx0, layer.time_mix_w1, xw))),
layer.time_mix_w0);
w = ggml_exp(ctx0, ggml_scale(ctx0, ggml_sigmoid(ctx0, w), -0.606531));
ggml_tensor * k = build_lora_mm(layer.time_mix_key, xk);
ggml_tensor * v = build_lora_mm(layer.time_mix_value, xv);
@ -95,14 +98,12 @@ ggml_tensor * llm_build_rwkv7_base::build_rwkv7_time_mix(llm_graph_input_rs * in
k = ggml_add(ctx0, k, ggml_sub(ctx0, ggml_mul(ctx0, a, ka), ka));
r = ggml_reshape_3d(ctx0, r, head_size, head_count, n_tokens);
w = ggml_reshape_3d(ctx0, w, head_size, head_count, n_tokens);
k = ggml_reshape_3d(ctx0, k, head_size, head_count, n_tokens);
v = ggml_reshape_3d(ctx0, v, head_size, head_count, n_tokens);
a = ggml_reshape_3d(ctx0, a, head_size, head_count, n_tokens);
ggml_tensor * wkv_state = build_rs(inp, mctx_cur->get_s_l(il), hparams.n_embd_s(), n_seqs);
ggml_tensor * wkv_output = ggml_rwkv_wkv7(ctx0, r, w, k, v, ggml_neg(ctx0, kk), ggml_mul(ctx0, kk, a), wkv_state);
ggml_tensor * wkv_output = ggml_rwkv_wkv7(ctx0, r, w, k, v, ggml_neg(ctx0, kk), ggml_mul(ctx0, kk, a), wkv_state, true);
cur = ggml_view_1d(ctx0, wkv_output, n_embd * n_tokens, 0);
wkv_state = ggml_view_1d(ctx0, wkv_output, n_embd * head_size * n_seqs, n_embd * n_tokens * sizeof(float));
@ -122,6 +123,8 @@ ggml_tensor * llm_build_rwkv7_base::build_rwkv7_time_mix(llm_graph_input_rs * in
} else {
cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
}
v = ggml_reshape_3d(ctx0, v, head_size, head_count, n_tokens);
ggml_tensor * rk = ggml_sum_rows(
ctx0, ggml_mul(ctx0, ggml_mul(ctx0, k, r), ggml_reshape_2d(ctx0, layer.time_mix_r_k, head_size, head_count)));
cur = ggml_add(ctx0, cur, ggml_reshape_2d(ctx0, ggml_mul(ctx0, v, rk), n_embd, n_tokens));
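Net effect on the rwkv7 time-mix graph: the repeat/mul/add token-shift pattern collapses into a single ggml_lerp over the fused lerp weights, the exp(-0.6065 * sigmoid(w)) transform moves into the wkv7 kernel via fuse_exp=true, and w/v are now passed to ggml_rwkv_wkv7 in their flat (n_embd, n_tokens) layout, with v reshaped back to heads only afterwards for the rk bonus term.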

View File

@ -66,7 +66,7 @@ llm_build_rwkv7::llm_build_rwkv7(const llama_model & model, const llm_graph_para
ffn_norm = ggml_get_rows(ctx0, ffn_norm, inp_out_ids);
x_prev = ggml_get_rows(ctx0, x_prev, inp_out_ids);
}
cur = build_rwkv7_channel_mix(layer, ffn_norm, x_prev, LLM_ARCH_RWKV7);
cur = build_rwkv7_channel_mix(layer, ffn_norm, x_prev, LLM_ARCH_RWKV7, il);
cur = ggml_add(ctx0, cur, ffn_inp);
cur = build_cvec(cur, il);

View File

@ -3681,17 +3681,47 @@ struct test_rwkv_wkv7 : public test_case {
ggml_tensor * build_graph(ggml_context * ctx) override {
const int64_t n_tokens = n_seq_tokens * n_seqs;
const int64_t n_embd = head_count * head_size;
ggml_tensor * r = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
ggml_tensor * w = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
ggml_tensor * w = ggml_new_tensor(ctx, type, 2, std::vector<int64_t>{ n_embd, n_tokens }.data());
ggml_tensor * k = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
ggml_tensor * v = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
ggml_tensor * v = ggml_new_tensor(ctx, type, 2, std::vector<int64_t>{ n_embd, n_tokens }.data());
ggml_tensor * a = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
ggml_tensor * b = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
// Outputs may become NaN with long sequence lengths without these normalizations
a = ggml_l2_norm(ctx, a, 1e-7F);
b = ggml_l2_norm(ctx, b, 1e-7F);
ggml_tensor * s = ggml_new_tensor(ctx, type, 2, std::vector<int64_t>{ head_size * head_size * head_count, n_seqs }.data());
ggml_tensor * out = ggml_rwkv_wkv7(ctx, r, w, k, v, a, b, s);
ggml_tensor * out = ggml_rwkv_wkv7(ctx, r, w, k, v, a, b, s, true);
return out;
}
};
// GGML_OP_LERP
struct test_lerp : public test_case {
const ggml_type type_t;
const ggml_type type_a;
const ggml_type type_b;
const std::array<int64_t, 4> ne0;
const std::array<int64_t, 4> ne1;
const std::array<int64_t, 4> ne2;
std::string vars() override {
return VARS_TO_STR6(type_a, type_b, type_t, ne0, ne1, ne2);
}
test_lerp(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32,
ggml_type type_t = GGML_TYPE_F32,
std::array<int64_t, 4> ne0 = {10, 10, 1, 1},
std::array<int64_t, 4> ne1 = {10, 10, 1, 1},
std::array<int64_t, 4> ne2 = {10, 10, 1, 1})
: type_a(type_a), type_b(type_b), type_t(type_t), ne0(ne0), ne1(ne1), ne2(ne2) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type_a, 4, ne0.data());
ggml_tensor * b = ggml_new_tensor(ctx, type_b, 4, ne1.data());
ggml_tensor * c = ggml_new_tensor(ctx, type_t, 4, ne2.data());
ggml_tensor * out = ggml_lerp(ctx, a, b, c);
return out;
}
};
@ -7612,6 +7642,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 32, 4));
test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 128, 4));
test_cases.emplace_back(new test_lerp(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {256, 16, 1, 1}, {256, 16, 1, 1}, {256, 16, 1, 1}));
test_cases.emplace_back(new test_lerp(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {256, 16, 1, 1}, {256, 16, 1, 1}, {256, 16, 1, 6}));
test_cases.emplace_back(new test_lerp(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F16, {256, 16, 1, 1}, {256, 16, 1, 1}, {256, 16, 1, 1}));
test_cases.emplace_back(new test_lerp(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F16, {256, 16, 1, 1}, {256, 16, 1, 1}, {256, 16, 1, 6}));
#if 0
// > 4GB A matrix. Too slow to be enabled by default.
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F16, 900000, 3, 2592, {1, 1}, {1, 1}));