Merge fd990da706 into 3bc8d2cf23
commit d7d754e43b
@@ -556,6 +556,7 @@ extern "C" {
         GGML_OP_GATED_LINEAR_ATTN,
         GGML_OP_RWKV_WKV7,
         GGML_OP_SOLVE_TRI,
+        GGML_OP_LERP,
 
         GGML_OP_UNARY,
 
@@ -583,6 +584,7 @@ extern "C" {
         GGML_UNARY_OP_TANH,
         GGML_UNARY_OP_ELU,
         GGML_UNARY_OP_RELU,
+        GGML_UNARY_OP_RELU_SQR,
         GGML_UNARY_OP_SIGMOID,
         GGML_UNARY_OP_GELU,
         GGML_UNARY_OP_GELU_QUICK,
@@ -1133,6 +1135,14 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor * a);
 
+    GGML_API struct ggml_tensor * ggml_relu_sqr(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a);
+
+    GGML_API struct ggml_tensor * ggml_relu_sqr_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a);
+
     GGML_API struct ggml_tensor * ggml_sigmoid(
             struct ggml_context * ctx,
             struct ggml_tensor * a);
 
@@ -2443,7 +2453,8 @@ extern "C" {
             struct ggml_tensor * v,
             struct ggml_tensor * a,
             struct ggml_tensor * b,
-            struct ggml_tensor * state);
+            struct ggml_tensor * state,
+            bool fuse_exp);
 
     /* Solves a specific equation of the form Ax=B, where A is a triangular matrix
      * without zeroes on the diagonal (i.e. invertible).
 
@@ -2466,6 +2477,14 @@ extern "C" {
             bool lower,
             bool uni);
 
+    // a + (b - a) * t
+    // used in rwkv7
+    GGML_API struct ggml_tensor * ggml_lerp(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a,
+            struct ggml_tensor * b,
+            struct ggml_tensor * t);
+
     // custom operators
 
     typedef void (*ggml_custom1_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int nth, void * userdata);
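As a usage reference only (not part of the change), a minimal sketch of how the API added above could be called. The context is assumed to be already initialized and the shapes are illustrative:

    #include "ggml.h"

    // Hypothetical illustration of the new declarations. ggml_relu_sqr fuses
    // relu + square; ggml_lerp computes a + (b - a) * t with t broadcast over rows.
    static struct ggml_tensor * example_graph(struct ggml_context * ctx) {
        struct ggml_tensor * x = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 8);
        struct ggml_tensor * y = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 8);
        struct ggml_tensor * t = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 1); // broadcasts over the 8 rows

        struct ggml_tensor * h = ggml_relu_sqr(ctx, x);    // relu(x)^2 in a single op
        struct ggml_tensor * m = ggml_lerp(ctx, x, y, t);  // x + (y - x) * t

        return ggml_add(ctx, h, m);
    }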
@@ -2019,6 +2019,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_solve_tri(params, tensor);
             } break;
+        case GGML_OP_LERP:
+            {
+                ggml_compute_forward_lerp(params, tensor);
+            }
+            break;
         case GGML_OP_MAP_CUSTOM1:
             {
                 ggml_compute_forward_map_custom1(params, tensor);
 
@@ -2180,6 +2185,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_CUMSUM:
         case GGML_OP_TRI:
         case GGML_OP_FILL:
+        case GGML_OP_LERP:
             {
                 n_tasks = n_threads;
             } break;
 
@@ -2216,6 +2222,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
                 case GGML_UNARY_OP_TANH:
                 case GGML_UNARY_OP_ELU:
                 case GGML_UNARY_OP_RELU:
+                case GGML_UNARY_OP_RELU_SQR:
                 case GGML_UNARY_OP_SIGMOID:
                 case GGML_UNARY_OP_HARDSWISH:
                 case GGML_UNARY_OP_HARDSIGMOID:
@@ -9458,6 +9458,10 @@ void ggml_compute_forward_unary(
             {
                 ggml_compute_forward_relu(params, dst);
             } break;
+        case GGML_UNARY_OP_RELU_SQR:
+            {
+                ggml_compute_forward_relu_sqr(params, dst);
+            } break;
         case GGML_UNARY_OP_SIGMOID:
             {
                 ggml_compute_forward_sigmoid(params, dst);
 
@@ -10189,9 +10193,9 @@ void ggml_compute_forward_solve_tri(const struct ggml_compute_params * params, s
 static void ggml_compute_forward_rwkv_wkv7_f32(
         const ggml_compute_params * params,
         ggml_tensor * dst) {
-    const int64_t T = dst->src[1]->ne[2];
+    const int64_t T = dst->src[4]->ne[2];
     const int64_t C = dst->ne[0];
-    const int64_t HEADS = dst->src[1]->ne[1];
+    const int64_t HEADS = dst->src[4]->ne[1];
     const int64_t n_seqs = dst->src[6]->ne[1];
     const int64_t head_size = C / HEADS;
 
@@ -10216,6 +10220,9 @@ static void ggml_compute_forward_rwkv_wkv7_f32(
     float * a = (float *) dst->src[4]->data;
     float * b = (float *) dst->src[5]->data;
 
+    const bool fuse_exp = (bool) ((int32_t *) dst->op_params)[0];
+    constexpr float w_scale = -0.6065306597f; // -exp(-0.5)
+
     int64_t t_stride = HEADS * head_size; // Same to C
 
     int64_t h_stride = C / HEADS;
 
@@ -10252,7 +10259,7 @@ static void ggml_compute_forward_rwkv_wkv7_f32(
                 int64_t h_2d_i_j_offset = h_2d_i_offset + j;
 
                 float r_val = r[t_h_j_offset];
-                float w_val = w[t_h_j_offset];
+                float w_val = fuse_exp ? expf(w_scale / (1.0f + expf(-w[t_h_j_offset]))) : w[t_h_j_offset];
                 float k_val = k[t_h_j_offset];
                 float b_val = b[t_h_j_offset];
                 float kv_val = v_val * k_val;
 
@@ -10311,6 +10318,11 @@ static void ggml_compute_forward_rwkv_wkv7_f32(
                 GGML_F32_VEC k_vec = GGML_F32_VEC_LOAD(&k[t_h_j_offset]);
                 GGML_F32_VEC b_vec = GGML_F32_VEC_LOAD(&b[t_h_j_offset]);
 
+                if (fuse_exp) {
+                    w_vec = ggml_v_sigmoid(w_vec, w_scale);
+                    w_vec = ggml_v_expf(w_vec);
+                }
+
                 k_vec = GGML_F32_VEC_MUL(v_vec, k_vec);
 
                 GGML_F32_VEC state_vec = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_j_offset]);
 
@@ -10330,7 +10342,7 @@ static void ggml_compute_forward_rwkv_wkv7_f32(
                 int64_t h_2d_i_j_offset = h_2d_i_offset + j;
 
                 float r_val = r[t_h_j_offset];
-                float w_val = w[t_h_j_offset];
+                float w_val = fuse_exp ? expf(w_scale / (1.0f + expf(-w[t_h_j_offset]))) : w[t_h_j_offset];
                 float k_val = k[t_h_j_offset];
                 float b_val = b[t_h_j_offset];
                 float kv_val = v[t_h_i_offset] * k_val;
 
@@ -10371,7 +10383,7 @@ static void ggml_compute_forward_rwkv_wkv7_f32(
                 int64_t h_2d_i_j_offset = h_2d_i_offset + j;
 
                 float r_val = r[t_h_j_offset];
-                float w_val = w[t_h_j_offset];
+                float w_val = fuse_exp ? expf(w_scale / (1.0f + expf(-w[t_h_j_offset]))) : w[t_h_j_offset];
                 float k_val = k[t_h_j_offset];
                 float b_val = b[t_h_j_offset];
                 float kv_val = v_val * k_val;
 
@@ -10405,6 +10417,177 @@ void ggml_compute_forward_rwkv_wkv7(
     }
 }
 
+static void ggml_compute_forward_lerp_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    const ggml_tensor * src2 = dst->src[2];
+
+    GGML_ASSERT(dst->type  == GGML_TYPE_F32);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(src2->type == GGML_TYPE_F32);
+
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+    GGML_ASSERT(src1->nb[0] == sizeof(float));
+    GGML_ASSERT(src2->nb[0] == sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_TENSOR_TERNARY_OP_LOCALS
+
+    GGML_ASSERT( nb0 == sizeof(float));
+    GGML_ASSERT(nb10 == sizeof(float));
+    GGML_ASSERT(nb20 == sizeof(float));
+
+    GGML_ASSERT(ggml_are_same_shape(src0, src1));
+    GGML_ASSERT(ne00 == ne10 && ne00 == ne20 && ne00 == ne0);
+
+    GGML_ASSERT(ne01 % ne21 == 0);
+    GGML_ASSERT(ne02 % ne22 == 0);
+
+    GGML_ASSERT(ne23 % ne03 == 0);
+    GGML_ASSERT(ne23 % ne13 == 0);
+
+    GGML_ASSERT(ne0 == ne00);
+    GGML_ASSERT(ne1 == ne01);
+    GGML_ASSERT(ne2 == ne02);
+    GGML_ASSERT(ne3 == ne23);
+
+    const int nr = ggml_nrows(dst);
+
+    const int dr = (nr + nth - 1)/nth;
+
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        const int i3 = ir/(ne2*ne1);
+        const int i2 = (ir - i3*ne2*ne1)/ne1;
+        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+        const int i03 = i3 % ne03;
+        const int i13 = i3 % ne13;
+
+        const int i21 = i1 % ne21;
+        const int i22 = i2 % ne22;
+
+        float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
+        const float * src0_ptr = (const float *) ((char *) src0->data + i03*nb03 + i2*nb02 + i1*nb01);
+        const float * src1_ptr = (const float *) ((char *) src1->data + i13*nb13 + i2*nb12 + i1*nb11);
+        const float * src2_ptr = (const float *) ((char *) src2->data + i3*nb23 + i22*nb22 + i21*nb21);
+
+        for (int64_t i0 = 0; i0 < ne0; ++i0) {
+            const float s0 = src0_ptr[i0];
+            const float s1 = src1_ptr[i0];
+            const float s2 = src2_ptr[i0];
+
+            dst_ptr[i0] = s0 + (s1 - s0) * s2;
+        }
+    }
+}
+
+static void ggml_compute_forward_lerp_f32_f32_f16(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    const ggml_tensor * src2 = dst->src[2];
+
+    GGML_ASSERT(dst->type  == GGML_TYPE_F32);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(src2->type == GGML_TYPE_F16);
+
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+    GGML_ASSERT(src1->nb[0] == sizeof(float));
+    GGML_ASSERT(src2->nb[0] == sizeof(ggml_fp16_t));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_TENSOR_TERNARY_OP_LOCALS
+
+    GGML_ASSERT( nb0 == sizeof(float));
+    GGML_ASSERT(nb10 == sizeof(float));
+    GGML_ASSERT(nb20 == sizeof(ggml_fp16_t));
+
+    GGML_ASSERT(ggml_are_same_shape(src0, src1));
+    GGML_ASSERT(ne00 == ne10 && ne00 == ne20 && ne00 == ne0);
+
+    GGML_ASSERT(ne01 % ne21 == 0);
+    GGML_ASSERT(ne02 % ne22 == 0);
+
+    GGML_ASSERT(ne23 % ne03 == 0);
+    GGML_ASSERT(ne23 % ne13 == 0);
+
+    GGML_ASSERT(ne0 == ne00);
+    GGML_ASSERT(ne1 == ne01);
+    GGML_ASSERT(ne2 == ne02);
+    GGML_ASSERT(ne3 == ne23);
+
+    const int nr = ggml_nrows(dst);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        const int i3 = ir/(ne2*ne1);
+        const int i2 = (ir - i3*ne2*ne1)/ne1;
+        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+        const int i03 = i3 % ne03;
+        const int i13 = i3 % ne13;
+
+        const int i21 = i1 % ne21;
+        const int i22 = i2 % ne22;
+
+        float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
+        const float * src0_ptr = (const float *) ((char *) src0->data + i03*nb03 + i2*nb02 + i1*nb01);
+        const float * src1_ptr = (const float *) ((char *) src1->data + i13*nb13 + i2*nb12 + i1*nb11);
+        const ggml_fp16_t * src2_ptr = (const ggml_fp16_t *) ((char *) src2->data + i3*nb23 + i22*nb22 + i21*nb21);
+
+        for (int64_t i0 = 0; i0 < ne0; ++i0) {
+            const float s0 = src0_ptr[i0];
+            const float s1 = src1_ptr[i0];
+            const float s2 = GGML_FP16_TO_FP32(src2_ptr[i0]);
+
+            dst_ptr[i0] = s0 + (s1 - s0) * s2;
+        }
+    }
+}
+
+void ggml_compute_forward_lerp(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src2 = dst->src[2];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                if (src2->type == GGML_TYPE_F32) {
+                    ggml_compute_forward_lerp_f32(params, dst);
+                } else if (src2->type == GGML_TYPE_F16) {
+                    ggml_compute_forward_lerp_f32_f32_f16(params, dst);
+                } else {
+                    GGML_ABORT("fatal error");
+                }
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_map_custom1
 
 void ggml_compute_forward_map_custom1(
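A note on the fuse_exp path above: when op_params[0] is set, the kernel no longer reads a precomputed decay but derives it from the raw w projection on the fly. A scalar reading of that transform (reference only, not an extra code path in the tree):

    #include <math.h>

    // Equivalent of the fused branch: w' = exp(-exp(-0.5) * sigmoid(w)).
    // w_scale = -exp(-0.5) ~= -0.6065306597, matching the constant in the kernels above.
    static float rwkv7_decay_from_w(float w) {
        const float w_scale   = -0.6065306597f;
        const float sigmoid_w = 1.0f / (1.0f + expf(-w));
        return expf(w_scale * sigmoid_w); // always in (exp(w_scale), 1)
    }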
@@ -101,6 +101,7 @@ void ggml_compute_forward_add_rel_pos(const struct ggml_compute_params * params,
 void ggml_compute_forward_rwkv_wkv6(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_rwkv_wkv7(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_solve_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_lerp(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_gla(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_map_custom1(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_map_custom2(const struct ggml_compute_params * params, struct ggml_tensor * dst);
@@ -28,6 +28,11 @@ static inline float op_relu(float x) {
     return (x > 0.f) ? x : 0.f;
 }
 
+static inline float op_relu_sqr(float x) {
+    float r = (x > 0.f) ? x : 0.f;
+    return r * r;
+}
+
 static inline float op_sigmoid(float x) {
     return 1.f / (1.f + expf(-x));
 }
 
@@ -262,6 +267,10 @@ void ggml_compute_forward_relu(const ggml_compute_params * params, ggml_tensor *
     unary_op<op_relu>(params, dst);
 }
 
+void ggml_compute_forward_relu_sqr(const ggml_compute_params * params, ggml_tensor * dst) {
+    unary_op<op_relu_sqr>(params, dst);
+}
+
 void ggml_compute_forward_sigmoid(const ggml_compute_params * params, ggml_tensor * dst) {
     unary_op<op_sigmoid>(params, dst);
 }
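op_relu_sqr is the squared ReLU, max(x, 0)^2, wired through the same unary_op<> template as the other activations. A scalar sketch of the identity that the llama.cpp FFN change at the end of this diff relies on (illustrative, not code from the tree):

    #include <math.h>

    // relu_sqr(x) == sqr(relu(x)): the fused op produces the same values as the
    // previous two-op graph (ggml_relu followed by ggml_sqr).
    static float relu_sqr_ref(float x) {
        const float r = fmaxf(x, 0.0f);
        return r * r;
    }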
@@ -13,6 +13,7 @@ void ggml_compute_forward_step(const struct ggml_compute_params * params, struct
 void ggml_compute_forward_tanh(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_elu(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_relu(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_relu_sqr(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_sigmoid(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_hardsigmoid(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_exp(const struct ggml_compute_params * params, struct ggml_tensor * dst);
@@ -1129,16 +1129,27 @@ inline static svfloat32_t ggml_v_expf(svbool_t pg, svfloat32_t x) {
         svsel_f32(c, svmul_f32_x(pg, svmla_f32_x(pg, s2, s2, j), s1), svmla_f32_x(pg, k, k, j)));
 }
 
-// computes silu x/(1+exp(-x)) in single precision vector
-inline static svfloat32_t ggml_v_silu(svbool_t pg, svfloat32_t x) {
+// computes (1+exp(-x)) in single precision vector
+inline static svfloat32_t ggml_v_one_plus_exp_neg_x(svbool_t pg, svfloat32_t x) {
     const svfloat32_t one = svdup_n_f32_x(pg, 1.0f);
     const svfloat32_t zero = svdup_n_f32_x(pg, 0.0f);
     const svfloat32_t neg_x = svsub_f32_x(pg, zero, x);
     const svfloat32_t exp_neg_x = ggml_v_expf(pg, neg_x);
-    const svfloat32_t one_plus_exp_neg_x = svadd_f32_x(pg, one, exp_neg_x);
+    return svadd_f32_x(pg, one, exp_neg_x);
+}
+
+// computes silu x/(1+exp(-x)) in single precision vector
+inline static svfloat32_t ggml_v_silu(svbool_t pg, svfloat32_t x) {
+    const svfloat32_t one_plus_exp_neg_x = ggml_v_one_plus_exp_neg_x(pg, x);
     return svdiv_f32_x(pg, x, one_plus_exp_neg_x);
 }
+
+// computes sigmoid 1/(1+exp(-x)) (with scale) in single precision vector
+inline static svfloat32_t ggml_v_sigmoid(svbool_t pg, svfloat32_t x, float scale) {
+    const svfloat32_t one_plus_exp_neg_x = ggml_v_one_plus_exp_neg_x(pg, x);
+    return svdiv_f32_x(pg, svdup_n_f32_x(pg, scale), one_plus_exp_neg_x);
+}
 
 #elif defined(__ARM_NEON) && defined(__aarch64__)
 
 // adapted from arm limited optimized routine
 
@@ -1168,16 +1179,27 @@ inline static float32x4_t ggml_v_expf(float32x4_t x) {
         vbslq_f32(c, vmulq_f32(vfmaq_f32(s2, s2, j), s1), vfmaq_f32(k, k, j)));
 }
 
-// computes silu x/(1+exp(-x)) in single precision vector
-inline static float32x4_t ggml_v_silu(float32x4_t x) {
+// computes (1+exp(-x)) in single precision vector
+inline static float32x4_t ggml_v_one_plus_exp_neg_x(float32x4_t x) {
     const float32x4_t one = vdupq_n_f32(1.0f);
     const float32x4_t zero = vdupq_n_f32(0.0f);
     const float32x4_t neg_x = vsubq_f32(zero, x);
     const float32x4_t exp_neg_x = ggml_v_expf(neg_x);
-    const float32x4_t one_plus_exp_neg_x = vaddq_f32(one, exp_neg_x);
+    return vaddq_f32(one, exp_neg_x);
+}
+
+// computes silu x/(1+exp(-x)) in single precision vector
+inline static float32x4_t ggml_v_silu(float32x4_t x) {
+    const float32x4_t one_plus_exp_neg_x = ggml_v_one_plus_exp_neg_x(x);
     return vdivq_f32(x, one_plus_exp_neg_x);
 }
+
+// computes sigmoid 1/(1+exp(-x)) (with scale) in single precision vector
+inline static float32x4_t ggml_v_sigmoid(float32x4_t x, float scale) {
+    const float32x4_t one_plus_exp_neg_x = ggml_v_one_plus_exp_neg_x(x);
+    return vdivq_f32(vdupq_n_f32(scale), one_plus_exp_neg_x);
+}
 
 #elif defined(__AVX512F__) && defined(__AVX512DQ__)
 
 // adapted from arm limited optimized routine
 
@@ -1211,16 +1233,27 @@ inline static __m512 ggml_v_expf(__m512 x) {
     return _mm512_mask_blend_ps(d, res, alt);
 }
 
-// computes silu x/(1+exp(-x)) in single precision vector
-inline static __m512 ggml_v_silu(__m512 x) {
+// computes (1+exp(-x)) in single precision vector
+inline static __m512 ggml_v_one_plus_exp_neg_x(__m512 x) {
     const __m512 one = _mm512_set1_ps(1);
     const __m512 zero = _mm512_setzero_ps();
     const __m512 neg_x = _mm512_sub_ps(zero, x);
     const __m512 exp_neg_x = ggml_v_expf(neg_x);
-    const __m512 one_plus_exp_neg_x = _mm512_add_ps(one, exp_neg_x);
+    return _mm512_add_ps(one, exp_neg_x);
+}
+
+// computes silu x/(1+exp(-x)) in single precision vector
+inline static __m512 ggml_v_silu(__m512 x) {
+    const __m512 one_plus_exp_neg_x = ggml_v_one_plus_exp_neg_x(x);
     return _mm512_div_ps(x, one_plus_exp_neg_x);
 }
+
+// computes sigmoid 1/(1+exp(-x)) (with scale) in single precision vector
+inline static __m512 ggml_v_sigmoid(__m512 x, float scale) {
+    const __m512 one_plus_exp_neg_x = ggml_v_one_plus_exp_neg_x(x);
+    return _mm512_div_ps(_mm512_set1_ps(scale), one_plus_exp_neg_x);
+}
 
 #elif defined(__AVX2__) && defined(__FMA__)
 
 // adapted from arm limited optimized routine
 
@@ -1266,16 +1299,27 @@ inline static __m256 ggml_v_expf(__m256 x) {
         _mm256_andnot_ps(_mm256_castsi256_ps(c), _mm256_fmadd_ps(k, j, k)))));
 }
 
-// computes silu x/(1+exp(-x)) in single precision vector
-inline static __m256 ggml_v_silu(__m256 x) {
+// computes (1+exp(-x)) in single precision vector
+inline static __m256 ggml_v_one_plus_exp_neg_x(__m256 x) {
    const __m256 one = _mm256_set1_ps(1);
    const __m256 zero = _mm256_setzero_ps();
    const __m256 neg_x = _mm256_sub_ps(zero, x);
    const __m256 exp_neg_x = ggml_v_expf(neg_x);
-   const __m256 one_plus_exp_neg_x = _mm256_add_ps(one, exp_neg_x);
+   return _mm256_add_ps(one, exp_neg_x);
+}
+
+// computes silu x/(1+exp(-x)) in single precision vector
+inline static __m256 ggml_v_silu(__m256 x) {
+   const __m256 one_plus_exp_neg_x = ggml_v_one_plus_exp_neg_x(x);
    return _mm256_div_ps(x, one_plus_exp_neg_x);
 }
+
+// computes sigmoid 1/(1+exp(-x)) (with scale) in single precision vector
+inline static __m256 ggml_v_sigmoid(__m256 x, float scale) {
+   const __m256 one_plus_exp_neg_x = ggml_v_one_plus_exp_neg_x(x);
+   return _mm256_div_ps(_mm256_set1_ps(scale), one_plus_exp_neg_x);
+}
 
 #elif defined(__SSE2__) // __AVX2__ / __ARM_NEON
 
 #if defined(__FMA__)
 
@@ -1320,16 +1364,27 @@ inline static __m128 ggml_v_expf(__m128 x) {
         _mm_andnot_ps(_mm_castsi128_ps(c), MADD128(k, j, k)))));
 }
 
-// computes silu x/(1+exp(-x)) in single precision vector
-inline static __m128 ggml_v_silu(__m128 x) {
+// computes (1+exp(-x)) in single precision vector
+inline static __m128 ggml_v_one_plus_exp_neg_x(__m128 x) {
    const __m128 one = _mm_set1_ps(1);
    const __m128 zero = _mm_setzero_ps();
    const __m128 neg_x = _mm_sub_ps(zero, x);
    const __m128 exp_neg_x = ggml_v_expf(neg_x);
-   const __m128 one_plus_exp_neg_x = _mm_add_ps(one, exp_neg_x);
+   return _mm_add_ps(one, exp_neg_x);
+}
+
+// computes silu x/(1+exp(-x)) in single precision vector
+inline static __m128 ggml_v_silu(__m128 x) {
+   const __m128 one_plus_exp_neg_x = ggml_v_one_plus_exp_neg_x(x);
    return _mm_div_ps(x, one_plus_exp_neg_x);
 }
+
+// computes sigmoid 1/(1+exp(-x)) (with scale) in single precision vector
+inline static __m128 ggml_v_sigmoid(__m128 x, float scale) {
+   const __m128 one_plus_exp_neg_x = ggml_v_one_plus_exp_neg_x(x);
+   return _mm_div_ps(_mm_set1_ps(scale), one_plus_exp_neg_x);
+}
 
 #elif defined(__riscv_v_intrinsic)
 
 // adapted from arm limited optimized routine
 
@@ -1374,14 +1429,25 @@ inline static vfloat32m2_t ggml_v_expf_m2(vfloat32m2_t x, int vl) {
         vl);
 }
 
-// computes silu x/(1+exp(-x)) in single precision vector
-inline static vfloat32m2_t ggml_v_silu_m2(vfloat32m2_t x, int vl) {
+// computes (1+exp(-x)) in single precision vector
+inline static vfloat32m2_t ggml_v_one_plus_exp_neg_x_m2(vfloat32m2_t x, int vl) {
     const vfloat32m2_t neg_x = __riscv_vfneg_v_f32m2(x, vl);
     const vfloat32m2_t exp_neg_x = ggml_v_expf_m2(neg_x, vl);
-    const vfloat32m2_t one_plus_exp_neg_x = __riscv_vfadd_vf_f32m2(exp_neg_x, 1.0f, vl);
+    return __riscv_vfadd_vf_f32m2(exp_neg_x, 1.0f, vl);
+}
+
+// computes silu x/(1+exp(-x)) in single precision vector
+inline static vfloat32m2_t ggml_v_silu_m2(vfloat32m2_t x, int vl) {
+    const vfloat32m2_t one_plus_exp_neg_x = ggml_v_one_plus_exp_neg_x_m2(x, vl);
     return __riscv_vfdiv_vv_f32m2(x, one_plus_exp_neg_x, vl);
 }
+
+// computes sigmoid 1/(1+exp(-x)) in single precision vector
+inline static vfloat32m2_t ggml_v_sigmoid_m2(vfloat32m2_t x, int vl, float scale) {
+    const vfloat32m2_t one_plus_exp_neg_x = ggml_v_one_plus_exp_neg_x_m2(x, vl);
+    return __riscv_vfrdiv_vf_f32m2(one_plus_exp_neg_x, scale, vl);
+}
 
 #endif // __ARM_NEON / __AVX2__ / __SSE2__ / __riscv_v_intrinsic
 
 inline static void ggml_vec_silu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
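The six backend hunks above all perform the same refactor: the 1 + exp(-x) computation is hoisted out of ggml_v_silu into a shared helper so the new ggml_v_sigmoid (used by the fused wkv7 CPU path earlier in this diff) can reuse it. Per lane, every variant computes the same scalar value; a reference reading (not an extra code path in the tree):

    #include <math.h>

    // Per-lane equivalent of ggml_v_sigmoid(x, scale): a scaled logistic sigmoid.
    // With scale = -exp(-0.5) this is exactly the argument fed to ggml_v_expf
    // in the fused wkv7 kernel.
    static float v_sigmoid_scalar(float x, float scale) {
        return scale / (1.0f + expf(-x));
    }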
@@ -52,6 +52,7 @@
 #include "ggml-cuda/unary.cuh"
 #include "ggml-cuda/upscale.cuh"
 #include "ggml-cuda/wkv.cuh"
+#include "ggml-cuda/lerp.cuh"
 #include "ggml-cuda/gla.cuh"
 #include "ggml-cuda/set.cuh"
 #include "ggml-cuda/set-rows.cuh"
 
@@ -2504,6 +2505,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_UNARY_OP_RELU:
             ggml_cuda_op_relu(ctx, dst);
             break;
+        case GGML_UNARY_OP_RELU_SQR:
+            ggml_cuda_op_relu_sqr(ctx, dst);
+            break;
         case GGML_UNARY_OP_SIGMOID:
             ggml_cuda_op_sigmoid(ctx, dst);
             break;
 
@@ -2727,6 +2731,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_RWKV_WKV7:
             ggml_cuda_op_rwkv_wkv7(ctx, dst);
             break;
+        case GGML_OP_LERP:
+            ggml_cuda_op_lerp(ctx, dst);
+            break;
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
             ggml_cuda_cross_entropy_loss_back(ctx, dst);
             break;
 
@@ -4522,6 +4529,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
             case GGML_UNARY_OP_GELU:
             case GGML_UNARY_OP_SILU:
             case GGML_UNARY_OP_RELU:
+            case GGML_UNARY_OP_RELU_SQR:
            case GGML_UNARY_OP_SIGMOID:
            case GGML_UNARY_OP_HARDSIGMOID:
            case GGML_UNARY_OP_HARDSWISH:
 
@@ -4839,6 +4847,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_RWKV_WKV6:
         case GGML_OP_GATED_LINEAR_ATTN:
         case GGML_OP_RWKV_WKV7:
+        case GGML_OP_LERP:
             return true;
         case GGML_OP_FLASH_ATTN_EXT:
             return ggml_cuda_flash_attn_ext_supported(dev_ctx->device, op);
@@ -0,0 +1,283 @@
+#include "lerp.cuh"
+#include <cstdint>
+
+template <typename src0_t, typename src1_t, typename src2_t, typename dst_t>
+static __global__ void k_lerp(
+        const src0_t * src0,
+        const src1_t * src1,
+        const src2_t * src2,
+        dst_t * dst,
+        const int ne0,
+        const int ne1,
+        const int ne2,
+        const uint3 ne3,
+        const uint3 ne03,
+        const uint3 ne13,
+        const uint3 ne21,
+        const uint3 ne22,
+        const int s1,
+        const int s2,
+        const int s3,
+        const int s01,
+        const int s02,
+        const int s03,
+        const int s11,
+        const int s12,
+        const int s13,
+        const int s21,
+        const int s22,
+        const int s23) {
+
+    const uint32_t i0s = blockDim.x * blockIdx.x + threadIdx.x;
+    const uint32_t i1 = (blockDim.y * blockIdx.y + threadIdx.y);
+    const uint32_t i2 = fastdiv((blockDim.z * blockIdx.z + threadIdx.z), ne3);
+    const uint32_t i3 = (blockDim.z * blockIdx.z + threadIdx.z) - (i2 * ne3.z);
+
+    if (i0s >= (uint32_t)ne0 || i1 >= (uint32_t)ne1 || i2 >= (uint32_t)ne2 || i3 >= ne3.z) {
+        return;
+    }
+
+    // src0/src1 broadcast in dim3
+    const uint32_t i03 = fastmodulo(i3, ne03);
+    const uint32_t i13 = fastmodulo(i3, ne13);
+
+    // src2 broadcast in dim1, dim2
+    const uint32_t i21 = fastmodulo(i1, ne21);
+    const uint32_t i22 = fastmodulo(i2, ne22);
+
+    const size_t i_src0 = i03*s03 + i2*s02 + i1*s01;
+    const size_t i_src1 = i13*s13 + i2*s12 + i1*s11;
+    const size_t i_src2 = i3*s23 + i22*s22 + i21*s21;
+    const size_t i_dst  = i3*s3 + i2*s2 + i1*s1;
+
+    const src0_t * src0_row = src0 + i_src0;
+    const src1_t * src1_row = src1 + i_src1;
+    const src2_t * src2_row = src2 + i_src2;
+    dst_t * dst_row = dst + i_dst;
+
+    for (int i0 = i0s; i0 < ne0; i0 += blockDim.x * gridDim.x) {
+        const float v0 = (float) src0_row[i0];
+        const float v1 = (float) src1_row[i0];
+        const float v2 = (float) src2_row[i0];
+
+        dst_row[i0] = (dst_t) (v0 + (v1 - v0) * v2);
+    }
+}
+
+template <typename src0_t, typename src1_t, typename src2_t, typename dst_t>
+static __global__ void k_lerp_unravel(
+        const src0_t * src0,
+        const src1_t * src1,
+        const src2_t * src2,
+        dst_t * dst,
+        const uint3 ne0,
+        const uint3 ne1,
+        const uint3 ne2,
+        const uint32_t ne3,
+        const uint3 prod_012,
+        const uint3 prod_01,
+        const uint3 ne03,
+        const uint3 ne13,
+        const uint3 ne21,
+        const uint3 ne22,
+        const int s1,
+        const int s2,
+        const int s3,
+        const int s01,
+        const int s02,
+        const int s03,
+        const int s11,
+        const int s12,
+        const int s13,
+        const int s21,
+        const int s22,
+        const int s23) {
+
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    const uint32_t i3 = fastdiv(i, prod_012);
+    const uint32_t i2 = fastdiv(i - i3 * prod_012.z, prod_01);
+    const uint32_t i1 = fastdiv(i - i3 * prod_012.z - i2 * prod_01.z, ne0);
+    const uint32_t i0 = i - i3 * prod_012.z - i2 * prod_01.z - i1 * ne0.z;
+
+    if (i0 >= ne0.z || i1 >= ne1.z || i2 >= ne2.z || i3 >= ne3) {
+        return;
+    }
+
+    // src0/src1 broadcast in dim3
+    const int i03 = fastmodulo(i3, ne03);
+    const int i13 = fastmodulo(i3, ne13);
+
+    // src2 broadcast in dim1, dim2
+    const int i21 = fastmodulo(i1, ne21);
+    const int i22 = fastmodulo(i2, ne22);
+
+    const size_t i_src0 = i03*s03 + i2*s02 + i1*s01;
+    const size_t i_src1 = i13*s13 + i2*s12 + i1*s11;
+    const size_t i_src2 = i3*s23 + i22*s22 + i21*s21;
+    const size_t i_dst  = i3*s3 + i2*s2 + i1*s1;
+
+    const src0_t * src0_row = src0 + i_src0;
+    const src1_t * src1_row = src1 + i_src1;
+    const src2_t * src2_row = src2 + i_src2;
+    dst_t * dst_row = dst + i_dst;
+
+    const float v0 = (float) src0_row[i0];
+    const float v1 = (float) src1_row[i0];
+    const float v2 = (float) src2_row[i0];
+
+    // dst = src0 + (src1 - src0) * src2
+    dst_row[i0] = (dst_t) (v0 + (v1 - v0) * v2);
+}
+
+template <typename src0_t, typename src1_t, typename src2_t, typename dst_t>
+static void launch_lerp(
+        const ggml_tensor * src0,
+        const ggml_tensor * src1,
+        const ggml_tensor * src2,
+        ggml_tensor * dst,
+        const src0_t * src0_dd,
+        const src1_t * src1_dd,
+        const src2_t * src2_dd,
+        dst_t * dst_dd,
+        cudaStream_t stream) {
+
+    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
+    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne)
+    GGML_TENSOR_LOCALS(int64_t, ne2, src2, ne)
+    GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
+
+    GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
+    GGML_TENSOR_LOCALS(size_t, nb1, src1, nb)
+    GGML_TENSOR_LOCALS(size_t, nb2, src2, nb)
+    GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
+
+    GGML_ASSERT(ne00 == ne10 && ne00 == ne20 && ne00 == ne0);
+    GGML_ASSERT(ne01 % ne21 == 0);
+    GGML_ASSERT(ne02 % ne22 == 0);
+    GGML_ASSERT(ne3 % ne03 == 0);
+    GGML_ASSERT(ne3 % ne13 == 0);
+    GGML_ASSERT(ne0 == ne00);
+    GGML_ASSERT(ne1 == ne01);
+    GGML_ASSERT(ne2 == ne02);
+    GGML_ASSERT(ne3 == ne23);
+
+    size_t s1 = nb1 / sizeof(dst_t);
+    size_t s2 = nb2 / sizeof(dst_t);
+    size_t s3 = nb3 / sizeof(dst_t);
+
+    size_t s01 = nb01 / sizeof(src0_t);
+    size_t s02 = nb02 / sizeof(src0_t);
+    size_t s03 = nb03 / sizeof(src0_t);
+
+    size_t s11 = nb11 / sizeof(src1_t);
+    size_t s12 = nb12 / sizeof(src1_t);
+    size_t s13 = nb13 / sizeof(src1_t);
+
+    size_t s21 = nb21 / sizeof(src2_t);
+    size_t s22 = nb22 / sizeof(src2_t);
+    size_t s23 = nb23 / sizeof(src2_t);
+
+    GGML_ASSERT(nb0 % sizeof(dst_t) == 0);
+    GGML_ASSERT(nb1 % sizeof(dst_t) == 0);
+    GGML_ASSERT(nb2 % sizeof(dst_t) == 0);
+    GGML_ASSERT(nb3 % sizeof(dst_t) == 0);
+
+    GGML_ASSERT(nb00 % sizeof(src0_t) == 0);
+    GGML_ASSERT(nb01 % sizeof(src0_t) == 0);
+    GGML_ASSERT(nb02 % sizeof(src0_t) == 0);
+    GGML_ASSERT(nb03 % sizeof(src0_t) == 0);
+
+    GGML_ASSERT(nb10 % sizeof(src1_t) == 0);
+    GGML_ASSERT(nb11 % sizeof(src1_t) == 0);
+    GGML_ASSERT(nb12 % sizeof(src1_t) == 0);
+    GGML_ASSERT(nb13 % sizeof(src1_t) == 0);
+
+    GGML_ASSERT(nb20 % sizeof(src2_t) == 0);
+    GGML_ASSERT(nb21 % sizeof(src2_t) == 0);
+    GGML_ASSERT(nb22 % sizeof(src2_t) == 0);
+    GGML_ASSERT(nb23 % sizeof(src2_t) == 0);
+
+    const int block_size = CUDA_LERP_BLOCK_SIZE;
+
+    int64_t hne0 = std::max(ne0 / 2LL, 1LL);
+
+    dim3 block_dims;
+    block_dims.x = std::min<unsigned int>(hne0, block_size);
+    block_dims.y = std::min<unsigned int>(ne1, block_size / block_dims.x);
+    block_dims.z = std::min(std::min<unsigned int>(ne2 * ne3, block_size / block_dims.x / block_dims.y), 64U);
+
+    dim3 block_nums(
+        (hne0 + block_dims.x - 1) / block_dims.x,
+        (ne1 + block_dims.y - 1) / block_dims.y,
+        (ne2 * ne3 + block_dims.z - 1) / block_dims.z);
+
+    const uint3 ne03_fastdiv = init_fastdiv_values((uint32_t) ne03);
+    const uint3 ne13_fastdiv = init_fastdiv_values((uint32_t) ne13);
+    const uint3 ne21_fastdiv = init_fastdiv_values((uint32_t) ne21);
+    const uint3 ne22_fastdiv = init_fastdiv_values((uint32_t) ne22);
+
+    if (block_nums.z > 65535 || block_nums.y > 65535) {
+        int block_num = (ne0 * ne1 * ne2 * ne3 + block_size - 1) / block_size;
+        const uint3 prod_012 = init_fastdiv_values((uint32_t) (ne0 * ne1 * ne2));
+        const uint3 prod_01 = init_fastdiv_values((uint32_t) (ne0 * ne1));
+        const uint3 ne0_fastdiv = init_fastdiv_values((uint32_t) ne0);
+        const uint3 ne1_fastdiv = init_fastdiv_values((uint32_t) ne1);
+        const uint3 ne2_fastdiv = init_fastdiv_values((uint32_t) ne2);
+
+        k_lerp_unravel<src0_t, src1_t, src2_t, dst_t>
+            <<<block_num, block_size, 0, stream>>>(
+                src0_dd, src1_dd, src2_dd, dst_dd,
+                ne0_fastdiv, ne1_fastdiv, ne2_fastdiv, ne3,
+                prod_012, prod_01,
+                ne03_fastdiv, ne13_fastdiv, ne21_fastdiv, ne22_fastdiv,
+                s1, s2, s3,
+                s01, s02, s03,
+                s11, s12, s13,
+                s21, s22, s23);
+    } else {
+        const uint3 ne3_fastdiv = init_fastdiv_values((uint32_t) ne3);
+
+        k_lerp<src0_t, src1_t, src2_t, dst_t>
+            <<<block_nums, block_dims, 0, stream>>>(
+                src0_dd, src1_dd, src2_dd, dst_dd,
+                ne0, ne1, ne2, ne3_fastdiv,
+                ne03_fastdiv, ne13_fastdiv, ne21_fastdiv, ne22_fastdiv,
+                s1, s2, s3,
+                s01, s02, s03,
+                s11, s12, s13,
+                s21, s22, s23);
+    }
+}
+
+void ggml_cuda_op_lerp(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    const ggml_tensor * src2 = dst->src[2];
+
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+    if (src2->type == GGML_TYPE_F32) {
+        launch_lerp<float, float, float, float>(
+            src0, src1, src2, dst,
+            (const float *) src0->data,
+            (const float *) src1->data,
+            (const float *) src2->data,
+            (float *) dst->data,
+            stream);
+    } else if (src2->type == GGML_TYPE_F16) {
+        launch_lerp<float, float, half, float>(
+            src0, src1, src2, dst,
+            (const float *) src0->data,
+            (const float *) src1->data,
+            (const half *) src2->data,
+            (float *) dst->data,
+            stream);
+    } else {
+        GGML_ABORT("fatal error");
+    }
+}
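Two CUDA kernels are added above: k_lerp uses a 3D grid, and k_lerp_unravel is the fallback for tensors whose grid would exceed CUDA's 65535 limit in the y or z dimension; it flattens everything into one index and decomposes it with the fastdiv helpers. A scalar sketch of that decomposition, with plain / and % standing in for fastdiv/fastmodulo (illustrative only):

    // Recover (i0, i1, i2, i3) from a flat element index i, as k_lerp_unravel does.
    static inline void lerp_unravel_index(int i, int ne0, int ne1, int ne2,
                                          int * i0, int * i1, int * i2, int * i3) {
        const int prod_01  = ne0 * ne1;
        const int prod_012 = ne0 * ne1 * ne2;
        *i3 = i / prod_012;
        *i2 = (i - *i3 * prod_012) / prod_01;
        *i1 = (i - *i3 * prod_012 - *i2 * prod_01) / ne0;
        *i0 =  i - *i3 * prod_012 - *i2 * prod_01 - *i1 * ne0;
    }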
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+#define CUDA_LERP_BLOCK_SIZE 256
+
+void ggml_cuda_op_lerp(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
@@ -45,6 +45,11 @@ static __device__ __forceinline__ float op_relu(float x) {
     return fmaxf(x, 0);
 }
 
+static __device__ __forceinline__ float op_relu_sqr(float x) {
+    float r = fmaxf(x, 0);
+    return r * r;
+}
+
 static __device__ __forceinline__ float op_sigmoid(float x) {
     return 1.0f / (1.0f + expf(-x));
 }
 
@@ -186,6 +191,10 @@ void ggml_cuda_op_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     ggml_cuda_op_unary<op_relu>(ctx, dst);
 }
 
+void ggml_cuda_op_relu_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_cuda_op_unary<op_relu_sqr>(ctx, dst);
+}
+
 void ggml_cuda_op_sigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     ggml_cuda_op_unary<op_sigmoid>(ctx, dst);
 }
 
@@ -41,6 +41,8 @@ void ggml_cuda_op_tanh(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
+void ggml_cuda_op_relu_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
 void ggml_cuda_op_sigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_hardsigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
@@ -65,7 +65,9 @@ static __global__ void rwkv_wkv_f32(const int B, const int T, const int C, const
     }
 }
 
-template <int block_size>
+constexpr float w_scale = -0.6065306597f; // -exp(-0.5)
+
+template <int block_size, bool fuse_exp>
 static __global__ void rwkv_wkv7_f32(const int B, const int T, const int C, const int H, const float * r, const float * w, const float * k, const float * v, const float * a, const float * b, const float * s, float * dst) {
     const int tid = threadIdx.x;
     const int bid = blockIdx.x;
 
@@ -89,7 +91,7 @@ static __global__ void rwkv_wkv7_f32(const int B, const int T, const int C, cons
     for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) {
         __syncthreads();
         _r[tid] = r[t];
-        _w[tid] = w[t];
+        _w[tid] = fuse_exp ? __expf(w_scale / (1.0f + __expf(-w[t]))) : w[t];
         _k[tid] = k[t];
         _a[tid] = a[t];
         _b[tid] = b[t];
 
@@ -179,9 +181,11 @@ void ggml_cuda_op_rwkv_wkv7(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
     const float * s_d = (const float *)dst->src[6]->data;
 
     const int64_t B = dst->src[6]->ne[1];
-    const int64_t T = dst->src[0]->ne[2];
+    const int64_t T = dst->src[4]->ne[2];
     const int64_t C = dst->ne[0];
-    const int64_t H = dst->src[0]->ne[1];
+    const int64_t H = dst->src[4]->ne[1];
+
+    const bool fuse_exp = (bool) ((int32_t *) dst->op_params)[0];
 
     float * dst_d = (float *)dst->data;
 
@@ -192,8 +196,16 @@ void ggml_cuda_op_rwkv_wkv7(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
     GGML_ASSERT(C / H == CUDA_WKV_BLOCK_SIZE || C / H == CUDA_WKV_BLOCK_SIZE * 2);
 
     if (C / H == CUDA_WKV_BLOCK_SIZE) {
-        rwkv_wkv7_f32<CUDA_WKV_BLOCK_SIZE><<<B * H, C / H, 0, stream>>>(B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d);
+        if (fuse_exp) {
+            rwkv_wkv7_f32<CUDA_WKV_BLOCK_SIZE, true><<<B * H, C / H, 0, stream>>>(B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d);
+        } else {
+            rwkv_wkv7_f32<CUDA_WKV_BLOCK_SIZE, false><<<B * H, C / H, 0, stream>>>(B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d);
+        }
     } else {
-        rwkv_wkv7_f32<CUDA_WKV_BLOCK_SIZE * 2><<<B * H, C / H, 0, stream>>>(B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d);
+        if (fuse_exp) {
+            rwkv_wkv7_f32<CUDA_WKV_BLOCK_SIZE * 2, true><<<B * H, C / H, 0, stream>>>(B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d);
+        } else {
+            rwkv_wkv7_f32<CUDA_WKV_BLOCK_SIZE * 2, false><<<B * H, C / H, 0, stream>>>(B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d);
+        }
     }
 }
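The CUDA path (and the SYCL path below) lifts fuse_exp into a template parameter so the ternary in the hot loop folds away at compile time; only the host-side dispatch branches on the flag. A minimal C++ sketch of that pattern, with illustrative names that are not from the tree:

    #include <cmath>

    // The runtime flag is turned into a template parameter once, so the
    // per-element branch disappears from the loop body.
    template <bool fuse_exp>
    static void transform_w(const float * w, float * out, int n) {
        for (int i = 0; i < n; ++i) {
            out[i] = fuse_exp ? std::exp(-0.6065306597f / (1.0f + std::exp(-w[i]))) : w[i];
        }
    }

    static void transform_w_dispatch(bool fuse_exp, const float * w, float * out, int n) {
        if (fuse_exp) {
            transform_w<true>(w, out, n);
        } else {
            transform_w<false>(w, out, n);
        }
    }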
@@ -1527,6 +1527,8 @@ int ggml_metal_op_rwkv(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
     GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
 
+    const int32_t fuse_exp = ((const int32_t *) op->op_params)[0];
+
     const int64_t B = op->op == GGML_OP_RWKV_WKV6 ? op->src[5]->ne[1] : op->src[6]->ne[1];
     const int64_t T = op->src[0]->ne[2];
     const int64_t C = op->ne[0];
 
@@ -1551,6 +1553,9 @@ int ggml_metal_op_rwkv(ggml_metal_op_t ctx, int idx) {
     ggml_metal_encoder_set_bytes (enc, (void *) &T, sizeof(T), ida++);
     ggml_metal_encoder_set_bytes (enc, (void *) &C, sizeof(C), ida++);
     ggml_metal_encoder_set_bytes (enc, (void *) &H, sizeof(H), ida++);
+    if (op->op == GGML_OP_RWKV_WKV7) {
+        ggml_metal_encoder_set_bytes (enc, (void *) &fuse_exp, sizeof(fuse_exp), ida++);
+    }
 
     ggml_metal_encoder_dispatch_threadgroups(enc, B * H, 1, 1, C/H, 1, 1);
@@ -2657,6 +2657,7 @@ kernel void kernel_rwkv_wkv7_f32(
         constant uint & T,
         constant uint & C,
         constant uint & H,
+        constant uint & fuse_exp,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint3 tpitg[[thread_position_in_threadgroup]],
         uint3 ntg[[threads_per_threadgroup]]) {
 
@@ -2666,6 +2667,8 @@ kernel void kernel_rwkv_wkv7_f32(
     const uint head_id = tgpig.x % H;
     const uint tid = tpitg.x;
 
+    constexpr float w_scale = -0.6065306597f; // -exp(-0.5)
+
     if (batch_id >= B || head_id >= H) {
         return;
     }
 
@@ -2692,7 +2695,7 @@ kernel void kernel_rwkv_wkv7_f32(
     for (uint t = start_t; t < end_t; t += C) {
         threadgroup_barrier(mem_flags::mem_threadgroup);
         _r[tid] = r[t];
-        _w[tid] = w[t];
+        _w[tid] = fuse_exp ? exp(w_scale / (1 + exp(-w[t]))) : w[t];
         _k[tid] = k[t];
         _a[tid] = a[t];
         _b[tid] = b[t];
@@ -96,7 +96,9 @@ static void rwkv_wkv6_f32_kernel(
     }
 }
 
-template <int block_size>
+constexpr float w_scale = -0.6065306597f; // -exp(-0.5)
+
+template <int block_size, bool fuse_exp>
 static void rwkv_wkv7_f32_kernel(
     const int B, const int T, const int C, const int H,
     const float* r, const float* w, const float* k, const float* v,
 
@@ -132,7 +134,7 @@ static void rwkv_wkv7_f32_kernel(
         item_ct1.barrier(sycl::access::fence_space::local_space);
 
         _r[tid] = r[t];
-        _w[tid] = w[t];
+        _w[tid] = fuse_exp ? sycl::native::exp(w_scale / (1.0f + sycl::native::exp(-w[t]))) : w[t];
         _k[tid] = k[t];
         _a[tid] = a[t];
         _b[tid] = b[t];
 
@@ -247,9 +249,11 @@ void ggml_sycl_op_rwkv_wkv7(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
     float* dst_d = (float*)dst->data;
 
     const int64_t B = dst->src[6]->ne[1];
-    const int64_t T = dst->src[0]->ne[2];
+    const int64_t T = dst->src[4]->ne[2];
     const int64_t C = dst->ne[0];
-    const int64_t H = dst->src[0]->ne[1];
+    const int64_t H = dst->src[4]->ne[1];
+
+    const bool fuse_exp = (bool) ((int32_t *) dst->op_params)[0];
 
     GGML_ASSERT(dst->src[6]->type == GGML_TYPE_F32);
     GGML_ASSERT(C % H == 0);
@@ -264,30 +268,61 @@ void ggml_sycl_op_rwkv_wkv7(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
 
     // Submit kernel
     if (C / H == WKV_BLOCK_SIZE) {
-        stream->submit([&](sycl::handler& cgh) {
-            sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
-
-            cgh.parallel_for(
-                sycl::nd_range<3>(grid_dims * block_dims, block_dims),
-                [=](sycl::nd_item<3> item_ct1) {
-                    rwkv_wkv7_f32_kernel<WKV_BLOCK_SIZE>(
-                        B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d,
-                        item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
-                    );
-                });
-        });
+        if (fuse_exp) {
+            stream->submit([&](sycl::handler& cgh) {
+                sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
+
+                cgh.parallel_for(
+                    sycl::nd_range<3>(grid_dims * block_dims, block_dims),
+                    [=](sycl::nd_item<3> item_ct1) {
+                        rwkv_wkv7_f32_kernel<WKV_BLOCK_SIZE, true>(
+                            B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d,
+                            item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
+                        );
+                    });
+            });
+        } else {
+            stream->submit([&](sycl::handler& cgh) {
+                sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
+
+                cgh.parallel_for(
+                    sycl::nd_range<3>(grid_dims * block_dims, block_dims),
+                    [=](sycl::nd_item<3> item_ct1) {
+                        rwkv_wkv7_f32_kernel<WKV_BLOCK_SIZE, false>(
+                            B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d,
+                            item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
+                        );
+                    });
+            });
+        }
     } else {
-        stream->submit([&](sycl::handler& cgh) {
-            sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
-
-            cgh.parallel_for(
-                sycl::nd_range<3>(grid_dims * block_dims, block_dims),
-                [=](sycl::nd_item<3> item_ct1) {
-                    rwkv_wkv7_f32_kernel<WKV_BLOCK_SIZE * 2>(
-                        B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d,
-                        item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
-                    );
-                });
-        });
+        if (fuse_exp) {
+            stream->submit([&](sycl::handler& cgh) {
+                sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
+
+                cgh.parallel_for(
+                    sycl::nd_range<3>(grid_dims * block_dims, block_dims),
+                    [=](sycl::nd_item<3> item_ct1) {
+                        rwkv_wkv7_f32_kernel<WKV_BLOCK_SIZE * 2, true>(
+                            B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d,
+                            item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
+                        );
+                    });
+            });
+        } else {
+            stream->submit([&](sycl::handler& cgh) {
+                sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
+
+                cgh.parallel_for(
+                    sycl::nd_range<3>(grid_dims * block_dims, block_dims),
+                    [=](sycl::nd_item<3> item_ct1) {
+                        rwkv_wkv7_f32_kernel<WKV_BLOCK_SIZE * 2, false>(
+                            B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d,
+                            item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
+                        );
+                    });
+            });
+        }
     }
 }
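The four submit bodies above differ only in the kernel's two template arguments, since those must be compile-time constants. A hedged sketch of how the duplication could be factored out (illustrative only; the helper name and parameter types are assumptions, not code from this change):

    // Hypothetical helper: one submit body parameterized on the compile-time arguments.
    template <int block_size, bool fuse_exp>
    static void submit_wkv7(sycl::queue * stream, size_t shared_mem_size,
                            const sycl::range<3> & grid_dims, const sycl::range<3> & block_dims,
                            int B, int T, int C, int H,
                            const float * r, const float * w, const float * k, const float * v,
                            const float * a, const float * b, const float * s, float * dst) {
        stream->submit([&](sycl::handler & cgh) {
            sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
            cgh.parallel_for(
                sycl::nd_range<3>(grid_dims * block_dims, block_dims),
                [=](sycl::nd_item<3> item_ct1) {
                    rwkv_wkv7_f32_kernel<block_size, fuse_exp>(
                        B, T, C, H, r, w, k, v, a, b, s, dst,
                        item_ct1, (float *) shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get());
                });
        });
    }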
@@ -789,6 +789,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_pool2d_f32;
     vk_pipeline pipeline_rwkv_wkv6_f32;
     vk_pipeline pipeline_rwkv_wkv7_f32;
+    vk_pipeline pipeline_rwkv_wkv7_f32_fuse_exp;
     vk_pipeline pipeline_ssm_scan_f32_d128;
     vk_pipeline pipeline_ssm_scan_f32_d256;
     vk_pipeline pipeline_ssm_conv_f32;
 
@@ -4362,6 +4363,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32, "rwkv_wkv7_f32", rwkv_wkv7_f32_len, rwkv_wkv7_f32_data, "main", 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32_fuse_exp, "rwkv_wkv7_f32_fuse_exp", rwkv_wkv7_f32_fuse_exp_len, rwkv_wkv7_f32_fuse_exp_data, "main", 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
 
     if (device->subgroup_arithmetic && device->subgroup_require_full_support) {
         ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size}, 1, true, true);
 
@@ -9143,6 +9145,9 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
         return nullptr;
     case GGML_OP_RWKV_WKV7:
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            if (dst->op_params[0] == 1) {
+                return ctx->device->pipeline_rwkv_wkv7_f32_fuse_exp;
+            }
             return ctx->device->pipeline_rwkv_wkv7_f32;
         }
         return nullptr;
@@ -9899,10 +9904,62 @@ static void ggml_vk_add_id(ggml_backend_vk_context * ctx, vk_context& subctx, co
         });
 }
 
-static void ggml_vk_op_f32_wkv(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_rwkv_wkv6_push_constants&& pc, int version) {
-    GGML_ASSERT(version == 6 || version == 7);
-    int num_srcs = version == 6 ? 6 : 7;
+static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
+    const size_t seq_length = dst->src[0]->ne[2];
+    const size_t n_embed = dst->ne[0];
+    const size_t n_heads = dst->src[0]->ne[1];
+    const size_t n_seqs = dst->src[5]->ne[1];
+
+    const vk_op_rwkv_wkv6_push_constants pc = {
+        (uint32_t)n_seqs,
+        (uint32_t)seq_length,
+        (uint32_t)n_embed,
+        (uint32_t)n_heads,
+    };
+
+    const int num_srcs = 6;
+    for (int i = 0; i < num_srcs; i++) {
+        GGML_ASSERT(!ggml_is_quantized(dst->src[i]->type));
+    }
+
+    GGML_ASSERT(dst->buffer != nullptr);
+
+    vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, dst->src[0], dst->src[1], dst->src[2], dst, dst->op);
+    GGML_ASSERT(pipeline != nullptr);
+
+    ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
+
+    vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst);
+    vk_subbuffer src_buf[6] = {};
+    for (int i = 0; i < num_srcs; i++) {
+        src_buf[i] = ggml_vk_tensor_subbuffer(ctx, dst->src[i]);
+    }
+
+    std::array<uint32_t, 3> elements = {
+        (uint32_t)(pc.B * pc.H),
+        1,
+        1
+    };
+
+    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
+        {src_buf[0], src_buf[1], src_buf[2], src_buf[3], src_buf[4], src_buf[5], dst_buf},
+        pc, elements);
+}
+
+static void ggml_vk_rwkv_wkv7(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
+    const size_t seq_length = dst->src[0]->ne[2];
+    const size_t n_embed = dst->ne[0];
+    const size_t n_heads = dst->src[0]->ne[1];
+    const size_t n_seqs = dst->src[6]->ne[1];
+
+    const vk_op_rwkv_wkv7_push_constants pc = {
+        (uint32_t)n_seqs,
+        (uint32_t)seq_length,
+        (uint32_t)n_embed,
+        (uint32_t)n_heads,
+    };
+
+    const int num_srcs = 7;
     for (int i = 0; i < num_srcs; i++) {
         GGML_ASSERT(!ggml_is_quantized(dst->src[i]->type));
     }
@@ -9926,54 +9983,9 @@ static void ggml_vk_op_f32_wkv(ggml_backend_vk_context * ctx, vk_context& subctx
         1
     };
 
-    if (version == 6) {
-        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
-            {src_buf[0], src_buf[1], src_buf[2], src_buf[3], src_buf[4], src_buf[5], dst_buf},
-            pc, elements);
-    } else if (version == 7) {
-        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
-            {src_buf[0], src_buf[1], src_buf[2], src_buf[3], src_buf[4], src_buf[5], src_buf[6], dst_buf},
-            pc, elements);
-    } else {
-        // shouldn't happen
-        GGML_ASSERT(false);
-    }
-}
-
-static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
-    const size_t seq_length = dst->src[0]->ne[2];
-    const size_t n_embed = dst->ne[0];
-    const size_t n_heads = dst->src[0]->ne[1];
-    const size_t n_seqs = dst->src[5]->ne[1];
-
-    ggml_vk_op_f32_wkv(
-        ctx, subctx, dst,
-        {
-            (uint32_t)n_seqs,
-            (uint32_t)seq_length,
-            (uint32_t)n_embed,
-            (uint32_t)n_heads,
-        },
-        6
-    );
-}
-
-static void ggml_vk_rwkv_wkv7(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
-    const size_t seq_length = dst->src[0]->ne[2];
-    const size_t n_embed = dst->ne[0];
-    const size_t n_heads = dst->src[0]->ne[1];
-    const size_t n_seqs = dst->src[6]->ne[1];
-
-    ggml_vk_op_f32_wkv(
-        ctx, subctx, dst,
-        {
-            (uint32_t)n_seqs,
-            (uint32_t)seq_length,
-            (uint32_t)n_embed,
-            (uint32_t)n_heads,
-        },
-        7
-    );
-}
+    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
+        {src_buf[0], src_buf[1], src_buf[2], src_buf[3], src_buf[4], src_buf[5], src_buf[6], dst_buf},
+        pc, elements);
 }
 
 static void ggml_vk_ssm_scan(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
@@ -15734,7 +15746,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
             src_clone[2], src_clone[3], src_clone[4], src_clone[5]);
     } else if (tensor->op == GGML_OP_RWKV_WKV7) {
         tensor_clone = ggml_rwkv_wkv7(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3],
-            src_clone[4], src_clone[5], src_clone[6]);
+            src_clone[4], src_clone[5], src_clone[6], (uint32_t) tensor->op_params[0]);
     } else if (tensor->op == GGML_OP_OPT_STEP_ADAMW) {
         src_clone[0]->flags = tensor->src[0]->flags;
         tensor_clone = ggml_opt_step_adamw(ggml_ctx, src_clone[0], src_clone[1],
@@ -967,7 +967,8 @@ void process_shaders() {
 
     string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
 
-    string_to_spv("rwkv_wkv7_f32", "wkv7.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
+    string_to_spv("rwkv_wkv7_f32", "wkv7.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"FUSE_EXP", "0"}}));
+    string_to_spv("rwkv_wkv7_f32_fuse_exp", "wkv7.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"FUSE_EXP", "1"}}));
 
     string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
     string_to_spv("opt_step_sgd_f32", "opt_step_sgd.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
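Note the difference from the CUDA and SYCL backends: rather than a template parameter, the Vulkan path bakes FUSE_EXP in at shader-compile time via string_to_spv, producing two SPIR-V variants of wkv7.comp. The host side (ggml_vk_op_get_pipeline in the hunk further up) then selects rwkv_wkv7_f32_fuse_exp whenever op_params[0] == 1, so the shader itself never branches on the flag.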
@@ -10,6 +10,7 @@ layout(push_constant) uniform Parameters {
     uint T;
     uint C;
     uint H;
+    uint fuse_exp;
 };
 
 layout(binding = 0) readonly buffer RBuf { A_TYPE r[]; };
 
@@ -45,10 +46,18 @@ void main() {
     const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid;
     const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid;
 
+#if FUSE_EXP
+    const float w_scale = -0.6065306597f; // -exp(-0.5)
+#endif
+
     for (uint t = start_t; t < end_t; t += C) {
         barrier();
         _r[tid] = r[t];
+#if FUSE_EXP
+        _w[tid] = exp(w_scale / (1 + exp(-w[t])));
+#else
         _w[tid] = w[t];
+#endif
         _k[tid] = k[t];
         _a[tid] = a[t];
         _b[tid] = b[t];
@@ -1030,6 +1030,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "GATED_LINEAR_ATTN",
     "RWKV_WKV7",
     "SOLVE_TRI",
+    "LERP",
 
     "UNARY",
 
@@ -1047,7 +1048,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "GLU",
 };
 
-static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95");
+static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT != 96");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
 
@@ -1139,6 +1140,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "gated_linear_attn(k, v, q, gate, s)",
     "rwkv_wkv7(r, w, k, v, a, b, s)",
     "A X = B, A triangular, solve X",
+    "x+(y-x)*t",
 
     "unary(x)",
 
@@ -1156,7 +1158,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "glu(x)",
 };
 
-static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95");
+static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT != 96");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -1168,6 +1170,7 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
     "TANH",
     "ELU",
     "RELU",
+    "RELU_SQR",
     "SIGMOID",
     "GELU",
     "GELU_QUICK",
 
@@ -1185,7 +1188,7 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
     "TRUNC",
 };
 
-static_assert(GGML_UNARY_OP_COUNT == 22, "GGML_UNARY_OP_COUNT != 22");
+static_assert(GGML_UNARY_OP_COUNT == 23, "GGML_UNARY_OP_COUNT != 23");
 
 static const char * GGML_GLU_OP_NAME[GGML_GLU_OP_COUNT] = {
     "REGLU",
@@ -2643,6 +2646,20 @@ struct ggml_tensor * ggml_relu_inplace(
    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_RELU);
}

// ggml_relu_sqr

struct ggml_tensor * ggml_relu_sqr(
        struct ggml_context * ctx,
        struct ggml_tensor * a) {
    return ggml_unary(ctx, a, GGML_UNARY_OP_RELU_SQR);
}

struct ggml_tensor * ggml_relu_sqr_inplace(
        struct ggml_context * ctx,
        struct ggml_tensor * a) {
    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_RELU_SQR);
}

// ggml_leaky_relu

struct ggml_tensor * ggml_leaky_relu(

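The new unary op is meant to stand in for the sqr(relu(x)) pair used by squared-ReLU FFNs. A minimal usage sketch, assuming a caller-provided context (the helper name is illustrative):

#include "ggml.h"

// builds both the fused op and the older two-op form; the two graphs should match element-wise
static void relu_sqr_sketch(struct ggml_context * ctx) {
    struct ggml_tensor * x      = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 64);
    struct ggml_tensor * fused  = ggml_relu_sqr(ctx, x);             // relu(x)^2 as a single unary op
    struct ggml_tensor * manual = ggml_sqr(ctx, ggml_relu(ctx, x));  // previous graph: sqr(relu(x))
    (void) fused; (void) manual;
}
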
@@ -5704,7 +5721,8 @@ struct ggml_tensor * ggml_rwkv_wkv7(
        struct ggml_tensor * v,
        struct ggml_tensor * a,
        struct ggml_tensor * b,
        struct ggml_tensor * state) {
        struct ggml_tensor * state,
        bool fuse_exp) {
    GGML_ASSERT(ggml_is_contiguous(r));
    GGML_ASSERT(ggml_is_contiguous(w));
    GGML_ASSERT(ggml_is_contiguous(k));

@@ -5713,14 +5731,16 @@ struct ggml_tensor * ggml_rwkv_wkv7(
    GGML_ASSERT(ggml_is_contiguous(b));
    GGML_ASSERT(ggml_is_contiguous(state));

    const int64_t S = k->ne[0];
    const int64_t H = k->ne[1];
    const int64_t n_tokens = k->ne[2];
    const int64_t S = a->ne[0];
    const int64_t H = a->ne[1];
    const int64_t n_tokens = a->ne[2];
    const int64_t n_seqs = state->ne[1];
    const int64_t n_embd = S * H;
    {
        GGML_ASSERT(w->ne[0] == S && w->ne[1] == H && w->ne[2] == n_tokens);
        GGML_ASSERT(r->ne[0] == S && r->ne[1] == H && r->ne[2] == n_tokens);
        GGML_ASSERT(w->ne[0] == n_embd && w->ne[1] == n_tokens);
        GGML_ASSERT(k->ne[0] == S && k->ne[1] == H && k->ne[2] == n_tokens);
        GGML_ASSERT(v->ne[0] == S && v->ne[1] == H && v->ne[2] == n_tokens);
        GGML_ASSERT(v->ne[0] == n_embd && v->ne[1] == n_tokens);
        GGML_ASSERT(a->ne[0] == S && a->ne[1] == H && a->ne[2] == n_tokens);
        GGML_ASSERT(b->ne[0] == S && b->ne[1] == H && b->ne[2] == n_tokens);
        GGML_ASSERT(ggml_nelements(state) == S * S * H * n_seqs);

@@ -5730,6 +5750,9 @@ struct ggml_tensor * ggml_rwkv_wkv7(
    const int64_t ne[4] = { S * H, n_tokens + S * n_seqs, 1, 1 };
    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);

    int32_t fuse_exp_i = fuse_exp ? 1 : 0;
    ggml_set_op_params(result, &fuse_exp_i, sizeof(fuse_exp_i));

    result->op = GGML_OP_RWKV_WKV7;
    result->src[0] = r;
    result->src[1] = w;

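Call-site sketch for the extended signature, assuming the tensors are already shaped as the asserts above require (the wrapper name is illustrative). With fuse_exp == false the op behaves as before; with fuse_exp == true the backend applies exp(-exp(-0.5) * sigmoid(w)) to w internally:

#include "ggml.h"

static struct ggml_tensor * rwkv_wkv7_call_sketch(
        struct ggml_context * ctx,
        struct ggml_tensor * r, struct ggml_tensor * w, struct ggml_tensor * k,
        struct ggml_tensor * v, struct ggml_tensor * a, struct ggml_tensor * b,
        struct ggml_tensor * state) {
    // the flag is stored in op_params via ggml_set_op_params(), so backends can branch on it
    return ggml_rwkv_wkv7(ctx, r, w, k, v, a, b, state, /*fuse_exp=*/ true);
}
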
@@ -5742,6 +5765,34 @@ struct ggml_tensor * ggml_rwkv_wkv7(
    return result;
}

// ggml_lerp

struct ggml_tensor * ggml_lerp(
        struct ggml_context * ctx,
        struct ggml_tensor * a,
        struct ggml_tensor * b,
        struct ggml_tensor * t) {
    // assume a and b are the same shape for now
    GGML_ASSERT(ggml_are_same_shape(a, b));

    GGML_ASSERT(t->ne[0] == a->ne[0]);
    GGML_ASSERT(a->ne[1] % t->ne[1] == 0);
    GGML_ASSERT(a->ne[2] % t->ne[2] == 0);

    // a/b can broadcast to t at dim3 for rwkv7
    GGML_ASSERT(t->ne[3] % a->ne[3] == 0);

    const int64_t ne[4] = { a->ne[0], a->ne[1], a->ne[2], t->ne[3] };
    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);

    result->op = GGML_OP_LERP;
    result->src[0] = a;
    result->src[1] = b;
    result->src[2] = t;

    return result;
}

// ggml_unary

static struct ggml_tensor * ggml_unary_impl(

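A shape/semantics sketch for the new op, assuming a caller-provided context (the helper name is illustrative): the result is a + (b - a) * t in F32, with a and b broadcast along dim 3 whenever t->ne[3] is a multiple of a->ne[3]:

#include "ggml.h"

static struct ggml_tensor * lerp_sketch(struct ggml_context * ctx) {
    struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 256, 16);
    struct ggml_tensor * b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 256, 16);
    struct ggml_tensor * t = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 256, 16);
    // same result as ggml_add(ctx, a, ggml_mul(ctx, ggml_sub(ctx, b, a), t)), but as one op
    return ggml_lerp(ctx, a, b, t);
}
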
@@ -1001,10 +1001,7 @@ ggml_tensor * llm_graph_context::build_ffn(
            } break;
        case LLM_FFN_RELU_SQR:
            {
                cur = ggml_relu(ctx0, cur);
                cb(cur, "ffn_relu", il);

                cur = ggml_sqr(ctx0, cur);
                cur = ggml_relu_sqr(ctx0, cur);
                cb(cur, "ffn_sqr(relu)", il);
            } break;
        case LLM_FFN_SWIGLU:

@@ -1307,8 +1304,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
                // TODO: add support for gated squared relu
                GGML_ABORT("fatal error: gated squared relu not implemented");
            } else {
                cur = ggml_relu(ctx0, cur);
                cur = ggml_sqr(ctx0, cur);
                cur = ggml_relu_sqr(ctx0, cur);
                cb(cur, "ffn_moe_relu_sqr", il);
            } break;
        default:

@@ -49,7 +49,8 @@ struct llm_build_rwkv7_base : public llm_graph_context {
    ggml_tensor * build_rwkv7_channel_mix(const llama_layer * layer,
            ggml_tensor * cur,
            ggml_tensor * x_prev,
            llm_arch arch) const;
            llm_arch arch,
            int il) const;
    ggml_tensor * build_rwkv7_time_mix(llm_graph_input_rs * inp,
            ggml_tensor * cur,
            ggml_tensor * x_prev,

@@ -7,16 +7,24 @@ llm_build_rwkv7_base::llm_build_rwkv7_base(const llama_model & model, const llm_
ggml_tensor * llm_build_rwkv7_base::build_rwkv7_channel_mix(const llama_layer * layer,
        ggml_tensor * cur,
        ggml_tensor * x_prev,
        llm_arch arch) const {
    ggml_tensor * sx = ggml_sub(ctx0, x_prev, cur);
        llm_arch arch,
        int il) const {
    switch (arch) {
        case LLM_ARCH_RWKV7:
            {
                ggml_tensor * xk = ggml_add(ctx0, ggml_mul(ctx0, sx, layer->channel_mix_lerp_k), cur);
                // cur + (x_prev - cur) * layer->channel_mix_lerp_k
                ggml_tensor * xk = ggml_lerp(ctx0, cur, x_prev, layer->channel_mix_lerp_k);

                ggml_tensor * k = ggml_sqr(ctx0, ggml_relu(ctx0, build_lora_mm(layer->channel_mix_key, xk)));

                cur = build_lora_mm(layer->channel_mix_value, k);
                cur = build_ffn(
                    xk,
                    layer->channel_mix_key, nullptr, nullptr, // up
                    nullptr, nullptr, nullptr, // gate
                    layer->channel_mix_value, nullptr, nullptr, // down
                    nullptr,
                    LLM_FFN_RELU_SQR,
                    LLM_FFN_SEQ,
                    il
                );
            }
            break;
        default:

@@ -46,11 +54,7 @@ ggml_tensor * llm_build_rwkv7_base::build_rwkv7_time_mix(llm_graph_input_rs * in

    bool has_gating = layer.time_mix_g1 && layer.time_mix_g2;

    ggml_tensor * sx = ggml_sub(ctx0, x_prev, cur);
    ggml_tensor * dummy = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_embd, n_seq_tokens, n_seqs, has_gating ? 6 : 5);
    sx = ggml_repeat(ctx0, sx, dummy);

    ggml_tensor * xxx = ggml_add(ctx0, ggml_mul(ctx0, sx, layer.time_mix_lerp_fused), cur);
    ggml_tensor * xxx = ggml_lerp(ctx0, cur, x_prev, layer.time_mix_lerp_fused);

    ggml_tensor * xr = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], 0);
    ggml_tensor * xw = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * sizeof(float));

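The old path materialized sx = x_prev - cur, repeated it 5 or 6 times along dim 3 via ggml_repeat, then did mul + add; ggml_lerp's dim-3 broadcast performs the same mix without the explicit repeat. A shape sketch under assumed (illustrative) sizes for the fused coefficient tensor:

#include "ggml.h"

static struct ggml_tensor * time_mix_lerp_sketch(struct ggml_context * ctx) {
    struct ggml_tensor * cur    = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 768, 32, 2, 1); // (n_embd, n_seq_tokens, n_seqs, 1)
    struct ggml_tensor * x_prev = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 768, 32, 2, 1);
    struct ggml_tensor * coef   = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 768, 1, 1, 6);  // fused lerp coefficients
    // result ne = {768, 32, 2, 6}: one mixed copy per coefficient slice, no explicit repeat needed
    return ggml_lerp(ctx, cur, x_prev, coef);
}
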
@@ -65,7 +69,6 @@ ggml_tensor * llm_build_rwkv7_base::build_rwkv7_time_mix(llm_graph_input_rs * in
    ggml_tensor * w = ggml_add(
        ctx0, ggml_mul_mat(ctx0, layer.time_mix_w2, ggml_tanh(ctx0, ggml_mul_mat(ctx0, layer.time_mix_w1, xw))),
        layer.time_mix_w0);
    w = ggml_exp(ctx0, ggml_scale(ctx0, ggml_sigmoid(ctx0, w), -0.606531));

    ggml_tensor * k = build_lora_mm(layer.time_mix_key, xk);
    ggml_tensor * v = build_lora_mm(layer.time_mix_value, xv);

@@ -95,14 +98,12 @@ ggml_tensor * llm_build_rwkv7_base::build_rwkv7_time_mix(llm_graph_input_rs * in
    k = ggml_add(ctx0, k, ggml_sub(ctx0, ggml_mul(ctx0, a, ka), ka));

    r = ggml_reshape_3d(ctx0, r, head_size, head_count, n_tokens);
    w = ggml_reshape_3d(ctx0, w, head_size, head_count, n_tokens);
    k = ggml_reshape_3d(ctx0, k, head_size, head_count, n_tokens);
    v = ggml_reshape_3d(ctx0, v, head_size, head_count, n_tokens);
    a = ggml_reshape_3d(ctx0, a, head_size, head_count, n_tokens);

    ggml_tensor * wkv_state = build_rs(inp, mctx_cur->get_s_l(il), hparams.n_embd_s(), n_seqs);

    ggml_tensor * wkv_output = ggml_rwkv_wkv7(ctx0, r, w, k, v, ggml_neg(ctx0, kk), ggml_mul(ctx0, kk, a), wkv_state);
    ggml_tensor * wkv_output = ggml_rwkv_wkv7(ctx0, r, w, k, v, ggml_neg(ctx0, kk), ggml_mul(ctx0, kk, a), wkv_state, true);
    cur = ggml_view_1d(ctx0, wkv_output, n_embd * n_tokens, 0);
    wkv_state = ggml_view_1d(ctx0, wkv_output, n_embd * head_size * n_seqs, n_embd * n_tokens * sizeof(float));

@@ -122,6 +123,8 @@ ggml_tensor * llm_build_rwkv7_base::build_rwkv7_time_mix(llm_graph_input_rs * in
    } else {
        cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
    }

    v = ggml_reshape_3d(ctx0, v, head_size, head_count, n_tokens);
    ggml_tensor * rk = ggml_sum_rows(
        ctx0, ggml_mul(ctx0, ggml_mul(ctx0, k, r), ggml_reshape_2d(ctx0, layer.time_mix_r_k, head_size, head_count)));
    cur = ggml_add(ctx0, cur, ggml_reshape_2d(ctx0, ggml_mul(ctx0, v, rk), n_embd, n_tokens));

@@ -66,7 +66,7 @@ llm_build_rwkv7::llm_build_rwkv7(const llama_model & model, const llm_graph_para
        ffn_norm = ggml_get_rows(ctx0, ffn_norm, inp_out_ids);
        x_prev = ggml_get_rows(ctx0, x_prev, inp_out_ids);
    }
    cur = build_rwkv7_channel_mix(layer, ffn_norm, x_prev, LLM_ARCH_RWKV7);
    cur = build_rwkv7_channel_mix(layer, ffn_norm, x_prev, LLM_ARCH_RWKV7, il);
    cur = ggml_add(ctx0, cur, ffn_inp);

    cur = build_cvec(cur, il);

@@ -3681,17 +3681,47 @@ struct test_rwkv_wkv7 : public test_case {

    ggml_tensor * build_graph(ggml_context * ctx) override {
        const int64_t n_tokens = n_seq_tokens * n_seqs;
        const int64_t n_embd = head_count * head_size;
        ggml_tensor * r = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
        ggml_tensor * w = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
        ggml_tensor * w = ggml_new_tensor(ctx, type, 2, std::vector<int64_t>{ n_embd, n_tokens }.data());
        ggml_tensor * k = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
        ggml_tensor * v = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
        ggml_tensor * v = ggml_new_tensor(ctx, type, 2, std::vector<int64_t>{ n_embd, n_tokens }.data());
        ggml_tensor * a = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
        ggml_tensor * b = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
        // Outputs may become NaN with long seqlen without these normalizations
        a = ggml_l2_norm(ctx, a, 1e-7F);
        b = ggml_l2_norm(ctx, b, 1e-7F);
        ggml_tensor * s = ggml_new_tensor(ctx, type, 2, std::vector<int64_t>{ head_size * head_size * head_count, n_seqs }.data());
        ggml_tensor * out = ggml_rwkv_wkv7(ctx, r, w, k, v, a, b, s);
        ggml_tensor * out = ggml_rwkv_wkv7(ctx, r, w, k, v, a, b, s, true);
        return out;
    }
};

// GGML_OP_LERP
struct test_lerp : public test_case {
    const ggml_type type_t;
    const ggml_type type_a;
    const ggml_type type_b;
    const std::array<int64_t, 4> ne0;
    const std::array<int64_t, 4> ne1;
    const std::array<int64_t, 4> ne2;

    std::string vars() override {
        return VARS_TO_STR6(type_a, type_b, type_t, ne0, ne1, ne2);
    }

    test_lerp(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32,
            ggml_type type_t = GGML_TYPE_F32,
            std::array<int64_t, 4> ne0 = {10, 10, 1, 1},
            std::array<int64_t, 4> ne1 = {10, 10, 1, 1},
            std::array<int64_t, 4> ne2 = {10, 10, 1, 1})
        : type_a(type_a), type_b(type_b), type_t(type_t), ne0(ne0), ne1(ne1), ne2(ne2) {}

    ggml_tensor * build_graph(ggml_context * ctx) override {
        ggml_tensor * a = ggml_new_tensor(ctx, type_a, 4, ne0.data());
        ggml_tensor * b = ggml_new_tensor(ctx, type_b, 4, ne1.data());
        ggml_tensor * c = ggml_new_tensor(ctx, type_t, 4, ne2.data());
        ggml_tensor * out = ggml_lerp(ctx, a, b, c);
        return out;
    }
};

@@ -7612,6 +7642,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
    test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 32, 4));
    test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 128, 4));

    test_cases.emplace_back(new test_lerp(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {256, 16, 1, 1}, {256, 16, 1, 1}, {256, 16, 1, 1}));
    test_cases.emplace_back(new test_lerp(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {256, 16, 1, 1}, {256, 16, 1, 1}, {256, 16, 1, 6}));
    test_cases.emplace_back(new test_lerp(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F16, {256, 16, 1, 1}, {256, 16, 1, 1}, {256, 16, 1, 1}));
    test_cases.emplace_back(new test_lerp(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F16, {256, 16, 1, 1}, {256, 16, 1, 1}, {256, 16, 1, 6}));

#if 0
    // > 4GB A matrix. Too slow to be enabled by default.
    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F16, 900000, 3, 2592, {1, 1}, {1, 1}));