ggml: add `lerp` op for linear interpolation in rwkv7

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
Molly Sophia 2026-01-11 20:59:50 +08:00
parent c389dc9e8a
commit a98cac62c9
11 changed files with 551 additions and 10 deletions

View File

@ -551,6 +551,7 @@ extern "C" {
GGML_OP_GATED_LINEAR_ATTN,
GGML_OP_RWKV_WKV7,
GGML_OP_SOLVE_TRI,
GGML_OP_LERP,
GGML_OP_UNARY,
@ -2461,6 +2462,14 @@ extern "C" {
bool lower,
bool uni);
// linear interpolation: dst = a + (b - a) * t
// t broadcasts along dims 1 and 2; a/b broadcast along dim 3
// used in rwkv7
GGML_API struct ggml_tensor * ggml_lerp(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
struct ggml_tensor * t);
// custom operators
typedef void (*ggml_custom1_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int nth, void * userdata);
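For orientation, here is a minimal sketch of how the new API could be exercised end-to-end on the CPU backend. The shapes, the 16 MB context size and the thread count are illustrative choices, not taken from this change:

#include "ggml.h"
#include "ggml-cpu.h"

int main(void) {
    struct ggml_init_params ip = { /*.mem_size =*/ 16*1024*1024, /*.mem_buffer =*/ NULL, /*.no_alloc =*/ false };
    struct ggml_context * ctx = ggml_init(ip);

    // a, b and t are all F32 and the same shape here; see ggml_lerp() for the broadcast rules
    struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 8);
    struct ggml_tensor * b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 8);
    struct ggml_tensor * t = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 8);
    // ... fill a, b and t with data ...

    // dst[i] = a[i] + (b[i] - a[i]) * t[i]
    struct ggml_tensor * dst = ggml_lerp(ctx, a, b, t);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, dst);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 4);

    ggml_free(ctx);
    return 0;
}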

View File

@ -2018,6 +2018,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
ggml_compute_forward_solve_tri(params, tensor);
} break;
case GGML_OP_LERP:
{
ggml_compute_forward_lerp(params, tensor);
} break;
case GGML_OP_MAP_CUSTOM1:
{
ggml_compute_forward_map_custom1(params, tensor);
@ -2179,6 +2184,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
case GGML_OP_CUMSUM:
case GGML_OP_TRI:
case GGML_OP_FILL:
case GGML_OP_LERP:
{
n_tasks = n_threads;
} break;

View File

@ -10099,6 +10099,177 @@ void ggml_compute_forward_rwkv_wkv7(
}
}
static void ggml_compute_forward_lerp_f32(
const ggml_compute_params * params,
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
const ggml_tensor * src2 = dst->src[2];
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT(src2->type == GGML_TYPE_F32);
GGML_ASSERT(src0->nb[0] == sizeof(float));
GGML_ASSERT(src1->nb[0] == sizeof(float));
GGML_ASSERT(src2->nb[0] == sizeof(float));
const int ith = params->ith;
const int nth = params->nth;
GGML_TENSOR_TERNARY_OP_LOCALS
GGML_ASSERT( nb0 == sizeof(float));
GGML_ASSERT(nb10 == sizeof(float));
GGML_ASSERT(nb20 == sizeof(float));
GGML_ASSERT(ggml_are_same_shape(src0, src1));
GGML_ASSERT(ne00 == ne10 && ne00 == ne20 && ne00 == ne0);
GGML_ASSERT(ne01 % ne21 == 0);
GGML_ASSERT(ne02 % ne22 == 0);
GGML_ASSERT(ne23 % ne03 == 0);
GGML_ASSERT(ne23 % ne13 == 0);
GGML_ASSERT(ne0 == ne00);
GGML_ASSERT(ne1 == ne01);
GGML_ASSERT(ne2 == ne02);
GGML_ASSERT(ne3 == ne23);
const int nr = ggml_nrows(dst);
// rows per thread
const int dr = (nr + nth - 1)/nth;
// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
for (int ir = ir0; ir < ir1; ++ir) {
const int i3 = ir/(ne2*ne1);
const int i2 = (ir - i3*ne2*ne1)/ne1;
const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
const int i03 = i3 % ne03;
const int i13 = i3 % ne13;
const int i21 = i1 % ne21;
const int i22 = i2 % ne22;
float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
const float * src0_ptr = (const float *) ((char *) src0->data + i03*nb03 + i2*nb02 + i1*nb01);
const float * src1_ptr = (const float *) ((char *) src1->data + i13*nb13 + i2*nb12 + i1*nb11);
const float * src2_ptr = (const float *) ((char *) src2->data + i3*nb23 + i22*nb22 + i21*nb21);
for (int64_t i0 = 0; i0 < ne0; ++i0) {
const float s0 = src0_ptr[i0];
const float s1 = src1_ptr[i0];
const float s2 = src2_ptr[i0];
dst_ptr[i0] = s0 + (s1 - s0) * s2;
}
}
}
static void ggml_compute_forward_lerp_f32_f32_f16(
const ggml_compute_params * params,
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
const ggml_tensor * src2 = dst->src[2];
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT(src2->type == GGML_TYPE_F16);
GGML_ASSERT(src0->nb[0] == sizeof(float));
GGML_ASSERT(src1->nb[0] == sizeof(float));
GGML_ASSERT(src2->nb[0] == sizeof(ggml_fp16_t));
const int ith = params->ith;
const int nth = params->nth;
GGML_TENSOR_TERNARY_OP_LOCALS
GGML_ASSERT( nb0 == sizeof(float));
GGML_ASSERT(nb10 == sizeof(float));
GGML_ASSERT(nb20 == sizeof(ggml_fp16_t));
GGML_ASSERT(ggml_are_same_shape(src0, src1));
GGML_ASSERT(ne00 == ne10 && ne00 == ne20 && ne00 == ne0);
GGML_ASSERT(ne01 % ne21 == 0);
GGML_ASSERT(ne02 % ne22 == 0);
GGML_ASSERT(ne23 % ne03 == 0);
GGML_ASSERT(ne23 % ne13 == 0);
GGML_ASSERT(ne0 == ne00);
GGML_ASSERT(ne1 == ne01);
GGML_ASSERT(ne2 == ne02);
GGML_ASSERT(ne3 == ne23);
const int nr = ggml_nrows(dst);
// rows per thread
const int dr = (nr + nth - 1)/nth;
// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
for (int ir = ir0; ir < ir1; ++ir) {
const int i3 = ir/(ne2*ne1);
const int i2 = (ir - i3*ne2*ne1)/ne1;
const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
const int i03 = i3 % ne03;
const int i13 = i3 % ne13;
const int i21 = i1 % ne21;
const int i22 = i2 % ne22;
float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
const float * src0_ptr = (const float *) ((char *) src0->data + i03*nb03 + i2*nb02 + i1*nb01);
const float * src1_ptr = (const float *) ((char *) src1->data + i13*nb13 + i2*nb12 + i1*nb11);
const ggml_fp16_t * src2_ptr = (const ggml_fp16_t *) ((char *) src2->data + i3*nb23 + i22*nb22 + i21*nb21);
for (int64_t i0 = 0; i0 < ne0; ++i0) {
const float s0 = src0_ptr[i0];
const float s1 = src1_ptr[i0];
const float s2 = GGML_FP16_TO_FP32(src2_ptr[i0]);
dst_ptr[i0] = s0 + (s1 - s0) * s2;
}
}
}
void ggml_compute_forward_lerp(
const ggml_compute_params * params,
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src2 = dst->src[2];
switch (src0->type) {
case GGML_TYPE_F32:
{
if (src2->type == GGML_TYPE_F32) {
ggml_compute_forward_lerp_f32(params, dst);
} else if (src2->type == GGML_TYPE_F16) {
ggml_compute_forward_lerp_f32_f32_f16(params, dst);
} else {
GGML_ABORT("fatal error");
}
} break;
default:
{
GGML_ABORT("fatal error");
}
}
}
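(Aside: both forward kernels above split the work by whole rows using ggml's usual ceil-divide partition. A tiny standalone illustration of that arithmetic, with made-up numbers:)

#include <stdio.h>

#define MIN(a, b) ((a) < (b) ? (a) : (b))

int main(void) {
    const int nr  = 10; // total rows in dst
    const int nth = 4;  // worker threads
    const int dr  = (nr + nth - 1)/nth; // rows per thread, rounded up
    for (int ith = 0; ith < nth; ++ith) {
        const int ir0 = dr*ith;            // first row for this thread
        const int ir1 = MIN(ir0 + dr, nr); // one past the last row
        printf("thread %d: rows [%d, %d)\n", ith, ir0, ir1);
    }
    return 0;
}

This prints the ranges [0,3), [3,6), [6,9), [9,10); the last thread picks up the remainder.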
// ggml_compute_forward_map_custom1
void ggml_compute_forward_map_custom1(

View File

@ -101,6 +101,7 @@ void ggml_compute_forward_add_rel_pos(const struct ggml_compute_params * params,
void ggml_compute_forward_rwkv_wkv6(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_rwkv_wkv7(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_solve_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_lerp(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_gla(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_map_custom1(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_map_custom2(const struct ggml_compute_params * params, struct ggml_tensor * dst);

View File

@ -52,6 +52,7 @@
#include "ggml-cuda/unary.cuh"
#include "ggml-cuda/upscale.cuh"
#include "ggml-cuda/wkv.cuh"
#include "ggml-cuda/lerp.cuh"
#include "ggml-cuda/gla.cuh"
#include "ggml-cuda/set.cuh"
#include "ggml-cuda/set-rows.cuh"
@ -2726,6 +2727,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_RWKV_WKV7:
ggml_cuda_op_rwkv_wkv7(ctx, dst);
break;
case GGML_OP_LERP:
ggml_cuda_op_lerp(ctx, dst);
break;
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
ggml_cuda_cross_entropy_loss_back(ctx, dst);
break;
@ -4635,6 +4639,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_RWKV_WKV6:
case GGML_OP_GATED_LINEAR_ATTN:
case GGML_OP_RWKV_WKV7:
case GGML_OP_LERP:
return true;
case GGML_OP_FLASH_ATTN_EXT:
return ggml_cuda_flash_attn_ext_supported(dev_ctx->device, op);

ggml/src/ggml-cuda/lerp.cu (new file, 283 lines)
View File

@ -0,0 +1,283 @@
#include "lerp.cuh"
#include <cstdint>
template <typename src0_t, typename src1_t, typename src2_t, typename dst_t>
static __global__ void k_lerp(
const src0_t * src0,
const src1_t * src1,
const src2_t * src2,
dst_t * dst,
const int ne0,
const int ne1,
const int ne2,
const uint3 ne3,
const uint3 ne03,
const uint3 ne13,
const uint3 ne21,
const uint3 ne22,
const int s1,
const int s2,
const int s3,
const int s01,
const int s02,
const int s03,
const int s11,
const int s12,
const int s13,
const int s21,
const int s22,
const int s23) {
const uint32_t i0s = blockDim.x * blockIdx.x + threadIdx.x;
const uint32_t i1 = (blockDim.y * blockIdx.y + threadIdx.y);
const uint32_t i2 = fastdiv((blockDim.z * blockIdx.z + threadIdx.z), ne3);
const uint32_t i3 = (blockDim.z * blockIdx.z + threadIdx.z) - (i2 * ne3.z);
if (i0s >= (uint32_t)ne0 || i1 >= (uint32_t)ne1 || i2 >= (uint32_t)ne2 || i3 >= ne3.z) {
return;
}
// src0/src1 broadcast in dim3
const uint32_t i03 = fastmodulo(i3, ne03);
const uint32_t i13 = fastmodulo(i3, ne13);
// src2 broadcast in dim1, dim2
const uint32_t i21 = fastmodulo(i1, ne21);
const uint32_t i22 = fastmodulo(i2, ne22);
const size_t i_src0 = i03*s03 + i2*s02 + i1*s01;
const size_t i_src1 = i13*s13 + i2*s12 + i1*s11;
const size_t i_src2 = i3*s23 + i22*s22 + i21*s21;
const size_t i_dst = i3*s3 + i2*s2 + i1*s1;
const src0_t * src0_row = src0 + i_src0;
const src1_t * src1_row = src1 + i_src1;
const src2_t * src2_row = src2 + i_src2;
dst_t * dst_row = dst + i_dst;
for (int i0 = i0s; i0 < ne0; i0 += blockDim.x * gridDim.x) {
const float v0 = (float) src0_row[i0];
const float v1 = (float) src1_row[i0];
const float v2 = (float) src2_row[i0];
dst_row[i0] = (dst_t) (v0 + (v1 - v0) * v2);
}
}
template <typename src0_t, typename src1_t, typename src2_t, typename dst_t>
static __global__ void k_lerp_unravel(
const src0_t * src0,
const src1_t * src1,
const src2_t * src2,
dst_t * dst,
const uint3 ne0,
const uint3 ne1,
const uint3 ne2,
const uint32_t ne3,
const uint3 prod_012,
const uint3 prod_01,
const uint3 ne03,
const uint3 ne13,
const uint3 ne21,
const uint3 ne22,
const int s1,
const int s2,
const int s3,
const int s01,
const int s02,
const int s03,
const int s11,
const int s12,
const int s13,
const int s21,
const int s22,
const int s23) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
const uint32_t i3 = fastdiv(i, prod_012);
const uint32_t i2 = fastdiv(i - i3 * prod_012.z, prod_01);
const uint32_t i1 = fastdiv(i - i3 * prod_012.z - i2 * prod_01.z, ne0);
const uint32_t i0 = i - i3 * prod_012.z - i2 * prod_01.z - i1 * ne0.z;
if (i0 >= ne0.z || i1 >= ne1.z || i2 >= ne2.z || i3 >= ne3) {
return;
}
// src0/src1 broadcast in dim3
const int i03 = fastmodulo(i3, ne03);
const int i13 = fastmodulo(i3, ne13);
// src2 broadcast in dim1, dim2
const int i21 = fastmodulo(i1, ne21);
const int i22 = fastmodulo(i2, ne22);
const size_t i_src0 = i03*s03 + i2*s02 + i1*s01;
const size_t i_src1 = i13*s13 + i2*s12 + i1*s11;
const size_t i_src2 = i3*s23 + i22*s22 + i21*s21;
const size_t i_dst = i3*s3 + i2*s2 + i1*s1;
const src0_t * src0_row = src0 + i_src0;
const src1_t * src1_row = src1 + i_src1;
const src2_t * src2_row = src2 + i_src2;
dst_t * dst_row = dst + i_dst;
const float v0 = (float) src0_row[i0];
const float v1 = (float) src1_row[i0];
const float v2 = (float) src2_row[i0];
// dst = src0 + (src1 - src0) * src2
dst_row[i0] = (dst_t) (v0 + (v1 - v0) * v2);
}
template <typename src0_t, typename src1_t, typename src2_t, typename dst_t>
static void launch_lerp(
const ggml_tensor * src0,
const ggml_tensor * src1,
const ggml_tensor * src2,
ggml_tensor * dst,
const src0_t * src0_dd,
const src1_t * src1_dd,
const src2_t * src2_dd,
dst_t * dst_dd,
cudaStream_t stream) {
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne)
GGML_TENSOR_LOCALS(int64_t, ne2, src2, ne)
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
GGML_TENSOR_LOCALS(size_t, nb1, src1, nb)
GGML_TENSOR_LOCALS(size_t, nb2, src2, nb)
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
GGML_ASSERT(ne00 == ne10 && ne00 == ne20 && ne00 == ne0);
GGML_ASSERT(ne01 % ne21 == 0);
GGML_ASSERT(ne02 % ne22 == 0);
GGML_ASSERT(ne3 % ne03 == 0);
GGML_ASSERT(ne3 % ne13 == 0);
GGML_ASSERT(ne0 == ne00);
GGML_ASSERT(ne1 == ne01);
GGML_ASSERT(ne2 == ne02);
GGML_ASSERT(ne3 == ne23);
size_t s1 = nb1 / sizeof(dst_t);
size_t s2 = nb2 / sizeof(dst_t);
size_t s3 = nb3 / sizeof(dst_t);
size_t s01 = nb01 / sizeof(src0_t);
size_t s02 = nb02 / sizeof(src0_t);
size_t s03 = nb03 / sizeof(src0_t);
size_t s11 = nb11 / sizeof(src1_t);
size_t s12 = nb12 / sizeof(src1_t);
size_t s13 = nb13 / sizeof(src1_t);
size_t s21 = nb21 / sizeof(src2_t);
size_t s22 = nb22 / sizeof(src2_t);
size_t s23 = nb23 / sizeof(src2_t);
GGML_ASSERT(nb0 % sizeof(dst_t) == 0);
GGML_ASSERT(nb1 % sizeof(dst_t) == 0);
GGML_ASSERT(nb2 % sizeof(dst_t) == 0);
GGML_ASSERT(nb3 % sizeof(dst_t) == 0);
GGML_ASSERT(nb00 % sizeof(src0_t) == 0);
GGML_ASSERT(nb01 % sizeof(src0_t) == 0);
GGML_ASSERT(nb02 % sizeof(src0_t) == 0);
GGML_ASSERT(nb03 % sizeof(src0_t) == 0);
GGML_ASSERT(nb10 % sizeof(src1_t) == 0);
GGML_ASSERT(nb11 % sizeof(src1_t) == 0);
GGML_ASSERT(nb12 % sizeof(src1_t) == 0);
GGML_ASSERT(nb13 % sizeof(src1_t) == 0);
GGML_ASSERT(nb20 % sizeof(src2_t) == 0);
GGML_ASSERT(nb21 % sizeof(src2_t) == 0);
GGML_ASSERT(nb22 % sizeof(src2_t) == 0);
GGML_ASSERT(nb23 % sizeof(src2_t) == 0);
const int block_size = CUDA_LERP_BLOCK_SIZE;
int64_t hne0 = std::max(ne0 / 2LL, 1LL);
dim3 block_dims;
block_dims.x = std::min<unsigned int>(hne0, block_size);
block_dims.y = std::min<unsigned int>(ne1, block_size / block_dims.x);
block_dims.z = std::min(std::min<unsigned int>(ne2 * ne3, block_size / block_dims.x / block_dims.y), 64U);
dim3 block_nums(
(hne0 + block_dims.x - 1) / block_dims.x,
(ne1 + block_dims.y - 1) / block_dims.y,
(ne2 * ne3 + block_dims.z - 1) / block_dims.z);
const uint3 ne03_fastdiv = init_fastdiv_values((uint32_t) ne03);
const uint3 ne13_fastdiv = init_fastdiv_values((uint32_t) ne13);
const uint3 ne21_fastdiv = init_fastdiv_values((uint32_t) ne21);
const uint3 ne22_fastdiv = init_fastdiv_values((uint32_t) ne22);
if (block_nums.z > 65535 || block_nums.y > 65535) {
int block_num = (ne0 * ne1 * ne2 * ne3 + block_size - 1) / block_size;
const uint3 prod_012 = init_fastdiv_values((uint32_t) (ne0 * ne1 * ne2));
const uint3 prod_01 = init_fastdiv_values((uint32_t) (ne0 * ne1));
const uint3 ne0_fastdiv = init_fastdiv_values((uint32_t) ne0);
const uint3 ne1_fastdiv = init_fastdiv_values((uint32_t) ne1);
const uint3 ne2_fastdiv = init_fastdiv_values((uint32_t) ne2);
k_lerp_unravel<src0_t, src1_t, src2_t, dst_t>
<<<block_num, block_size, 0, stream>>>(
src0_dd, src1_dd, src2_dd, dst_dd,
ne0_fastdiv, ne1_fastdiv, ne2_fastdiv, ne3,
prod_012, prod_01,
ne03_fastdiv, ne13_fastdiv, ne21_fastdiv, ne22_fastdiv,
s1, s2, s3,
s01, s02, s03,
s11, s12, s13,
s21, s22, s23);
} else {
const uint3 ne3_fastdiv = init_fastdiv_values((uint32_t) ne3);
k_lerp<src0_t, src1_t, src2_t, dst_t>
<<<block_nums, block_dims, 0, stream>>>(
src0_dd, src1_dd, src2_dd, dst_dd,
ne0, ne1, ne2, ne3_fastdiv,
ne03_fastdiv, ne13_fastdiv, ne21_fastdiv, ne22_fastdiv,
s1, s2, s3,
s01, s02, s03,
s11, s12, s13,
s21, s22, s23);
}
}
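The two launch paths closely follow the pattern of ggml's generic broadcast launcher: the 3D-grid kernel k_lerp is preferred, and the flattened k_lerp_unravel is only used when the y or z grid dimension would exceed CUDA's 65535 limit. A standalone sketch of that check, re-using the formulas above with made-up extents chosen so the fallback triggers:

#include <stdio.h>

static unsigned int umin(unsigned int a, unsigned int b) { return a < b ? a : b; }

int main(void) {
    // made-up tensor extents, not taken from the patch
    const long long ne0 = 4096, ne1 = 1, ne2 = 1, ne3 = 1000000;
    const unsigned int block_size = 256; // CUDA_LERP_BLOCK_SIZE

    const long long hne0 = ne0/2 > 1 ? ne0/2 : 1;
    const unsigned int bx = umin((unsigned int) hne0, block_size);
    const unsigned int by = umin((unsigned int) ne1, block_size/bx);
    const unsigned int bz = umin(umin((unsigned int)(ne2*ne3), block_size/bx/by), 64u);

    const unsigned int gy = (unsigned int)((ne1 + by - 1)/by);
    const unsigned int gz = (unsigned int)((ne2*ne3 + bz - 1)/bz);

    printf("grid.y = %u, grid.z = %u -> %s\n", gy, gz,
           (gz > 65535u || gy > 65535u) ? "k_lerp_unravel" : "k_lerp");
    return 0;
}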
void ggml_cuda_op_lerp(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
const ggml_tensor * src2 = dst->src[2];
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
if (src2->type == GGML_TYPE_F32) {
launch_lerp<float, float, float, float>(
src0, src1, src2, dst,
(const float *) src0->data,
(const float *) src1->data,
(const float *) src2->data,
(float *) dst->data,
stream);
} else if (src2->type == GGML_TYPE_F16) {
launch_lerp<float, float, half, float>(
src0, src1, src2, dst,
(const float *) src0->data,
(const float *) src1->data,
(const half *) src2->data,
(float *) dst->data,
stream);
} else {
GGML_ABORT("fatal error");
}
}

View File

@ -0,0 +1,5 @@
#include "common.cuh"
#define CUDA_LERP_BLOCK_SIZE 256
void ggml_cuda_op_lerp(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@ -1030,6 +1030,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"GATED_LINEAR_ATTN",
"RWKV_WKV7",
"SOLVE_TRI",
"LERP",
"UNARY",
@ -1047,7 +1048,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"GLU",
};
static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95");
static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT != 96");
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
@ -1139,6 +1140,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"gated_linear_attn(k, v, q, gate, s)",
"rwkv_wkv7(r, w, k, v, a, b, s)",
"A X = B, A triangular, solve X",
"x+(y-x)*t",
"unary(x)",
@ -1156,7 +1158,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"glu(x)",
};
static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95");
static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT != 96");
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
@ -5742,6 +5744,34 @@ struct ggml_tensor * ggml_rwkv_wkv7(
return result;
}
// ggml_lerp
struct ggml_tensor * ggml_lerp(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
struct ggml_tensor * t) {
// assume a and b are the same shape for now
GGML_ASSERT(ggml_are_same_shape(a, b));
GGML_ASSERT(t->ne[0] == a->ne[0]);
GGML_ASSERT(a->ne[1] % t->ne[1] == 0);
GGML_ASSERT(a->ne[2] % t->ne[2] == 0);
// a/b can be broadcast along dim 3 to match t (needed for rwkv7); the result takes t->ne[3]
GGML_ASSERT(t->ne[3] % a->ne[3] == 0);
const int64_t ne[4] = { a->ne[0], a->ne[1], a->ne[2], t->ne[3] };
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
result->op = GGML_OP_LERP;
result->src[0] = a;
result->src[1] = b;
result->src[2] = t;
return result;
}
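To make the broadcast rules concrete: t may be smaller than a/b along dims 1 and 2, while a/b may be smaller than t along dim 3, and the result takes t's ne[3]. A sketch with made-up sizes (the rwkv7 graph code below uses roughly this pattern, with n_embd x n_seq_tokens x n_seqs activations and 5 or 6 fused coefficient slices):

// a, b : {768, 32, 2, 1}  -- e.g. current and shifted activations
// t    : {768,  1, 1, 6}  -- per-channel coefficients, one slice per mix target
struct ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 768, 32, 2, 1);
struct ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 768, 32, 2, 1);
struct ggml_tensor * t = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 768,  1, 1, 6);
struct ggml_tensor * x = ggml_lerp(ctx, a, b, t); // -> {768, 32, 2, 6}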
// ggml_unary
static struct ggml_tensor * ggml_unary_impl(

View File

@ -49,7 +49,8 @@ struct llm_build_rwkv7_base : public llm_graph_context {
ggml_tensor * build_rwkv7_channel_mix(const llama_layer * layer,
ggml_tensor * cur,
ggml_tensor * x_prev,
llm_arch arch) const;
llm_arch arch,
int il) const;
ggml_tensor * build_rwkv7_time_mix(llm_graph_input_rs * inp,
ggml_tensor * cur,
ggml_tensor * x_prev,

View File

@ -9,11 +9,11 @@ ggml_tensor * llm_build_rwkv7_base::build_rwkv7_channel_mix(const llama_layer *
ggml_tensor * x_prev,
llm_arch arch,
int il) const {
ggml_tensor * sx = ggml_sub(ctx0, x_prev, cur);
switch (arch) {
case LLM_ARCH_RWKV7:
{
ggml_tensor * xk = ggml_add(ctx0, ggml_mul(ctx0, sx, layer->channel_mix_lerp_k), cur);
// cur + (x_prev - cur) * layer->channel_mix_lerp_k
ggml_tensor * xk = ggml_lerp(ctx0, cur, x_prev, layer->channel_mix_lerp_k);
cur = build_ffn(
xk,
@ -54,11 +54,7 @@ ggml_tensor * llm_build_rwkv7_base::build_rwkv7_time_mix(llm_graph_input_rs * in
bool has_gating = layer.time_mix_g1 && layer.time_mix_g2;
ggml_tensor * sx = ggml_sub(ctx0, x_prev, cur);
ggml_tensor * dummy = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_embd, n_seq_tokens, n_seqs, has_gating ? 6 : 5);
sx = ggml_repeat(ctx0, sx, dummy);
ggml_tensor * xxx = ggml_add(ctx0, ggml_mul(ctx0, sx, layer.time_mix_lerp_fused), cur);
ggml_tensor * xxx = ggml_lerp(ctx0, cur, x_prev, layer.time_mix_lerp_fused);
ggml_tensor * xr = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], 0);
ggml_tensor * xw = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * sizeof(float));

View File

@ -3666,6 +3666,35 @@ struct test_rwkv_wkv7 : public test_case {
}
};
// GGML_OP_LERP
struct test_lerp : public test_case {
const ggml_type type_a;
const ggml_type type_b;
const ggml_type type_t;
const std::array<int64_t, 4> ne0;
const std::array<int64_t, 4> ne1;
const std::array<int64_t, 4> ne2;
std::string vars() override {
return VARS_TO_STR6(type_a, type_b, type_t, ne0, ne1, ne2);
}
test_lerp(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32,
ggml_type type_t = GGML_TYPE_F32,
std::array<int64_t, 4> ne0 = {10, 10, 1, 1},
std::array<int64_t, 4> ne1 = {10, 10, 1, 1},
std::array<int64_t, 4> ne2 = {10, 10, 1, 1})
: type_a(type_a), type_b(type_b), type_t(type_t), ne0(ne0), ne1(ne1), ne2(ne2) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type_a, 4, ne0.data());
ggml_tensor * b = ggml_new_tensor(ctx, type_b, 4, ne1.data());
ggml_tensor * c = ggml_new_tensor(ctx, type_t, 4, ne2.data());
ggml_tensor * out = ggml_lerp(ctx, a, b, c);
return out;
}
};
// GGML_OP_MUL_MAT
struct test_mul_mat : public test_case {
const ggml_type type_a;
@ -7508,6 +7537,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 32, 4));
test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 128, 4));
test_cases.emplace_back(new test_lerp(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {256, 16, 1, 1}, {256, 16, 1, 1}, {256, 16, 1, 1}));
test_cases.emplace_back(new test_lerp(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {256, 16, 1, 1}, {256, 16, 1, 1}, {256, 16, 1, 6}));
test_cases.emplace_back(new test_lerp(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F16, {256, 16, 1, 1}, {256, 16, 1, 1}, {256, 16, 1, 1}));
test_cases.emplace_back(new test_lerp(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F16, {256, 16, 1, 1}, {256, 16, 1, 1}, {256, 16, 1, 6}));
#if 0
// > 4GB A matrix. Too slow to be enabled by default.
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F16, 900000, 3, 2592, {1, 1}, {1, 1}));