Merge branch 'master' into compilade/refactor-kv-cache

This commit is contained in:
Francis Couture-Harpin 2025-07-07 14:57:56 -04:00
commit f71635824b
40 changed files with 1850 additions and 284 deletions

View File

@ -136,6 +136,11 @@ static bool run(llama_context * ctx, const common_params & params) {
std::vector<llama_token> tokens = common_tokenize(ctx, params.prompt, add_bos);
if (tokens.empty()) {
LOG_ERR("%s : there are not input tokens to process - (try to provide a prompt with '-p')\n", __func__);
return false;
}
if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()))) {
LOG_ERR("%s : failed to eval\n", __func__);
return false;

View File

@ -557,6 +557,8 @@ extern "C" {
GGML_GLU_OP_REGLU,
GGML_GLU_OP_GEGLU,
GGML_GLU_OP_SWIGLU,
GGML_GLU_OP_GEGLU_ERF,
GGML_GLU_OP_GEGLU_QUICK,
GGML_GLU_OP_COUNT,
};
@ -1147,6 +1149,22 @@ extern "C" {
struct ggml_context * ctx,
struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_geglu_erf(
struct ggml_context * ctx,
struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_geglu_erf_swapped(
struct ggml_context * ctx,
struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_geglu_quick(
struct ggml_context * ctx,
struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_geglu_quick_swapped(
struct ggml_context * ctx,
struct ggml_tensor * a);
// A: n columns, r rows,
// B: n columns, r rows,
GGML_API struct ggml_tensor * ggml_glu_split(
@ -1170,6 +1188,16 @@ extern "C" {
struct ggml_tensor * a,
struct ggml_tensor * b);
GGML_API struct ggml_tensor * ggml_geglu_erf_split(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b);
GGML_API struct ggml_tensor * ggml_geglu_quick_split(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b);
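A minimal usage sketch of the new entry points (illustrative: the context, tensor names, and shapes are assumptions, not part of this diff). The single-tensor forms read the value and gate halves from one fused tensor, the _swapped forms exchange the halves, and the _split forms take value and gate as separate tensors:

    #include "ggml.h"

    // cur: [2*n_ff, n_tokens] F32, value|gate fused along dim 0
    static struct ggml_tensor * ffn_geglu_erf(struct ggml_context * ctx,
                                              struct ggml_tensor  * cur) {
        return ggml_geglu_erf(ctx, cur); // -> [n_ff, n_tokens]
    }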
// normalize along rows
GGML_API struct ggml_tensor * ggml_norm(
struct ggml_context * ctx,

View File

@ -67,6 +67,7 @@
#include <aclnnop/aclnn_pow.h>
#include <aclnnop/aclnn_grouped_matmul_v3.h>
#include <aclnnop/aclnn_fused_infer_attention_score_v2.h>
#include <aclnnop/aclnn_zero.h>
#include <float.h>
#include <cmath>
@ -804,10 +805,11 @@ static aclTensor* aclnn_zero(ggml_backend_cann_context& ctx, void* buffer,
nb[i] = nb[i - 1] * ne[i - 1];
}
ggml_cann_async_memset(ctx, buffer, n_bytes, 0);
aclTensor* zero =
ggml_cann_create_tensor(buffer, type, type_size, ne, nb, dims);
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceZero, zero);
return zero;
GGML_UNUSED(n_bytes);
}
/**

View File

@ -2172,6 +2172,8 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
case GGML_GLU_OP_REGLU:
case GGML_GLU_OP_GEGLU:
case GGML_GLU_OP_SWIGLU:
case GGML_GLU_OP_GEGLU_ERF:
case GGML_GLU_OP_GEGLU_QUICK:
{
n_tasks = n_threads;
} break;

View File

@ -3614,6 +3614,292 @@ static void ggml_compute_forward_swiglu(
}
}
// ggml_compute_forward_geglu_erf
static void ggml_compute_forward_geglu_erf_f32(
const ggml_compute_params * params,
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
char * src0_d = (char *) src0->data;
char * src1_d = (char *) (src1 ? src1->data : src0->data);
const size_t src0_o = src0->nb[1];
const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
GGML_ASSERT(ggml_is_contiguous_1(src0));
GGML_ASSERT(ggml_is_contiguous_1(dst));
if (src1) {
GGML_ASSERT(ggml_is_contiguous_1(src1));
GGML_ASSERT(src0->type == src1->type);
}
const int ith = params->ith;
const int nth = params->nth;
const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
const int nr = ggml_nrows(src0);
GGML_ASSERT(dst->ne[0] == nc);
GGML_ASSERT(ggml_nrows(dst) == nr);
const int32_t swapped = ggml_get_op_params_i32(dst, 1);
// rows per thread
const int dr = (nr + nth - 1)/nth;
// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
for (int i1 = ir0; i1 < ir1; i1++) {
float * src0_p = (float *) (src0_d + i1*src0_o);
float * src1_p = (float *) (src1_d + i1*src1_o);
if (!src1) {
src0_p += swapped ? nc : 0;
src1_p += swapped ? 0 : nc;
}
ggml_vec_geglu_erf_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
#ifndef NDEBUG
for (int k = 0; k < nc; k++) {
const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
GGML_UNUSED(x);
assert(!isnan(x));
assert(!isinf(x));
}
#endif
}
}
static void ggml_compute_forward_geglu_erf_f16(
const ggml_compute_params * params,
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
char * src0_d = (char *) src0->data;
char * src1_d = (char *) (src1 ? src1->data : src0->data);
const size_t src0_o = src0->nb[1];
const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
GGML_ASSERT(ggml_is_contiguous_1(src0));
GGML_ASSERT(ggml_is_contiguous_1(dst));
if (src1) {
GGML_ASSERT(ggml_is_contiguous_1(src1));
GGML_ASSERT(src0->type == src1->type);
}
const int ith = params->ith;
const int nth = params->nth;
const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
const int nr = ggml_nrows(src0);
GGML_ASSERT(dst->ne[0] == nc);
GGML_ASSERT(ggml_nrows(dst) == nr);
const int32_t swapped = ggml_get_op_params_i32(dst, 1);
// rows per thread
const int dr = (nr + nth - 1)/nth;
// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
for (int i1 = ir0; i1 < ir1; i1++) {
ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
if (!src1) {
src0_p += swapped ? nc : 0;
src1_p += swapped ? 0 : nc;
}
ggml_vec_geglu_erf_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
#ifndef NDEBUG
for (int k = 0; k < nc; k++) {
const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
const float v = GGML_FP16_TO_FP32(x);
GGML_UNUSED(v);
assert(!isnan(v));
assert(!isinf(v));
}
#endif
}
}
static void ggml_compute_forward_geglu_erf(
const ggml_compute_params * params,
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
switch (src0->type) {
case GGML_TYPE_F32:
{
ggml_compute_forward_geglu_erf_f32(params, dst);
} break;
case GGML_TYPE_F16:
{
ggml_compute_forward_geglu_erf_f16(params, dst);
} break;
default:
{
GGML_ABORT("fatal error");
}
}
}
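The row partition used by these kernels is a plain ceil-divide over threads; a self-contained illustration of the same arithmetic (values chosen for the example):

    #include <stdio.h>
    #define MIN(a, b) ((a) < (b) ? (a) : (b))

    int main(void) {
        const int nr = 10, nth = 4;           // 10 rows over 4 threads
        const int dr = (nr + nth - 1) / nth;  // ceil(10/4) = 3 rows per thread
        for (int ith = 0; ith < nth; ith++) {
            // prints [0,3) [3,6) [6,9) [9,10) -- the last thread takes the short tail
            printf("thread %d: rows [%d, %d)\n", ith, dr*ith, MIN(dr*ith + dr, nr));
        }
        return 0;
    }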
// ggml_compute_forward_geglu_quick
static void ggml_compute_forward_geglu_quick_f32(
const ggml_compute_params * params,
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
char * src0_d = (char *) src0->data;
char * src1_d = (char *) (src1 ? src1->data : src0->data);
const size_t src0_o = src0->nb[1];
const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
GGML_ASSERT(ggml_is_contiguous_1(src0));
GGML_ASSERT(ggml_is_contiguous_1(dst));
if (src1) {
GGML_ASSERT(ggml_is_contiguous_1(src1));
GGML_ASSERT(src0->type == src1->type);
}
const int ith = params->ith;
const int nth = params->nth;
const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
const int nr = ggml_nrows(src0);
GGML_ASSERT(dst->ne[0] == nc);
GGML_ASSERT(ggml_nrows(dst) == nr);
const int32_t swapped = ggml_get_op_params_i32(dst, 1);
// rows per thread
const int dr = (nr + nth - 1)/nth;
// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
for (int i1 = ir0; i1 < ir1; i1++) {
float * src0_p = (float *) (src0_d + i1*src0_o);
float * src1_p = (float *) (src1_d + i1*src1_o);
if (!src1) {
src0_p += swapped ? nc : 0;
src1_p += swapped ? 0 : nc;
}
ggml_vec_geglu_quick_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
#ifndef NDEBUG
for (int k = 0; k < nc; k++) {
const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
GGML_UNUSED(x);
assert(!isnan(x));
assert(!isinf(x));
}
#endif
}
}
static void ggml_compute_forward_geglu_quick_f16(
const ggml_compute_params * params,
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
char * src0_d = (char *) src0->data;
char * src1_d = (char *) (src1 ? src1->data : src0->data);
const size_t src0_o = src0->nb[1];
const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
GGML_ASSERT(ggml_is_contiguous_1(src0));
GGML_ASSERT(ggml_is_contiguous_1(dst));
if (src1) {
GGML_ASSERT(ggml_is_contiguous_1(src1));
GGML_ASSERT(src0->type == src1->type);
}
const int ith = params->ith;
const int nth = params->nth;
const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
const int nr = ggml_nrows(src0);
GGML_ASSERT(dst->ne[0] == nc);
GGML_ASSERT(ggml_nrows(dst) == nr);
const int32_t swapped = ggml_get_op_params_i32(dst, 1);
// rows per thread
const int dr = (nr + nth - 1)/nth;
// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
for (int i1 = ir0; i1 < ir1; i1++) {
ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
if (!src1) {
src0_p += swapped ? nc : 0;
src1_p += swapped ? 0 : nc;
}
ggml_vec_geglu_quick_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
#ifndef NDEBUG
for (int k = 0; k < nc; k++) {
const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
const float v = GGML_FP16_TO_FP32(x);
GGML_UNUSED(v);
assert(!isnan(v));
assert(!isinf(v));
}
#endif
}
}
static void ggml_compute_forward_geglu_quick(
const ggml_compute_params * params,
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
switch (src0->type) {
case GGML_TYPE_F32:
{
ggml_compute_forward_geglu_quick_f32(params, dst);
} break;
case GGML_TYPE_F16:
{
ggml_compute_forward_geglu_quick_f16(params, dst);
} break;
default:
{
GGML_ABORT("fatal error");
}
}
}
// ggml_compute_forward_norm
static void ggml_compute_forward_norm_f32(
@ -8779,6 +9065,14 @@ void ggml_compute_forward_glu(
{
ggml_compute_forward_swiglu(params, dst);
} break;
case GGML_GLU_OP_GEGLU_ERF:
{
ggml_compute_forward_geglu_erf(params, dst);
} break;
case GGML_GLU_OP_GEGLU_QUICK:
{
ggml_compute_forward_geglu_quick(params, dst);
} break;
default:
{
GGML_ABORT("fatal error");

View File

@ -959,6 +959,46 @@ inline static void ggml_vec_swiglu_f16(const int n, ggml_fp16_t * y, const ggml_
}
}
inline static void ggml_vec_geglu_erf_f32(const int n, float * y, const float * x, const float * g) {
for (int i = 0; i < n; ++i) {
float xi = x[i];
y[i] = 0.5f * xi * (1.0f + erff(xi*SQRT_2_INV)) * g[i];
}
}
inline static void ggml_vec_geglu_erf_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {
for (int i = 0; i < n; ++i) {
float xi = GGML_CPU_FP16_TO_FP32(x[i]);
float gi = GGML_CPU_FP16_TO_FP32(g[i]);
y[i] = GGML_CPU_FP32_TO_FP16(0.5f * xi * (1.0f + erff(xi*SQRT_2_INV)) * gi);
}
}
#ifdef GGML_GELU_QUICK_FP16
inline static void ggml_vec_geglu_quick_f32(const int n, float * y, const float * x, const float * g) {
uint16_t t;
for (int i = 0; i < n; ++i) {
ggml_fp16_t fp16 = GGML_CPU_FP32_TO_FP16(x[i]);
memcpy(&t, &fp16, sizeof(uint16_t));
y[i] = GGML_CPU_FP16_TO_FP32(ggml_table_gelu_quick_f16[t]) * g[i];
}
}
#else
inline static void ggml_vec_geglu_quick_f32(const int n, float * y, const float * x, const float * g) {
for (int i = 0; i < n; ++i) {
y[i] = ggml_gelu_quick_f32(x[i]) * g[i];
}
}
#endif
inline static void ggml_vec_geglu_quick_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {
const uint16_t * i16 = (const uint16_t *) x;
for (int i = 0; i < n; ++i) {
float v = GGML_CPU_FP16_TO_FP32(g[i]);
y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(ggml_table_gelu_quick_f16[i16[i]]) * v);
}
}
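For reference, the two gate functions vectorized above are the exact erf-based GELU and the sigmoid ("quick") approximation; a standalone comparison (illustrative, plain libm):

    #include <math.h>
    #include <stdio.h>

    int main(void) {
        // gelu_erf(x)   = 0.5*x*(1 + erf(x/sqrt(2)))
        // gelu_quick(x) = x / (1 + exp(-1.702*x)) = x*sigmoid(1.702*x)
        for (float x = -2.0f; x <= 2.0f; x += 1.0f) {
            const float g_erf   = 0.5f*x*(1.0f + erff(x*0.70710678f));
            const float g_quick = x/(1.0f + expf(-1.702f*x));
            printf("x=%+.1f  erf=%+.5f  quick=%+.5f\n", x, g_erf, g_quick);
        }
        return 0;
    }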
inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) {
#ifndef GGML_USE_ACCELERATE
ggml_float sum = 0.0;

View File

@ -168,6 +168,10 @@ static void ggml_cuda_get_rows_switch_src0_type(
get_rows_cuda_float((const float *) src0_d, src1_d, dst_d,
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
break;
case GGML_TYPE_I32:
get_rows_cuda_float((const int32_t *) src0_d, src1_d, dst_d,
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
break;
case GGML_TYPE_BF16:
get_rows_cuda_float((const nv_bfloat16 *) src0_d, src1_d, dst_d,
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
@ -210,6 +214,10 @@ void get_rows_cuda(
ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (float *) dst_d,
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
break;
case GGML_TYPE_I32:
ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (int32_t *) dst_d,
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
break;
case GGML_TYPE_F16:
ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (half *) dst_d,
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
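At the graph level the new cases allow gathering rows from int32 tensors on CUDA; a hedged sketch, assuming an initialized ggml context ctx (shapes illustrative; ggml_get_rows keeps an I32 source as I32):

    struct ggml_tensor * table = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, 8, 4); // 4 rows of 8
    struct ggml_tensor * ids   = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2);    // row indices
    struct ggml_tensor * rows  = ggml_get_rows(ctx, table, ids);               // [8, 2], I32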

View File

@ -2314,6 +2314,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_GLU_OP_SWIGLU:
ggml_cuda_op_swiglu(ctx, dst);
break;
case GGML_GLU_OP_GEGLU_ERF:
ggml_cuda_op_geglu_erf(ctx, dst);
break;
case GGML_GLU_OP_GEGLU_QUICK:
ggml_cuda_op_geglu_quick(ctx, dst);
break;
default:
return false;
}
@ -3116,6 +3122,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_GLU_OP_REGLU:
case GGML_GLU_OP_GEGLU:
case GGML_GLU_OP_SWIGLU:
case GGML_GLU_OP_GEGLU_ERF:
case GGML_GLU_OP_GEGLU_QUICK:
return ggml_is_contiguous_1(op->src[0]);
default:
return false;
@ -3192,6 +3200,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
switch (op->src[0]->type) {
case GGML_TYPE_F16:
case GGML_TYPE_F32:
case GGML_TYPE_BF16:
case GGML_TYPE_I32:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:

View File

@ -285,6 +285,14 @@ void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_unary_gated<op_silu>(ctx, dst);
}
void ggml_cuda_op_geglu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_unary_gated<op_gelu_erf>(ctx, dst);
}
void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_unary_gated<op_gelu_quick>(ctx, dst);
}
/* silu_back */
static __device__ __forceinline__ float op_silu_back(float grad, float x) {

View File

@ -64,3 +64,7 @@ void ggml_cuda_op_reglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_geglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_geglu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@ -530,6 +530,8 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_REGLU,
GGML_METAL_KERNEL_TYPE_GEGLU,
GGML_METAL_KERNEL_TYPE_SWIGLU,
GGML_METAL_KERNEL_TYPE_GEGLU_ERF,
GGML_METAL_KERNEL_TYPE_GEGLU_QUICK,
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
GGML_METAL_KERNEL_TYPE_MEAN,
GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
@ -1510,6 +1512,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REGLU, reglu, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU, geglu, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SWIGLU, swiglu, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU_ERF, geglu_erf, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU_QUICK, geglu_quick, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MEAN, mean, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
@ -1693,6 +1697,8 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
case GGML_GLU_OP_REGLU:
case GGML_GLU_OP_GEGLU:
case GGML_GLU_OP_SWIGLU:
case GGML_GLU_OP_GEGLU_ERF:
case GGML_GLU_OP_GEGLU_QUICK:
return ggml_is_contiguous_1(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
default:
return false;
@ -2456,6 +2462,12 @@ static bool ggml_metal_encode_node(
case GGML_GLU_OP_SWIGLU:
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU].pipeline;
break;
case GGML_GLU_OP_GEGLU_ERF:
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GEGLU_ERF].pipeline;
break;
case GGML_GLU_OP_GEGLU_QUICK:
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GEGLU_QUICK].pipeline;
break;
default:
GGML_ABORT("fatal error");
}

View File

@ -109,6 +109,7 @@ void dequantize_q4_0_t4(device const block_q4_0 * xb, short il, thread type4 & r
}
void quantize_q4_0(device const float * src, device block_q4_0 & dst) {
#pragma METAL fp math_mode(safe)
float amax = 0.0f; // absolute max
float max = 0.0f;
@ -167,6 +168,7 @@ void quantize_q4_1(device const float * src, device block_q4_1 & dst) {
}
void quantize_q5_0(device const float * src, device block_q5_0 & dst) {
#pragma METAL fp math_mode(safe)
float amax = 0.0f; // absolute max
float max = 0.0f;
@ -461,6 +463,7 @@ void dequantize_q8_0_t4(device const block_q8_0 *xb, short il, thread type4 & re
}
void quantize_q8_0(device const float * src, device block_q8_0 & dst) {
#pragma METAL fp math_mode(safe)
float amax = 0.0f; // absolute max
for (int j = 0; j < QK8_0; j++) {
@ -1258,6 +1261,50 @@ kernel void kernel_swiglu(
}
}
kernel void kernel_geglu_erf(
device const char * src0,
device const char * src1,
device char * dst,
constant ggml_metal_kargs_glu & args,
uint tgpig[[threadgroup_position_in_grid]],
uint tpitg[[thread_position_in_threadgroup]],
uint ntg[[threads_per_threadgroup]]) {
device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1);
for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
const float x0 = src0_row[i0];
const float x1 = src1_row[i0];
const float gelu_erf = 0.5f*x0*(1.0f+erf_approx<float>(x0*SQRT_2_INV));
dst_row[i0] = gelu_erf*x1;
}
}
kernel void kernel_geglu_quick(
device const char * src0,
device const char * src1,
device char * dst,
constant ggml_metal_kargs_glu & args,
uint tgpig[[threadgroup_position_in_grid]],
uint tpitg[[thread_position_in_threadgroup]],
uint ntg[[threads_per_threadgroup]]) {
device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1);
for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
const float x0 = src0_row[i0];
const float x1 = src1_row[i0];
const float gelu_quick = x0*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x0)));
dst_row[i0] = gelu_quick*x1;
}
}
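A hedged note on the addressing: for the fused single-tensor case the host side is expected to point src0/src1 at the same buffer and use args.i00/args.i10 as the column offsets of the two halves, mirroring the CPU swapped handling (sketch only, not confirmed by this excerpt):

    // const int32_t swapped = ggml_get_op_params_i32(dst, 1);
    // args.i00 = swapped ? ne0 : 0;  // where the value half starts
    // args.i10 = swapped ? 0 : ne0;  // where the gate half starts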
template <bool norm>
kernel void kernel_sum_rows(
constant ggml_metal_kargs_sum_rows & args,

View File

@ -398,12 +398,13 @@ struct ggml_backend_opencl_context {
cl_kernel kernel_scale;
cl_kernel kernel_silu, kernel_silu_4;
cl_kernel kernel_gelu, kernel_gelu_4;
cl_kernel kernel_gelu_erf, kernel_gelu_erf_4;
cl_kernel kernel_gelu_quick, kernel_gelu_quick_4;
cl_kernel kernel_relu;
cl_kernel kernel_sigmoid_f32, kernel_sigmoid_f16;
cl_kernel kernel_clamp;
cl_kernel kernel_geglu, kernel_reglu, kernel_swiglu,
kernel_geglu_f16, kernel_reglu_f16, kernel_swiglu_f16;
cl_kernel kernel_geglu, kernel_reglu, kernel_swiglu, kernel_geglu_erf, kernel_geglu_quick,
kernel_geglu_f16, kernel_reglu_f16, kernel_swiglu_f16, kernel_geglu_erf_f16, kernel_geglu_quick_f16;
cl_kernel kernel_norm;
cl_kernel kernel_rms_norm;
cl_kernel kernel_group_norm;
@ -736,6 +737,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
CL_CHECK((backend_ctx->kernel_gelu = clCreateKernel(backend_ctx->program_gelu, "kernel_gelu", &err), err));
CL_CHECK((backend_ctx->kernel_gelu_4 = clCreateKernel(backend_ctx->program_gelu, "kernel_gelu_4", &err), err));
CL_CHECK((backend_ctx->kernel_gelu_erf = clCreateKernel(backend_ctx->program_gelu, "kernel_gelu_erf", &err), err));
CL_CHECK((backend_ctx->kernel_gelu_erf_4 = clCreateKernel(backend_ctx->program_gelu, "kernel_gelu_erf_4", &err), err));
CL_CHECK((backend_ctx->kernel_gelu_quick = clCreateKernel(backend_ctx->program_gelu, "kernel_gelu_quick", &err), err));
CL_CHECK((backend_ctx->kernel_gelu_quick_4 = clCreateKernel(backend_ctx->program_gelu, "kernel_gelu_quick_4", &err), err));
GGML_LOG_CONT(".");
@ -753,12 +756,16 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
backend_ctx->program_glu =
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
CL_CHECK((backend_ctx->kernel_geglu = clCreateKernel(backend_ctx->program_glu, "kernel_geglu", &err), err));
CL_CHECK((backend_ctx->kernel_reglu = clCreateKernel(backend_ctx->program_glu, "kernel_reglu", &err), err));
CL_CHECK((backend_ctx->kernel_swiglu = clCreateKernel(backend_ctx->program_glu, "kernel_swiglu", &err), err));
CL_CHECK((backend_ctx->kernel_geglu_f16 = clCreateKernel(backend_ctx->program_glu, "kernel_geglu_f16", &err), err));
CL_CHECK((backend_ctx->kernel_reglu_f16 = clCreateKernel(backend_ctx->program_glu, "kernel_reglu_f16", &err), err));
CL_CHECK((backend_ctx->kernel_swiglu_f16 = clCreateKernel(backend_ctx->program_glu, "kernel_swiglu_f16", &err), err));
CL_CHECK((backend_ctx->kernel_geglu = clCreateKernel(backend_ctx->program_glu, "kernel_geglu", &err), err));
CL_CHECK((backend_ctx->kernel_reglu = clCreateKernel(backend_ctx->program_glu, "kernel_reglu", &err), err));
CL_CHECK((backend_ctx->kernel_swiglu = clCreateKernel(backend_ctx->program_glu, "kernel_swiglu", &err), err));
CL_CHECK((backend_ctx->kernel_geglu_erf = clCreateKernel(backend_ctx->program_glu, "kernel_geglu_erf", &err), err));
CL_CHECK((backend_ctx->kernel_geglu_quick = clCreateKernel(backend_ctx->program_glu, "kernel_geglu_quick", &err), err));
CL_CHECK((backend_ctx->kernel_geglu_f16 = clCreateKernel(backend_ctx->program_glu, "kernel_geglu_f16", &err), err));
CL_CHECK((backend_ctx->kernel_reglu_f16 = clCreateKernel(backend_ctx->program_glu, "kernel_reglu_f16", &err), err));
CL_CHECK((backend_ctx->kernel_swiglu_f16 = clCreateKernel(backend_ctx->program_glu, "kernel_swiglu_f16", &err), err));
CL_CHECK((backend_ctx->kernel_geglu_erf_f16 = clCreateKernel(backend_ctx->program_glu, "kernel_geglu_erf_f16", &err), err));
CL_CHECK((backend_ctx->kernel_geglu_quick_f16 = clCreateKernel(backend_ctx->program_glu, "kernel_geglu_quick_f16", &err), err));
GGML_LOG_CONT(".");
}
@ -2262,6 +2269,7 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
case GGML_UNARY_OP_GELU:
case GGML_UNARY_OP_SILU:
case GGML_UNARY_OP_RELU:
case GGML_UNARY_OP_GELU_ERF:
case GGML_UNARY_OP_GELU_QUICK:
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
case GGML_UNARY_OP_SIGMOID:
@ -2277,6 +2285,8 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
case GGML_GLU_OP_GEGLU:
case GGML_GLU_OP_REGLU:
case GGML_GLU_OP_SWIGLU:
case GGML_GLU_OP_GEGLU_ERF:
case GGML_GLU_OP_GEGLU_QUICK:
return ggml_is_contiguous_1(op->src[0]) && (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16);
default:
return false;
@ -3864,6 +3874,44 @@ static void ggml_cl_gelu(ggml_backend_t backend, const ggml_tensor * src0, const
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
}
static void ggml_cl_gelu_erf(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_ASSERT(src0);
GGML_ASSERT(src0->extra);
GGML_ASSERT(dst);
GGML_ASSERT(dst->extra);
UNUSED(src1);
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
cl_ulong offset0 = extra0->offset + src0->view_offs;
cl_ulong offsetd = extrad->offset + dst->view_offs;
cl_kernel kernel;
int n = ggml_nelements(dst);
if (n % 4 == 0) {
kernel = backend_ctx->kernel_gelu_erf_4;
n /= 4;
} else {
kernel = backend_ctx->kernel_gelu_erf;
}
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
size_t global_work_size[] = {(size_t)n, 1, 1};
size_t local_work_size[] = {64, 1, 1};
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
}
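The vector kernel processes one float4 per work-item, so it is only chosen when the element count divides evenly; a minimal mirror of that host-side choice (illustrative):

    #include <stdbool.h>

    // use_float4(1024) -> true:  kernel_gelu_erf_4 over 1024/4 = 256 work-items
    // use_float4(1030) -> false: scalar kernel_gelu_erf over 1030 work-items
    static bool use_float4(int n) { return n % 4 == 0; }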
static void ggml_cl_gelu_quick(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_ASSERT(src0);
GGML_ASSERT(src0->extra);
@ -6254,6 +6302,20 @@ static void ggml_cl_glu(ggml_backend_t backend, const ggml_tensor * src0, const
kernel = backend_ctx->kernel_swiglu_f16;
}
break;
case GGML_GLU_OP_GEGLU_ERF:
if (dst->type == GGML_TYPE_F32) {
kernel = backend_ctx->kernel_geglu_erf;
} else {
kernel = backend_ctx->kernel_geglu_erf_f16;
}
break;
case GGML_GLU_OP_GEGLU_QUICK:
if (dst->type == GGML_TYPE_F32) {
kernel = backend_ctx->kernel_geglu_quick;
} else {
kernel = backend_ctx->kernel_geglu_quick_f16;
}
break;
default:
GGML_ABORT("Unsupported glu op");
}
@ -6368,6 +6430,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
}
func = ggml_cl_gelu;
break;
case GGML_UNARY_OP_GELU_ERF:
if (!any_on_device) {
return false;
}
func = ggml_cl_gelu_erf;
break;
case GGML_UNARY_OP_GELU_QUICK:
if (!any_on_device) {
return false;

View File

@ -6,6 +6,7 @@
#define GELU_COEF_A 0.044715f
#define GELU_QUICK_COEF -1.702f
#define SQRT_2_OVER_PI 0.79788456080286535587989211986876f
#define SQRT_2_INV 0.70710678118654752440084436210484f
kernel void kernel_gelu(
global float * src0,
@ -35,6 +36,32 @@ kernel void kernel_gelu_4(
dst[get_global_id(0)] = 0.5f*x*(1.0f + tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
}
kernel void kernel_gelu_erf(
global float * src0,
ulong offset0,
global float * dst,
ulong offsetd
) {
src0 = (global float*)((global char*)src0 + offset0);
dst = (global float*)((global char*)dst + offsetd);
float x = src0[get_global_id(0)];
dst[get_global_id(0)] = 0.5f*x*(1.0f + erf(x*SQRT_2_INV));
}
kernel void kernel_gelu_erf_4(
global float4 * src0,
ulong offset0,
global float4 * dst,
ulong offsetd
) {
src0 = (global float4*)((global char*)src0 + offset0);
dst = (global float4*)((global char*)dst + offsetd);
float4 x = src0[get_global_id(0)];
dst[get_global_id(0)] = 0.5f*x*(1.0f + erf(x*SQRT_2_INV));
}
kernel void kernel_gelu_quick(
global float * src0,
ulong offset0,

View File

@ -1,7 +1,9 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#define GELU_COEF_A 0.044715f
#define GELU_QUICK_COEF -1.702f
#define SQRT_2_OVER_PI 0.79788456080286535587989211986876f
#define SQRT_2_INV 0.70710678118654752440084436210484f
//------------------------------------------------------------------------------
// geglu
@ -199,3 +201,137 @@ kernel void kernel_swiglu_f16(
dst_row[i0] = silu*x1;
}
}
//------------------------------------------------------------------------------
// geglu_erf
//------------------------------------------------------------------------------
kernel void kernel_geglu_erf(
global char * src0,
ulong offset0,
global char * src1,
ulong offset1,
global char * dst,
ulong offsetd,
ulong nb01,
ulong nb11,
int ne0,
ulong nb1,
int ne00_off,
int ne10_off
) {
src0 = (global char*)((global char*)src0 + offset0);
src1 = (global char*)((global char*)src1 + offset1);
dst = (global char*)((global char*)dst + offsetd);
global float * src0_row = (global float *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
global float * src1_row = (global float *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
global float * dst_row = (global float *) ((global char *) dst + get_group_id(0)*nb1);
for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
const float x0 = src0_row[i0];
const float x1 = src1_row[i0];
const float gelu_erf = 0.5f*x0*(1.0f + erf(x0*SQRT_2_INV));
dst_row[i0] = gelu_erf*x1;
}
}
kernel void kernel_geglu_erf_f16(
global char * src0,
ulong offset0,
global char * src1,
ulong offset1,
global char * dst,
ulong offsetd,
ulong nb01,
ulong nb11,
int ne0,
ulong nb1,
int ne00_off,
int ne10_off
) {
src0 = (global char*)((global char*)src0 + offset0);
src1 = (global char*)((global char*)src1 + offset1);
dst = (global char*)((global char*)dst + offsetd);
global half * src0_row = (global half *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
global half * src1_row = (global half *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
global half * dst_row = (global half *) ((global char *) dst + get_group_id(0)*nb1);
for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
const half x0 = src0_row[i0];
const half x1 = src1_row[i0];
const half gelu_erf = 0.5f*x0*(1.0f + erf(x0*SQRT_2_INV));
dst_row[i0] = gelu_erf*x1;
}
}
//------------------------------------------------------------------------------
// geglu_quick
//------------------------------------------------------------------------------
kernel void kernel_geglu_quick(
global char * src0,
ulong offset0,
global char * src1,
ulong offset1,
global char * dst,
ulong offsetd,
ulong nb01,
ulong nb11,
int ne0,
ulong nb1,
int ne00_off,
int ne10_off
) {
src0 = (global char*)((global char*)src0 + offset0);
src1 = (global char*)((global char*)src1 + offset1);
dst = (global char*)((global char*)dst + offsetd);
global float * src0_row = (global float *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
global float * src1_row = (global float *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
global float * dst_row = (global float *) ((global char *) dst + get_group_id(0)*nb1);
for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
const float x0 = src0_row[i0];
const float x1 = src1_row[i0];
const float gelu_quick = x0*(1.0f/(1.0f + exp(GELU_QUICK_COEF*x0)));
dst_row[i0] = gelu_quick*x1;
}
}
kernel void kernel_geglu_quick_f16(
global char * src0,
ulong offset0,
global char * src1,
ulong offset1,
global char * dst,
ulong offsetd,
ulong nb01,
ulong nb11,
int ne0,
ulong nb1,
int ne00_off,
int ne10_off
) {
src0 = (global char*)((global char*)src0 + offset0);
src1 = (global char*)((global char*)src1 + offset1);
dst = (global char*)((global char*)dst + offsetd);
global half * src0_row = (global half *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
global half * src1_row = (global half *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
global half * dst_row = (global half *) ((global char *) dst + get_group_id(0)*nb1);
for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
const half x0 = src0_row[i0];
const half x1 = src1_row[i0];
const half gelu_quick = x0*(1.0f/(1.0f + exp(GELU_QUICK_COEF*x0)));
dst_row[i0] = gelu_quick*x1;
}
}

View File

@ -383,6 +383,24 @@ static void gated_op_fused_swiglu(const T * x, const T * g, T * dst, const uint6
}
}
template<typename T>
static void gated_op_fused_geglu_erf(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o0, const uint64_t o1, const sycl::nd_item<1> &item_ct1) {
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
const int64_t j0 = (i / n) * o0 + (i % n);
const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n);
dst[i] = op_gelu_erf(x[j0]) * g[j1];
}
}
template<typename T>
static void gated_op_fused_geglu_quick(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o0, const uint64_t o1, const sycl::nd_item<1> &item_ct1) {
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
const int64_t j0 = (i / n) * o0 + (i % n);
const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n);
dst[i] = op_gelu_quick(x[j0]) * g[j1];
}
}
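The j0/j1 arithmetic maps a flat output index onto rows whose value and gate parts may have different strides (o0 == o1 is the fused layout); a standalone check with illustrative values:

    #include <assert.h>
    #include <stdint.h>

    int main(void) {
        const uint64_t n = 4, o0 = 8, o1 = 8; // 4 output cols; fused rows of stride 8
        const uint64_t i  = 9;                // output row i/n = 2, column i%n = 1
        const uint64_t j0 = (i / n) * o0 + (i % n);
        const uint64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n);
        assert(j0 == 17 && j1 == 17);
        return 0;
    }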
namespace ggml_sycl_detail {
static void acc_f32_sycl(const float *x, const float *y, float *dst,
const int n_elements, const int ne10, const int ne11,
@ -978,6 +996,28 @@ static inline void ggml_sycl_op_swiglu(ggml_backend_sycl_context & ctx, ggml_ten
});
}
static inline void ggml_sycl_op_geglu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
ggml_sycl_detail::dispatch_ggml_sycl_op_fused_glu(ctx, dst,
[](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE);
sycl_parallel_for(main_stream,
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
gated_op_fused_geglu_erf(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
});
});
}
static inline void ggml_sycl_op_geglu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
ggml_sycl_detail::dispatch_ggml_sycl_op_fused_glu(ctx, dst,
[](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE);
sycl_parallel_for(main_stream,
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
gated_op_fused_geglu_quick(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
});
});
}
void ggml_sycl_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
@ -1118,3 +1158,13 @@ void ggml_sycl_swiglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
ggml_sycl_op_swiglu(ctx, dst);
}
void ggml_sycl_geglu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
ggml_sycl_op_geglu_erf(ctx, dst);
}
void ggml_sycl_geglu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
ggml_sycl_op_geglu_quick(ctx, dst);
}

View File

@ -80,5 +80,7 @@ void ggml_sycl_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
void ggml_sycl_geglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
void ggml_sycl_reglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
void ggml_sycl_swiglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
void ggml_sycl_geglu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
void ggml_sycl_geglu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
#endif // GGML_SYCL_ELEMENTWISE_HPP

View File

@ -3687,6 +3687,12 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
case GGML_GLU_OP_SWIGLU:
ggml_sycl_swiglu(ctx, dst);
break;
case GGML_GLU_OP_GEGLU_ERF:
ggml_sycl_geglu_erf(ctx, dst);
break;
case GGML_GLU_OP_GEGLU_QUICK:
ggml_sycl_geglu_quick(ctx, dst);
break;
default:
return false;
}
@ -4232,6 +4238,8 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_GLU_OP_REGLU:
case GGML_GLU_OP_GEGLU:
case GGML_GLU_OP_SWIGLU:
case GGML_GLU_OP_GEGLU_ERF:
case GGML_GLU_OP_GEGLU_QUICK:
return ggml_is_contiguous_1(op->src[0]);
default:
return false;

View File

@ -456,6 +456,8 @@ struct vk_device_struct {
vk_pipeline pipeline_geglu[2];
vk_pipeline pipeline_reglu[2];
vk_pipeline pipeline_swiglu[2];
vk_pipeline pipeline_geglu_erf[2];
vk_pipeline pipeline_geglu_quick[2];
vk_pipeline pipeline_leaky_relu_f32;
vk_pipeline pipeline_silu_back_f32;
@ -499,6 +501,8 @@ struct vk_device_struct {
ggml_backend_buffer_type buffer_type;
bool disable_fusion;
#ifdef GGML_VULKAN_MEMORY_DEBUG
std::unique_ptr<vk_memory_logger> memory_logger;
#endif
@ -634,6 +638,7 @@ struct vk_flash_attn_push_constants {
uint32_t nev3;
uint32_t nem1;
uint32_t nem2;
uint32_t nem3;
uint32_t nb01;
uint32_t nb02;
@ -649,8 +654,7 @@ struct vk_flash_attn_push_constants {
float max_bias;
float logit_softcap;
uint32_t mask;
uint32_t n_head_log2;
uint32_t mask_n_head_log2;
float m0;
float m1;
@ -1089,8 +1093,8 @@ static size_t vk_skip_checks;
static size_t vk_output_tensor;
static void ggml_vk_print_tensor(const ggml_tensor * tensor, const char * name);
static void ggml_vk_check_results_0(ggml_tensor * tensor);
static void ggml_vk_check_results_1(ggml_tensor * tensor);
static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx);
static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx);
#endif
typedef void (*ggml_vk_func_t)(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
@ -2821,6 +2825,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_GLU(geglu)
CREATE_GLU(reglu)
CREATE_GLU(swiglu)
CREATE_GLU(geglu_erf)
CREATE_GLU(geglu_quick)
#undef CREATE_GLU
ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
@ -3503,6 +3509,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
device->idx = idx;
device->disable_fusion = getenv("GGML_VK_DISABLE_FUSION") != nullptr;
return device;
}
@ -6107,6 +6115,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
const uint32_t nem1 = mask ? mask->ne[1] : 0;
const uint32_t nem2 = mask ? mask->ne[2] : 0;
const uint32_t nem3 = mask ? mask->ne[3] : 0;
const uint32_t HSK = nek0;
const uint32_t HSV = nev0;
@ -6174,7 +6183,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
}
if (N == 1 && qk_ratio > 1 && qk_ratio <= max_gqa &&
qk_ratio * nek2 == neq2 && nek2 == nev2 && neq3 == 1 && nek3 == 1 && nev3 == 1) {
qk_ratio * nek2 == neq2 && nek2 == nev2 && nem2 <= 1) {
// grouped query attention - make the N dimension equal to gqa_ratio, reduce
// workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1
// and change addressing calculations to index Q's dimension 2.
@ -6344,17 +6353,19 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
}
}
uint32_t mask_n_head_log2 = ((mask != nullptr) << 16) | n_head_log2;
const vk_flash_attn_push_constants pc = { N, KV,
(uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3,
(uint32_t)neq2, (uint32_t)neq3,
(uint32_t)nek2, (uint32_t)nek3,
(uint32_t)nev2, (uint32_t)nev3,
nem1, nem2,
nem1, nem2, nem3,
q_stride, (uint32_t)nbq2, (uint32_t)nbq3,
k_stride, (uint32_t)nbk2, (uint32_t)nbk3,
v_stride, (uint32_t)nbv2, (uint32_t)nbv3,
scale, max_bias, logit_softcap,
mask != nullptr, n_head_log2, m0, m1,
mask_n_head_log2, m0, m1,
gqa_ratio, split_kv, split_k };
ggml_vk_sync_buffers(subctx);
@ -6575,6 +6586,10 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return ctx->device->pipeline_reglu[dst->type == GGML_TYPE_F16];
case GGML_GLU_OP_SWIGLU:
return ctx->device->pipeline_swiglu[dst->type == GGML_TYPE_F16];
case GGML_GLU_OP_GEGLU_ERF:
return ctx->device->pipeline_geglu_erf[dst->type == GGML_TYPE_F16];
case GGML_GLU_OP_GEGLU_QUICK:
return ctx->device->pipeline_geglu_quick[dst->type == GGML_TYPE_F16];
default:
break;
}
@ -7643,8 +7658,7 @@ static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f }, dryrun);
}
static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
float * op_params = (float *)dst->op_params;
static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, float * op_params, bool dryrun = false) {
const uint32_t src0_type_size = ggml_type_size(src0->type);
const uint32_t src1_type_size = ggml_type_size(src1->type);
const uint32_t dst_type_size = ggml_type_size(dst->type);
@ -8874,7 +8888,7 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
}
}
static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_tensor* tensor, int tensor_idx, bool use_fence, bool almost_ready);
static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_cgraph * cgraph, ggml_tensor* tensor, int tensor_idx, bool use_fence, bool almost_ready);
// Returns true if node has enqueued work into the queue, false otherwise
// If submit is true, all operations queued so far are submitted to Vulkan to overlap cmdlist creation and GPU execution.
@ -8919,6 +8933,8 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_GLU_OP_GEGLU:
case GGML_GLU_OP_REGLU:
case GGML_GLU_OP_SWIGLU:
case GGML_GLU_OP_GEGLU_ERF:
case GGML_GLU_OP_GEGLU_QUICK:
break;
default:
return false;
@ -9133,9 +9149,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
// fused rms_norm + mul
ggml_tensor *mul = cgraph->nodes[node_idx + 1];
ggml_tensor *other_src = mul->src[0] == node ? mul->src[1] : mul->src[0];
ggml_vk_rms_norm(ctx, compute_ctx, src0, other_src, mul, dryrun);
ggml_vk_rms_norm(ctx, compute_ctx, src0, other_src, mul, (float *)node->op_params, dryrun);
} else {
ggml_vk_rms_norm(ctx, compute_ctx, src0, src0, node, dryrun);
ggml_vk_rms_norm(ctx, compute_ctx, src0, src0, node, (float *)node->op_params, dryrun);
}
break;
case GGML_OP_RMS_NORM_BACK:
@ -9166,6 +9182,8 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_GLU_OP_GEGLU:
case GGML_GLU_OP_REGLU:
case GGML_GLU_OP_SWIGLU:
case GGML_GLU_OP_GEGLU_ERF:
case GGML_GLU_OP_GEGLU_QUICK:
ggml_vk_glu(ctx, compute_ctx, src0, src1, node, dryrun);
break;
default:
@ -9293,7 +9311,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
ctx->compute_ctx.reset();
bool ok = ggml_vk_compute_forward(ctx, node_begin, node_idx_begin, false, almost_ready);
bool ok = ggml_vk_compute_forward(ctx, cgraph, node_begin, node_idx_begin, false, almost_ready);
if (!ok) {
if (node->op == GGML_OP_UNARY) {
std::cerr << __func__ << ": error: op not supported UNARY " << node->name << " (" << ggml_unary_op_name(static_cast<ggml_unary_op>(node->op_params[0])) << ")" << std::endl;
@ -9308,7 +9326,8 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
return true;
}
static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * tensor, int tensor_idx, bool use_fence = true, bool almost_ready = false) {
static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, ggml_tensor * tensor, int tensor_idx, bool use_fence = true, bool almost_ready = false) {
GGML_UNUSED(cgraph);
ggml_backend_buffer * buf = nullptr;
switch (tensor->op) {
@ -9384,6 +9403,8 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
case GGML_GLU_OP_GEGLU:
case GGML_GLU_OP_REGLU:
case GGML_GLU_OP_SWIGLU:
case GGML_GLU_OP_GEGLU_ERF:
case GGML_GLU_OP_GEGLU_QUICK:
buf = tensor->buffer;
break;
default:
@ -9416,7 +9437,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
// Only run if ctx hasn't been submitted yet
if (!subctx->seqs.empty()) {
#ifdef GGML_VULKAN_CHECK_RESULTS
ggml_vk_check_results_0(tensor);
ggml_vk_check_results_0(ctx, cgraph, tensor_idx);
use_fence = true;
#endif
@ -9436,7 +9457,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
ggml_vk_wait_for_fence(ctx);
}
#ifdef GGML_VULKAN_CHECK_RESULTS
ggml_vk_check_results_1(tensor);
ggml_vk_check_results_1(ctx, cgraph, tensor_idx);
#endif
}
@ -9883,6 +9904,37 @@ static bool ggml_vk_is_empty(ggml_tensor * node) {
return ggml_is_empty(node) || node->op == GGML_OP_NONE || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE;
}
static bool ggml_vk_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops) {
if (!ggml_can_fuse(cgraph, node_idx, ops)) {
return false;
}
if (ops.size() == 2 && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) {
// additional constraints specific to this fusion
const ggml_tensor *rms_norm = cgraph->nodes[node_idx];
const ggml_tensor *mul = cgraph->nodes[node_idx + 1];
GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT(rms_norm->type == GGML_TYPE_F32);
// rms_norm only supports f32
if (mul->src[0]->type != GGML_TYPE_F32 ||
mul->src[1]->type != GGML_TYPE_F32 ||
mul->type != GGML_TYPE_F32) {
return false;
}
// if rms_norm is the B operand, then we don't handle broadcast
if (rms_norm == mul->src[1] &&
mul->src[0]->ne[1] != rms_norm->ne[1]) {
return false;
}
// rms_norm shader assumes contiguous rows
if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) {
return false;
}
}
return true;
}
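In graph terms the fusion targets the usual normalize-then-scale pair; a sketch using ggml's public API (eps and shapes are illustrative):

    #include "ggml.h"

    static struct ggml_tensor * rms_norm_mul(struct ggml_context * ctx,
                                             struct ggml_tensor  * x,
                                             struct ggml_tensor  * w) {
        struct ggml_tensor * t = ggml_rms_norm(ctx, x, 1e-6f);
        return ggml_mul(ctx, t, w); // fused into one dispatch when the checks above pass
    }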
static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)");
ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
@ -9896,7 +9948,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
uint64_t total_mat_mul_bytes = 0;
for (int i = 0; i < cgraph->n_nodes; i++) {
if (ggml_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
if (!ctx->device->disable_fusion && ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
ctx->num_additional_fused_ops = 1;
}
ggml_vk_build_graph(ctx, cgraph, i, nullptr, 0, true, false, false, false);
@ -9966,7 +10018,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
mul_mat_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
}
if (ggml_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
if (!ctx->device->disable_fusion && ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
ctx->num_additional_fused_ops = 1;
}
@ -10194,6 +10246,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_GLU_OP_GEGLU:
case GGML_GLU_OP_REGLU:
case GGML_GLU_OP_SWIGLU:
case GGML_GLU_OP_GEGLU_ERF:
case GGML_GLU_OP_GEGLU_QUICK:
return ggml_is_contiguous(op->src[0]) &&
(op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
(op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
@ -10287,12 +10341,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
if (op->src[3] && op->src[3]->type != GGML_TYPE_F16) {
return false;
}
// TODO: support broadcast
// note: this was initially implemented in https://github.com/ggml-org/llama.cpp/pull/14449, but
// the interface of ggml_flash_attn_ext() changed in https://github.com/ggml-org/llama.cpp/pull/14505
if (op->src[0]->ne[3] != 1 || (op->src[3] && op->src[3]->ne[2] != 1)) {
return false;
}
// It's straightforward to support different K/V dequant, but would
// significantly increase the number of pipelines
if (op->src[1]->type != op->src[2]->type) {
@ -10747,11 +10795,21 @@ void * comp_result;
size_t comp_size;
size_t comp_nb[GGML_MAX_DIMS];
size_t check_counter = 0;
static void ggml_vk_check_results_0(ggml_tensor * tensor) {
static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx) {
ggml_tensor * tensor = cgraph->nodes[tensor_idx];
if (tensor->op == GGML_OP_TRANSPOSE) {
return;
}
bool fused_rms_norm_mul = false;
int rms_norm_idx = -1;
if (ctx->num_additional_fused_ops == 1 &&
tensor->op == GGML_OP_RMS_NORM &&
cgraph->nodes[tensor_idx + 1]->op == GGML_OP_MUL) {
fused_rms_norm_mul = true;
tensor = cgraph->nodes[tensor_idx + 1];
}
check_counter++;
if (!(vk_output_tensor > 0 && vk_output_tensor == check_counter) && check_counter <= vk_skip_checks) {
return;
@ -10779,6 +10837,15 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
for (int i = 0; i < 6; i++) {
ggml_tensor * srci = tensor->src[i];
if (fused_rms_norm_mul) {
rms_norm_idx = tensor->src[0]->op == GGML_OP_RMS_NORM ? 0 : 1;
ggml_tensor *rms_norm = tensor->src[rms_norm_idx];
switch (i) {
case 0: srci = rms_norm->src[0]; break;
case 1: srci = tensor->src[1 - rms_norm_idx]; break;
default: continue;
}
}
if (srci == nullptr) {
continue;
}
@ -10836,7 +10903,12 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
} else if (tensor->op == GGML_OP_SUB) {
tensor_clone = ggml_sub(ggml_ctx, src_clone[0], src_clone[1]);
} else if (tensor->op == GGML_OP_MUL) {
tensor_clone = ggml_mul(ggml_ctx, src_clone[0], src_clone[1]);
if (fused_rms_norm_mul) {
tensor_clone = ggml_rms_norm(ggml_ctx, src_clone[0], *(float *)tensor->src[rms_norm_idx]->op_params);
tensor_clone = ggml_mul(ggml_ctx, tensor_clone, src_clone[1 - rms_norm_idx]);
} else {
tensor_clone = ggml_mul(ggml_ctx, src_clone[0], src_clone[1]);
}
} else if (tensor->op == GGML_OP_DIV) {
tensor_clone = ggml_div(ggml_ctx, src_clone[0], src_clone[1]);
} else if (tensor->op == GGML_OP_CONCAT) {
@ -11027,10 +11099,10 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
GGML_ABORT("fatal error");
}
ggml_cgraph * cgraph = ggml_new_graph(ggml_ctx);
ggml_build_forward_expand(cgraph, tensor_clone);
ggml_cgraph * cgraph_cpu = ggml_new_graph(ggml_ctx);
ggml_build_forward_expand(cgraph_cpu, tensor_clone);
ggml_graph_compute_with_ctx(ggml_ctx, cgraph, 8);
ggml_graph_compute_with_ctx(ggml_ctx, cgraph_cpu, 8);
if (vk_output_tensor > 0 && vk_output_tensor == check_counter) {
ggml_vk_print_tensor(tensor_clone, "tensor_clone");
@ -11053,10 +11125,19 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
VK_LOG_DEBUG("END ggml_vk_check_results_0(" << tensor->name << ")");
}
static void ggml_vk_check_results_1(ggml_tensor * tensor) {
static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx) {
ggml_tensor * tensor = cgraph->nodes[tensor_idx];
if (tensor->op == GGML_OP_TRANSPOSE) {
return;
}
bool fused_rms_norm_mul = false;
if (ctx->num_additional_fused_ops == 1 &&
tensor->op == GGML_OP_RMS_NORM &&
cgraph->nodes[tensor_idx + 1]->op == GGML_OP_MUL) {
fused_rms_norm_mul = true;
tensor = cgraph->nodes[tensor_idx + 1];
}
if (!(vk_output_tensor > 0 && vk_output_tensor == check_counter) && check_counter <= vk_skip_checks) {
return;
}

View File

@ -101,8 +101,8 @@ void main() {
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
#endif
uint32_t m_offset = 0;
if (p.nem2 != 1) {
m_offset = (iq3 % p.nem2) * p.nem1 * KV;
if (p.nem2 != 1 || p.nem3 != 1) {
m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
}
[[dont_unroll]]
@ -149,7 +149,7 @@ void main() {
}
}
if (p.mask != 0) {
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
uint32_t c = (idx + tid) % Bc;

View File

@ -25,6 +25,7 @@ layout (push_constant) uniform parameter {
uint32_t nev3;
uint32_t nem1;
uint32_t nem2;
uint32_t nem3;
uint32_t nb01;
uint32_t nb02;
@ -40,8 +41,7 @@ layout (push_constant) uniform parameter {
float max_bias;
float logit_softcap;
uint32_t mask;
uint32_t n_head_log2;
uint32_t mask_n_head_log2;
float m0;
float m1;
@ -50,6 +50,9 @@ layout (push_constant) uniform parameter {
uint32_t k_num;
} p;
#define MASK_ENABLE_BIT (1<<16)
#define N_LOG2_MASK 0xFFFF
layout (binding = 4) writeonly buffer O {D_TYPE data_o[];};
#if defined(A_TYPE_PACKED16)
@ -100,8 +103,10 @@ ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const i
{
const uint32_t h = iq2 + (r % p.gqa_ratio);
const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);
uint32_t n_head_log2 = p.mask_n_head_log2 & N_LOG2_MASK;
const ACC_TYPE base = ACC_TYPE(h < n_head_log2 ? p.m0 : p.m1);
const int exph = int(h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1);
return ACC_TYPE(pow(base, ACC_TYPE(exph)));
}
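A round-trip of the packed field used above, host-side (illustrative):

    #include <assert.h>
    #include <stdint.h>

    #define MASK_ENABLE_BIT (1u << 16)
    #define N_LOG2_MASK     0xFFFFu

    int main(void) {
        const uint32_t n_head_log2 = 5;
        const uint32_t packed = (1u << 16) | n_head_log2; // mask present
        assert((packed & MASK_ENABLE_BIT) != 0);          // bit 16: mask enabled
        assert((packed & N_LOG2_MASK) == n_head_log2);    // low 16 bits: n_head_log2
        return 0;
    }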

View File

@ -126,8 +126,8 @@ void main() {
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
#endif
uint32_t m_offset = 0;
if (p.nem2 != 1) {
m_offset = (iq3 % p.nem2) * p.nem1 * KV;
if (p.nem2 != 1 || p.nem3 != 1) {
m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
}
[[dont_unroll]]
@ -182,7 +182,7 @@ void main() {
barrier();
}
if (p.mask != 0) {
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
uint32_t c = (idx + tid) % Bc;
uint32_t r = (idx + tid) / Bc;

View File

@ -131,8 +131,8 @@ void main() {
}
uint32_t m_offset = 0;
if (p.nem2 != 1) {
m_offset = (iq3 % p.nem2) * p.nem1 * KV * 2 /*sizeof(float16_t)*/;
if (p.nem2 != 1 || p.nem3 != 1) {
m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV * 2 /*sizeof(float16_t)*/;
}
[[dont_unroll]]
@ -153,7 +153,7 @@ void main() {
}
}
if (p.mask != 0) {
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);

View File

@ -0,0 +1,27 @@
#version 450
#include "glu_head.comp"
// based on Abramowitz and Stegun formula 7.1.26 (a Hastings-type approximation)
// ref: https://www.johndcook.com/blog/python_erf/
const float p_erf = 0.3275911f;
const float a1_erf = 0.254829592f;
const float a2_erf = -0.284496736f;
const float a3_erf = 1.421413741f;
const float a4_erf = -1.453152027f;
const float a5_erf = 1.061405429f;
const float SQRT_2_INV = 0.70710678118654752440084436210484f;
float op(float a, float b) {
const float a_div_sqr2 = a * SQRT_2_INV;
const float sign_x = sign(a_div_sqr2);
const float x = abs(a_div_sqr2);
const float t = 1.0f / (1.0f + p_erf * x);
const float y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x);
const float erf_approx = sign_x * y;
return 0.5f * a * (1.0f + erf_approx) * b;
}
#include "glu_main.comp"

View File

@ -0,0 +1,11 @@
#version 450
#include "glu_head.comp"
const float GELU_QUICK_COEF = -1.702f;
float op(float a, float b) {
return a * (1.0f / (1.0f + exp(GELU_QUICK_COEF * a))) * b;
}
#include "glu_main.comp"

View File

@ -500,10 +500,9 @@ void main() {
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
const uint ib = idx / 128; // 2 values per idx
const uint ib32 = (idx % 128) / 16; // 0..7
const uint ib8 = (idx % 128) / 4;
const int i8 = 2 * int(idx % 4);
const uint ib = idx / 32; // 8 values per idx
const uint ib32 = (idx % 32) / 4; // 0..7
const uint ib8 = idx % 32;
const float d = float(data_a[ib].d);
const uint qh = data_a[ib].qh[ib32];
@ -512,22 +511,16 @@ void main() {
const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA;
const int16_t grid = int16_t(iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)]);
const ivec2 gvec = ivec2(
bitfieldExtract(grid, 2 * (i8), 2),
bitfieldExtract(grid, 2 * (i8 + 1), 2)
);
const vec2 v = dl * (vec2(gvec) + delta);
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
[[unroll]] for (int k = 0; k < 8; ++k) {
buf_a[buf_idx + k] = FLOAT_TYPE(dl * (bitfieldExtract(grid, 2 * k, 2) + delta));
}
#elif defined(DATA_A_IQ1_M)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
const uint ib = idx / 128; // 2 values per idx
const uint ib8 = (idx % 128) / 4;
const uint ib = idx / 32; // 8 values per idx
const uint ib8 = idx % 32;
const uint ib16 = ib8 / 2;
const int i8 = 2 * int(idx % 4);
const uint16_t[4] scales = data_a[ib].scales;
const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12;
@ -538,21 +531,17 @@ void main() {
const float dl = d * (2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1);
const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA;
const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]);
const ivec2 gvec = ivec2(
bitfieldExtract(grid, 2 * (i8), 2),
bitfieldExtract(grid, 2 * (i8 + 1), 2)
);
const vec2 v = dl * (vec2(gvec) + delta);
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
[[unroll]] for (int k = 0; k < 8; ++k) {
buf_a[buf_idx + k] = FLOAT_TYPE(dl * (bitfieldExtract(grid, 2 * k, 2) + delta));
}
#elif defined(DATA_A_IQ2_XXS)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
const uint ib = idx / 128; // 2 values per idx
const uint ib32 = (idx % 128) / 16; // 0..7
const uint ib8 = (idx / 4) % 4;
const uint ib = idx / 32; // 8 values per idx
const uint ib32 = (idx % 32) / 4; // 0..7
const uint ib8 = idx % 4;
const float d = float(data_a[ib].d);
const uint qs = data_a[ib].qs[8 * ib32 + ib8];
@ -562,63 +551,81 @@ void main() {
data_a[ib].qs[8*ib32 + 6],
data_a[ib].qs[8*ib32 + 7]
));
const float db = d * 0.25 * (0.5 + (signs >> 28));
const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + (signs >> 28)));
const uint32_t sign7 = bitfieldExtract(signs, 7 * int(ib8), 7);
const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4));
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
const uint grid = iq2xxs_grid[qs][(idx % 4) / 2] >> (16 * (idx & 1));
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147
const uint sign = sign7 | (bitCount(sign7) << 7);
const uvec2 grid = iq2xxs_grid[qs];
const vec4 grid0 = vec4(unpack8(grid.x));
const vec4 grid1 = vec4(unpack8(grid.y));
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
buf_a[buf_idx ] = db * FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x);
buf_a[buf_idx + 1] = db * FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y);
buf_a[buf_idx + 2] = db * FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z);
buf_a[buf_idx + 3] = db * FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w);
buf_a[buf_idx + 4] = db * FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x);
buf_a[buf_idx + 5] = db * FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y);
buf_a[buf_idx + 6] = db * FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z);
buf_a[buf_idx + 7] = db * FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w);
#elif defined(DATA_A_IQ2_XS)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
const uint ib = idx / 128; // 2 values per idx
const uint ib32 = (idx % 128) / 16; // 0..7
const uint ib8 = (idx / 4) % 4; // 0..3
const uint ib = idx / 32; // 8 values per idx
const uint ib32 = (idx % 32) / 4; // 0..7
const uint ib8 = idx % 4; // 0..3
const float d = float(data_a[ib].d);
const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf;
const float db = d * 0.25 * (0.5 + scale);
const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + scale));
const uint qs = data_a[ib].qs[4 * ib32 + ib8];
const uint sign7 = qs >> 9;
const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4));
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
const uint grid = iq2xs_grid[qs & 511][(idx % 4) / 2] >> (16 * (idx & 1));
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147
const uint sign = sign7 | (bitCount(sign7) << 7);
const uvec2 grid = iq2xs_grid[qs & 511];
const vec4 grid0 = vec4(unpack8(grid.x));
const vec4 grid1 = vec4(unpack8(grid.y));
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
buf_a[buf_idx ] = db * FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x);
buf_a[buf_idx + 1] = db * FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y);
buf_a[buf_idx + 2] = db * FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z);
buf_a[buf_idx + 3] = db * FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w);
buf_a[buf_idx + 4] = db * FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x);
buf_a[buf_idx + 5] = db * FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y);
buf_a[buf_idx + 6] = db * FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z);
buf_a[buf_idx + 7] = db * FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w);
#elif defined(DATA_A_IQ2_S)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
const uint ib = idx / 128; // 2 values per idx
const uint ib8 = (idx % 128) / 4; // 0..31
const uint ib32 = ib8 / 4; // 0..7
const uint ib = idx / 32; // 8 values per idx
const uint ib8 = idx % 32; // 0..31
const uint ib32 = ib8 / 4; // 0..7
const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf;
const uint qs = data_a[ib].qs[ib8];
const uint qh = data_a[ib].qh[ib32];
const uint qhshift = 2 * (ib8 % 4);
const uint sign = data_a[ib].qs[QUANT_K / 8 + ib8] >> (2 * (idx % 4));
const uint sign = data_a[ib].qs[QUANT_K / 8 + ib8];
const float d = float(data_a[ib].d);
const float db = d * 0.25 * (0.5 + scale);
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
const uint16_t grid = unpack16(iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(idx & 2) >> 1])[idx & 1];
const vec2 v = db * vec2(sign01) * vec2(unpack8(uint32_t(grid)).xy); // vec4 used due to #12147
const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + scale));
const uvec2 grid = iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)];
const vec4 grid0 = vec4(unpack8(grid.x));
const vec4 grid1 = vec4(unpack8(grid.y));
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
buf_a[buf_idx ] = db * FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x);
buf_a[buf_idx + 1] = db * FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y);
buf_a[buf_idx + 2] = db * FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z);
buf_a[buf_idx + 3] = db * FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w);
buf_a[buf_idx + 4] = db * FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x);
buf_a[buf_idx + 5] = db * FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y);
buf_a[buf_idx + 6] = db * FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z);
buf_a[buf_idx + 7] = db * FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w);
#elif defined(DATA_A_IQ3_XXS)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
const uint ib = idx / 128; // 2 values per idx
const uint iqs = (idx % 128) / 2; // 0..63
const uint ib = idx / 64; // 4 values per idx
const uint iqs = idx % 64; // 0..63
const uint is = QUANT_K / 4 + 4 * (iqs / 8); // 8 values
const float d = float(data_a[ib].d);
@ -631,33 +638,36 @@ void main() {
));
const float db = d * 0.5 * (0.5 + (signs >> 28));
const uint32_t sign7 = bitfieldExtract(signs, 7 * (int(iqs / 2) % 4), 7);
const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4));
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
const uint grid = iq3xxs_grid[qs] >> (16 * (idx & 1));
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147
const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (4 * (idx % 2));
const uint grid = iq3xxs_grid[qs];
const vec4 v = db * vec4(unpack8(grid));
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
buf_a[buf_idx ] = FLOAT_TYPE((sign & 1) != 0 ? -v.x : v.x);
buf_a[buf_idx + 1] = FLOAT_TYPE((sign & 2) != 0 ? -v.y : v.y);
buf_a[buf_idx + 2] = FLOAT_TYPE((sign & 4) != 0 ? -v.z : v.z);
buf_a[buf_idx + 3] = FLOAT_TYPE((sign & 8) != 0 ? -v.w : v.w);
#elif defined(DATA_A_IQ3_S)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
const uint ib = idx / 128; // 2 values per idx
const uint iqs = (idx % 128) / 2; // 0..63
const uint ib = idx / 64; // 4 values per idx
const uint iqs = idx % 64; // 0..63
const uint iqh = iqs / 8;
const float d = float(data_a[ib].d);
const uint qs = data_a[ib].qs[iqs];
const uint qh = data_a[ib].qh[iqh];
const int8_t sign = int8_t(data_a[ib].signs[iqs / 2] >> (2 * (idx % 4)));
const int8_t sign = int8_t(data_a[ib].signs[iqs / 2] >> (4 * (idx % 2)));
const uint scale = data_a[ib].scales[iqs / 16];
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(sign << 1, sign)));
const float db = d * (1 + 2 * ((scale >> (4 * (iqh & 1))) & 0xf));
const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)] >> (16 * (idx % 2));
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147
const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)];
const vec4 v = db * vec4(unpack8(grid));
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
buf_a[buf_idx ] = FLOAT_TYPE((sign & 1) != 0 ? -v.x : v.x);
buf_a[buf_idx + 1] = FLOAT_TYPE((sign & 2) != 0 ? -v.y : v.y);
buf_a[buf_idx + 2] = FLOAT_TYPE((sign & 4) != 0 ? -v.z : v.z);
buf_a[buf_idx + 3] = FLOAT_TYPE((sign & 8) != 0 ? -v.w : v.w);
#elif defined(DATA_A_IQ4_XS)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
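The common thread in the hunks above is that each invocation now decodes a full grid entry instead of a 2-value slice: with LOAD_VEC_A = 8 and 256 values per block, a block spans 256 / 8 = 32 idx steps (64 steps for the 4-wide iq3 variants), which is where the new ib = idx / 32 and ib8 = idx % 32 come from. The other recurring idiom is the sign reconstruction: the iq2 formats store only 7 sign bits per 8-value group and derive the 8th from parity. A hedged C++ sketch of both pieces, assuming QUANT_K = 256 as for these quant types:

    #include <bit> // std::popcount (C++20)

    // Index mapping for 8-wide loads of a 256-value block (e.g. IQ1_S):
    // ib selects the block, ib8 one of its 32 eight-value grid entries,
    // ib32 one of its 8 thirty-two-value scale groups.
    static void iq_index_8wide(unsigned idx, unsigned & ib, unsigned & ib32, unsigned & ib8) {
        ib   = idx / 32;
        ib32 = (idx % 32) / 4;
        ib8  = idx % 32;
    }

    // The 8th sign bit of an iq2 group is not stored: it is chosen so the
    // 8-bit mask always has even popcount, matching
    // `sign7 | (bitCount(sign7) << 7)` in the shader (which leaves the high
    // bits of bitCount unmasked because only bits 0..7 are ever tested).
    static unsigned expand_sign7(unsigned sign7) { // sign7 in [0, 128)
        return sign7 | ((std::popcount(sign7) & 1u) << 7);
    }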

View File

@ -360,9 +360,9 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
for (const auto& tname : type_names) {
std::string load_vec_quant = "2";
if ((tname == "q4_0") || (tname == "q4_1"))
if ((tname == "q4_0") || (tname == "q4_1") || (tname == "iq1_s") || (tname == "iq1_m") || (tname == "iq2_xxs") || (tname == "iq2_xs") || (tname == "iq2_s"))
load_vec_quant = "8";
else if ((tname == "q5_0") || (tname == "q5_1") || (tname == "q8_0") || (tname == "iq4_nl"))
else if ((tname == "q5_0") || (tname == "q5_1") || (tname == "q8_0") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_nl"))
load_vec_quant = "4";
if (tname == "bf16") {
@ -593,6 +593,10 @@ void process_shaders() {
string_to_spv("reglu_f32", "reglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("swiglu_f16", "swiglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("swiglu_f32", "swiglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("geglu_erf_f16", "geglu_erf.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("geglu_erf_f32", "geglu_erf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("geglu_quick_f16","geglu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("geglu_quick_f32","geglu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("silu_back_f32", "silu_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});

View File

@ -1140,9 +1140,11 @@ static const char * GGML_GLU_OP_NAME[GGML_GLU_OP_COUNT] = {
"REGLU",
"GEGLU",
"SWIGLU",
"GEGLU_ERF",
"GEGLU_QUICK",
};
static_assert(GGML_GLU_OP_COUNT == 3, "GGML_GLU_OP_COUNT != 3");
static_assert(GGML_GLU_OP_COUNT == 5, "GGML_GLU_OP_COUNT != 5");
static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
@ -2768,6 +2770,48 @@ struct ggml_tensor * ggml_swiglu_split(
return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_SWIGLU, false);
}
// ggml_geglu_erf
struct ggml_tensor * ggml_geglu_erf(
struct ggml_context * ctx,
struct ggml_tensor * a) {
return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU_ERF, false);
}
struct ggml_tensor * ggml_geglu_erf_swapped(
struct ggml_context * ctx,
struct ggml_tensor * a) {
return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU_ERF, true);
}
struct ggml_tensor * ggml_geglu_erf_split(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b) {
return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_GEGLU_ERF, false);
}
// ggml_geglu_quick
struct ggml_tensor * ggml_geglu_quick(
struct ggml_context * ctx,
struct ggml_tensor * a) {
return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU_QUICK, false);
}
struct ggml_tensor * ggml_geglu_quick_swapped(
struct ggml_context * ctx,
struct ggml_tensor * a) {
return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU_QUICK, true);
}
struct ggml_tensor * ggml_geglu_quick_split(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b) {
return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_GEGLU_QUICK, false);
}
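Like the existing GLU entry points, the fused forms split their single input in half along the first dimension, the _split forms take the two projections as separate tensors, and the _swapped forms exchange which half is activated. A minimal usage sketch mirroring the clip.cpp FFN path further below (names are illustrative):

    #include "ggml.h"

    // Sketch: GELU-erf gating. `a` is the activated operand; `b`, if
    // non-NULL, is the elementwise multiplier (otherwise `a` is split in half).
    static struct ggml_tensor * ffn_geglu_erf(struct ggml_context * ctx,
                                              struct ggml_tensor  * a,
                                              struct ggml_tensor  * b) {
        if (b) {
            return ggml_geglu_erf_split(ctx, a, b); // GELU_erf(a) * b
        }
        return ggml_geglu_erf(ctx, a); // splits a; result has ne[0] == a->ne[0] / 2
    }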
// ggml_norm
static struct ggml_tensor * ggml_norm_impl(

View File

@ -166,6 +166,8 @@ bool llama_batch_allocr::init(
// note: tracking the other way around is not necessary for now
//seq_cpl[s0][s1] = true;
has_cpl = true;
}
}
}
@ -405,6 +407,10 @@ uint32_t llama_batch_allocr::get_n_outputs() const {
return n_outputs;
}
uint32_t llama_batch_allocr::get_n_used() const {
return n_used;
}
std::vector<int32_t> & llama_batch_allocr::get_out_ids() {
return out_ids;
}
@ -420,6 +426,8 @@ llama_pos llama_batch_allocr::seq_pos_max(llama_seq_id seq_id) const {
void llama_batch_allocr::split_reset() {
out_ids.clear();
n_used = 0;
used.clear();
used.resize(get_n_tokens(), false);
@ -444,6 +452,7 @@ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
idxs.push_back(cur_idx);
used[cur_idx] = true;
++n_used;
++cur_idx;
@ -459,9 +468,17 @@ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
return ubatch_add(idxs, idxs.size(), false);
}
llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch, bool sequential) {
if (sequential && has_cpl) {
LLAMA_LOG_ERROR("%s: sequential split is not supported when there are coupled sequences in the input batch\n", __func__);
return {};
}
std::vector<seq_set_t> cur_seq_set;
llama_seq_id last_seq_id = -1;
// determine the non-overlapping sequence sets participating in this ubatch
for (int32_t i = 0; i < batch.n_tokens; ++i) {
if (used[i]) {
@ -478,9 +495,16 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
}
}
// accept only increasing sequence ids
if (sequential) {
add = add && (cur_seq_set.empty() || batch.seq_id[i][0] == last_seq_id + 1);
}
if (add) {
cur_seq_set.push_back(seq_set[i]);
last_seq_id = batch.seq_id[i][0];
if (cur_seq_set.size() > n_ubatch) {
break;
}
@ -529,6 +553,7 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
idxs_per_seq[s].push_back(idx);
used[idx] = true;
++n_used;
++cur_idx[s];
}
@ -570,6 +595,7 @@ llama_ubatch llama_batch_allocr::split_seq(uint32_t n_ubatch) {
idxs.push_back(cur_idx);
used[cur_idx] = true;
++n_used;
if (idxs.size() >= n_ubatch) {
break;
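Together with the new n_used bookkeeping, callers can now tell a clean end-of-batch apart from a batch that cannot be split at all; the consumption pattern, paraphrasing the memory-module hunks below (a sketch using llama.cpp internal types, not compilable standalone), is roughly:

    // `balloc` is the llama_batch_allocr from the hunks above.
    std::vector<llama_ubatch> ubatches;
    while (true) {
        llama_ubatch ubatch = balloc.split_equal(n_ubatch, /*sequential =*/ false);
        if (ubatch.n_tokens == 0) {
            break; // no more splittable tokens
        }
        ubatches.push_back(std::move(ubatch));
    }
    if (balloc.get_n_used() < balloc.get_n_tokens()) {
        // some tokens were never placed in any ubatch: no suitable split exists
    }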

View File

@ -54,6 +54,7 @@ public:
uint32_t get_n_tokens() const;
uint32_t get_n_outputs() const;
uint32_t get_n_used() const;
// the array of output indices in the order they were encountered during the ubatch splitting
std::vector<int32_t> & get_out_ids();
@ -69,7 +70,8 @@ public:
llama_ubatch split_simple(uint32_t n_ubatch);
// make ubatches of equal-length sequences sets
llama_ubatch split_equal(uint32_t n_ubatch);
// if sequential == true, the tokens in the ubatch will have increasing sequential sequence ids
// (i.e. each ubatch's sequence ids form a consecutive, increasing run)
llama_ubatch split_equal(uint32_t n_ubatch, bool sequential);
// sequence-set-wise split - each ubatch contains a single sequence-set
llama_ubatch split_seq(uint32_t n_ubatch);
@ -112,6 +114,9 @@ private:
using pos_set_t = std::set<llama_pos>;
using seq_cpl_t = std::vector<bool>;
// helper flag to quickly determine if there are any coupled sequences in the batch
bool has_cpl;
std::vector<pos_set_t> seq_pos; // seq_pos[s]: the set of positions in sequence s
std::vector<seq_cpl_t> seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1
@ -125,6 +130,8 @@ private:
// batch indices of the output
std::vector<int32_t> out_ids;
uint32_t n_used;
// used[i] indicates if token i has already been used in a previous ubatch
std::vector<bool> used;

View File

@ -1099,8 +1099,7 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con
auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams);
// note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
inp->kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
//cb(inp_kq_mask, "KQ_mask", -1);
inp->kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
ggml_set_input(inp->kq_mask);
inp->kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->kq_mask, GGML_TYPE_F16) : inp->kq_mask;
@ -1170,7 +1169,7 @@ static std::unique_ptr<llm_graph_input_attn_kv_unified> build_attn_inp_kv_unifie
inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
ggml_set_input(inp->self_kq_mask);
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@ -1312,7 +1311,7 @@ llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {
const int32_t n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;
inp->cross_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
inp->cross_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
ggml_set_input(inp->cross_kq_mask);
inp->cross_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->cross_kq_mask, GGML_TYPE_F16) : inp->cross_kq_mask;
@ -1376,7 +1375,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
ggml_set_input(inp->self_kq_mask);
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@ -1390,7 +1389,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
ggml_set_input(inp->self_kq_mask_swa);
inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
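ggml tensors are always logically 4-D, so ggml_new_tensor_2d already yields ne = {n_kv, n_tokens, 1, 1}; the switch to ggml_new_tensor_4d allocates the same buffer and only spells out the trailing dimensions, documenting that the masks are meant to broadcast over the head and batch axes. The two forms agree in shape (sketch; `ctx0`, `n_kv`, `n_tokens` as in the hunks above, padding omitted):

    // Both calls produce ne = {n_kv, n_tokens, 1, 1}; only the explicitness differs.
    struct ggml_tensor * m2 = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, n_tokens);
    struct ggml_tensor * m4 = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1, 1);
    // m2->ne[i] == m4->ne[i] for i = 0..3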

View File

@ -228,8 +228,8 @@ public:
ggml_tensor * get_kq_mask() const { return kq_mask_cnv; }
ggml_tensor * kq_mask = nullptr; // F32 [n_tokens, n_batch]
ggml_tensor * kq_mask_cnv = nullptr; // [n_tokens, n_batch]
ggml_tensor * kq_mask = nullptr; // F32 [n_tokens, n_batch, 1, 1]
ggml_tensor * kq_mask_cnv = nullptr; // [n_tokens, n_batch, 1, 1]
const llama_hparams & hparams;
const llama_cparams & cparams;
@ -257,8 +257,8 @@ public:
ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch]
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch, 1, 1]
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch, 1, 1]
const llama_hparams & hparams;
const llama_cparams & cparams;
@ -293,10 +293,10 @@ public:
ggml_tensor * self_k_idxs_swa = nullptr; // I64 [n_batch]
ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch]
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch]
ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch]
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch, 1, 1]
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch, 1, 1]
ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch, 1, 1]
ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch, 1, 1]
const llama_hparams & hparams;
const llama_cparams & cparams;
@ -313,8 +313,8 @@ public:
ggml_tensor * get_kq_mask_cross() const { return cross_kq_mask_cnv; }
ggml_tensor * cross_kq_mask = nullptr; // F32 [n_outputs_enc, n_batch]
ggml_tensor * cross_kq_mask_cnv = nullptr; // F32 [n_outputs_enc, n_batch]
ggml_tensor * cross_kq_mask = nullptr; // F32 [n_outputs_enc, n_batch, 1, 1]
ggml_tensor * cross_kq_mask_cnv = nullptr; // F32 [n_outputs_enc, n_batch, 1, 1]
const llama_cross * cross = nullptr;
};

View File

@ -113,6 +113,11 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
ubatches.push_back(std::move(ubatch)); // NOLINT
}
if (balloc.get_n_used() < balloc.get_n_tokens()) {
// failed to find a suitable split
break;
}
auto sinfos_base = kv_base->prepare(ubatches);
if (sinfos_base.empty()) {
break;
@ -135,7 +140,7 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
std::vector<llama_ubatch> ubatches;
while (true) {
auto ubatch = balloc.split_equal(n_ubatch);
auto ubatch = balloc.split_equal(n_ubatch, false);
if (ubatch.n_tokens == 0) {
break;
@ -144,6 +149,11 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
ubatches.push_back(std::move(ubatch)); // NOLINT
}
if (balloc.get_n_used() < balloc.get_n_tokens()) {
// failed to find a suitable split
break;
}
auto sinfos_base = kv_base->prepare(ubatches);
if (sinfos_base.empty()) {
break;

View File

@ -360,6 +360,11 @@ llama_memory_context_ptr llama_kv_cache_unified::init_batch(
ubatches.push_back(std::move(ubatch)); // NOLINT
}
if (balloc.get_n_used() < balloc.get_n_tokens()) {
// failed to find a suitable split
break;
}
auto sinfos = prepare(ubatches);
if (sinfos.empty()) {
break;

View File

@ -70,7 +70,7 @@ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ba
// if all tokens are output, split by sequence
ubatch = balloc.split_seq(n_ubatch);
} else {
ubatch = balloc.split_equal(n_ubatch);
ubatch = balloc.split_equal(n_ubatch, false);
}
if (ubatch.n_tokens == 0) {
@ -80,6 +80,11 @@ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ba
ubatches.push_back(std::move(ubatch)); // NOLINT
}
if (balloc.get_n_used() < balloc.get_n_tokens()) {
// failed to find a suitable split
break;
}
// prepare the recurrent batches first
if (!mem_recr->prepare(ubatches)) {
// TODO: will the recurrent cache be in an undefined context at this point?

View File

@ -374,10 +374,11 @@ llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr &
// if all tokens are output, split by sequence
ubatch = balloc.split_seq(n_ubatch);
} else {
ubatch = balloc.split_equal(n_ubatch);
ubatch = balloc.split_equal(n_ubatch, false);
}
if (ubatch.n_tokens == 0) {
if (balloc.get_n_used() < balloc.get_n_tokens()) {
// failed to find a suitable split
break;
}

File diff suppressed because it is too large

View File

@ -1405,8 +1405,7 @@ struct clip_graph {
ggml_tensor * x = embeddings;
embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_2_w, embeddings);
x = ggml_mul_mat(ctx0, model.mm_model_mlp_1_w,x);
embeddings = ggml_silu_inplace(ctx0, embeddings);
embeddings = ggml_mul(ctx0, embeddings,x);
embeddings = ggml_swiglu_split(ctx0, embeddings, x);
embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_3_w, embeddings);
}
// arrangement of BOI/EOI token embeddings
@ -1502,15 +1501,8 @@ struct clip_graph {
cur = ggml_mul_mat(ctx0, model.mm_1_w, cur);
// swiglu
{
int64_t split_point = cur->ne[0] / 2;
ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
// see SwiGLU in ultravox_model.py, the second half passed through is silu, not the first half
x1 = ggml_silu(ctx0, x1);
cur = ggml_mul(ctx0, x0, x1);
}
// see SwiGLU in ultravox_model.py: it is the second half that is passed through silu, not the first half
cur = ggml_swiglu_swapped(ctx0, cur);
// mid-norm
cur = ggml_rms_norm(ctx0, cur, 1e-6);
@ -1769,35 +1761,42 @@ private:
cur = tmp;
}
// we only support parallel ffn for now
switch (type_op) {
case FFN_SILU:
{
if (gate) {
cur = ggml_swiglu_split(ctx0, cur, tmp);
cb(cur, "ffn_swiglu", il);
} else {
cur = ggml_silu(ctx0, cur);
cb(cur, "ffn_silu", il);
} break;
case FFN_GELU:
{
if (gate) {
cur = ggml_geglu_split(ctx0, cur, tmp);
cb(cur, "ffn_geglu", il);
} else {
cur = ggml_gelu(ctx0, cur);
cb(cur, "ffn_gelu", il);
} break;
case FFN_GELU_ERF:
{
if (gate) {
cur = ggml_geglu_erf_split(ctx0, cur, tmp);
cb(cur, "ffn_geglu_erf", il);
} else {
cur = ggml_gelu_erf(ctx0, cur);
cb(cur, "ggml_gelu_erf", il);
cb(cur, "ffn_gelu_erf", il);
} break;
case FFN_GELU_QUICK:
{
if (gate) {
cur = ggml_geglu_quick_split(ctx0, cur, tmp);
cb(cur, "ffn_geglu_quick", il);
} else {
cur = ggml_gelu_quick(ctx0, cur);
cb(cur, "ffn_relu", il);
cb(cur, "ffn_gelu_quick", il);
} break;
}
// we only support parallel ffn for now
if (gate) {
cur = ggml_mul(ctx0, cur, tmp);
cb(cur, "ffn_gate_par", il);
}
if (down) {
cur = ggml_mul_mat(ctx0, down, cur);
}

View File

@ -132,6 +132,28 @@ def test_chat_template():
assert res.body["__verbose"]["prompt"] == "<s> <|start_header_id|>system<|end_header_id|>\n\nBook<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat is the best book<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
@pytest.mark.parametrize("prefill,re_prefill", [
("Whill", "Whill"),
([{"type": "text", "text": "Wh"}, {"type": "text", "text": "ill"}], "Whill"),
])
def test_chat_template_assistant_prefill(prefill, re_prefill):
global server
server.chat_template = "llama3"
server.debug = True # to get the "__verbose" object in the response
server.start()
res = server.make_request("POST", "/chat/completions", data={
"max_tokens": 8,
"messages": [
{"role": "system", "content": "Book"},
{"role": "user", "content": "What is the best book"},
{"role": "assistant", "content": prefill},
]
})
assert res.status_code == 200
assert "__verbose" in res.body
assert res.body["__verbose"]["prompt"] == f"<s> <|start_header_id|>system<|end_header_id|>\n\nBook<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat is the best book<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n{re_prefill}"
def test_apply_chat_template():
global server
server.chat_template = "command-r"
@ -228,6 +250,7 @@ def test_completion_with_grammar(jinja: bool, grammar: str, n_predicted: int, re
[{"role": "system", "content": 123}],
# [{"content": "hello"}], # TODO: should not be a valid case
[{"role": "system", "content": "test"}, {}],
[{"role": "user", "content": "test"}, {"role": "assistant", "content": "test"}, {"role": "assistant", "content": "test"}],
])
def test_invalid_chat_completion_req(messages):
global server

View File

@ -792,7 +792,13 @@ static json oaicompat_chat_params_parse(
/* Append assistant prefilled message */
if (prefill_assistant_message) {
chat_params.prompt += last_message.content;
if (!last_message.content_parts.empty()) {
for (auto & p : last_message.content_parts) {
chat_params.prompt += p.text;
}
} else {
chat_params.prompt += last_message.content;
}
}
llama_params["chat_format"] = static_cast<int>(chat_params.format);