diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 1988d16dc4..1553c2b93c 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -556,6 +556,7 @@ extern "C" { GGML_OP_GATED_LINEAR_ATTN, GGML_OP_RWKV_WKV7, GGML_OP_SOLVE_TRI, + GGML_OP_LERP, GGML_OP_UNARY, @@ -583,6 +584,7 @@ extern "C" { GGML_UNARY_OP_TANH, GGML_UNARY_OP_ELU, GGML_UNARY_OP_RELU, + GGML_UNARY_OP_RELU_SQR, GGML_UNARY_OP_SIGMOID, GGML_UNARY_OP_GELU, GGML_UNARY_OP_GELU_QUICK, @@ -1133,6 +1135,14 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a); + GGML_API struct ggml_tensor * ggml_relu_sqr( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_relu_sqr_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + GGML_API struct ggml_tensor * ggml_sigmoid( struct ggml_context * ctx, struct ggml_tensor * a); @@ -2443,7 +2453,8 @@ extern "C" { struct ggml_tensor * v, struct ggml_tensor * a, struct ggml_tensor * b, - struct ggml_tensor * state); + struct ggml_tensor * state, + bool fuse_exp); /* Solves a specific equation of the form Ax=B, where A is a triangular matrix * without zeroes on the diagonal (i.e. invertible). @@ -2466,6 +2477,14 @@ extern "C" { bool lower, bool uni); + // a + (b - a) * t + // used in rwkv7 + GGML_API struct ggml_tensor * ggml_lerp( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * t); + // custom operators typedef void (*ggml_custom1_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int nth, void * userdata); diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index b1de2ae871..033dbc24e7 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -2019,6 +2019,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_solve_tri(params, tensor); } break; + case GGML_OP_LERP: + { + ggml_compute_forward_lerp(params, tensor); + } + break; case GGML_OP_MAP_CUSTOM1: { ggml_compute_forward_map_custom1(params, tensor); @@ -2180,6 +2185,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_OP_CUMSUM: case GGML_OP_TRI: case GGML_OP_FILL: + case GGML_OP_LERP: { n_tasks = n_threads; } break; @@ -2216,6 +2222,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_UNARY_OP_TANH: case GGML_UNARY_OP_ELU: case GGML_UNARY_OP_RELU: + case GGML_UNARY_OP_RELU_SQR: case GGML_UNARY_OP_SIGMOID: case GGML_UNARY_OP_HARDSWISH: case GGML_UNARY_OP_HARDSIGMOID: diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 48c8964361..53f8cc413d 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -9458,6 +9458,10 @@ void ggml_compute_forward_unary( { ggml_compute_forward_relu(params, dst); } break; + case GGML_UNARY_OP_RELU_SQR: + { + ggml_compute_forward_relu_sqr(params, dst); + } break; case GGML_UNARY_OP_SIGMOID: { ggml_compute_forward_sigmoid(params, dst); @@ -10189,9 +10193,9 @@ void ggml_compute_forward_solve_tri(const struct ggml_compute_params * params, s static void ggml_compute_forward_rwkv_wkv7_f32( const ggml_compute_params * params, ggml_tensor * dst) { - const int64_t T = dst->src[1]->ne[2]; + const int64_t T = dst->src[4]->ne[2]; const int64_t C = dst->ne[0]; - const int64_t HEADS = dst->src[1]->ne[1]; + const int64_t HEADS = dst->src[4]->ne[1]; const int64_t n_seqs = dst->src[6]->ne[1]; const int64_t head_size = C / HEADS; @@ -10216,6 +10220,9 @@ static void 
ggml_compute_forward_rwkv_wkv7_f32( float * a = (float *) dst->src[4]->data; float * b = (float *) dst->src[5]->data; + const bool fuse_exp = (bool) ((int32_t *) dst->op_params)[0]; + constexpr float w_scale = -0.6065306597f; // -exp(-0.5) + int64_t t_stride = HEADS * head_size; // Same to C int64_t h_stride = C / HEADS; @@ -10252,7 +10259,7 @@ static void ggml_compute_forward_rwkv_wkv7_f32( int64_t h_2d_i_j_offset = h_2d_i_offset + j; float r_val = r[t_h_j_offset]; - float w_val = w[t_h_j_offset]; + float w_val = fuse_exp ? expf(w_scale / (1.0f + expf(-w[t_h_j_offset]))) : w[t_h_j_offset]; float k_val = k[t_h_j_offset]; float b_val = b[t_h_j_offset]; float kv_val = v_val * k_val; @@ -10311,6 +10318,11 @@ static void ggml_compute_forward_rwkv_wkv7_f32( GGML_F32_VEC k_vec = GGML_F32_VEC_LOAD(&k[t_h_j_offset]); GGML_F32_VEC b_vec = GGML_F32_VEC_LOAD(&b[t_h_j_offset]); + if (fuse_exp) { + w_vec = ggml_v_sigmoid(w_vec, w_scale); + w_vec = ggml_v_expf(w_vec); + } + k_vec = GGML_F32_VEC_MUL(v_vec, k_vec); GGML_F32_VEC state_vec = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_j_offset]); @@ -10330,7 +10342,7 @@ static void ggml_compute_forward_rwkv_wkv7_f32( int64_t h_2d_i_j_offset = h_2d_i_offset + j; float r_val = r[t_h_j_offset]; - float w_val = w[t_h_j_offset]; + float w_val = fuse_exp ? expf(w_scale / (1.0f + expf(-w[t_h_j_offset]))) : w[t_h_j_offset]; float k_val = k[t_h_j_offset]; float b_val = b[t_h_j_offset]; float kv_val = v[t_h_i_offset] * k_val; @@ -10371,7 +10383,7 @@ static void ggml_compute_forward_rwkv_wkv7_f32( int64_t h_2d_i_j_offset = h_2d_i_offset + j; float r_val = r[t_h_j_offset]; - float w_val = w[t_h_j_offset]; + float w_val = fuse_exp ? expf(w_scale / (1.0f + expf(-w[t_h_j_offset]))) : w[t_h_j_offset]; float k_val = k[t_h_j_offset]; float b_val = b[t_h_j_offset]; float kv_val = v_val * k_val; @@ -10405,6 +10417,177 @@ void ggml_compute_forward_rwkv_wkv7( } } +static void ggml_compute_forward_lerp_f32( + const ggml_compute_params * params, + ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + const ggml_tensor * src2 = dst->src[2]; + + GGML_ASSERT(dst->type == GGML_TYPE_F32); + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(src2->type == GGML_TYPE_F32); + + GGML_ASSERT(src0->nb[0] == sizeof(float)); + GGML_ASSERT(src1->nb[0] == sizeof(float)); + GGML_ASSERT(src2->nb[0] == sizeof(float)); + + const int ith = params->ith; + const int nth = params->nth; + + GGML_TENSOR_TERNARY_OP_LOCALS + + GGML_ASSERT( nb0 == sizeof(float)); + GGML_ASSERT(nb10 == sizeof(float)); + GGML_ASSERT(nb20 == sizeof(float)); + + GGML_ASSERT(ggml_are_same_shape(src0, src1)); + GGML_ASSERT(ne00 == ne10 && ne00 == ne20 && ne00 == ne0); + + GGML_ASSERT(ne01 % ne21 == 0); + GGML_ASSERT(ne02 % ne22 == 0); + + GGML_ASSERT(ne23 % ne03 == 0); + GGML_ASSERT(ne23 % ne13 == 0); + + GGML_ASSERT(ne0 == ne00); + GGML_ASSERT(ne1 == ne01); + GGML_ASSERT(ne2 == ne02); + GGML_ASSERT(ne3 == ne23); + + const int nr = ggml_nrows(dst); + + const int dr = (nr + nth - 1)/nth; + + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int ir = ir0; ir < ir1; ++ir) { + const int i3 = ir/(ne2*ne1); + const int i2 = (ir - i3*ne2*ne1)/ne1; + const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + + const int i03 = i3 % ne03; + const int i13 = i3 % ne13; + + const int i21 = i1 % ne21; + const int i22 = i2 % ne22; + + float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1); + const float * src0_ptr = (const 
float *) ((char *) src0->data + i03*nb03 + i2*nb02 + i1*nb01); + const float * src1_ptr = (const float *) ((char *) src1->data + i13*nb13 + i2*nb12 + i1*nb11); + const float * src2_ptr = (const float *) ((char *) src2->data + i3*nb23 + i22*nb22 + i21*nb21); + + for (int64_t i0 = 0; i0 < ne0; ++i0) { + const float s0 = src0_ptr[i0]; + const float s1 = src1_ptr[i0]; + const float s2 = src2_ptr[i0]; + + dst_ptr[i0] = s0 + (s1 - s0) * s2; + } + } +} + +static void ggml_compute_forward_lerp_f32_f32_f16( + const ggml_compute_params * params, + ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + const ggml_tensor * src2 = dst->src[2]; + + GGML_ASSERT(dst->type == GGML_TYPE_F32); + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(src2->type == GGML_TYPE_F16); + + GGML_ASSERT(src0->nb[0] == sizeof(float)); + GGML_ASSERT(src1->nb[0] == sizeof(float)); + GGML_ASSERT(src2->nb[0] == sizeof(ggml_fp16_t)); + + const int ith = params->ith; + const int nth = params->nth; + + GGML_TENSOR_TERNARY_OP_LOCALS + + GGML_ASSERT( nb0 == sizeof(float)); + GGML_ASSERT(nb10 == sizeof(float)); + GGML_ASSERT(nb20 == sizeof(ggml_fp16_t)); + + GGML_ASSERT(ggml_are_same_shape(src0, src1)); + GGML_ASSERT(ne00 == ne10 && ne00 == ne20 && ne00 == ne0); + + GGML_ASSERT(ne01 % ne21 == 0); + GGML_ASSERT(ne02 % ne22 == 0); + + GGML_ASSERT(ne23 % ne03 == 0); + GGML_ASSERT(ne23 % ne13 == 0); + + GGML_ASSERT(ne0 == ne00); + GGML_ASSERT(ne1 == ne01); + GGML_ASSERT(ne2 == ne02); + GGML_ASSERT(ne3 == ne23); + + const int nr = ggml_nrows(dst); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int ir = ir0; ir < ir1; ++ir) { + const int i3 = ir/(ne2*ne1); + const int i2 = (ir - i3*ne2*ne1)/ne1; + const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + + const int i03 = i3 % ne03; + const int i13 = i3 % ne13; + + const int i21 = i1 % ne21; + const int i22 = i2 % ne22; + + float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1); + const float * src0_ptr = (const float *) ((char *) src0->data + i03*nb03 + i2*nb02 + i1*nb01); + const float * src1_ptr = (const float *) ((char *) src1->data + i13*nb13 + i2*nb12 + i1*nb11); + const ggml_fp16_t * src2_ptr = (const ggml_fp16_t *) ((char *) src2->data + i3*nb23 + i22*nb22 + i21*nb21); + + for (int64_t i0 = 0; i0 < ne0; ++i0) { + const float s0 = src0_ptr[i0]; + const float s1 = src1_ptr[i0]; + const float s2 = GGML_FP16_TO_FP32(src2_ptr[i0]); + + dst_ptr[i0] = s0 + (s1 - s0) * s2; + } + } +} + +void ggml_compute_forward_lerp( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src2 = dst->src[2]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + if (src2->type == GGML_TYPE_F32) { + ggml_compute_forward_lerp_f32(params, dst); + } else if (src2->type == GGML_TYPE_F16) { + ggml_compute_forward_lerp_f32_f32_f16(params, dst); + } else { + GGML_ABORT("fatal error"); + } + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + // ggml_compute_forward_map_custom1 void ggml_compute_forward_map_custom1( diff --git a/ggml/src/ggml-cpu/ops.h b/ggml/src/ggml-cpu/ops.h index 0fdfee7976..c048b4b272 100644 --- a/ggml/src/ggml-cpu/ops.h +++ b/ggml/src/ggml-cpu/ops.h @@ -101,6 +101,7 @@ void ggml_compute_forward_add_rel_pos(const struct ggml_compute_params * params, void 
ggml_compute_forward_rwkv_wkv6(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_rwkv_wkv7(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_solve_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_lerp(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_gla(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_map_custom1(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_map_custom2(const struct ggml_compute_params * params, struct ggml_tensor * dst); diff --git a/ggml/src/ggml-cpu/unary-ops.cpp b/ggml/src/ggml-cpu/unary-ops.cpp index 1d9873ad0f..505b24ce2e 100644 --- a/ggml/src/ggml-cpu/unary-ops.cpp +++ b/ggml/src/ggml-cpu/unary-ops.cpp @@ -28,6 +28,11 @@ static inline float op_relu(float x) { return (x > 0.f) ? x : 0.f; } +static inline float op_relu_sqr(float x) { + float r = (x > 0.f) ? x : 0.f; + return r * r; +} + static inline float op_sigmoid(float x) { return 1.f / (1.f + expf(-x)); } @@ -262,6 +267,10 @@ void ggml_compute_forward_relu(const ggml_compute_params * params, ggml_tensor * unary_op(params, dst); } +void ggml_compute_forward_relu_sqr(const ggml_compute_params * params, ggml_tensor * dst) { + unary_op(params, dst); +} + void ggml_compute_forward_sigmoid(const ggml_compute_params * params, ggml_tensor * dst) { unary_op(params, dst); } diff --git a/ggml/src/ggml-cpu/unary-ops.h b/ggml/src/ggml-cpu/unary-ops.h index bcad5a3af1..0229950804 100644 --- a/ggml/src/ggml-cpu/unary-ops.h +++ b/ggml/src/ggml-cpu/unary-ops.h @@ -13,6 +13,7 @@ void ggml_compute_forward_step(const struct ggml_compute_params * params, struct void ggml_compute_forward_tanh(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_elu(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_relu(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_relu_sqr(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_sigmoid(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_hardsigmoid(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_exp(const struct ggml_compute_params * params, struct ggml_tensor * dst); diff --git a/ggml/src/ggml-cpu/vec.h b/ggml/src/ggml-cpu/vec.h index 3198b33b50..5f22b4d489 100644 --- a/ggml/src/ggml-cpu/vec.h +++ b/ggml/src/ggml-cpu/vec.h @@ -1129,16 +1129,27 @@ inline static svfloat32_t ggml_v_expf(svbool_t pg, svfloat32_t x) { svsel_f32(c, svmul_f32_x(pg, svmla_f32_x(pg, s2, s2, j), s1), svmla_f32_x(pg, k, k, j))); } -// computes silu x/(1+exp(-x)) in single precision vector -inline static svfloat32_t ggml_v_silu(svbool_t pg, svfloat32_t x) { +// computes (1+exp(-x)) in single precision vector +inline static svfloat32_t ggml_v_one_plus_exp_neg_x(svbool_t pg, svfloat32_t x) { const svfloat32_t one = svdup_n_f32_x(pg, 1.0f); const svfloat32_t zero = svdup_n_f32_x(pg, 0.0f); const svfloat32_t neg_x = svsub_f32_x(pg, zero, x); const svfloat32_t exp_neg_x = ggml_v_expf(pg, neg_x); - const svfloat32_t one_plus_exp_neg_x = svadd_f32_x(pg, one, exp_neg_x); + return svadd_f32_x(pg, one, exp_neg_x); +} + +// computes silu x/(1+exp(-x)) in single precision vector +inline 
static svfloat32_t ggml_v_silu(svbool_t pg, svfloat32_t x) { + const svfloat32_t one_plus_exp_neg_x = ggml_v_one_plus_exp_neg_x(pg, x); return svdiv_f32_x(pg, x, one_plus_exp_neg_x); } +// computes sigmoid 1/(1+exp(-x)) (with scale) in single precision vector +inline static svfloat32_t ggml_v_sigmoid(svbool_t pg, svfloat32_t x, float scale) { + const svfloat32_t one_plus_exp_neg_x = ggml_v_one_plus_exp_neg_x(pg, x); + return svdiv_f32_x(pg, svdup_n_f32_x(pg, scale), one_plus_exp_neg_x); +} + #elif defined(__ARM_NEON) && defined(__aarch64__) // adapted from arm limited optimized routine @@ -1168,16 +1179,27 @@ inline static float32x4_t ggml_v_expf(float32x4_t x) { vbslq_f32(c, vmulq_f32(vfmaq_f32(s2, s2, j), s1), vfmaq_f32(k, k, j))); } -// computes silu x/(1+exp(-x)) in single precision vector -inline static float32x4_t ggml_v_silu(float32x4_t x) { +// computes (1+exp(-x)) in single precision vector +inline static float32x4_t ggml_v_one_plus_exp_neg_x(float32x4_t x) { const float32x4_t one = vdupq_n_f32(1.0f); const float32x4_t zero = vdupq_n_f32(0.0f); const float32x4_t neg_x = vsubq_f32(zero, x); const float32x4_t exp_neg_x = ggml_v_expf(neg_x); - const float32x4_t one_plus_exp_neg_x = vaddq_f32(one, exp_neg_x); + return vaddq_f32(one, exp_neg_x); +} + +// computes silu x/(1+exp(-x)) in single precision vector +inline static float32x4_t ggml_v_silu(float32x4_t x) { + const float32x4_t one_plus_exp_neg_x = ggml_v_one_plus_exp_neg_x(x); return vdivq_f32(x, one_plus_exp_neg_x); } +// computes sigmoid 1/(1+exp(-x)) (with scale) in single precision vector +inline static float32x4_t ggml_v_sigmoid(float32x4_t x, float scale) { + const float32x4_t one_plus_exp_neg_x = ggml_v_one_plus_exp_neg_x(x); + return vdivq_f32(vdupq_n_f32(scale), one_plus_exp_neg_x); +} + #elif defined(__AVX512F__) && defined(__AVX512DQ__) // adapted from arm limited optimized routine @@ -1211,16 +1233,27 @@ inline static __m512 ggml_v_expf(__m512 x) { return _mm512_mask_blend_ps(d, res, alt); } -// computes silu x/(1+exp(-x)) in single precision vector -inline static __m512 ggml_v_silu(__m512 x) { +// computes (1+exp(-x)) in single precision vector +inline static __m512 ggml_v_one_plus_exp_neg_x(__m512 x) { const __m512 one = _mm512_set1_ps(1); const __m512 zero = _mm512_setzero_ps(); const __m512 neg_x = _mm512_sub_ps(zero, x); const __m512 exp_neg_x = ggml_v_expf(neg_x); - const __m512 one_plus_exp_neg_x = _mm512_add_ps(one, exp_neg_x); + return _mm512_add_ps(one, exp_neg_x); +} + +// computes silu x/(1+exp(-x)) in single precision vector +inline static __m512 ggml_v_silu(__m512 x) { + const __m512 one_plus_exp_neg_x = ggml_v_one_plus_exp_neg_x(x); return _mm512_div_ps(x, one_plus_exp_neg_x); } +// computes sigmoid 1/(1+exp(-x)) (with scale) in single precision vector +inline static __m512 ggml_v_sigmoid(__m512 x, float scale) { + const __m512 one_plus_exp_neg_x = ggml_v_one_plus_exp_neg_x(x); + return _mm512_div_ps(_mm512_set1_ps(scale), one_plus_exp_neg_x); +} + #elif defined(__AVX2__) && defined(__FMA__) // adapted from arm limited optimized routine @@ -1266,16 +1299,27 @@ inline static __m256 ggml_v_expf(__m256 x) { _mm256_andnot_ps(_mm256_castsi256_ps(c), _mm256_fmadd_ps(k, j, k))))); } -// computes silu x/(1+exp(-x)) in single precision vector -inline static __m256 ggml_v_silu(__m256 x) { +// computes (1+exp(-x)) in single precision vector +inline static __m256 ggml_v_one_plus_exp_neg_x(__m256 x) { const __m256 one = _mm256_set1_ps(1); const __m256 zero = _mm256_setzero_ps(); const __m256 neg_x = 
_mm256_sub_ps(zero, x); const __m256 exp_neg_x = ggml_v_expf(neg_x); - const __m256 one_plus_exp_neg_x = _mm256_add_ps(one, exp_neg_x); + return _mm256_add_ps(one, exp_neg_x); +} + +// computes silu x/(1+exp(-x)) in single precision vector +inline static __m256 ggml_v_silu(__m256 x) { + const __m256 one_plus_exp_neg_x = ggml_v_one_plus_exp_neg_x(x); return _mm256_div_ps(x, one_plus_exp_neg_x); } +// computes sigmoid 1/(1+exp(-x)) (with scale) in single precision vector +inline static __m256 ggml_v_sigmoid(__m256 x, float scale) { + const __m256 one_plus_exp_neg_x = ggml_v_one_plus_exp_neg_x(x); + return _mm256_div_ps(_mm256_set1_ps(scale), one_plus_exp_neg_x); +} + #elif defined(__SSE2__) // __AVX2__ / __ARM_NEON #if defined(__FMA__) @@ -1320,16 +1364,27 @@ inline static __m128 ggml_v_expf(__m128 x) { _mm_andnot_ps(_mm_castsi128_ps(c), MADD128(k, j, k))))); } -// computes silu x/(1+exp(-x)) in single precision vector -inline static __m128 ggml_v_silu(__m128 x) { +// computes (1+exp(-x)) in single precision vector +inline static __m128 ggml_v_one_plus_exp_neg_x(__m128 x) { const __m128 one = _mm_set1_ps(1); const __m128 zero = _mm_setzero_ps(); const __m128 neg_x = _mm_sub_ps(zero, x); const __m128 exp_neg_x = ggml_v_expf(neg_x); - const __m128 one_plus_exp_neg_x = _mm_add_ps(one, exp_neg_x); + return _mm_add_ps(one, exp_neg_x); +} + +// computes silu x/(1+exp(-x)) in single precision vector +inline static __m128 ggml_v_silu(__m128 x) { + const __m128 one_plus_exp_neg_x = ggml_v_one_plus_exp_neg_x(x); return _mm_div_ps(x, one_plus_exp_neg_x); } +// computes sigmoid 1/(1+exp(-x)) (with scale) in single precision vector +inline static __m128 ggml_v_sigmoid(__m128 x, float scale) { + const __m128 one_plus_exp_neg_x = ggml_v_one_plus_exp_neg_x(x); + return _mm_div_ps(_mm_set1_ps(scale), one_plus_exp_neg_x); +} + #elif defined(__riscv_v_intrinsic) // adapted from arm limited optimized routine @@ -1374,14 +1429,25 @@ inline static vfloat32m2_t ggml_v_expf_m2(vfloat32m2_t x, int vl) { vl); } -// computes silu x/(1+exp(-x)) in single precision vector -inline static vfloat32m2_t ggml_v_silu_m2(vfloat32m2_t x, int vl) { +// computes (1+exp(-x)) in single precision vector +inline static vfloat32m2_t ggml_v_one_plus_exp_neg_x_m2(vfloat32m2_t x, int vl) { const vfloat32m2_t neg_x = __riscv_vfneg_v_f32m2(x, vl); const vfloat32m2_t exp_neg_x = ggml_v_expf_m2(neg_x, vl); - const vfloat32m2_t one_plus_exp_neg_x = __riscv_vfadd_vf_f32m2(exp_neg_x, 1.0f, vl); + return __riscv_vfadd_vf_f32m2(exp_neg_x, 1.0f, vl); +} + +// computes silu x/(1+exp(-x)) in single precision vector +inline static vfloat32m2_t ggml_v_silu_m2(vfloat32m2_t x, int vl) { + const vfloat32m2_t one_plus_exp_neg_x = ggml_v_one_plus_exp_neg_x_m2(x, vl); return __riscv_vfdiv_vv_f32m2(x, one_plus_exp_neg_x, vl); } +// computes sigmoid 1/(1+exp(-x)) in single precision vector +inline static vfloat32m2_t ggml_v_sigmoid_m2(vfloat32m2_t x, int vl, float scale) { + const vfloat32m2_t one_plus_exp_neg_x = ggml_v_one_plus_exp_neg_x_m2(x, vl); + return __riscv_vfrdiv_vf_f32m2(one_plus_exp_neg_x, scale, vl); +} + #endif // __ARM_NEON / __AVX2__ / __SSE2__ / __riscv_v_intrinsic inline static void ggml_vec_silu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 08383edb40..db435705c0 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -52,6 +52,7 @@ #include "ggml-cuda/unary.cuh" #include "ggml-cuda/upscale.cuh" #include 
"ggml-cuda/wkv.cuh" +#include "ggml-cuda/lerp.cuh" #include "ggml-cuda/gla.cuh" #include "ggml-cuda/set.cuh" #include "ggml-cuda/set-rows.cuh" @@ -2504,6 +2505,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_UNARY_OP_RELU: ggml_cuda_op_relu(ctx, dst); break; + case GGML_UNARY_OP_RELU_SQR: + ggml_cuda_op_relu_sqr(ctx, dst); + break; case GGML_UNARY_OP_SIGMOID: ggml_cuda_op_sigmoid(ctx, dst); break; @@ -2727,6 +2731,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_RWKV_WKV7: ggml_cuda_op_rwkv_wkv7(ctx, dst); break; + case GGML_OP_LERP: + ggml_cuda_op_lerp(ctx, dst); + break; case GGML_OP_CROSS_ENTROPY_LOSS_BACK: ggml_cuda_cross_entropy_loss_back(ctx, dst); break; @@ -4522,6 +4529,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_UNARY_OP_GELU: case GGML_UNARY_OP_SILU: case GGML_UNARY_OP_RELU: + case GGML_UNARY_OP_RELU_SQR: case GGML_UNARY_OP_SIGMOID: case GGML_UNARY_OP_HARDSIGMOID: case GGML_UNARY_OP_HARDSWISH: @@ -4839,6 +4847,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_RWKV_WKV6: case GGML_OP_GATED_LINEAR_ATTN: case GGML_OP_RWKV_WKV7: + case GGML_OP_LERP: return true; case GGML_OP_FLASH_ATTN_EXT: return ggml_cuda_flash_attn_ext_supported(dev_ctx->device, op); diff --git a/ggml/src/ggml-cuda/lerp.cu b/ggml/src/ggml-cuda/lerp.cu new file mode 100644 index 0000000000..3b7eae6090 --- /dev/null +++ b/ggml/src/ggml-cuda/lerp.cu @@ -0,0 +1,283 @@ +#include "lerp.cuh" +#include + +template +static __global__ void k_lerp( + const src0_t * src0, + const src1_t * src1, + const src2_t * src2, + dst_t * dst, + const int ne0, + const int ne1, + const int ne2, + const uint3 ne3, + const uint3 ne03, + const uint3 ne13, + const uint3 ne21, + const uint3 ne22, + const int s1, + const int s2, + const int s3, + const int s01, + const int s02, + const int s03, + const int s11, + const int s12, + const int s13, + const int s21, + const int s22, + const int s23) { + + const uint32_t i0s = blockDim.x * blockIdx.x + threadIdx.x; + const uint32_t i1 = (blockDim.y * blockIdx.y + threadIdx.y); + const uint32_t i2 = fastdiv((blockDim.z * blockIdx.z + threadIdx.z), ne3); + const uint32_t i3 = (blockDim.z * blockIdx.z + threadIdx.z) - (i2 * ne3.z); + + if (i0s >= (uint32_t)ne0 || i1 >= (uint32_t)ne1 || i2 >= (uint32_t)ne2 || i3 >= ne3.z) { + return; + } + + // src0/src1 broadcast in dim3 + const uint32_t i03 = fastmodulo(i3, ne03); + const uint32_t i13 = fastmodulo(i3, ne13); + + // src2 broadcast in dim1, dim2 + const uint32_t i21 = fastmodulo(i1, ne21); + const uint32_t i22 = fastmodulo(i2, ne22); + + const size_t i_src0 = i03*s03 + i2*s02 + i1*s01; + const size_t i_src1 = i13*s13 + i2*s12 + i1*s11; + const size_t i_src2 = i3*s23 + i22*s22 + i21*s21; + const size_t i_dst = i3*s3 + i2*s2 + i1*s1; + + const src0_t * src0_row = src0 + i_src0; + const src1_t * src1_row = src1 + i_src1; + const src2_t * src2_row = src2 + i_src2; + dst_t * dst_row = dst + i_dst; + + for (int i0 = i0s; i0 < ne0; i0 += blockDim.x * gridDim.x) { + const float v0 = (float) src0_row[i0]; + const float v1 = (float) src1_row[i0]; + const float v2 = (float) src2_row[i0]; + + dst_row[i0] = (dst_t) (v0 + (v1 - v0) * v2); + } +} + +template +static __global__ void k_lerp_unravel( + const src0_t * src0, + const src1_t * src1, + const src2_t * src2, + dst_t * dst, + const uint3 ne0, + const uint3 ne1, + const uint3 ne2, + const uint32_t ne3, + const uint3 
prod_012, + const uint3 prod_01, + const uint3 ne03, + const uint3 ne13, + const uint3 ne21, + const uint3 ne22, + const int s1, + const int s2, + const int s3, + const int s01, + const int s02, + const int s03, + const int s11, + const int s12, + const int s13, + const int s21, + const int s22, + const int s23) { + + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + const uint32_t i3 = fastdiv(i, prod_012); + const uint32_t i2 = fastdiv(i - i3 * prod_012.z, prod_01); + const uint32_t i1 = fastdiv(i - i3 * prod_012.z - i2 * prod_01.z, ne0); + const uint32_t i0 = i - i3 * prod_012.z - i2 * prod_01.z - i1 * ne0.z; + + if (i0 >= ne0.z || i1 >= ne1.z || i2 >= ne2.z || i3 >= ne3) { + return; + } + + // src0/src1 broadcast in dim3 + const int i03 = fastmodulo(i3, ne03); + const int i13 = fastmodulo(i3, ne13); + + // src2 broadcast in dim1, dim2 + const int i21 = fastmodulo(i1, ne21); + const int i22 = fastmodulo(i2, ne22); + + const size_t i_src0 = i03*s03 + i2*s02 + i1*s01; + const size_t i_src1 = i13*s13 + i2*s12 + i1*s11; + const size_t i_src2 = i3*s23 + i22*s22 + i21*s21; + const size_t i_dst = i3*s3 + i2*s2 + i1*s1; + + const src0_t * src0_row = src0 + i_src0; + const src1_t * src1_row = src1 + i_src1; + const src2_t * src2_row = src2 + i_src2; + dst_t * dst_row = dst + i_dst; + + const float v0 = (float) src0_row[i0]; + const float v1 = (float) src1_row[i0]; + const float v2 = (float) src2_row[i0]; + + // dst = src0 + (src1 - src0) * src2 + dst_row[i0] = (dst_t) (v0 + (v1 - v0) * v2); +} + +template <typename src0_t, typename src1_t, typename src2_t, typename dst_t> +static void launch_lerp( + const ggml_tensor * src0, + const ggml_tensor * src1, + const ggml_tensor * src2, + ggml_tensor * dst, + const src0_t * src0_dd, + const src1_t * src1_dd, + const src2_t * src2_dd, + dst_t * dst_dd, + cudaStream_t stream) { + + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) + GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) + GGML_TENSOR_LOCALS(int64_t, ne2, src2, ne) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) + GGML_TENSOR_LOCALS(size_t, nb1, src1, nb) + GGML_TENSOR_LOCALS(size_t, nb2, src2, nb) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + + GGML_ASSERT(ne00 == ne10 && ne00 == ne20 && ne00 == ne0); + GGML_ASSERT(ne01 % ne21 == 0); + GGML_ASSERT(ne02 % ne22 == 0); + GGML_ASSERT(ne3 % ne03 == 0); + GGML_ASSERT(ne3 % ne13 == 0); + GGML_ASSERT(ne0 == ne00); + GGML_ASSERT(ne1 == ne01); + GGML_ASSERT(ne2 == ne02); + GGML_ASSERT(ne3 == ne23); + + size_t s1 = nb1 / sizeof(dst_t); + size_t s2 = nb2 / sizeof(dst_t); + size_t s3 = nb3 / sizeof(dst_t); + + size_t s01 = nb01 / sizeof(src0_t); + size_t s02 = nb02 / sizeof(src0_t); + size_t s03 = nb03 / sizeof(src0_t); + + size_t s11 = nb11 / sizeof(src1_t); + size_t s12 = nb12 / sizeof(src1_t); + size_t s13 = nb13 / sizeof(src1_t); + + size_t s21 = nb21 / sizeof(src2_t); + size_t s22 = nb22 / sizeof(src2_t); + size_t s23 = nb23 / sizeof(src2_t); + + GGML_ASSERT(nb0 % sizeof(dst_t) == 0); + GGML_ASSERT(nb1 % sizeof(dst_t) == 0); + GGML_ASSERT(nb2 % sizeof(dst_t) == 0); + GGML_ASSERT(nb3 % sizeof(dst_t) == 0); + + GGML_ASSERT(nb00 % sizeof(src0_t) == 0); + GGML_ASSERT(nb01 % sizeof(src0_t) == 0); + GGML_ASSERT(nb02 % sizeof(src0_t) == 0); + GGML_ASSERT(nb03 % sizeof(src0_t) == 0); + + GGML_ASSERT(nb10 % sizeof(src1_t) == 0); + GGML_ASSERT(nb11 % sizeof(src1_t) == 0); + GGML_ASSERT(nb12 % sizeof(src1_t) == 0); + GGML_ASSERT(nb13 % sizeof(src1_t) == 0); + + GGML_ASSERT(nb20 % sizeof(src2_t) == 0); + GGML_ASSERT(nb21 % sizeof(src2_t) == 0); + GGML_ASSERT(nb22 % sizeof(src2_t) == 0); +
GGML_ASSERT(nb23 % sizeof(src2_t) == 0); + + const int block_size = CUDA_LERP_BLOCK_SIZE; + + int64_t hne0 = std::max(ne0 / 2LL, 1LL); + + dim3 block_dims; + block_dims.x = std::min<unsigned int>(hne0, block_size); + block_dims.y = std::min<unsigned int>(ne1, block_size / block_dims.x); + block_dims.z = std::min(std::min<unsigned int>(ne2 * ne3, block_size / block_dims.x / block_dims.y), 64U); + + dim3 block_nums( + (hne0 + block_dims.x - 1) / block_dims.x, + (ne1 + block_dims.y - 1) / block_dims.y, + (ne2 * ne3 + block_dims.z - 1) / block_dims.z); + + const uint3 ne03_fastdiv = init_fastdiv_values((uint32_t) ne03); + const uint3 ne13_fastdiv = init_fastdiv_values((uint32_t) ne13); + const uint3 ne21_fastdiv = init_fastdiv_values((uint32_t) ne21); + const uint3 ne22_fastdiv = init_fastdiv_values((uint32_t) ne22); + + if (block_nums.z > 65535 || block_nums.y > 65535) { + int block_num = (ne0 * ne1 * ne2 * ne3 + block_size - 1) / block_size; + const uint3 prod_012 = init_fastdiv_values((uint32_t) (ne0 * ne1 * ne2)); + const uint3 prod_01 = init_fastdiv_values((uint32_t) (ne0 * ne1)); + const uint3 ne0_fastdiv = init_fastdiv_values((uint32_t) ne0); + const uint3 ne1_fastdiv = init_fastdiv_values((uint32_t) ne1); + const uint3 ne2_fastdiv = init_fastdiv_values((uint32_t) ne2); + + k_lerp_unravel<src0_t, src1_t, src2_t, dst_t> + <<<block_num, block_size, 0, stream>>>( + src0_dd, src1_dd, src2_dd, dst_dd, + ne0_fastdiv, ne1_fastdiv, ne2_fastdiv, ne3, + prod_012, prod_01, + ne03_fastdiv, ne13_fastdiv, ne21_fastdiv, ne22_fastdiv, + s1, s2, s3, + s01, s02, s03, + s11, s12, s13, + s21, s22, s23); + } else { + const uint3 ne3_fastdiv = init_fastdiv_values((uint32_t) ne3); + + k_lerp<src0_t, src1_t, src2_t, dst_t> + <<<block_nums, block_dims, 0, stream>>>( + src0_dd, src1_dd, src2_dd, dst_dd, + ne0, ne1, ne2, ne3_fastdiv, + ne03_fastdiv, ne13_fastdiv, ne21_fastdiv, ne22_fastdiv, + s1, s2, s3, + s01, s02, s03, + s11, s12, s13, + s21, s22, s23); + } +} + +void ggml_cuda_op_lerp(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + const ggml_tensor * src2 = dst->src[2]; + + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + if (src2->type == GGML_TYPE_F32) { + launch_lerp( + src0, src1, src2, dst, + (const float *) src0->data, + (const float *) src1->data, + (const float *) src2->data, + (float *) dst->data, + stream); + } else if (src2->type == GGML_TYPE_F16) { + launch_lerp( + src0, src1, src2, dst, + (const float *) src0->data, + (const float *) src1->data, + (const half *) src2->data, + (float *) dst->data, + stream); + } else { + GGML_ABORT("fatal error"); + } +} diff --git a/ggml/src/ggml-cuda/lerp.cuh b/ggml/src/ggml-cuda/lerp.cuh new file mode 100644 index 0000000000..c504e82f86 --- /dev/null +++ b/ggml/src/ggml-cuda/lerp.cuh @@ -0,0 +1,5 @@ +#include "common.cuh" + +#define CUDA_LERP_BLOCK_SIZE 256 + +void ggml_cuda_op_lerp(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu index d4866067a4..560fade72d 100644 --- a/ggml/src/ggml-cuda/unary.cu +++ b/ggml/src/ggml-cuda/unary.cu @@ -45,6 +45,11 @@ static __device__ __forceinline__ float op_relu(float x) { return fmaxf(x, 0); } +static __device__ __forceinline__ float op_relu_sqr(float x) { + float r = fmaxf(x, 0); + return r * r; +} + static __device__ __forceinline__ float op_sigmoid(float x) { return 1.0f / (1.0f + expf(-x)); } @@ -186,6 +191,10 @@ void ggml_cuda_op_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_unary<op_relu>(ctx, dst); } +void ggml_cuda_op_relu_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary<op_relu_sqr>(ctx, dst); +} + void ggml_cuda_op_sigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { ggml_cuda_op_unary<op_sigmoid>(ctx, dst); } diff --git a/ggml/src/ggml-cuda/unary.cuh b/ggml/src/ggml-cuda/unary.cuh index 609046e569..769341ac44 100644 --- a/ggml/src/ggml-cuda/unary.cuh +++ b/ggml/src/ggml-cuda/unary.cuh @@ -41,6 +41,8 @@ void ggml_cuda_op_tanh(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); +void ggml_cuda_op_relu_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + void ggml_cuda_op_sigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_hardsigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/wkv.cu b/ggml/src/ggml-cuda/wkv.cu index d2fced705e..67e4ae9221 100644 --- a/ggml/src/ggml-cuda/wkv.cu +++ b/ggml/src/ggml-cuda/wkv.cu @@ -65,7 +65,9 @@ static __global__ void rwkv_wkv_f32(const int B, const int T, const int C, const } } -template <int block_size> +constexpr float w_scale = -0.6065306597f; // -exp(-0.5) + +template <int block_size, bool fuse_exp> static __global__ void rwkv_wkv7_f32(const int B, const int T, const int C, const int H, const float * r, const float * w, const float * k, const float * v, const float * a, const float * b, const float * s, float * dst) { const int tid = threadIdx.x; const int bid = blockIdx.x; @@ -89,7 +91,7 @@ static __global__ void rwkv_wkv7_f32(const int B, const int T, const int C, cons for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) { __syncthreads(); _r[tid] = r[t]; - _w[tid] = w[t]; + _w[tid] = fuse_exp ?
__expf(w_scale / (1.0f + __expf(-w[t]))) : w[t]; _k[tid] = k[t]; _a[tid] = a[t]; _b[tid] = b[t]; @@ -179,9 +181,11 @@ void ggml_cuda_op_rwkv_wkv7(ggml_backend_cuda_context & ctx, ggml_tensor * dst) const float * s_d = (const float *)dst->src[6]->data; const int64_t B = dst->src[6]->ne[1]; - const int64_t T = dst->src[0]->ne[2]; + const int64_t T = dst->src[4]->ne[2]; const int64_t C = dst->ne[0]; - const int64_t H = dst->src[0]->ne[1]; + const int64_t H = dst->src[4]->ne[1]; + + const bool fuse_exp = (bool) ((int32_t *) dst->op_params)[0]; float * dst_d = (float *)dst->data; @@ -192,8 +196,16 @@ void ggml_cuda_op_rwkv_wkv7(ggml_backend_cuda_context & ctx, ggml_tensor * dst) GGML_ASSERT(C / H == CUDA_WKV_BLOCK_SIZE || C / H == CUDA_WKV_BLOCK_SIZE * 2); if (C / H == CUDA_WKV_BLOCK_SIZE) { - rwkv_wkv7_f32<CUDA_WKV_BLOCK_SIZE><<<B * H, C / H, 0, stream>>>(B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d); + if (fuse_exp) { + rwkv_wkv7_f32<CUDA_WKV_BLOCK_SIZE, true><<<B * H, C / H, 0, stream>>>(B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d); + } else { + rwkv_wkv7_f32<CUDA_WKV_BLOCK_SIZE, false><<<B * H, C / H, 0, stream>>>(B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d); + } } else { - rwkv_wkv7_f32<CUDA_WKV_BLOCK_SIZE * 2><<<B * H, C / H, 0, stream>>>(B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d); + if (fuse_exp) { + rwkv_wkv7_f32<CUDA_WKV_BLOCK_SIZE * 2, true><<<B * H, C / H, 0, stream>>>(B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d); + } else { + rwkv_wkv7_f32<CUDA_WKV_BLOCK_SIZE * 2, false><<<B * H, C / H, 0, stream>>>(B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d); + } } } diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index 7f4cfbba22..51cde64848 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -1527,6 +1527,8 @@ int ggml_metal_op_rwkv(ggml_metal_op_t ctx, int idx) { GGML_TENSOR_LOCALS( int32_t, ne, op, ne); GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + const int32_t fuse_exp = ((const int32_t *) op->op_params)[0]; + const int64_t B = op->op == GGML_OP_RWKV_WKV6 ? op->src[5]->ne[1] : op->src[6]->ne[1]; const int64_t T = op->src[0]->ne[2]; const int64_t C = op->ne[0]; @@ -1551,6 +1553,9 @@ int ggml_metal_op_rwkv(ggml_metal_op_t ctx, int idx) { ggml_metal_encoder_set_bytes (enc, (void *) &T, sizeof(T), ida++); ggml_metal_encoder_set_bytes (enc, (void *) &C, sizeof(C), ida++); ggml_metal_encoder_set_bytes (enc, (void *) &H, sizeof(H), ida++); + if (op->op == GGML_OP_RWKV_WKV7) { + ggml_metal_encoder_set_bytes (enc, (void *) &fuse_exp, sizeof(fuse_exp), ida++); + } ggml_metal_encoder_dispatch_threadgroups(enc, B * H, 1, 1, C/H, 1, 1); diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 17e358d1a8..dbbbf44bd5 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -2657,6 +2657,7 @@ kernel void kernel_rwkv_wkv7_f32( constant uint & T, constant uint & C, constant uint & H, + constant uint & fuse_exp, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]]) { @@ -2666,6 +2667,8 @@ kernel void kernel_rwkv_wkv7_f32( const uint head_id = tgpig.x % H; const uint tid = tpitg.x; + constexpr float w_scale = -0.6065306597f; // -exp(-0.5) + if (batch_id >= B || head_id >= H) { return; } @@ -2692,7 +2695,7 @@ kernel void kernel_rwkv_wkv7_f32( for (uint t = start_t; t < end_t; t += C) { threadgroup_barrier(mem_flags::mem_threadgroup); _r[tid] = r[t]; - _w[tid] = w[t]; + _w[tid] = fuse_exp ?
exp(w_scale / (1 + exp(-w[t]))) : w[t]; _k[tid] = k[t]; _a[tid] = a[t]; _b[tid] = b[t]; diff --git a/ggml/src/ggml-sycl/wkv.cpp b/ggml/src/ggml-sycl/wkv.cpp index c10e2f7645..f8ead16d56 100644 --- a/ggml/src/ggml-sycl/wkv.cpp +++ b/ggml/src/ggml-sycl/wkv.cpp @@ -96,7 +96,9 @@ static void rwkv_wkv6_f32_kernel( } } -template <int block_size> +constexpr float w_scale = -0.6065306597f; // -exp(-0.5) + +template <int block_size, bool fuse_exp> static void rwkv_wkv7_f32_kernel( const int B, const int T, const int C, const int H, const float* r, const float* w, const float* k, const float* v, @@ -132,7 +134,7 @@ static void rwkv_wkv7_f32_kernel( item_ct1.barrier(sycl::access::fence_space::local_space); _r[tid] = r[t]; - _w[tid] = w[t]; + _w[tid] = fuse_exp ? sycl::native::exp(w_scale / (1.0f + sycl::native::exp(-w[t]))) : w[t]; _k[tid] = k[t]; _a[tid] = a[t]; _b[tid] = b[t]; @@ -247,9 +249,11 @@ void ggml_sycl_op_rwkv_wkv7(ggml_backend_sycl_context& ctx, ggml_tensor* dst) { float* dst_d = (float*)dst->data; const int64_t B = dst->src[6]->ne[1]; - const int64_t T = dst->src[0]->ne[2]; + const int64_t T = dst->src[4]->ne[2]; const int64_t C = dst->ne[0]; - const int64_t H = dst->src[0]->ne[1]; + const int64_t H = dst->src[4]->ne[1]; + + const bool fuse_exp = (bool) ((int32_t *) dst->op_params)[0]; GGML_ASSERT(dst->src[6]->type == GGML_TYPE_F32); GGML_ASSERT(C % H == 0); @@ -264,30 +268,61 @@ void ggml_sycl_op_rwkv_wkv7(ggml_backend_sycl_context& ctx, ggml_tensor* dst) { // Submit kernel if (C / H == WKV_BLOCK_SIZE) { - stream->submit([&](sycl::handler& cgh) { - sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh); + if (fuse_exp) { + stream->submit([&](sycl::handler& cgh) { + sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh); - cgh.parallel_for( - sycl::nd_range<3>(grid_dims * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) { - rwkv_wkv7_f32_kernel<WKV_BLOCK_SIZE>( - B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d, - item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get() - ); + cgh.parallel_for( + sycl::nd_range<3>(grid_dims * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + rwkv_wkv7_f32_kernel<WKV_BLOCK_SIZE, true>( + B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d, + item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get() + ); + }); + }); + } else { + stream->submit([&](sycl::handler& cgh) { + sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh); + + cgh.parallel_for( + sycl::nd_range<3>(grid_dims * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + rwkv_wkv7_f32_kernel<WKV_BLOCK_SIZE, false>( + B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d, + item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get() + ); + }); }); - }); + }); + } } else { - stream->submit([&](sycl::handler& cgh) { - sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh); + if (fuse_exp) { + stream->submit([&](sycl::handler& cgh) { + sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh); - cgh.parallel_for( - sycl::nd_range<3>(grid_dims * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) { - rwkv_wkv7_f32_kernel<WKV_BLOCK_SIZE * 2>( - B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d, - item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get() - ); - }); - }); + cgh.parallel_for( + sycl::nd_range<3>(grid_dims * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + rwkv_wkv7_f32_kernel<WKV_BLOCK_SIZE * 2, true>( + B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d, + item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get() + ); + }); + }); + } else { + stream->submit([&](sycl::handler& cgh) { + sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh); + + cgh.parallel_for( +
sycl::nd_range<3>(grid_dims * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + rwkv_wkv7_f32_kernel<WKV_BLOCK_SIZE * 2, false>( + B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d, + item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get() + ); + }); + }); + } } } diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index a99375c088..8cfe636773 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -789,6 +789,7 @@ struct vk_device_struct { vk_pipeline pipeline_pool2d_f32; vk_pipeline pipeline_rwkv_wkv6_f32; vk_pipeline pipeline_rwkv_wkv7_f32; + vk_pipeline pipeline_rwkv_wkv7_f32_fuse_exp; vk_pipeline pipeline_ssm_scan_f32_d128; vk_pipeline pipeline_ssm_scan_f32_d256; vk_pipeline pipeline_ssm_conv_f32; @@ -4362,6 +4363,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1); ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32, "rwkv_wkv7_f32", rwkv_wkv7_f32_len, rwkv_wkv7_f32_data, "main", 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32_fuse_exp, "rwkv_wkv7_f32_fuse_exp", rwkv_wkv7_f32_fuse_exp_len, rwkv_wkv7_f32_fuse_exp_data, "main", 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1); if (device->subgroup_arithmetic && device->subgroup_require_full_support) { ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size}, 1, true, true); @@ -9143,6 +9145,9 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return nullptr; case GGML_OP_RWKV_WKV7: if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + if (dst->op_params[0] == 1) { + return ctx->device->pipeline_rwkv_wkv7_f32_fuse_exp; + } return ctx->device->pipeline_rwkv_wkv7_f32; } return nullptr; @@ -9899,10 +9904,62 @@ static void ggml_vk_add_id(ggml_backend_vk_context * ctx, vk_context& subctx, co }); } -static void ggml_vk_op_f32_wkv(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_rwkv_wkv6_push_constants&& pc, int version) { - GGML_ASSERT(version == 6 || version == 7); - int num_srcs = version == 6 ?
6 : 7; +static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) { + const size_t seq_length = dst->src[0]->ne[2]; + const size_t n_embed = dst->ne[0]; + const size_t n_heads = dst->src[0]->ne[1]; + const size_t n_seqs = dst->src[5]->ne[1]; + const vk_op_rwkv_wkv6_push_constants pc = { + (uint32_t)n_seqs, + (uint32_t)seq_length, + (uint32_t)n_embed, + (uint32_t)n_heads, + }; + + const int num_srcs = 6; + for (int i = 0; i < num_srcs; i++) { + GGML_ASSERT(!ggml_is_quantized(dst->src[i]->type)); + } + + GGML_ASSERT(dst->buffer != nullptr); + + vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, dst->src[0], dst->src[1], dst->src[2], dst, dst->op); + GGML_ASSERT(pipeline != nullptr); + + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); + + vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst); + vk_subbuffer src_buf[6] = {}; + for (int i = 0; i < num_srcs; i++) { + src_buf[i] = ggml_vk_tensor_subbuffer(ctx, dst->src[i]); + } + + std::array elements = { + (uint32_t)(pc.B * pc.H), + 1, + 1 + }; + + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, + {src_buf[0], src_buf[1], src_buf[2], src_buf[3], src_buf[4], src_buf[5], dst_buf}, + pc, elements); +} + +static void ggml_vk_rwkv_wkv7(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) { + const size_t seq_length = dst->src[0]->ne[2]; + const size_t n_embed = dst->ne[0]; + const size_t n_heads = dst->src[0]->ne[1]; + const size_t n_seqs = dst->src[6]->ne[1]; + + const vk_op_rwkv_wkv7_push_constants pc = { + (uint32_t)n_seqs, + (uint32_t)seq_length, + (uint32_t)n_embed, + (uint32_t)n_heads, + }; + + const int num_srcs = 7; for (int i = 0; i < num_srcs; i++) { GGML_ASSERT(!ggml_is_quantized(dst->src[i]->type)); } @@ -9926,54 +9983,9 @@ static void ggml_vk_op_f32_wkv(ggml_backend_vk_context * ctx, vk_context& subctx 1 }; - if (version == 6) { - ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, - {src_buf[0], src_buf[1], src_buf[2], src_buf[3], src_buf[4], src_buf[5], dst_buf}, - pc, elements); - } else if (version == 7) { - ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, - {src_buf[0], src_buf[1], src_buf[2], src_buf[3], src_buf[4], src_buf[5], src_buf[6], dst_buf}, - pc, elements); - } else { - // shouldn't happen - GGML_ASSERT(false); - } -} - -static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) { - const size_t seq_length = dst->src[0]->ne[2]; - const size_t n_embed = dst->ne[0]; - const size_t n_heads = dst->src[0]->ne[1]; - const size_t n_seqs = dst->src[5]->ne[1]; - - ggml_vk_op_f32_wkv( - ctx, subctx, dst, - { - (uint32_t)n_seqs, - (uint32_t)seq_length, - (uint32_t)n_embed, - (uint32_t)n_heads, - }, - 6 - ); -} - -static void ggml_vk_rwkv_wkv7(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) { - const size_t seq_length = dst->src[0]->ne[2]; - const size_t n_embed = dst->ne[0]; - const size_t n_heads = dst->src[0]->ne[1]; - const size_t n_seqs = dst->src[6]->ne[1]; - - ggml_vk_op_f32_wkv( - ctx, subctx, dst, - { - (uint32_t)n_seqs, - (uint32_t)seq_length, - (uint32_t)n_embed, - (uint32_t)n_heads, - }, - 7 - ); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, + {src_buf[0], src_buf[1], src_buf[2], src_buf[3], src_buf[4], src_buf[5], src_buf[6], dst_buf}, + pc, elements); } static void ggml_vk_ssm_scan(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) { @@ -15734,7 +15746,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * src_clone[2], src_clone[3], 
src_clone[4], src_clone[5]); } else if (tensor->op == GGML_OP_RWKV_WKV7) { tensor_clone = ggml_rwkv_wkv7(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3], - src_clone[4], src_clone[5], src_clone[6]); + src_clone[4], src_clone[5], src_clone[6], (uint32_t) tensor->op_params[0]); } else if (tensor->op == GGML_OP_OPT_STEP_ADAMW) { src_clone[0]->flags = tensor->src[0]->flags; tensor_clone = ggml_opt_step_adamw(ggml_ctx, src_clone[0], src_clone[1], diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index bbdbf9dcaa..2e7550912e 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -967,7 +967,8 @@ void process_shaders() { string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); - string_to_spv("rwkv_wkv7_f32", "wkv7.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); + string_to_spv("rwkv_wkv7_f32", "wkv7.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"FUSE_EXP", "0"}})); + string_to_spv("rwkv_wkv7_f32_fuse_exp", "wkv7.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"FUSE_EXP", "1"}})); string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); string_to_spv("opt_step_sgd_f32", "opt_step_sgd.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/wkv7.comp b/ggml/src/ggml-vulkan/vulkan-shaders/wkv7.comp index 88c1c02b32..fafa1346fc 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/wkv7.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/wkv7.comp @@ -10,6 +10,7 @@ layout(push_constant) uniform Parameters { uint T; uint C; uint H; + uint fuse_exp; }; layout(binding = 0) readonly buffer RBuf { A_TYPE r[]; }; @@ -45,10 +46,18 @@ void main() { const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid; const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid; +#if FUSE_EXP + const float w_scale = -0.6065306597f; // -exp(-0.5) +#endif + for (uint t = start_t; t < end_t; t += C) { barrier(); _r[tid] = r[t]; +#if FUSE_EXP + _w[tid] = exp(w_scale / (1 + exp(-w[t]))); +#else _w[tid] = w[t]; +#endif _k[tid] = k[t]; _a[tid] = a[t]; _b[tid] = b[t]; diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 1725ad1654..7d8ab774b0 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -1030,6 +1030,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "GATED_LINEAR_ATTN", "RWKV_WKV7", "SOLVE_TRI", + "LERP", "UNARY", @@ -1047,7 +1048,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "GLU", }; -static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95"); +static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT != 96"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -1139,6 +1140,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "gated_linear_attn(k, v, q, gate, s)", "rwkv_wkv7(r, w, k, v, a, b, s)", "A X = B, A triangular, solve X", + "x+(y-x)*t", "unary(x)", @@ -1156,7 +1158,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "glu(x)", }; -static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95"); +static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT != 96"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -1168,6 +1170,7 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = { "TANH", "ELU", "RELU", + "RELU_SQR", "SIGMOID", "GELU", "GELU_QUICK", @@ -1185,7 +1188,7 @@ static const char * 
GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = { "TRUNC", }; -static_assert(GGML_UNARY_OP_COUNT == 22, "GGML_UNARY_OP_COUNT != 22"); +static_assert(GGML_UNARY_OP_COUNT == 23, "GGML_UNARY_OP_COUNT != 23"); static const char * GGML_GLU_OP_NAME[GGML_GLU_OP_COUNT] = { "REGLU", @@ -2643,6 +2646,20 @@ struct ggml_tensor * ggml_relu_inplace( return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_RELU); } +// ggml_relu_sqr + +struct ggml_tensor * ggml_relu_sqr( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary(ctx, a, GGML_UNARY_OP_RELU_SQR); +} + +struct ggml_tensor * ggml_relu_sqr_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_RELU_SQR); +} + // ggml_leaky_relu struct ggml_tensor * ggml_leaky_relu( @@ -5704,7 +5721,8 @@ struct ggml_tensor * ggml_rwkv_wkv7( struct ggml_tensor * v, struct ggml_tensor * a, struct ggml_tensor * b, - struct ggml_tensor * state) { + struct ggml_tensor * state, + bool fuse_exp) { GGML_ASSERT(ggml_is_contiguous(r)); GGML_ASSERT(ggml_is_contiguous(w)); GGML_ASSERT(ggml_is_contiguous(k)); @@ -5713,14 +5731,16 @@ struct ggml_tensor * ggml_rwkv_wkv7( GGML_ASSERT(ggml_is_contiguous(b)); GGML_ASSERT(ggml_is_contiguous(state)); - const int64_t S = k->ne[0]; - const int64_t H = k->ne[1]; - const int64_t n_tokens = k->ne[2]; + const int64_t S = a->ne[0]; + const int64_t H = a->ne[1]; + const int64_t n_tokens = a->ne[2]; const int64_t n_seqs = state->ne[1]; + const int64_t n_embd = S * H; { - GGML_ASSERT(w->ne[0] == S && w->ne[1] == H && w->ne[2] == n_tokens); + GGML_ASSERT(r->ne[0] == S && r->ne[1] == H && r->ne[2] == n_tokens); + GGML_ASSERT(w->ne[0] == n_embd && w->ne[1] == n_tokens); GGML_ASSERT(k->ne[0] == S && k->ne[1] == H && k->ne[2] == n_tokens); - GGML_ASSERT(v->ne[0] == S && v->ne[1] == H && v->ne[2] == n_tokens); + GGML_ASSERT(v->ne[0] == n_embd && v->ne[1] == n_tokens); GGML_ASSERT(a->ne[0] == S && a->ne[1] == H && a->ne[2] == n_tokens); GGML_ASSERT(b->ne[0] == S && b->ne[1] == H && b->ne[2] == n_tokens); GGML_ASSERT(ggml_nelements(state) == S * S * H * n_seqs); @@ -5730,6 +5750,9 @@ struct ggml_tensor * ggml_rwkv_wkv7( const int64_t ne[4] = { S * H, n_tokens + S * n_seqs, 1, 1 }; struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); + int32_t fuse_exp_i = fuse_exp ? 
1 : 0; + ggml_set_op_params(result, &fuse_exp_i, sizeof(fuse_exp_i)); + result->op = GGML_OP_RWKV_WKV7; result->src[0] = r; result->src[1] = w; @@ -5742,6 +5765,34 @@ struct ggml_tensor * ggml_rwkv_wkv7( return result; } +// ggml_lerp + +struct ggml_tensor * ggml_lerp( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * t) { + // assume a and b are the same shape for now + GGML_ASSERT(ggml_are_same_shape(a, b)); + + GGML_ASSERT(t->ne[0] == a->ne[0]); + GGML_ASSERT(a->ne[1] % t->ne[1] == 0); + GGML_ASSERT(a->ne[2] % t->ne[2] == 0); + + // a/b can broadcast to t at dim3 for rwkv7 + GGML_ASSERT(t->ne[3] % a->ne[3] == 0); + + const int64_t ne[4] = { a->ne[0], a->ne[1], a->ne[2], t->ne[3] }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); + + result->op = GGML_OP_LERP; + result->src[0] = a; + result->src[1] = b; + result->src[2] = t; + + return result; +} + // ggml_unary static struct ggml_tensor * ggml_unary_impl( diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 16d42c4ae3..783539981a 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -1001,10 +1001,7 @@ ggml_tensor * llm_graph_context::build_ffn( } break; case LLM_FFN_RELU_SQR: { - cur = ggml_relu(ctx0, cur); - cb(cur, "ffn_relu", il); - - cur = ggml_sqr(ctx0, cur); + cur = ggml_relu_sqr(ctx0, cur); cb(cur, "ffn_sqr(relu)", il); } break; case LLM_FFN_SWIGLU: @@ -1307,8 +1304,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn( // TODO: add support for gated squared relu GGML_ABORT("fatal error: gated squared relu not implemented"); } else { - cur = ggml_relu(ctx0, cur); - cur = ggml_sqr(ctx0, cur); + cur = ggml_relu_sqr(ctx0, cur); cb(cur, "ffn_moe_relu_sqr", il); } break; default: diff --git a/src/models/models.h b/src/models/models.h index 3a44f7f140..1c9b40b01e 100644 --- a/src/models/models.h +++ b/src/models/models.h @@ -49,7 +49,8 @@ struct llm_build_rwkv7_base : public llm_graph_context { ggml_tensor * build_rwkv7_channel_mix(const llama_layer * layer, ggml_tensor * cur, ggml_tensor * x_prev, - llm_arch arch) const; + llm_arch arch, + int il) const; ggml_tensor * build_rwkv7_time_mix(llm_graph_input_rs * inp, ggml_tensor * cur, ggml_tensor * x_prev, diff --git a/src/models/rwkv7-base.cpp b/src/models/rwkv7-base.cpp index cda4465384..085cc68523 100644 --- a/src/models/rwkv7-base.cpp +++ b/src/models/rwkv7-base.cpp @@ -7,16 +7,24 @@ llm_build_rwkv7_base::llm_build_rwkv7_base(const llama_model & model, const llm_ ggml_tensor * llm_build_rwkv7_base::build_rwkv7_channel_mix(const llama_layer * layer, ggml_tensor * cur, ggml_tensor * x_prev, - llm_arch arch) const { - ggml_tensor * sx = ggml_sub(ctx0, x_prev, cur); + llm_arch arch, + int il) const { switch (arch) { case LLM_ARCH_RWKV7: { - ggml_tensor * xk = ggml_add(ctx0, ggml_mul(ctx0, sx, layer->channel_mix_lerp_k), cur); + // cur + (x_prev - cur) * layer->channel_mix_lerp_k + ggml_tensor * xk = ggml_lerp(ctx0, cur, x_prev, layer->channel_mix_lerp_k); - ggml_tensor * k = ggml_sqr(ctx0, ggml_relu(ctx0, build_lora_mm(layer->channel_mix_key, xk))); - - cur = build_lora_mm(layer->channel_mix_value, k); + cur = build_ffn( + xk, + layer->channel_mix_key, nullptr, nullptr, // up + nullptr, nullptr, nullptr, // gate + layer->channel_mix_value, nullptr, nullptr, // down + nullptr, + LLM_FFN_RELU_SQR, + LLM_FFN_SEQ, + il + ); } break; default: @@ -46,11 +54,7 @@ ggml_tensor * llm_build_rwkv7_base::build_rwkv7_time_mix(llm_graph_input_rs * in bool has_gating = layer.time_mix_g1 && 
layer.time_mix_g2; - ggml_tensor * sx = ggml_sub(ctx0, x_prev, cur); - ggml_tensor * dummy = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_embd, n_seq_tokens, n_seqs, has_gating ? 6 : 5); - sx = ggml_repeat(ctx0, sx, dummy); - - ggml_tensor * xxx = ggml_add(ctx0, ggml_mul(ctx0, sx, layer.time_mix_lerp_fused), cur); + ggml_tensor * xxx = ggml_lerp(ctx0, cur, x_prev, layer.time_mix_lerp_fused); ggml_tensor * xr = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], 0); ggml_tensor * xw = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * sizeof(float)); @@ -65,7 +69,6 @@ ggml_tensor * llm_build_rwkv7_base::build_rwkv7_time_mix(llm_graph_input_rs * in ggml_tensor * w = ggml_add( ctx0, ggml_mul_mat(ctx0, layer.time_mix_w2, ggml_tanh(ctx0, ggml_mul_mat(ctx0, layer.time_mix_w1, xw))), layer.time_mix_w0); - w = ggml_exp(ctx0, ggml_scale(ctx0, ggml_sigmoid(ctx0, w), -0.606531)); ggml_tensor * k = build_lora_mm(layer.time_mix_key, xk); ggml_tensor * v = build_lora_mm(layer.time_mix_value, xv); @@ -95,14 +98,12 @@ ggml_tensor * llm_build_rwkv7_base::build_rwkv7_time_mix(llm_graph_input_rs * in k = ggml_add(ctx0, k, ggml_sub(ctx0, ggml_mul(ctx0, a, ka), ka)); r = ggml_reshape_3d(ctx0, r, head_size, head_count, n_tokens); - w = ggml_reshape_3d(ctx0, w, head_size, head_count, n_tokens); k = ggml_reshape_3d(ctx0, k, head_size, head_count, n_tokens); - v = ggml_reshape_3d(ctx0, v, head_size, head_count, n_tokens); a = ggml_reshape_3d(ctx0, a, head_size, head_count, n_tokens); ggml_tensor * wkv_state = build_rs(inp, mctx_cur->get_s_l(il), hparams.n_embd_s(), n_seqs); - ggml_tensor * wkv_output = ggml_rwkv_wkv7(ctx0, r, w, k, v, ggml_neg(ctx0, kk), ggml_mul(ctx0, kk, a), wkv_state); + ggml_tensor * wkv_output = ggml_rwkv_wkv7(ctx0, r, w, k, v, ggml_neg(ctx0, kk), ggml_mul(ctx0, kk, a), wkv_state, true); cur = ggml_view_1d(ctx0, wkv_output, n_embd * n_tokens, 0); wkv_state = ggml_view_1d(ctx0, wkv_output, n_embd * head_size * n_seqs, n_embd * n_tokens * sizeof(float)); @@ -122,6 +123,8 @@ ggml_tensor * llm_build_rwkv7_base::build_rwkv7_time_mix(llm_graph_input_rs * in } else { cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens); } + + v = ggml_reshape_3d(ctx0, v, head_size, head_count, n_tokens); ggml_tensor * rk = ggml_sum_rows( ctx0, ggml_mul(ctx0, ggml_mul(ctx0, k, r), ggml_reshape_2d(ctx0, layer.time_mix_r_k, head_size, head_count))); cur = ggml_add(ctx0, cur, ggml_reshape_2d(ctx0, ggml_mul(ctx0, v, rk), n_embd, n_tokens)); diff --git a/src/models/rwkv7.cpp b/src/models/rwkv7.cpp index 5caf6553df..4faafa9d64 100644 --- a/src/models/rwkv7.cpp +++ b/src/models/rwkv7.cpp @@ -66,7 +66,7 @@ llm_build_rwkv7::llm_build_rwkv7(const llama_model & model, const llm_graph_para ffn_norm = ggml_get_rows(ctx0, ffn_norm, inp_out_ids); x_prev = ggml_get_rows(ctx0, x_prev, inp_out_ids); } - cur = build_rwkv7_channel_mix(layer, ffn_norm, x_prev, LLM_ARCH_RWKV7); + cur = build_rwkv7_channel_mix(layer, ffn_norm, x_prev, LLM_ARCH_RWKV7, il); cur = ggml_add(ctx0, cur, ffn_inp); cur = build_cvec(cur, il); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 411467e968..d8bbe8d5e2 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -3681,17 +3681,47 @@ struct test_rwkv_wkv7 : public test_case { ggml_tensor * build_graph(ggml_context * ctx) override { const int64_t n_tokens = n_seq_tokens * n_seqs; + const int64_t n_embd = head_count * head_size; ggml_tensor * r = ggml_new_tensor(ctx, type, 3, std::vector{ head_size, head_count, n_tokens }.data()); - ggml_tensor * 
w = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data()); + ggml_tensor * w = ggml_new_tensor(ctx, type, 2, std::vector<int64_t>{ n_embd, n_tokens }.data()); ggml_tensor * k = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data()); - ggml_tensor * v = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data()); + ggml_tensor * v = ggml_new_tensor(ctx, type, 2, std::vector<int64_t>{ n_embd, n_tokens }.data()); ggml_tensor * a = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data()); ggml_tensor * b = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data()); // Outputs may become NaN with long seqlen without these normalization a = ggml_l2_norm(ctx, a, 1e-7F); b = ggml_l2_norm(ctx, b, 1e-7F); ggml_tensor * s = ggml_new_tensor(ctx, type, 2, std::vector<int64_t>{ head_size * head_size * head_count, n_seqs }.data()); - ggml_tensor * out = ggml_rwkv_wkv7(ctx, r, w, k, v, a, b, s); + ggml_tensor * out = ggml_rwkv_wkv7(ctx, r, w, k, v, a, b, s, true); return out; } }; + +// GGML_OP_LERP +struct test_lerp : public test_case { + const ggml_type type_t; + const ggml_type type_a; + const ggml_type type_b; + const std::array<int64_t, 4> ne0; + const std::array<int64_t, 4> ne1; + const std::array<int64_t, 4> ne2; + + std::string vars() override { + return VARS_TO_STR6(type_a, type_b, type_t, ne0, ne1, ne2); + } + + test_lerp(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32, + ggml_type type_t = GGML_TYPE_F32, + std::array<int64_t, 4> ne0 = {10, 10, 1, 1}, + std::array<int64_t, 4> ne1 = {10, 10, 1, 1}, + std::array<int64_t, 4> ne2 = {10, 10, 1, 1}) + : type_a(type_a), type_b(type_b), type_t(type_t), ne0(ne0), ne1(ne1), ne2(ne2) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * a = ggml_new_tensor(ctx, type_a, 4, ne0.data()); + ggml_tensor * b = ggml_new_tensor(ctx, type_b, 4, ne1.data()); + ggml_tensor * c = ggml_new_tensor(ctx, type_t, 4, ne2.data()); + ggml_tensor * out = ggml_lerp(ctx, a, b, c); return out; } }; @@ -7612,6 +7642,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() { test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 32, 4)); test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 128, 4)); + test_cases.emplace_back(new test_lerp(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {256, 16, 1, 1}, {256, 16, 1, 1}, {256, 16, 1, 1})); + test_cases.emplace_back(new test_lerp(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {256, 16, 1, 1}, {256, 16, 1, 1}, {256, 16, 1, 6})); + test_cases.emplace_back(new test_lerp(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F16, {256, 16, 1, 1}, {256, 16, 1, 1}, {256, 16, 1, 1})); + test_cases.emplace_back(new test_lerp(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F16, {256, 16, 1, 1}, {256, 16, 1, 1}, {256, 16, 1, 6})); + #if 0 // > 4GB A matrix. Too slow to be enabled by default. test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F16, 900000, 3, 2592, {1, 1}, {1, 1}));
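For reference, the two element-wise semantics introduced by this patch reduce to the following scalar forms (a minimal sketch in plain C; the helper names are illustrative only and not part of the patch):

#include <math.h>

// GGML_OP_LERP: dst = a + (b - a) * t, applied element-wise after broadcasting
// t across dims 1/2 and a/b across dim 3, as asserted in ggml_lerp().
static inline float lerp_ref(float a, float b, float t) {
    return a + (b - a) * t;
}

// fuse_exp = true folds the RWKV7 decay transform into the wkv7 kernels:
// w = exp(-exp(-0.5) * sigmoid(w_in)), with -exp(-0.5) precomputed as w_scale.
static inline float wkv7_decay_ref(float w_in) {
    const float w_scale = -0.6065306597f; // -exp(-0.5)
    return expf(w_scale / (1.0f + expf(-w_in)));
}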