diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index bd3f4a487c..f1366c4195 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -551,6 +551,7 @@ extern "C" {
         GGML_OP_GATED_LINEAR_ATTN,
         GGML_OP_RWKV_WKV7,
         GGML_OP_SOLVE_TRI,
+        GGML_OP_LERP,
 
         GGML_OP_UNARY,
 
@@ -2461,6 +2462,14 @@ extern "C" {
            bool                  lower,
            bool                  uni);
 
+    // a + (b - a) * t
+    // used in rwkv7
+    GGML_API struct ggml_tensor * ggml_lerp(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            struct ggml_tensor  * t);
+
     // custom operators
 
     typedef void (*ggml_custom1_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int nth, void * userdata);
diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c
index f7ba1fe317..0f1f8fffa8 100644
--- a/ggml/src/ggml-cpu/ggml-cpu.c
+++ b/ggml/src/ggml-cpu/ggml-cpu.c
@@ -2018,6 +2018,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_solve_tri(params, tensor);
             } break;
+        case GGML_OP_LERP:
+            {
+                ggml_compute_forward_lerp(params, tensor);
+            }
+            break;
         case GGML_OP_MAP_CUSTOM1:
             {
                 ggml_compute_forward_map_custom1(params, tensor);
@@ -2179,6 +2184,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_CUMSUM:
         case GGML_OP_TRI:
         case GGML_OP_FILL:
+        case GGML_OP_LERP:
             {
                 n_tasks = n_threads;
             } break;
diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
index f4741da9d7..43e67c7131 100644
--- a/ggml/src/ggml-cpu/ops.cpp
+++ b/ggml/src/ggml-cpu/ops.cpp
@@ -10099,6 +10099,177 @@ void ggml_compute_forward_rwkv_wkv7(
     }
 }
 
+static void ggml_compute_forward_lerp_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    const ggml_tensor * src2 = dst->src[2];
+
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(src2->type == GGML_TYPE_F32);
+
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+    GGML_ASSERT(src1->nb[0] == sizeof(float));
+    GGML_ASSERT(src2->nb[0] == sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_TENSOR_TERNARY_OP_LOCALS
+
+    GGML_ASSERT( nb0 == sizeof(float));
+    GGML_ASSERT(nb10 == sizeof(float));
+    GGML_ASSERT(nb20 == sizeof(float));
+
+    GGML_ASSERT(ggml_are_same_shape(src0, src1));
+    GGML_ASSERT(ne00 == ne10 && ne00 == ne20 && ne00 == ne0);
+
+    GGML_ASSERT(ne01 % ne21 == 0);
+    GGML_ASSERT(ne02 % ne22 == 0);
+
+    GGML_ASSERT(ne23 % ne03 == 0);
+    GGML_ASSERT(ne23 % ne13 == 0);
+
+    GGML_ASSERT(ne0 == ne00);
+    GGML_ASSERT(ne1 == ne01);
+    GGML_ASSERT(ne2 == ne02);
+    GGML_ASSERT(ne3 == ne23);
+
+    const int nr = ggml_nrows(dst);
+
+    const int dr = (nr + nth - 1)/nth;
+
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        const int i3 = ir/(ne2*ne1);
+        const int i2 = (ir - i3*ne2*ne1)/ne1;
+        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+        const int i03 = i3 % ne03;
+        const int i13 = i3 % ne13;
+
+        const int i21 = i1 % ne21;
+        const int i22 = i2 % ne22;
+
+        float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
+        const float * src0_ptr = (const float *) ((char *) src0->data + i03*nb03 + i2*nb02 + i1*nb01);
+        const float * src1_ptr = (const float *) ((char *) src1->data + i13*nb13 + i2*nb12 + i1*nb11);
+        const float * src2_ptr = (const float *) ((char *) src2->data + i3*nb23 + i22*nb22 + i21*nb21);
+
+        for (int64_t i0 = 0; i0 < ne0; ++i0) {
+            const float s0 = src0_ptr[i0];
+            const float s1 = src1_ptr[i0];
+            const float s2 = src2_ptr[i0];
+
+            dst_ptr[i0] = s0 + (s1 - s0) * s2;
+        }
+    }
+}
+
+static void ggml_compute_forward_lerp_f32_f32_f16(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    const ggml_tensor * src2 = dst->src[2];
+
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(src2->type == GGML_TYPE_F16);
+
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+    GGML_ASSERT(src1->nb[0] == sizeof(float));
+    GGML_ASSERT(src2->nb[0] == sizeof(ggml_fp16_t));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_TENSOR_TERNARY_OP_LOCALS
+
+    GGML_ASSERT( nb0 == sizeof(float));
+    GGML_ASSERT(nb10 == sizeof(float));
+    GGML_ASSERT(nb20 == sizeof(ggml_fp16_t));
+
+    GGML_ASSERT(ggml_are_same_shape(src0, src1));
+    GGML_ASSERT(ne00 == ne10 && ne00 == ne20 && ne00 == ne0);
+
+    GGML_ASSERT(ne01 % ne21 == 0);
+    GGML_ASSERT(ne02 % ne22 == 0);
+
+    GGML_ASSERT(ne23 % ne03 == 0);
+    GGML_ASSERT(ne23 % ne13 == 0);
+
+    GGML_ASSERT(ne0 == ne00);
+    GGML_ASSERT(ne1 == ne01);
+    GGML_ASSERT(ne2 == ne02);
+    GGML_ASSERT(ne3 == ne23);
+
+    const int nr = ggml_nrows(dst);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        const int i3 = ir/(ne2*ne1);
+        const int i2 = (ir - i3*ne2*ne1)/ne1;
+        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+        const int i03 = i3 % ne03;
+        const int i13 = i3 % ne13;
+
+        const int i21 = i1 % ne21;
+        const int i22 = i2 % ne22;
+
+        float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
+        const float * src0_ptr = (const float *) ((char *) src0->data + i03*nb03 + i2*nb02 + i1*nb01);
+        const float * src1_ptr = (const float *) ((char *) src1->data + i13*nb13 + i2*nb12 + i1*nb11);
+        const ggml_fp16_t * src2_ptr = (const ggml_fp16_t *) ((char *) src2->data + i3*nb23 + i22*nb22 + i21*nb21);
+
+        for (int64_t i0 = 0; i0 < ne0; ++i0) {
+            const float s0 = src0_ptr[i0];
+            const float s1 = src1_ptr[i0];
+            const float s2 = GGML_FP16_TO_FP32(src2_ptr[i0]);
+
+            dst_ptr[i0] = s0 + (s1 - s0) * s2;
+        }
+    }
+}
+
+void ggml_compute_forward_lerp(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src2 = dst->src[2];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                if (src2->type == GGML_TYPE_F32) {
+                    ggml_compute_forward_lerp_f32(params, dst);
+                } else if (src2->type == GGML_TYPE_F16) {
+                    ggml_compute_forward_lerp_f32_f32_f16(params, dst);
+                } else {
+                    GGML_ABORT("fatal error");
+                }
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_map_custom1
 
 void ggml_compute_forward_map_custom1(
diff --git a/ggml/src/ggml-cpu/ops.h b/ggml/src/ggml-cpu/ops.h
index 0fdfee7976..c048b4b272 100644
--- a/ggml/src/ggml-cpu/ops.h
+++ b/ggml/src/ggml-cpu/ops.h
@@ -101,6 +101,7 @@ void ggml_compute_forward_add_rel_pos(const struct ggml_compute_params * params,
 void ggml_compute_forward_rwkv_wkv6(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_rwkv_wkv7(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_solve_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_lerp(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_gla(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_map_custom1(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_map_custom2(const struct ggml_compute_params * params, struct ggml_tensor * dst);
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index bac69cdd1c..456ca40e3d 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -52,6 +52,7 @@
 #include "ggml-cuda/unary.cuh"
 #include "ggml-cuda/upscale.cuh"
 #include "ggml-cuda/wkv.cuh"
+#include "ggml-cuda/lerp.cuh"
 #include "ggml-cuda/gla.cuh"
 #include "ggml-cuda/set.cuh"
 #include "ggml-cuda/set-rows.cuh"
@@ -2726,6 +2727,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_RWKV_WKV7:
             ggml_cuda_op_rwkv_wkv7(ctx, dst);
             break;
+        case GGML_OP_LERP:
+            ggml_cuda_op_lerp(ctx, dst);
+            break;
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
             ggml_cuda_cross_entropy_loss_back(ctx, dst);
             break;
@@ -4635,6 +4639,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_RWKV_WKV6:
         case GGML_OP_GATED_LINEAR_ATTN:
         case GGML_OP_RWKV_WKV7:
+        case GGML_OP_LERP:
             return true;
         case GGML_OP_FLASH_ATTN_EXT:
             return ggml_cuda_flash_attn_ext_supported(dev_ctx->device, op);
diff --git a/ggml/src/ggml-cuda/lerp.cu b/ggml/src/ggml-cuda/lerp.cu
new file mode 100644
index 0000000000..3b7eae6090
--- /dev/null
+++ b/ggml/src/ggml-cuda/lerp.cu
@@ -0,0 +1,283 @@
+#include "lerp.cuh"
+#include <algorithm>
+
+template <typename src0_t, typename src1_t, typename src2_t, typename dst_t>
+static __global__ void k_lerp(
+        const src0_t * src0,
+        const src1_t * src1,
+        const src2_t * src2,
+        dst_t * dst,
+        const int ne0,
+        const int ne1,
+        const int ne2,
+        const uint3 ne3,
+        const uint3 ne03,
+        const uint3 ne13,
+        const uint3 ne21,
+        const uint3 ne22,
+        const int s1,
+        const int s2,
+        const int s3,
+        const int s01,
+        const int s02,
+        const int s03,
+        const int s11,
+        const int s12,
+        const int s13,
+        const int s21,
+        const int s22,
+        const int s23) {
+
+    const uint32_t i0s = blockDim.x * blockIdx.x + threadIdx.x;
+    const uint32_t i1 = (blockDim.y * blockIdx.y + threadIdx.y);
+    const uint32_t i2 = fastdiv((blockDim.z * blockIdx.z + threadIdx.z), ne3);
+    const uint32_t i3 = (blockDim.z * blockIdx.z + threadIdx.z) - (i2 * ne3.z);
+
+    if (i0s >= (uint32_t)ne0 || i1 >= (uint32_t)ne1 || i2 >= (uint32_t)ne2 || i3 >= ne3.z) {
+        return;
+    }
+
+    // src0/src1 broadcast in dim3
+    const uint32_t i03 = fastmodulo(i3, ne03);
+    const uint32_t i13 = fastmodulo(i3, ne13);
+
+    // src2 broadcast in dim1, dim2
+    const uint32_t i21 = fastmodulo(i1, ne21);
+    const uint32_t i22 = fastmodulo(i2, ne22);
+
+    const size_t i_src0 = i03*s03 + i2*s02 + i1*s01;
+    const size_t i_src1 = i13*s13 + i2*s12 + i1*s11;
+    const size_t i_src2 = i3*s23 + i22*s22 + i21*s21;
+    const size_t i_dst  = i3*s3 + i2*s2 + i1*s1;
+
+    const src0_t * src0_row = src0 + i_src0;
+    const src1_t * src1_row = src1 + i_src1;
+    const src2_t * src2_row = src2 + i_src2;
+    dst_t * dst_row = dst + i_dst;
+
+    for (int i0 = i0s; i0 < ne0; i0 += blockDim.x * gridDim.x) {
+        const float v0 = (float) src0_row[i0];
+        const float v1 = (float) src1_row[i0];
+        const float v2 = (float) src2_row[i0];
+
+        dst_row[i0] = (dst_t) (v0 + (v1 - v0) * v2);
+    }
+}
+
+template <typename src0_t, typename src1_t, typename src2_t, typename dst_t>
+static __global__ void k_lerp_unravel(
+        const src0_t * src0,
+        const src1_t * src1,
+        const src2_t * src2,
+        dst_t * dst,
+        const uint3 ne0,
+        const uint3 ne1,
+        const uint3 ne2,
+        const uint32_t ne3,
+        const uint3 prod_012,
+        const uint3 prod_01,
+        const uint3 ne03,
+        const uint3 ne13,
+        const uint3 ne21,
+        const uint3 ne22,
+        const int s1,
+        const int s2,
+        const int s3,
+        const int s01,
+        const int s02,
+        const int s03,
+        const int s11,
+        const int s12,
+        const int s13,
+        const int s21,
+        const int s22,
+        const int s23) {
+
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    const uint32_t i3 = fastdiv(i, prod_012);
+    const uint32_t i2 = fastdiv(i - i3 * prod_012.z, prod_01);
+    const uint32_t i1 = fastdiv(i - i3 * prod_012.z - i2 * prod_01.z, ne0);
+    const uint32_t i0 = i - i3 * prod_012.z - i2 * prod_01.z - i1 * ne0.z;
+
+    if (i0 >= ne0.z || i1 >= ne1.z || i2 >= ne2.z || i3 >= ne3) {
+        return;
+    }
+
+    // src0/src1 broadcast in dim3
+    const int i03 = fastmodulo(i3, ne03);
+    const int i13 = fastmodulo(i3, ne13);
+
+    // src2 broadcast in dim1, dim2
+    const int i21 = fastmodulo(i1, ne21);
+    const int i22 = fastmodulo(i2, ne22);
+
+    const size_t i_src0 = i03*s03 + i2*s02 + i1*s01;
+    const size_t i_src1 = i13*s13 + i2*s12 + i1*s11;
+    const size_t i_src2 = i3*s23 + i22*s22 + i21*s21;
+    const size_t i_dst  = i3*s3 + i2*s2 + i1*s1;
+
+    const src0_t * src0_row = src0 + i_src0;
+    const src1_t * src1_row = src1 + i_src1;
+    const src2_t * src2_row = src2 + i_src2;
+    dst_t * dst_row = dst + i_dst;
+
+    const float v0 = (float) src0_row[i0];
+    const float v1 = (float) src1_row[i0];
+    const float v2 = (float) src2_row[i0];
+
+    // dst = src0 + (src1 - src0) * src2
+    dst_row[i0] = (dst_t) (v0 + (v1 - v0) * v2);
+}
+
+template <typename src0_t, typename src1_t, typename src2_t, typename dst_t>
+static void launch_lerp(
+        const ggml_tensor * src0,
+        const ggml_tensor * src1,
+        const ggml_tensor * src2,
+        ggml_tensor * dst,
+        const src0_t * src0_dd,
+        const src1_t * src1_dd,
+        const src2_t * src2_dd,
+        dst_t * dst_dd,
+        cudaStream_t stream) {
+
+    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
+    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne)
+    GGML_TENSOR_LOCALS(int64_t, ne2, src2, ne)
+    GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
+
+    GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
+    GGML_TENSOR_LOCALS(size_t, nb1, src1, nb)
+    GGML_TENSOR_LOCALS(size_t, nb2, src2, nb)
+    GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
+
+    GGML_ASSERT(ne00 == ne10 && ne00 == ne20 && ne00 == ne0);
+    GGML_ASSERT(ne01 % ne21 == 0);
+    GGML_ASSERT(ne02 % ne22 == 0);
+    GGML_ASSERT(ne3 % ne03 == 0);
+    GGML_ASSERT(ne3 % ne13 == 0);
+    GGML_ASSERT(ne0 == ne00);
+    GGML_ASSERT(ne1 == ne01);
+    GGML_ASSERT(ne2 == ne02);
+    GGML_ASSERT(ne3 == ne23);
+
+    size_t s1 = nb1 / sizeof(dst_t);
+    size_t s2 = nb2 / sizeof(dst_t);
+    size_t s3 = nb3 / sizeof(dst_t);
+
+    size_t s01 = nb01 / sizeof(src0_t);
+    size_t s02 = nb02 / sizeof(src0_t);
+    size_t s03 = nb03 / sizeof(src0_t);
+
+    size_t s11 = nb11 / sizeof(src1_t);
+    size_t s12 = nb12 / sizeof(src1_t);
+    size_t s13 = nb13 / sizeof(src1_t);
+
+    size_t s21 = nb21 / sizeof(src2_t);
+    size_t s22 = nb22 / sizeof(src2_t);
+    size_t s23 = nb23 / sizeof(src2_t);
+
+    GGML_ASSERT(nb0 % sizeof(dst_t) == 0);
+    GGML_ASSERT(nb1 % sizeof(dst_t) == 0);
+    GGML_ASSERT(nb2 % sizeof(dst_t) == 0);
+    GGML_ASSERT(nb3 % sizeof(dst_t) == 0);
+
+    GGML_ASSERT(nb00 % sizeof(src0_t) == 0);
+    GGML_ASSERT(nb01 % sizeof(src0_t) == 0);
+    GGML_ASSERT(nb02 % sizeof(src0_t) == 0);
+    GGML_ASSERT(nb03 % sizeof(src0_t) == 0);
+
+    GGML_ASSERT(nb10 % sizeof(src1_t) == 0);
+    GGML_ASSERT(nb11 % sizeof(src1_t) == 0);
+    GGML_ASSERT(nb12 % sizeof(src1_t) == 0);
+    GGML_ASSERT(nb13 % sizeof(src1_t) == 0);
+
+    GGML_ASSERT(nb20 % sizeof(src2_t) == 0);
+    GGML_ASSERT(nb21 % sizeof(src2_t) == 0);
+    GGML_ASSERT(nb22 % sizeof(src2_t) == 0);
+    GGML_ASSERT(nb23 % sizeof(src2_t) == 0);
+
+    const int block_size = CUDA_LERP_BLOCK_SIZE;
+
+    int64_t hne0 = std::max(ne0 / 2LL, 1LL);
+
+    dim3 block_dims;
+    block_dims.x = std::min<unsigned int>(hne0, block_size);
+    block_dims.y = std::min<unsigned int>(ne1, block_size / block_dims.x);
+    block_dims.z = std::min(std::min<unsigned int>(ne2 * ne3, block_size / block_dims.x / block_dims.y), 64U);
+
+    dim3 block_nums(
+        (hne0 + block_dims.x - 1) / block_dims.x,
+        (ne1 + block_dims.y - 1) / block_dims.y,
+        (ne2 * ne3 + block_dims.z - 1) / block_dims.z);
+
+    const uint3 ne03_fastdiv = init_fastdiv_values((uint32_t) ne03);
+    const uint3 ne13_fastdiv = init_fastdiv_values((uint32_t) ne13);
+    const uint3 ne21_fastdiv = init_fastdiv_values((uint32_t) ne21);
+    const uint3 ne22_fastdiv = init_fastdiv_values((uint32_t) ne22);
+
+    if (block_nums.z > 65535 || block_nums.y > 65535) {
+        int block_num = (ne0 * ne1 * ne2 * ne3 + block_size - 1) / block_size;
+        const uint3 prod_012 = init_fastdiv_values((uint32_t) (ne0 * ne1 * ne2));
+        const uint3 prod_01 = init_fastdiv_values((uint32_t) (ne0 * ne1));
+        const uint3 ne0_fastdiv = init_fastdiv_values((uint32_t) ne0);
+        const uint3 ne1_fastdiv = init_fastdiv_values((uint32_t) ne1);
+        const uint3 ne2_fastdiv = init_fastdiv_values((uint32_t) ne2);
+
+        k_lerp_unravel
+            <<<block_num, block_size, 0, stream>>>(
+                src0_dd, src1_dd, src2_dd, dst_dd,
+                ne0_fastdiv, ne1_fastdiv, ne2_fastdiv, ne3,
+                prod_012, prod_01,
+                ne03_fastdiv, ne13_fastdiv, ne21_fastdiv, ne22_fastdiv,
+                s1, s2, s3,
+                s01, s02, s03,
+                s11, s12, s13,
+                s21, s22, s23);
+    } else {
+        const uint3 ne3_fastdiv = init_fastdiv_values((uint32_t) ne3);
+
+        k_lerp
+            <<<block_nums, block_dims, 0, stream>>>(
+                src0_dd, src1_dd, src2_dd, dst_dd,
+                ne0, ne1, ne2, ne3_fastdiv,
+                ne03_fastdiv, ne13_fastdiv, ne21_fastdiv, ne22_fastdiv,
+                s1, s2, s3,
+                s01, s02, s03,
+                s11, s12, s13,
+                s21, s22, s23);
+    }
+}
+
+void ggml_cuda_op_lerp(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    const ggml_tensor * src2 = dst->src[2];
+
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+    if (src2->type == GGML_TYPE_F32) {
+        launch_lerp(
+            src0, src1, src2, dst,
+            (const float *) src0->data,
+            (const float *) src1->data,
+            (const float *) src2->data,
+            (float *) dst->data,
+            stream);
+    } else if (src2->type == GGML_TYPE_F16) {
+        launch_lerp(
+            src0, src1, src2, dst,
+            (const float *) src0->data,
+            (const float *) src1->data,
+            (const half *) src2->data,
+            (float *) dst->data,
+            stream);
+    } else {
+        GGML_ABORT("fatal error");
+    }
+}
diff --git a/ggml/src/ggml-cuda/lerp.cuh b/ggml/src/ggml-cuda/lerp.cuh
new file mode 100644
index 0000000000..c504e82f86
--- /dev/null
+++ b/ggml/src/ggml-cuda/lerp.cuh
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+#define CUDA_LERP_BLOCK_SIZE 256
+
+void ggml_cuda_op_lerp(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index aa1188d44c..990f3ffa0e 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -1030,6 +1030,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "GATED_LINEAR_ATTN",
     "RWKV_WKV7",
     "SOLVE_TRI",
+    "LERP",
 
     "UNARY",
 
@@ -1047,7 +1048,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "GLU",
 };
 
-static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95");
+static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT != 96");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -1139,6 +1140,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "gated_linear_attn(k, v, q, gate, s)",
     "rwkv_wkv7(r, w, k, v, a, b, s)",
     "A X = B, A triangular, solve X",
+    "x+(y-x)*t",
 
     "unary(x)",
 
@@ -1156,7 +1158,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "glu(x)",
 };
 
-static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95");
+static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT != 96");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -5742,6 +5744,34 @@ struct ggml_tensor * ggml_rwkv_wkv7(
     return result;
 }
 
+// ggml_lerp
+
+struct ggml_tensor * ggml_lerp(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        struct ggml_tensor  * t) {
+    // assume a and b are the same shape for now
+    GGML_ASSERT(ggml_are_same_shape(a, b));
+
+    GGML_ASSERT(t->ne[0] == a->ne[0]);
+    GGML_ASSERT(a->ne[1] % t->ne[1] == 0);
+    GGML_ASSERT(a->ne[2] % t->ne[2] == 0);
+
+    // a/b can broadcast to t at dim3 for rwkv7
+    GGML_ASSERT(t->ne[3] % a->ne[3] == 0);
+
+    const int64_t ne[4] = { a->ne[0], a->ne[1], a->ne[2], t->ne[3] };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+    result->op     = GGML_OP_LERP;
+    result->src[0] = a;
+    result->src[1] = b;
+    result->src[2] = t;
+
+    return result;
+}
+
 // ggml_unary
 
 static struct ggml_tensor * ggml_unary_impl(
diff --git a/src/models/models.h b/src/models/models.h
index 72b2b760c6..1513acec00 100644
--- a/src/models/models.h
+++ b/src/models/models.h
@@ -49,7 +49,8 @@ struct llm_build_rwkv7_base : public llm_graph_context {
     ggml_tensor * build_rwkv7_channel_mix(const llama_layer * layer,
                                           ggml_tensor * cur,
                                           ggml_tensor * x_prev,
-                                          llm_arch arch) const;
+                                          llm_arch arch,
+                                          int il) const;
     ggml_tensor * build_rwkv7_time_mix(llm_graph_input_rs * inp,
                                        ggml_tensor * cur,
                                        ggml_tensor * x_prev,
diff --git a/src/models/rwkv7-base.cpp b/src/models/rwkv7-base.cpp
index d975ced1f0..085cc68523 100644
--- a/src/models/rwkv7-base.cpp
+++ b/src/models/rwkv7-base.cpp
@@ -9,11 +9,11 @@ ggml_tensor * llm_build_rwkv7_base::build_rwkv7_channel_mix(const llama_layer *
                                                             ggml_tensor * x_prev,
                                                             llm_arch arch,
                                                             int il) const {
-    ggml_tensor * sx = ggml_sub(ctx0, x_prev, cur);
     switch (arch) {
         case LLM_ARCH_RWKV7:
             {
-                ggml_tensor * xk = ggml_add(ctx0, ggml_mul(ctx0, sx, layer->channel_mix_lerp_k), cur);
+                // cur + (x_prev - cur) * layer->channel_mix_lerp_k
+                ggml_tensor * xk = ggml_lerp(ctx0, cur, x_prev, layer->channel_mix_lerp_k);
 
                 cur = build_ffn(
                     xk,
@@ -54,11 +54,7 @@ ggml_tensor * llm_build_rwkv7_base::build_rwkv7_time_mix(llm_graph_input_rs * in
 
     bool has_gating = layer.time_mix_g1 && layer.time_mix_g2;
 
-    ggml_tensor * sx = ggml_sub(ctx0, x_prev, cur);
-    ggml_tensor * dummy = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_embd, n_seq_tokens, n_seqs, has_gating ? 6 : 5);
-    sx = ggml_repeat(ctx0, sx, dummy);
-
-    ggml_tensor * xxx = ggml_add(ctx0, ggml_mul(ctx0, sx, layer.time_mix_lerp_fused), cur);
+    ggml_tensor * xxx = ggml_lerp(ctx0, cur, x_prev, layer.time_mix_lerp_fused);
 
     ggml_tensor * xr = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], 0);
     ggml_tensor * xw = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * sizeof(float));
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index e6a72a29c5..9ecee15220 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -3666,6 +3666,35 @@
     }
 };
 
+// GGML_OP_LERP
+struct test_lerp : public test_case {
+    const ggml_type type_t;
+    const ggml_type type_a;
+    const ggml_type type_b;
+    const std::array<int64_t, 4> ne0;
+    const std::array<int64_t, 4> ne1;
+    const std::array<int64_t, 4> ne2;
+
+    std::string vars() override {
+        return VARS_TO_STR6(type_a, type_b, type_t, ne0, ne1, ne2);
+    }
+
+    test_lerp(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32,
+              ggml_type type_t = GGML_TYPE_F32,
+              std::array<int64_t, 4> ne0 = {10, 10, 1, 1},
+              std::array<int64_t, 4> ne1 = {10, 10, 1, 1},
+              std::array<int64_t, 4> ne2 = {10, 10, 1, 1})
+        : type_a(type_a), type_b(type_b), type_t(type_t), ne0(ne0), ne1(ne1), ne2(ne2) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type_a, 4, ne0.data());
+        ggml_tensor * b = ggml_new_tensor(ctx, type_b, 4, ne1.data());
+        ggml_tensor * c = ggml_new_tensor(ctx, type_t, 4, ne2.data());
+        ggml_tensor * out = ggml_lerp(ctx, a, b, c);
+        return out;
+    }
+};
+
 // GGML_OP_MUL_MAT
 struct test_mul_mat : public test_case {
     const ggml_type type_a;
@@ -7508,6 +7537,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 32, 4));
     test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 128, 4));
 
+    test_cases.emplace_back(new test_lerp(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {256, 16, 1, 1}, {256, 16, 1, 1}, {256, 16, 1, 1}));
+    test_cases.emplace_back(new test_lerp(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {256, 16, 1, 1}, {256, 16, 1, 1}, {256, 16, 1, 6}));
+    test_cases.emplace_back(new test_lerp(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F16, {256, 16, 1, 1}, {256, 16, 1, 1}, {256, 16, 1, 1}));
+    test_cases.emplace_back(new test_lerp(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F16, {256, 16, 1, 1}, {256, 16, 1, 1}, {256, 16, 1, 6}));
+
 #if 0
     // > 4GB A matrix. Too slow to be enabled by default.
     test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F16, 900000, 3, 2592, {1, 1}, {1, 1}));
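
Usage sketch (not part of the patch): a minimal standalone C program showing how ggml_lerp is intended to be called on the CPU backend, mirroring the broadcast pattern exercised by the new test_lerp cases, where t may have a larger dim 3 than a/b, as in the fused RWKV7 time-mix. The scratch size, thread count, and fill values are arbitrary, and the example assumes a ggml build that contains this patch, with ggml_graph_compute_with_ctx available from ggml-cpu.h.

// lerp_demo.c — hypothetical example, not included in the patch.
#include "ggml.h"
#include "ggml-cpu.h"
#include <stdio.h>

int main(void) {
    struct ggml_init_params ip = {
        /*.mem_size   =*/ 64*1024*1024,   // arbitrary scratch size
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(ip);

    // a, b: [256, 16, 1, 1]; t: [256, 16, 1, 6] -> a/b broadcast along dim 3, dst is [256, 16, 1, 6]
    struct ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 256, 16, 1, 1);
    struct ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 256, 16, 1, 1);
    struct ggml_tensor * t = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 256, 16, 1, 6);

    // fill with something deterministic
    for (int64_t i = 0; i < ggml_nelements(a); ++i) ((float *) a->data)[i] = 1.0f;
    for (int64_t i = 0; i < ggml_nelements(b); ++i) ((float *) b->data)[i] = 3.0f;
    for (int64_t i = 0; i < ggml_nelements(t); ++i) ((float *) t->data)[i] = 0.25f;

    // dst = a + (b - a) * t, i.e. 1 + (3 - 1) * 0.25 = 1.5 everywhere
    struct ggml_tensor * dst = ggml_lerp(ctx, a, b, t);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, dst);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 4);

    printf("dst[0] = %f\n", ((float *) dst->data)[0]);   // expect 1.5

    ggml_free(ctx);
    return 0;
}

On the graph side this is the same pattern the patch applies in rwkv7-base.cpp: the previous ggml_sub + ggml_repeat + ggml_mul + ggml_add chain collapses into a single ggml_lerp(ctx0, cur, x_prev, layer.time_mix_lerp_fused) node, with the dim-3 broadcast of cur/x_prev against the fused lerp weights handled inside the op.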