ggml: add `lerp` op for linear interpolation in rwkv7
Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
parent c389dc9e8a
commit a98cac62c9
@@ -551,6 +551,7 @@ extern "C" {
        GGML_OP_GATED_LINEAR_ATTN,
        GGML_OP_RWKV_WKV7,
        GGML_OP_SOLVE_TRI,
+       GGML_OP_LERP,

        GGML_OP_UNARY,

@@ -2461,6 +2462,14 @@ extern "C" {
            bool                  lower,
            bool                  uni);

+    // a + (b - a) * t
+    // used in rwkv7
+    GGML_API struct ggml_tensor * ggml_lerp(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            struct ggml_tensor  * t);
+
    // custom operators

    typedef void (*ggml_custom1_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int nth, void * userdata);
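For reference, a minimal sketch of how the new public API might be called from graph-building code. The helper name and tensor shapes below are illustrative assumptions, not part of this commit:

    #include "ggml.h"

    // Hypothetical helper: token-shift style mixing of the current and previous
    // hidden state, as done in RWKV-style blocks:
    //   out = cur + (x_prev - cur) * mix
    static struct ggml_tensor * mix_states(struct ggml_context * ctx,
                                           struct ggml_tensor  * cur,     // F32 [n_embd, n_tokens]
                                           struct ggml_tensor  * x_prev,  // F32 [n_embd, n_tokens]
                                           struct ggml_tensor  * mix) {   // F32 [n_embd, 1], broadcast over tokens
        return ggml_lerp(ctx, cur, x_prev, mix);
    }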
@@ -2018,6 +2018,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
            {
                ggml_compute_forward_solve_tri(params, tensor);
            } break;
+        case GGML_OP_LERP:
+            {
+                ggml_compute_forward_lerp(params, tensor);
+            }
+            break;
        case GGML_OP_MAP_CUSTOM1:
            {
                ggml_compute_forward_map_custom1(params, tensor);
@@ -2179,6 +2184,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
        case GGML_OP_CUMSUM:
        case GGML_OP_TRI:
        case GGML_OP_FILL:
+        case GGML_OP_LERP:
            {
                n_tasks = n_threads;
            } break;
@@ -10099,6 +10099,177 @@ void ggml_compute_forward_rwkv_wkv7(
    }
}

+static void ggml_compute_forward_lerp_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    const ggml_tensor * src2 = dst->src[2];
+
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(src2->type == GGML_TYPE_F32);
+
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+    GGML_ASSERT(src1->nb[0] == sizeof(float));
+    GGML_ASSERT(src2->nb[0] == sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_TENSOR_TERNARY_OP_LOCALS
+
+    GGML_ASSERT( nb0 == sizeof(float));
+    GGML_ASSERT(nb10 == sizeof(float));
+    GGML_ASSERT(nb20 == sizeof(float));
+
+    GGML_ASSERT(ggml_are_same_shape(src0, src1));
+    GGML_ASSERT(ne00 == ne10 && ne00 == ne20 && ne00 == ne0);
+
+    GGML_ASSERT(ne01 % ne21 == 0);
+    GGML_ASSERT(ne02 % ne22 == 0);
+
+    GGML_ASSERT(ne23 % ne03 == 0);
+    GGML_ASSERT(ne23 % ne13 == 0);
+
+    GGML_ASSERT(ne0 == ne00);
+    GGML_ASSERT(ne1 == ne01);
+    GGML_ASSERT(ne2 == ne02);
+    GGML_ASSERT(ne3 == ne23);
+
+    const int nr = ggml_nrows(dst);
+
+    const int dr = (nr + nth - 1)/nth;
+
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        const int i3 = ir/(ne2*ne1);
+        const int i2 = (ir - i3*ne2*ne1)/ne1;
+        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+        const int i03 = i3 % ne03;
+        const int i13 = i3 % ne13;
+
+        const int i21 = i1 % ne21;
+        const int i22 = i2 % ne22;
+
+        float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
+        const float * src0_ptr = (const float *) ((char *) src0->data + i03*nb03 + i2*nb02 + i1*nb01);
+        const float * src1_ptr = (const float *) ((char *) src1->data + i13*nb13 + i2*nb12 + i1*nb11);
+        const float * src2_ptr = (const float *) ((char *) src2->data + i3*nb23 + i22*nb22 + i21*nb21);
+
+        for (int64_t i0 = 0; i0 < ne0; ++i0) {
+            const float s0 = src0_ptr[i0];
+            const float s1 = src1_ptr[i0];
+            const float s2 = src2_ptr[i0];
+
+            dst_ptr[i0] = s0 + (s1 - s0) * s2;
+        }
+    }
+}
+
+static void ggml_compute_forward_lerp_f32_f32_f16(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    const ggml_tensor * src2 = dst->src[2];
+
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(src2->type == GGML_TYPE_F16);
+
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+    GGML_ASSERT(src1->nb[0] == sizeof(float));
+    GGML_ASSERT(src2->nb[0] == sizeof(ggml_fp16_t));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_TENSOR_TERNARY_OP_LOCALS
+
+    GGML_ASSERT( nb0 == sizeof(float));
+    GGML_ASSERT(nb10 == sizeof(float));
+    GGML_ASSERT(nb20 == sizeof(ggml_fp16_t));
+
+    GGML_ASSERT(ggml_are_same_shape(src0, src1));
+    GGML_ASSERT(ne00 == ne10 && ne00 == ne20 && ne00 == ne0);
+
+    GGML_ASSERT(ne01 % ne21 == 0);
+    GGML_ASSERT(ne02 % ne22 == 0);
+
+    GGML_ASSERT(ne23 % ne03 == 0);
+    GGML_ASSERT(ne23 % ne13 == 0);
+
+    GGML_ASSERT(ne0 == ne00);
+    GGML_ASSERT(ne1 == ne01);
+    GGML_ASSERT(ne2 == ne02);
+    GGML_ASSERT(ne3 == ne23);
+
+    const int nr = ggml_nrows(dst);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        const int i3 = ir/(ne2*ne1);
+        const int i2 = (ir - i3*ne2*ne1)/ne1;
+        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+        const int i03 = i3 % ne03;
+        const int i13 = i3 % ne13;
+
+        const int i21 = i1 % ne21;
+        const int i22 = i2 % ne22;
+
+        float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
+        const float * src0_ptr = (const float *) ((char *) src0->data + i03*nb03 + i2*nb02 + i1*nb01);
+        const float * src1_ptr = (const float *) ((char *) src1->data + i13*nb13 + i2*nb12 + i1*nb11);
+        const ggml_fp16_t * src2_ptr = (const ggml_fp16_t *) ((char *) src2->data + i3*nb23 + i22*nb22 + i21*nb21);
+
+        for (int64_t i0 = 0; i0 < ne0; ++i0) {
+            const float s0 = src0_ptr[i0];
+            const float s1 = src1_ptr[i0];
+            const float s2 = GGML_FP16_TO_FP32(src2_ptr[i0]);
+
+            dst_ptr[i0] = s0 + (s1 - s0) * s2;
+        }
+    }
+}
+
+void ggml_compute_forward_lerp(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src2 = dst->src[2];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                if (src2->type == GGML_TYPE_F32) {
+                    ggml_compute_forward_lerp_f32(params, dst);
+                } else if (src2->type == GGML_TYPE_F16) {
+                    ggml_compute_forward_lerp_f32_f32_f16(params, dst);
+                } else {
+                    GGML_ABORT("fatal error");
+                }
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
// ggml_compute_forward_map_custom1

void ggml_compute_forward_map_custom1(
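A quick worked example of the row split used by both CPU variants above: the nr = ggml_nrows(dst) rows are divided across nth threads with dr = (nr + nth - 1)/nth. For nr = 10 and nth = 4, dr = (10 + 3)/4 = 3, so the threads process row ranges [0,3), [3,6), [6,9) and [9,10); the MIN() clamp keeps the last range inside the tensor.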
@@ -101,6 +101,7 @@ void ggml_compute_forward_add_rel_pos(const struct ggml_compute_params * params,
void ggml_compute_forward_rwkv_wkv6(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_rwkv_wkv7(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_solve_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_lerp(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_gla(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_map_custom1(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_map_custom2(const struct ggml_compute_params * params, struct ggml_tensor * dst);
@@ -52,6 +52,7 @@
#include "ggml-cuda/unary.cuh"
#include "ggml-cuda/upscale.cuh"
#include "ggml-cuda/wkv.cuh"
+#include "ggml-cuda/lerp.cuh"
#include "ggml-cuda/gla.cuh"
#include "ggml-cuda/set.cuh"
#include "ggml-cuda/set-rows.cuh"

@@ -2726,6 +2727,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
        case GGML_OP_RWKV_WKV7:
            ggml_cuda_op_rwkv_wkv7(ctx, dst);
            break;
+        case GGML_OP_LERP:
+            ggml_cuda_op_lerp(ctx, dst);
+            break;
        case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
            ggml_cuda_cross_entropy_loss_back(ctx, dst);
            break;

@@ -4635,6 +4639,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
        case GGML_OP_RWKV_WKV6:
        case GGML_OP_GATED_LINEAR_ATTN:
        case GGML_OP_RWKV_WKV7:
+        case GGML_OP_LERP:
            return true;
        case GGML_OP_FLASH_ATTN_EXT:
            return ggml_cuda_flash_attn_ext_supported(dev_ctx->device, op);
@@ -0,0 +1,283 @@
#include "lerp.cuh"
#include <cstdint>

template <typename src0_t, typename src1_t, typename src2_t, typename dst_t>
static __global__ void k_lerp(
        const src0_t * src0,
        const src1_t * src1,
        const src2_t * src2,
        dst_t * dst,
        const int ne0,
        const int ne1,
        const int ne2,
        const uint3 ne3,
        const uint3 ne03,
        const uint3 ne13,
        const uint3 ne21,
        const uint3 ne22,
        const int s1,
        const int s2,
        const int s3,
        const int s01,
        const int s02,
        const int s03,
        const int s11,
        const int s12,
        const int s13,
        const int s21,
        const int s22,
        const int s23) {

    const uint32_t i0s = blockDim.x * blockIdx.x + threadIdx.x;
    const uint32_t i1  = (blockDim.y * blockIdx.y + threadIdx.y);
    const uint32_t i2  = fastdiv((blockDim.z * blockIdx.z + threadIdx.z), ne3);
    const uint32_t i3  = (blockDim.z * blockIdx.z + threadIdx.z) - (i2 * ne3.z);

    if (i0s >= (uint32_t)ne0 || i1 >= (uint32_t)ne1 || i2 >= (uint32_t)ne2 || i3 >= ne3.z) {
        return;
    }

    // src0/src1 broadcast in dim3
    const uint32_t i03 = fastmodulo(i3, ne03);
    const uint32_t i13 = fastmodulo(i3, ne13);

    // src2 broadcast in dim1, dim2
    const uint32_t i21 = fastmodulo(i1, ne21);
    const uint32_t i22 = fastmodulo(i2, ne22);

    const size_t i_src0 = i03*s03 + i2*s02 + i1*s01;
    const size_t i_src1 = i13*s13 + i2*s12 + i1*s11;
    const size_t i_src2 = i3*s23 + i22*s22 + i21*s21;
    const size_t i_dst  = i3*s3 + i2*s2 + i1*s1;

    const src0_t * src0_row = src0 + i_src0;
    const src1_t * src1_row = src1 + i_src1;
    const src2_t * src2_row = src2 + i_src2;
    dst_t * dst_row = dst + i_dst;

    for (int i0 = i0s; i0 < ne0; i0 += blockDim.x * gridDim.x) {
        const float v0 = (float) src0_row[i0];
        const float v1 = (float) src1_row[i0];
        const float v2 = (float) src2_row[i0];

        dst_row[i0] = (dst_t) (v0 + (v1 - v0) * v2);
    }
}

template <typename src0_t, typename src1_t, typename src2_t, typename dst_t>
static __global__ void k_lerp_unravel(
        const src0_t * src0,
        const src1_t * src1,
        const src2_t * src2,
        dst_t * dst,
        const uint3 ne0,
        const uint3 ne1,
        const uint3 ne2,
        const uint32_t ne3,
        const uint3 prod_012,
        const uint3 prod_01,
        const uint3 ne03,
        const uint3 ne13,
        const uint3 ne21,
        const uint3 ne22,
        const int s1,
        const int s2,
        const int s3,
        const int s01,
        const int s02,
        const int s03,
        const int s11,
        const int s12,
        const int s13,
        const int s21,
        const int s22,
        const int s23) {

    const int i = blockDim.x*blockIdx.x + threadIdx.x;

    const uint32_t i3 = fastdiv(i, prod_012);
    const uint32_t i2 = fastdiv(i - i3 * prod_012.z, prod_01);
    const uint32_t i1 = fastdiv(i - i3 * prod_012.z - i2 * prod_01.z, ne0);
    const uint32_t i0 = i - i3 * prod_012.z - i2 * prod_01.z - i1 * ne0.z;

    if (i0 >= ne0.z || i1 >= ne1.z || i2 >= ne2.z || i3 >= ne3) {
        return;
    }

    // src0/src1 broadcast in dim3
    const int i03 = fastmodulo(i3, ne03);
    const int i13 = fastmodulo(i3, ne13);

    // src2 broadcast in dim1, dim2
    const int i21 = fastmodulo(i1, ne21);
    const int i22 = fastmodulo(i2, ne22);

    const size_t i_src0 = i03*s03 + i2*s02 + i1*s01;
    const size_t i_src1 = i13*s13 + i2*s12 + i1*s11;
    const size_t i_src2 = i3*s23 + i22*s22 + i21*s21;
    const size_t i_dst  = i3*s3 + i2*s2 + i1*s1;

    const src0_t * src0_row = src0 + i_src0;
    const src1_t * src1_row = src1 + i_src1;
    const src2_t * src2_row = src2 + i_src2;
    dst_t * dst_row = dst + i_dst;

    const float v0 = (float) src0_row[i0];
    const float v1 = (float) src1_row[i0];
    const float v2 = (float) src2_row[i0];

    // dst = src0 + (src1 - src0) * src2
    dst_row[i0] = (dst_t) (v0 + (v1 - v0) * v2);
}

template <typename src0_t, typename src1_t, typename src2_t, typename dst_t>
static void launch_lerp(
        const ggml_tensor * src0,
        const ggml_tensor * src1,
        const ggml_tensor * src2,
        ggml_tensor * dst,
        const src0_t * src0_dd,
        const src1_t * src1_dd,
        const src2_t * src2_dd,
        dst_t * dst_dd,
        cudaStream_t stream) {

    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne)
    GGML_TENSOR_LOCALS(int64_t, ne2, src2, ne)
    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne)

    GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
    GGML_TENSOR_LOCALS(size_t, nb1, src1, nb)
    GGML_TENSOR_LOCALS(size_t, nb2, src2, nb)
    GGML_TENSOR_LOCALS(size_t, nb,  dst,  nb)

    GGML_ASSERT(ne00 == ne10 && ne00 == ne20 && ne00 == ne0);
    GGML_ASSERT(ne01 % ne21 == 0);
    GGML_ASSERT(ne02 % ne22 == 0);
    GGML_ASSERT(ne3 % ne03 == 0);
    GGML_ASSERT(ne3 % ne13 == 0);
    GGML_ASSERT(ne0 == ne00);
    GGML_ASSERT(ne1 == ne01);
    GGML_ASSERT(ne2 == ne02);
    GGML_ASSERT(ne3 == ne23);

    size_t s1 = nb1 / sizeof(dst_t);
    size_t s2 = nb2 / sizeof(dst_t);
    size_t s3 = nb3 / sizeof(dst_t);

    size_t s01 = nb01 / sizeof(src0_t);
    size_t s02 = nb02 / sizeof(src0_t);
    size_t s03 = nb03 / sizeof(src0_t);

    size_t s11 = nb11 / sizeof(src1_t);
    size_t s12 = nb12 / sizeof(src1_t);
    size_t s13 = nb13 / sizeof(src1_t);

    size_t s21 = nb21 / sizeof(src2_t);
    size_t s22 = nb22 / sizeof(src2_t);
    size_t s23 = nb23 / sizeof(src2_t);

    GGML_ASSERT(nb0 % sizeof(dst_t) == 0);
    GGML_ASSERT(nb1 % sizeof(dst_t) == 0);
    GGML_ASSERT(nb2 % sizeof(dst_t) == 0);
    GGML_ASSERT(nb3 % sizeof(dst_t) == 0);

    GGML_ASSERT(nb00 % sizeof(src0_t) == 0);
    GGML_ASSERT(nb01 % sizeof(src0_t) == 0);
    GGML_ASSERT(nb02 % sizeof(src0_t) == 0);
    GGML_ASSERT(nb03 % sizeof(src0_t) == 0);

    GGML_ASSERT(nb10 % sizeof(src1_t) == 0);
    GGML_ASSERT(nb11 % sizeof(src1_t) == 0);
    GGML_ASSERT(nb12 % sizeof(src1_t) == 0);
    GGML_ASSERT(nb13 % sizeof(src1_t) == 0);

    GGML_ASSERT(nb20 % sizeof(src2_t) == 0);
    GGML_ASSERT(nb21 % sizeof(src2_t) == 0);
    GGML_ASSERT(nb22 % sizeof(src2_t) == 0);
    GGML_ASSERT(nb23 % sizeof(src2_t) == 0);

    const int block_size = CUDA_LERP_BLOCK_SIZE;

    int64_t hne0 = std::max(ne0 / 2LL, 1LL);

    dim3 block_dims;
    block_dims.x = std::min<unsigned int>(hne0, block_size);
    block_dims.y = std::min<unsigned int>(ne1, block_size / block_dims.x);
    block_dims.z = std::min(std::min<unsigned int>(ne2 * ne3, block_size / block_dims.x / block_dims.y), 64U);

    dim3 block_nums(
        (hne0 + block_dims.x - 1) / block_dims.x,
        (ne1 + block_dims.y - 1) / block_dims.y,
        (ne2 * ne3 + block_dims.z - 1) / block_dims.z);

    const uint3 ne03_fastdiv = init_fastdiv_values((uint32_t) ne03);
    const uint3 ne13_fastdiv = init_fastdiv_values((uint32_t) ne13);
    const uint3 ne21_fastdiv = init_fastdiv_values((uint32_t) ne21);
    const uint3 ne22_fastdiv = init_fastdiv_values((uint32_t) ne22);

    if (block_nums.z > 65535 || block_nums.y > 65535) {
        int block_num = (ne0 * ne1 * ne2 * ne3 + block_size - 1) / block_size;
        const uint3 prod_012 = init_fastdiv_values((uint32_t) (ne0 * ne1 * ne2));
        const uint3 prod_01  = init_fastdiv_values((uint32_t) (ne0 * ne1));
        const uint3 ne0_fastdiv = init_fastdiv_values((uint32_t) ne0);
        const uint3 ne1_fastdiv = init_fastdiv_values((uint32_t) ne1);
        const uint3 ne2_fastdiv = init_fastdiv_values((uint32_t) ne2);

        k_lerp_unravel<src0_t, src1_t, src2_t, dst_t>
            <<<block_num, block_size, 0, stream>>>(
                src0_dd, src1_dd, src2_dd, dst_dd,
                ne0_fastdiv, ne1_fastdiv, ne2_fastdiv, ne3,
                prod_012, prod_01,
                ne03_fastdiv, ne13_fastdiv, ne21_fastdiv, ne22_fastdiv,
                s1, s2, s3,
                s01, s02, s03,
                s11, s12, s13,
                s21, s22, s23);
    } else {
        const uint3 ne3_fastdiv = init_fastdiv_values((uint32_t) ne3);

        k_lerp<src0_t, src1_t, src2_t, dst_t>
            <<<block_nums, block_dims, 0, stream>>>(
                src0_dd, src1_dd, src2_dd, dst_dd,
                ne0, ne1, ne2, ne3_fastdiv,
                ne03_fastdiv, ne13_fastdiv, ne21_fastdiv, ne22_fastdiv,
                s1, s2, s3,
                s01, s02, s03,
                s11, s12, s13,
                s21, s22, s23);
    }
}

void ggml_cuda_op_lerp(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];
    const ggml_tensor * src2 = dst->src[2];

    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);

    if (src2->type == GGML_TYPE_F32) {
        launch_lerp<float, float, float, float>(
            src0, src1, src2, dst,
            (const float *) src0->data,
            (const float *) src1->data,
            (const float *) src2->data,
            (float *) dst->data,
            stream);
    } else if (src2->type == GGML_TYPE_F16) {
        launch_lerp<float, float, half, float>(
            src0, src1, src2, dst,
            (const float *) src0->data,
            (const float *) src1->data,
            (const half *) src2->data,
            (float *) dst->data,
            stream);
    } else {
        GGML_ABORT("fatal error");
    }
}
@@ -0,0 +1,5 @@
#include "common.cuh"

#define CUDA_LERP_BLOCK_SIZE 256

void ggml_cuda_op_lerp(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
@@ -1030,6 +1030,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
    "GATED_LINEAR_ATTN",
    "RWKV_WKV7",
    "SOLVE_TRI",
+   "LERP",

    "UNARY",

@@ -1047,7 +1048,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
    "GLU",
};

-static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95");
+static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT != 96");

static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
    "none",

@@ -1139,6 +1140,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
    "gated_linear_attn(k, v, q, gate, s)",
    "rwkv_wkv7(r, w, k, v, a, b, s)",
    "A X = B, A triangular, solve X",
+   "x+(y-x)*t",

    "unary(x)",

@@ -1156,7 +1158,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
    "glu(x)",
};

-static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95");
+static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT != 96");

static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");

@@ -5742,6 +5744,34 @@ struct ggml_tensor * ggml_rwkv_wkv7(
    return result;
}

+// ggml_lerp
+
+struct ggml_tensor * ggml_lerp(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        struct ggml_tensor  * t) {
+    // assume a and b are the same shape for now
+    GGML_ASSERT(ggml_are_same_shape(a, b));
+
+    GGML_ASSERT(t->ne[0] == a->ne[0]);
+    GGML_ASSERT(a->ne[1] % t->ne[1] == 0);
+    GGML_ASSERT(a->ne[2] % t->ne[2] == 0);
+
+    // a/b can broadcast to t at dim3 for rwkv7
+    GGML_ASSERT(t->ne[3] % a->ne[3] == 0);
+
+    const int64_t ne[4] = { a->ne[0], a->ne[1], a->ne[2], t->ne[3] };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+    result->op     = GGML_OP_LERP;
+    result->src[0] = a;
+    result->src[1] = b;
+    result->src[2] = t;
+
+    return result;
+}
+
// ggml_unary

static struct ggml_tensor * ggml_unary_impl(
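To make the shape rule above concrete: the result takes its first three dimensions from a and its fourth from t, with t broadcast along dims 1 and 2 and a/b broadcast along dim 3. A hedged illustration using the fused rwkv7 time-mix case (the concrete sizes are assumptions for illustration, not taken from this commit):

    // a, b : F32 [n_embd, n_tokens, n_seqs, 1]   e.g. cur and x_prev
    // t    : F32 [n_embd, 1,        1,      6]   e.g. a fused set of 6 interpolation vectors
    // out  : F32 [n_embd, n_tokens, n_seqs, 6]   = { a->ne[0], a->ne[1], a->ne[2], t->ne[3] }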
@@ -49,7 +49,8 @@ struct llm_build_rwkv7_base : public llm_graph_context {
    ggml_tensor * build_rwkv7_channel_mix(const llama_layer * layer,
                                          ggml_tensor * cur,
                                          ggml_tensor * x_prev,
-                                         llm_arch arch) const;
+                                         llm_arch arch,
+                                         int il) const;
    ggml_tensor * build_rwkv7_time_mix(llm_graph_input_rs * inp,
                                       ggml_tensor * cur,
                                       ggml_tensor * x_prev,
@@ -9,11 +9,11 @@ ggml_tensor * llm_build_rwkv7_base::build_rwkv7_channel_mix(const llama_layer *
                                                            ggml_tensor * x_prev,
+                                                           llm_arch arch,
+                                                           int il) const {
-   ggml_tensor * sx = ggml_sub(ctx0, x_prev, cur);
    switch (arch) {
        case LLM_ARCH_RWKV7:
            {
-               ggml_tensor * xk = ggml_add(ctx0, ggml_mul(ctx0, sx, layer->channel_mix_lerp_k), cur);
+               // cur + (x_prev - cur) * layer->channel_mix_lerp_k
+               ggml_tensor * xk = ggml_lerp(ctx0, cur, x_prev, layer->channel_mix_lerp_k);

                cur = build_ffn(
                    xk,
@@ -54,11 +54,7 @@ ggml_tensor * llm_build_rwkv7_base::build_rwkv7_time_mix(llm_graph_input_rs * in
    bool has_gating = layer.time_mix_g1 && layer.time_mix_g2;

-   ggml_tensor * sx = ggml_sub(ctx0, x_prev, cur);
-   ggml_tensor * dummy = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_embd, n_seq_tokens, n_seqs, has_gating ? 6 : 5);
-   sx = ggml_repeat(ctx0, sx, dummy);
-
-   ggml_tensor * xxx = ggml_add(ctx0, ggml_mul(ctx0, sx, layer.time_mix_lerp_fused), cur);
+   ggml_tensor * xxx = ggml_lerp(ctx0, cur, x_prev, layer.time_mix_lerp_fused);

    ggml_tensor * xr = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], 0);
    ggml_tensor * xw = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * sizeof(float));
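Both call sites above rely on the same identity: with sx = x_prev - cur, the old graph ggml_add(ggml_mul(ctx0, sx, mix), cur) computes cur + (x_prev - cur) * mix = (1 - mix) * cur + mix * x_prev, which is exactly ggml_lerp(ctx0, cur, x_prev, mix). Folding it into a single GGML_OP_LERP node changes the graph, not the math, and because the op broadcasts mix itself, the intermediate sx and the ggml_repeat through a dummy tensor in the fused time-mix path are no longer needed.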
@@ -3666,6 +3666,35 @@ struct test_rwkv_wkv7 : public test_case {
    }
};

+// GGML_OP_LERP
+struct test_lerp : public test_case {
+    const ggml_type type_t;
+    const ggml_type type_a;
+    const ggml_type type_b;
+    const std::array<int64_t, 4> ne0;
+    const std::array<int64_t, 4> ne1;
+    const std::array<int64_t, 4> ne2;
+
+    std::string vars() override {
+        return VARS_TO_STR6(type_a, type_b, type_t, ne0, ne1, ne2);
+    }
+
+    test_lerp(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32,
+            ggml_type type_t = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne0 = {10, 10, 1, 1},
+            std::array<int64_t, 4> ne1 = {10, 10, 1, 1},
+            std::array<int64_t, 4> ne2 = {10, 10, 1, 1})
+        : type_a(type_a), type_b(type_b), type_t(type_t), ne0(ne0), ne1(ne1), ne2(ne2) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type_a, 4, ne0.data());
+        ggml_tensor * b = ggml_new_tensor(ctx, type_b, 4, ne1.data());
+        ggml_tensor * c = ggml_new_tensor(ctx, type_t, 4, ne2.data());
+        ggml_tensor * out = ggml_lerp(ctx, a, b, c);
+        return out;
+    }
+};
+
// GGML_OP_MUL_MAT
struct test_mul_mat : public test_case {
    const ggml_type type_a;
@@ -7508,6 +7537,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
    test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 32, 4));
    test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 128, 4));

+    test_cases.emplace_back(new test_lerp(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {256, 16, 1, 1}, {256, 16, 1, 1}, {256, 16, 1, 1}));
+    test_cases.emplace_back(new test_lerp(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {256, 16, 1, 1}, {256, 16, 1, 1}, {256, 16, 1, 6}));
+    test_cases.emplace_back(new test_lerp(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F16, {256, 16, 1, 1}, {256, 16, 1, 1}, {256, 16, 1, 1}));
+    test_cases.emplace_back(new test_lerp(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F16, {256, 16, 1, 1}, {256, 16, 1, 1}, {256, 16, 1, 6}));
+
#if 0
    // > 4GB A matrix. Too slow to be enabled by default.
    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F16, 900000, 3, 2592, {1, 1}, {1, 1}));
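If the usual test-backend-ops workflow applies, the new cases are listed under the op name LERP and can be run in isolation with something like `test-backend-ops test -o LERP` (invocation assumed from the existing harness, not shown in this commit).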