diff --git a/examples/eval-callback/eval-callback.cpp b/examples/eval-callback/eval-callback.cpp index bbbec6a01a..4afd80eb45 100644 --- a/examples/eval-callback/eval-callback.cpp +++ b/examples/eval-callback/eval-callback.cpp @@ -136,6 +136,11 @@ static bool run(llama_context * ctx, const common_params & params) { std::vector<llama_token> tokens = common_tokenize(ctx, params.prompt, add_bos); + if (tokens.empty()) { + LOG_ERR("%s : there are no input tokens to process (try to provide a prompt with '-p')\n", __func__); + return false; + } + if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()))) { LOG_ERR("%s : failed to eval\n", __func__); return false; diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index b2ad13f194..949eac9a5a 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -557,6 +557,8 @@ extern "C" { GGML_GLU_OP_REGLU, GGML_GLU_OP_GEGLU, GGML_GLU_OP_SWIGLU, + GGML_GLU_OP_GEGLU_ERF, + GGML_GLU_OP_GEGLU_QUICK, GGML_GLU_OP_COUNT, }; @@ -1147,6 +1149,22 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a); + GGML_API struct ggml_tensor * ggml_geglu_erf( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_geglu_erf_swapped( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_geglu_quick( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_geglu_quick_swapped( + struct ggml_context * ctx, + struct ggml_tensor * a); + // A: n columns, r rows, // B: n columns, r rows, GGML_API struct ggml_tensor * ggml_glu_split( @@ -1170,6 +1188,16 @@ extern "C" { struct ggml_tensor * a, struct ggml_tensor * b); + GGML_API struct ggml_tensor * ggml_geglu_erf_split( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + GGML_API struct ggml_tensor * ggml_geglu_quick_split( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + // normalize along rows GGML_API struct ggml_tensor * ggml_norm( struct ggml_context * ctx, diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index 69483de8f3..4d5c2c1825 100755 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -67,6 +67,7 @@ #include #include #include +#include #include #include @@ -804,10 +805,11 @@ static aclTensor* aclnn_zero(ggml_backend_cann_context& ctx, void* buffer, nb[i] = nb[i - 1] * ne[i - 1]; } - ggml_cann_async_memset(ctx, buffer, n_bytes, 0); aclTensor* zero = ggml_cann_create_tensor(buffer, type, type_size, ne, nb, dims); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceZero, zero); return zero; + GGML_UNUSED(n_bytes); } /** diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 11ff228f07..c5271b7757 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -2172,6 +2172,8 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_GLU_OP_REGLU: case GGML_GLU_OP_GEGLU: case GGML_GLU_OP_SWIGLU: + case GGML_GLU_OP_GEGLU_ERF: + case GGML_GLU_OP_GEGLU_QUICK: { n_tasks = n_threads; } break; diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 0fb2c08b5b..aaeee614ab 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -3614,6 +3614,292 @@ static void ggml_compute_forward_swiglu( } } +// ggml_compute_forward_geglu_erf + +static void ggml_compute_forward_geglu_erf_f32( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1]; + char * src0_d = (char *) src0->data; + char * src1_d = (char *) (src1 ? src1->data : src0->data); + const size_t src0_o = src0->nb[1]; + const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1]; + + GGML_ASSERT(ggml_is_contiguous_1(src0)); + GGML_ASSERT(ggml_is_contiguous_1(dst)); + + if (src1) { + GGML_ASSERT(ggml_is_contiguous_1(src1)); + GGML_ASSERT(src0->type == src1->type); + } + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2; + const int nr = ggml_nrows(src0); + + GGML_ASSERT(dst->ne[0] == nc); + GGML_ASSERT(ggml_nrows(dst) == nr); + + const int32_t swapped = ggml_get_op_params_i32(dst, 1); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + float * src0_p = (float *) (src0_d + i1*src0_o); + float * src1_p = (float *) (src1_d + i1*src1_o); + + if (!src1) { + src0_p += swapped ? nc : 0; + src1_p += swapped ? 0 : nc; + } + + ggml_vec_geglu_erf_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p); + +#ifndef NDEBUG + for (int k = 0; k < nc; k++) { + const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + GGML_UNUSED(x); + assert(!isnan(x)); + assert(!isinf(x)); + } +#endif + } +} + +static void ggml_compute_forward_geglu_erf_f16( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + char * src0_d = (char *) src0->data; + char * src1_d = (char *) (src1 ? src1->data : src0->data); + const size_t src0_o = src0->nb[1]; + const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1]; + + GGML_ASSERT(ggml_is_contiguous_1(src0)); + GGML_ASSERT(ggml_is_contiguous_1(dst)); + + if (src1) { + GGML_ASSERT(ggml_is_contiguous_1(src1)); + GGML_ASSERT(src0->type == src1->type); + } + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2; + const int nr = ggml_nrows(src0); + + GGML_ASSERT(dst->ne[0] == nc); + GGML_ASSERT(ggml_nrows(dst) == nr); + + const int32_t swapped = ggml_get_op_params_i32(dst, 1); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o); + ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o); + + if (!src1) { + src0_p += swapped ? nc : 0; + src1_p += swapped ? 
0 : nc; + } + + ggml_vec_geglu_erf_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p); + +#ifndef NDEBUG + for (int k = 0; k < nc; k++) { + const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + const float v = GGML_FP16_TO_FP32(x); + GGML_UNUSED(v); + assert(!isnan(v)); + assert(!isinf(v)); + } +#endif + } +} + +static void ggml_compute_forward_geglu_erf( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_geglu_erf_f32(params, dst); + } break; + case GGML_TYPE_F16: + { + ggml_compute_forward_geglu_erf_f16(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + +// ggml_compute_forward_geglu_quick + +static void ggml_compute_forward_geglu_quick_f32( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + char * src0_d = (char *) src0->data; + char * src1_d = (char *) (src1 ? src1->data : src0->data); + const size_t src0_o = src0->nb[1]; + const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1]; + + GGML_ASSERT(ggml_is_contiguous_1(src0)); + GGML_ASSERT(ggml_is_contiguous_1(dst)); + + if (src1) { + GGML_ASSERT(ggml_is_contiguous_1(src1)); + GGML_ASSERT(src0->type == src1->type); + } + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2; + const int nr = ggml_nrows(src0); + + GGML_ASSERT(dst->ne[0] == nc); + GGML_ASSERT(ggml_nrows(dst) == nr); + + const int32_t swapped = ggml_get_op_params_i32(dst, 1); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + float * src0_p = (float *) (src0_d + i1*src0_o); + float * src1_p = (float *) (src1_d + i1*src1_o); + + if (!src1) { + src0_p += swapped ? nc : 0; + src1_p += swapped ? 0 : nc; + } + + ggml_vec_geglu_quick_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p); + +#ifndef NDEBUG + for (int k = 0; k < nc; k++) { + const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + GGML_UNUSED(x); + assert(!isnan(x)); + assert(!isinf(x)); + } +#endif + } +} + +static void ggml_compute_forward_geglu_quick_f16( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + char * src0_d = (char *) src0->data; + char * src1_d = (char *) (src1 ? src1->data : src0->data); + const size_t src0_o = src0->nb[1]; + const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1]; + + GGML_ASSERT(ggml_is_contiguous_1(src0)); + GGML_ASSERT(ggml_is_contiguous_1(dst)); + + if (src1) { + GGML_ASSERT(ggml_is_contiguous_1(src1)); + GGML_ASSERT(src0->type == src1->type); + } + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src1 ? 
src0->ne[0] : src0->ne[0] / 2; + const int nr = ggml_nrows(src0); + + GGML_ASSERT(dst->ne[0] == nc); + GGML_ASSERT(ggml_nrows(dst) == nr); + + const int32_t swapped = ggml_get_op_params_i32(dst, 1); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o); + ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o); + + if (!src1) { + src0_p += swapped ? nc : 0; + src1_p += swapped ? 0 : nc; + } + + ggml_vec_geglu_quick_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p); + +#ifndef NDEBUG + for (int k = 0; k < nc; k++) { + const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + const float v = GGML_FP16_TO_FP32(x); + GGML_UNUSED(v); + assert(!isnan(v)); + assert(!isinf(v)); + } +#endif + } +} + +static void ggml_compute_forward_geglu_quick( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_geglu_quick_f32(params, dst); + } break; + case GGML_TYPE_F16: + { + ggml_compute_forward_geglu_quick_f16(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + // ggml_compute_forward_norm static void ggml_compute_forward_norm_f32( @@ -8779,6 +9065,14 @@ void ggml_compute_forward_glu( { ggml_compute_forward_swiglu(params, dst); } break; + case GGML_GLU_OP_GEGLU_ERF: + { + ggml_compute_forward_geglu_erf(params, dst); + } break; + case GGML_GLU_OP_GEGLU_QUICK: + { + ggml_compute_forward_geglu_quick(params, dst); + } break; default: { GGML_ABORT("fatal error"); diff --git a/ggml/src/ggml-cpu/vec.h b/ggml/src/ggml-cpu/vec.h index c432c99081..1f5857a23e 100644 --- a/ggml/src/ggml-cpu/vec.h +++ b/ggml/src/ggml-cpu/vec.h @@ -959,6 +959,46 @@ inline static void ggml_vec_swiglu_f16(const int n, ggml_fp16_t * y, const ggml_ } } +inline static void ggml_vec_geglu_erf_f32(const int n, float * y, const float * x, const float * g) { + for (int i = 0; i < n; ++i) { + float xi = x[i]; + y[i] = 0.5f * xi * (1.0f + erff(xi*SQRT_2_INV)) * g[i]; + } +} + +inline static void ggml_vec_geglu_erf_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) { + for (int i = 0; i < n; ++i) { + float xi = GGML_CPU_FP16_TO_FP32(x[i]); + float gi = GGML_CPU_FP16_TO_FP32(g[i]); + y[i] = GGML_CPU_FP32_TO_FP16(0.5f * xi * (1.0f + erff(xi*SQRT_2_INV)) * gi); + } +} + +#ifdef GGML_GELU_QUICK_FP16 +inline static void ggml_vec_geglu_quick_f32(const int n, float * y, const float * x, const float * g) { + uint16_t t; + for (int i = 0; i < n; ++i) { + ggml_fp16_t fp16 = GGML_CPU_FP32_TO_FP16(x[i]); + memcpy(&t, &fp16, sizeof(uint16_t)); + y[i] = GGML_CPU_FP16_TO_FP32(ggml_table_gelu_quick_f16[t]) * g[i]; + } +} +#else +inline static void ggml_vec_geglu_quick_f32(const int n, float * y, const float * x, const float * g) { + for (int i = 0; i < n; ++i) { + y[i] = ggml_gelu_quick_f32(x[i]) * g[i]; + } +} +#endif + +inline static void ggml_vec_geglu_quick_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) { + const uint16_t * i16 = (const uint16_t *) x; + for (int i = 0; i < n; ++i) { + float v = GGML_CPU_FP16_TO_FP32(g[i]); + y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(ggml_table_gelu_quick_f16[i16[i]]) * v); + } +} + inline static void ggml_vec_sum_f32(const int n, float * s, 
const float * x) { #ifndef GGML_USE_ACCELERATE ggml_float sum = 0.0; diff --git a/ggml/src/ggml-cuda/getrows.cu b/ggml/src/ggml-cuda/getrows.cu index 963e4d03dd..f77b2629a1 100644 --- a/ggml/src/ggml-cuda/getrows.cu +++ b/ggml/src/ggml-cuda/getrows.cu @@ -168,6 +168,10 @@ static void ggml_cuda_get_rows_switch_src0_type( get_rows_cuda_float((const float *) src0_d, src1_d, dst_d, ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream); break; + case GGML_TYPE_I32: + get_rows_cuda_float((const int32_t *) src0_d, src1_d, dst_d, + ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream); + break; case GGML_TYPE_BF16: get_rows_cuda_float((const nv_bfloat16 *) src0_d, src1_d, dst_d, ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream); @@ -210,6 +214,10 @@ void get_rows_cuda( ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (float *) dst_d, ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream); break; + case GGML_TYPE_I32: + ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (int32_t *) dst_d, + ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream); + break; case GGML_TYPE_F16: ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (half *) dst_d, ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 1c04bba52e..b6b7960f12 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2314,6 +2314,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_GLU_OP_SWIGLU: ggml_cuda_op_swiglu(ctx, dst); break; + case GGML_GLU_OP_GEGLU_ERF: + ggml_cuda_op_geglu_erf(ctx, dst); + break; + case GGML_GLU_OP_GEGLU_QUICK: + ggml_cuda_op_geglu_quick(ctx, dst); + break; default: return false; } @@ -3116,6 +3122,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_GLU_OP_REGLU: case GGML_GLU_OP_GEGLU: case GGML_GLU_OP_SWIGLU: + case GGML_GLU_OP_GEGLU_ERF: + case GGML_GLU_OP_GEGLU_QUICK: return ggml_is_contiguous_1(op->src[0]); default: return false; @@ -3192,6 +3200,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g switch (op->src[0]->type) { case GGML_TYPE_F16: case GGML_TYPE_F32: + case GGML_TYPE_BF16: + case GGML_TYPE_I32: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu index ba3c0f1376..f9c7b83c40 100644 --- a/ggml/src/ggml-cuda/unary.cu +++ b/ggml/src/ggml-cuda/unary.cu @@ -285,6 +285,14 @@ void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { ggml_cuda_op_unary_gated(ctx, dst); } +void ggml_cuda_op_geglu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary_gated(ctx, dst); +} + +void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary_gated(ctx, dst); +} + /* silu_back */ static __device__ __forceinline__ float op_silu_back(float grad, float x) { diff --git a/ggml/src/ggml-cuda/unary.cuh b/ggml/src/ggml-cuda/unary.cuh index 9094f1d0ba..289d690e5c 100644 --- a/ggml/src/ggml-cuda/unary.cuh +++ b/ggml/src/ggml-cuda/unary.cuh @@ -64,3 +64,7 @@ void ggml_cuda_op_reglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_geglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void 
ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_geglu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index 5e5467e88a..40fc315e82 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -530,6 +530,8 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_REGLU, GGML_METAL_KERNEL_TYPE_GEGLU, GGML_METAL_KERNEL_TYPE_SWIGLU, + GGML_METAL_KERNEL_TYPE_GEGLU_ERF, + GGML_METAL_KERNEL_TYPE_GEGLU_QUICK, GGML_METAL_KERNEL_TYPE_SUM_ROWS, GGML_METAL_KERNEL_TYPE_MEAN, GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, @@ -1510,6 +1512,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REGLU, reglu, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU, geglu, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SWIGLU, swiglu, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU_ERF, geglu_erf, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU_QUICK, geglu_quick, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MEAN, mean, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true); @@ -1693,6 +1697,8 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex case GGML_GLU_OP_REGLU: case GGML_GLU_OP_GEGLU: case GGML_GLU_OP_SWIGLU: + case GGML_GLU_OP_GEGLU_ERF: + case GGML_GLU_OP_GEGLU_QUICK: return ggml_is_contiguous_1(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; default: return false; @@ -2456,6 +2462,12 @@ static bool ggml_metal_encode_node( case GGML_GLU_OP_SWIGLU: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU].pipeline; break; + case GGML_GLU_OP_GEGLU_ERF: + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GEGLU_ERF].pipeline; + break; + case GGML_GLU_OP_GEGLU_QUICK: + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GEGLU_QUICK].pipeline; + break; default: GGML_ABORT("fatal error"); } diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index ebde005f33..22240bab47 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -109,6 +109,7 @@ void dequantize_q4_0_t4(device const block_q4_0 * xb, short il, thread type4 & r } void quantize_q4_0(device const float * src, device block_q4_0 & dst) { +#pragma METAL fp math_mode(safe) float amax = 0.0f; // absolute max float max = 0.0f; @@ -167,6 +168,7 @@ void quantize_q4_1(device const float * src, device block_q4_1 & dst) { } void quantize_q5_0(device const float * src, device block_q5_0 & dst) { +#pragma METAL fp math_mode(safe) float amax = 0.0f; // absolute max float max = 0.0f; @@ -461,6 +463,7 @@ void dequantize_q8_0_t4(device const block_q8_0 *xb, short il, thread type4 & re } void quantize_q8_0(device const float * src, device block_q8_0 & dst) { +#pragma METAL fp math_mode(safe) float amax = 0.0f; // absolute max for (int j = 0; j < QK8_0; j++) { @@ -1258,6 +1261,50 @@ kernel void kernel_swiglu( } } +kernel void kernel_geglu_erf( + device const char * src0, + device const char * src1, + device char * dst, + constant ggml_metal_kargs_glu & args, + uint tgpig[[threadgroup_position_in_grid]], + uint tpitg[[thread_position_in_threadgroup]], + uint ntg[[threads_per_threadgroup]]) { + device const float * src0_row = 
(device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00; + device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10; + device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1); + + for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) { + const float x0 = src0_row[i0]; + const float x1 = src1_row[i0]; + + const float gelu_erf = 0.5f*x0*(1.0f+erf_approx(x0*SQRT_2_INV)); + + dst_row[i0] = gelu_erf*x1; + } +} + +kernel void kernel_geglu_quick( + device const char * src0, + device const char * src1, + device char * dst, + constant ggml_metal_kargs_glu & args, + uint tgpig[[threadgroup_position_in_grid]], + uint tpitg[[thread_position_in_threadgroup]], + uint ntg[[threads_per_threadgroup]]) { + device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00; + device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10; + device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1); + + for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) { + const float x0 = src0_row[i0]; + const float x1 = src1_row[i0]; + + const float gelu_quick = x0*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x0))); + + dst_row[i0] = gelu_quick*x1; + } +} + template kernel void kernel_sum_rows( constant ggml_metal_kargs_sum_rows & args, diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 2450100b43..a9fc039038 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -398,12 +398,13 @@ struct ggml_backend_opencl_context { cl_kernel kernel_scale; cl_kernel kernel_silu, kernel_silu_4; cl_kernel kernel_gelu, kernel_gelu_4; + cl_kernel kernel_gelu_erf, kernel_gelu_erf_4; cl_kernel kernel_gelu_quick, kernel_gelu_quick_4; cl_kernel kernel_relu; cl_kernel kernel_sigmoid_f32, kernel_sigmoid_f16; cl_kernel kernel_clamp; - cl_kernel kernel_geglu, kernel_reglu, kernel_swiglu, - kernel_geglu_f16, kernel_reglu_f16, kernel_swiglu_f16; + cl_kernel kernel_geglu, kernel_reglu, kernel_swiglu, kernel_geglu_erf, kernel_geglu_quick, + kernel_geglu_f16, kernel_reglu_f16, kernel_swiglu_f16, kernel_geglu_erf_f16, kernel_geglu_quick_f16; cl_kernel kernel_norm; cl_kernel kernel_rms_norm; cl_kernel kernel_group_norm; @@ -736,6 +737,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve CL_CHECK((backend_ctx->kernel_gelu = clCreateKernel(backend_ctx->program_gelu, "kernel_gelu", &err), err)); CL_CHECK((backend_ctx->kernel_gelu_4 = clCreateKernel(backend_ctx->program_gelu, "kernel_gelu_4", &err), err)); + CL_CHECK((backend_ctx->kernel_gelu_erf = clCreateKernel(backend_ctx->program_gelu, "kernel_gelu_erf", &err), err)); + CL_CHECK((backend_ctx->kernel_gelu_erf_4 = clCreateKernel(backend_ctx->program_gelu, "kernel_gelu_erf_4", &err), err)); CL_CHECK((backend_ctx->kernel_gelu_quick = clCreateKernel(backend_ctx->program_gelu, "kernel_gelu_quick", &err), err)); CL_CHECK((backend_ctx->kernel_gelu_quick_4 = clCreateKernel(backend_ctx->program_gelu, "kernel_gelu_quick_4", &err), err)); GGML_LOG_CONT("."); @@ -753,12 +756,16 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve backend_ctx->program_glu = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - CL_CHECK((backend_ctx->kernel_geglu = clCreateKernel(backend_ctx->program_glu, "kernel_geglu", &err), err)); - 
CL_CHECK((backend_ctx->kernel_reglu = clCreateKernel(backend_ctx->program_glu, "kernel_reglu", &err), err)); - CL_CHECK((backend_ctx->kernel_swiglu = clCreateKernel(backend_ctx->program_glu, "kernel_swiglu", &err), err)); - CL_CHECK((backend_ctx->kernel_geglu_f16 = clCreateKernel(backend_ctx->program_glu, "kernel_geglu_f16", &err), err)); - CL_CHECK((backend_ctx->kernel_reglu_f16 = clCreateKernel(backend_ctx->program_glu, "kernel_reglu_f16", &err), err)); - CL_CHECK((backend_ctx->kernel_swiglu_f16 = clCreateKernel(backend_ctx->program_glu, "kernel_swiglu_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_geglu = clCreateKernel(backend_ctx->program_glu, "kernel_geglu", &err), err)); + CL_CHECK((backend_ctx->kernel_reglu = clCreateKernel(backend_ctx->program_glu, "kernel_reglu", &err), err)); + CL_CHECK((backend_ctx->kernel_swiglu = clCreateKernel(backend_ctx->program_glu, "kernel_swiglu", &err), err)); + CL_CHECK((backend_ctx->kernel_geglu_erf = clCreateKernel(backend_ctx->program_glu, "kernel_geglu_erf", &err), err)); + CL_CHECK((backend_ctx->kernel_geglu_quick = clCreateKernel(backend_ctx->program_glu, "kernel_geglu_quick", &err), err)); + CL_CHECK((backend_ctx->kernel_geglu_f16 = clCreateKernel(backend_ctx->program_glu, "kernel_geglu_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_reglu_f16 = clCreateKernel(backend_ctx->program_glu, "kernel_reglu_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_swiglu_f16 = clCreateKernel(backend_ctx->program_glu, "kernel_swiglu_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_geglu_erf_f16 = clCreateKernel(backend_ctx->program_glu, "kernel_geglu_erf_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_geglu_quick_f16 = clCreateKernel(backend_ctx->program_glu, "kernel_geglu_quick_f16", &err), err)); GGML_LOG_CONT("."); } @@ -2262,6 +2269,7 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te case GGML_UNARY_OP_GELU: case GGML_UNARY_OP_SILU: case GGML_UNARY_OP_RELU: + case GGML_UNARY_OP_GELU_ERF: case GGML_UNARY_OP_GELU_QUICK: return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; case GGML_UNARY_OP_SIGMOID: @@ -2277,6 +2285,8 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te case GGML_GLU_OP_GEGLU: case GGML_GLU_OP_REGLU: case GGML_GLU_OP_SWIGLU: + case GGML_GLU_OP_GEGLU_ERF: + case GGML_GLU_OP_GEGLU_QUICK: return ggml_is_contiguous_1(op->src[0]) && (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16); default: return false; @@ -3864,6 +3874,44 @@ static void ggml_cl_gelu(ggml_backend_t backend, const ggml_tensor * src0, const backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); } +static void ggml_cl_gelu_erf(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + UNUSED(src1); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + cl_kernel kernel; + + int n = ggml_nelements(dst); + + if (n % 4 == 0) { + kernel = backend_ctx->kernel_gelu_erf_4; + n /= 4; + } else { + kernel = backend_ctx->kernel_gelu_erf; + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + 
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + + size_t global_work_size[] = {(size_t)n, 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); +} + static void ggml_cl_gelu_quick(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(src0); GGML_ASSERT(src0->extra); @@ -6254,6 +6302,20 @@ static void ggml_cl_glu(ggml_backend_t backend, const ggml_tensor * src0, const kernel = backend_ctx->kernel_swiglu_f16; } break; + case GGML_GLU_OP_GEGLU_ERF: + if (dst->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_geglu_erf; + } else { + kernel = backend_ctx->kernel_geglu_erf_f16; + } + break; + case GGML_GLU_OP_GEGLU_QUICK: + if (dst->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_geglu_quick; + } else { + kernel = backend_ctx->kernel_geglu_quick_f16; + } + break; default: GGML_ABORT("Unsupported glu op"); } @@ -6368,6 +6430,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor } func = ggml_cl_gelu; break; + case GGML_UNARY_OP_GELU_ERF: + if (!any_on_device) { + return false; + } + func = ggml_cl_gelu_erf; + break; case GGML_UNARY_OP_GELU_QUICK: if (!any_on_device) { return false; diff --git a/ggml/src/ggml-opencl/kernels/gelu.cl b/ggml/src/ggml-opencl/kernels/gelu.cl index 71c310cc9f..1ab426c774 100644 --- a/ggml/src/ggml-opencl/kernels/gelu.cl +++ b/ggml/src/ggml-opencl/kernels/gelu.cl @@ -6,6 +6,7 @@ #define GELU_COEF_A 0.044715f #define GELU_QUICK_COEF -1.702f #define SQRT_2_OVER_PI 0.79788456080286535587989211986876f +#define SQRT_2_INV 0.70710678118654752440084436210484f kernel void kernel_gelu( global float * src0, @@ -35,6 +36,32 @@ kernel void kernel_gelu_4( dst[get_global_id(0)] = 0.5f*x*(1.0f + tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); } +kernel void kernel_gelu_erf( + global float * src0, + ulong offset0, + global float * dst, + ulong offsetd +) { + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + float x = src0[get_global_id(0)]; + dst[get_global_id(0)] = 0.5f*x*(1.0f + erf(x*SQRT_2_INV)); +} + +kernel void kernel_gelu_erf_4( + global float4 * src0, + ulong offset0, + global float4 * dst, + ulong offsetd +) { + src0 = (global float4*)((global char*)src0 + offset0); + dst = (global float4*)((global char*)dst + offsetd); + + float4 x = src0[get_global_id(0)]; + dst[get_global_id(0)] = 0.5f*x*(1.0f + erf(x*SQRT_2_INV)); +} + kernel void kernel_gelu_quick( global float * src0, ulong offset0, diff --git a/ggml/src/ggml-opencl/kernels/glu.cl b/ggml/src/ggml-opencl/kernels/glu.cl index ba861d8b18..7cca16e6a9 100644 --- a/ggml/src/ggml-opencl/kernels/glu.cl +++ b/ggml/src/ggml-opencl/kernels/glu.cl @@ -1,7 +1,9 @@ #pragma OPENCL EXTENSION cl_khr_fp16 : enable #define GELU_COEF_A 0.044715f +#define GELU_QUICK_COEF -1.702f #define SQRT_2_OVER_PI 0.79788456080286535587989211986876f +#define SQRT_2_INV 0.70710678118654752440084436210484f //------------------------------------------------------------------------------ // geglu @@ -199,3 +201,137 @@ kernel void kernel_swiglu_f16( dst_row[i0] = silu*x1; } } + +//------------------------------------------------------------------------------ +// geglu_erf +//------------------------------------------------------------------------------ 
+kernel void kernel_geglu_erf( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global char * dst, + ulong offsetd, + ulong nb01, + ulong nb11, + int ne0, + ulong nb1, + int ne00_off, + int ne10_off +) { + src0 = (global char*)((global char*)src0 + offset0); + src1 = (global char*)((global char*)src1 + offset1); + dst = (global char*)((global char*)dst + offsetd); + + global float * src0_row = (global float *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off; + global float * src1_row = (global float *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off; + global float * dst_row = (global float *) ((global char *) dst + get_group_id(0)*nb1); + + for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) { + const float x0 = src0_row[i0]; + const float x1 = src1_row[i0]; + + const float gelu_erf = 0.5f*x0*(1.0f + erf(x0*SQRT_2_INV)); + + dst_row[i0] = gelu_erf*x1; + } +} + +kernel void kernel_geglu_erf_f16( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global char * dst, + ulong offsetd, + ulong nb01, + ulong nb11, + int ne0, + ulong nb1, + int ne00_off, + int ne10_off +) { + src0 = (global char*)((global char*)src0 + offset0); + src1 = (global char*)((global char*)src1 + offset1); + dst = (global char*)((global char*)dst + offsetd); + + global half * src0_row = (global half *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off; + global half * src1_row = (global half *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off; + global half * dst_row = (global half *) ((global char *) dst + get_group_id(0)*nb1); + + for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) { + const half x0 = src0_row[i0]; + const half x1 = src1_row[i0]; + + const half gelu_erf = 0.5f*x0*(1.0f + erf(x0*SQRT_2_INV)); + + dst_row[i0] = gelu_erf*x1; + } +} + +//------------------------------------------------------------------------------ +// geglu_quick +//------------------------------------------------------------------------------ +kernel void kernel_geglu_quick( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global char * dst, + ulong offsetd, + ulong nb01, + ulong nb11, + int ne0, + ulong nb1, + int ne00_off, + int ne10_off +) { + src0 = (global char*)((global char*)src0 + offset0); + src1 = (global char*)((global char*)src1 + offset1); + dst = (global char*)((global char*)dst + offsetd); + + global float * src0_row = (global float *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off; + global float * src1_row = (global float *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off; + global float * dst_row = (global float *) ((global char *) dst + get_group_id(0)*nb1); + + for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) { + const float x0 = src0_row[i0]; + const float x1 = src1_row[i0]; + + const float gelu_quick = x0*(1.0f/(1.0f + exp(GELU_QUICK_COEF*x0))); + + dst_row[i0] = gelu_quick*x1; + } +} + +kernel void kernel_geglu_quick_f16( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global char * dst, + ulong offsetd, + ulong nb01, + ulong nb11, + int ne0, + ulong nb1, + int ne00_off, + int ne10_off +) { + src0 = (global char*)((global char*)src0 + offset0); + src1 = (global char*)((global char*)src1 + offset1); + dst = (global char*)((global char*)dst + offsetd); + + global half * src0_row = (global half *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off; + global half * src1_row = (global half *) 
((global char *) src1 + get_group_id(0)*nb11) + ne10_off; + global half * dst_row = (global half *) ((global char *) dst + get_group_id(0)*nb1); + + for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) { + const half x0 = src0_row[i0]; + const half x1 = src1_row[i0]; + + const half gelu_quick = x0*(1.0f/(1.0f + exp(GELU_QUICK_COEF*x0))); + + dst_row[i0] = gelu_quick*x1; + } +} diff --git a/ggml/src/ggml-sycl/element_wise.cpp b/ggml/src/ggml-sycl/element_wise.cpp index c7788bdb6b..0363b06a3e 100644 --- a/ggml/src/ggml-sycl/element_wise.cpp +++ b/ggml/src/ggml-sycl/element_wise.cpp @@ -383,6 +383,24 @@ static void gated_op_fused_swiglu(const T * x, const T * g, T * dst, const uint6 } } +template +static void gated_op_fused_geglu_erf(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o0, const uint64_t o1, const sycl::nd_item<1> &item_ct1) { + SYCL_GLOBAL_ID_LOOP(k, item_ct1) { + const int64_t j0 = (i / n) * o0 + (i % n); + const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n); + dst[i] = op_gelu_erf(x[j0]) * g[j1]; + } +} + +template +static void gated_op_fused_geglu_quick(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o0, const uint64_t o1, const sycl::nd_item<1> &item_ct1) { + SYCL_GLOBAL_ID_LOOP(k, item_ct1) { + const int64_t j0 = (i / n) * o0 + (i % n); + const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n); + dst[i] = op_gelu_quick(x[j0]) * g[j1]; + } +} + namespace ggml_sycl_detail { static void acc_f32_sycl(const float *x, const float *y, float *dst, const int n_elements, const int ne10, const int ne11, @@ -978,6 +996,28 @@ static inline void ggml_sycl_op_swiglu(ggml_backend_sycl_context & ctx, ggml_ten }); } +static inline void ggml_sycl_op_geglu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + ggml_sycl_detail::dispatch_ggml_sycl_op_fused_glu(ctx, dst, + [](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) { + const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE); + sycl_parallel_for(main_stream, + sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) { + gated_op_fused_geglu_erf(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1); + }); + }); +} + +static inline void ggml_sycl_op_geglu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + ggml_sycl_detail::dispatch_ggml_sycl_op_fused_glu(ctx, dst, + [](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) { + const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE); + sycl_parallel_for(main_stream, + sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) { + gated_op_fused_geglu_quick(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1); + }); + }); +} + void ggml_sycl_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); @@ -1118,3 +1158,13 @@ void ggml_sycl_swiglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); ggml_sycl_op_swiglu(ctx, dst); } + +void ggml_sycl_geglu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); + ggml_sycl_op_geglu_erf(ctx, dst); +} + +void 
ggml_sycl_geglu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); + ggml_sycl_op_geglu_quick(ctx, dst); +} diff --git a/ggml/src/ggml-sycl/element_wise.hpp b/ggml/src/ggml-sycl/element_wise.hpp index 86068b1012..50749e87d7 100644 --- a/ggml/src/ggml-sycl/element_wise.hpp +++ b/ggml/src/ggml-sycl/element_wise.hpp @@ -80,5 +80,7 @@ void ggml_sycl_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst); void ggml_sycl_geglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst); void ggml_sycl_reglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst); void ggml_sycl_swiglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst); +void ggml_sycl_geglu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst); +void ggml_sycl_geglu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst); #endif // GGML_SYCL_ELEMENTWISE_HPP diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index b063a698de..21c81e99a1 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -3687,6 +3687,12 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg case GGML_GLU_OP_SWIGLU: ggml_sycl_swiglu(ctx, dst); break; + case GGML_GLU_OP_GEGLU_ERF: + ggml_sycl_geglu_erf(ctx, dst); + break; + case GGML_GLU_OP_GEGLU_QUICK: + ggml_sycl_geglu_quick(ctx, dst); + break; default: return false; } @@ -4232,6 +4238,8 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g case GGML_GLU_OP_REGLU: case GGML_GLU_OP_GEGLU: case GGML_GLU_OP_SWIGLU: + case GGML_GLU_OP_GEGLU_ERF: + case GGML_GLU_OP_GEGLU_QUICK: return ggml_is_contiguous_1(op->src[0]); default: return false; diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index c0032cba21..7b8faf1787 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -456,6 +456,8 @@ struct vk_device_struct { vk_pipeline pipeline_geglu[2]; vk_pipeline pipeline_reglu[2]; vk_pipeline pipeline_swiglu[2]; + vk_pipeline pipeline_geglu_erf[2]; + vk_pipeline pipeline_geglu_quick[2]; vk_pipeline pipeline_leaky_relu_f32; vk_pipeline pipeline_silu_back_f32; @@ -499,6 +501,8 @@ struct vk_device_struct { ggml_backend_buffer_type buffer_type; + bool disable_fusion; + #ifdef GGML_VULKAN_MEMORY_DEBUG std::unique_ptr memory_logger; #endif @@ -634,6 +638,7 @@ struct vk_flash_attn_push_constants { uint32_t nev3; uint32_t nem1; uint32_t nem2; + uint32_t nem3; uint32_t nb01; uint32_t nb02; @@ -649,8 +654,7 @@ struct vk_flash_attn_push_constants { float max_bias; float logit_softcap; - uint32_t mask; - uint32_t n_head_log2; + uint32_t mask_n_head_log2; float m0; float m1; @@ -1089,8 +1093,8 @@ static size_t vk_skip_checks; static size_t vk_output_tensor; static void ggml_vk_print_tensor(const ggml_tensor * tensor, const char * name); -static void ggml_vk_check_results_0(ggml_tensor * tensor); -static void ggml_vk_check_results_1(ggml_tensor * tensor); +static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx); +static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx); #endif typedef void (*ggml_vk_func_t)(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst); @@ -2821,6 +2825,8 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_GLU(geglu) CREATE_GLU(reglu) CREATE_GLU(swiglu) + 
CREATE_GLU(geglu_erf) + CREATE_GLU(geglu_quick) #undef CREATE_GLU ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); @@ -3503,6 +3509,8 @@ static vk_device ggml_vk_get_device(size_t idx) { device->idx = idx; + device->disable_fusion = getenv("GGML_VK_DISABLE_FUSION") != nullptr; + return device; } @@ -6107,6 +6115,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx const uint32_t nem1 = mask ? mask->ne[1] : 0; const uint32_t nem2 = mask ? mask->ne[2] : 0; + const uint32_t nem3 = mask ? mask->ne[3] : 0; const uint32_t HSK = nek0; const uint32_t HSV = nev0; @@ -6174,7 +6183,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx } if (N == 1 && qk_ratio > 1 && qk_ratio <= max_gqa && - qk_ratio * nek2 == neq2 && nek2 == nev2 && neq3 == 1 && nek3 == 1 && nev3 == 1) { + qk_ratio * nek2 == neq2 && nek2 == nev2 && nem2 <= 1) { // grouped query attention - make the N dimension equal to gqa_ratio, reduce // workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1 // and change addressing calculations to index Q's dimension 2. @@ -6344,17 +6353,19 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx } } + uint32_t mask_n_head_log2 = ((mask != nullptr) << 16) | n_head_log2; + const vk_flash_attn_push_constants pc = { N, KV, (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3, (uint32_t)neq2, (uint32_t)neq3, (uint32_t)nek2, (uint32_t)nek3, (uint32_t)nev2, (uint32_t)nev3, - nem1, nem2, + nem1, nem2, nem3, q_stride, (uint32_t)nbq2, (uint32_t)nbq3, k_stride, (uint32_t)nbk2, (uint32_t)nbk3, v_stride, (uint32_t)nbv2, (uint32_t)nbv3, scale, max_bias, logit_softcap, - mask != nullptr, n_head_log2, m0, m1, + mask_n_head_log2, m0, m1, gqa_ratio, split_kv, split_k }; ggml_vk_sync_buffers(subctx); @@ -6575,6 +6586,10 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_reglu[dst->type == GGML_TYPE_F16]; case GGML_GLU_OP_SWIGLU: return ctx->device->pipeline_swiglu[dst->type == GGML_TYPE_F16]; + case GGML_GLU_OP_GEGLU_ERF: + return ctx->device->pipeline_geglu_erf[dst->type == GGML_TYPE_F16]; + case GGML_GLU_OP_GEGLU_QUICK: + return ctx->device->pipeline_geglu_quick[dst->type == GGML_TYPE_F16]; default: break; } @@ -7643,8 +7658,7 @@ static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f }, dryrun); } -static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { - float * op_params = (float *)dst->op_params; +static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, float * op_params, bool dryrun = false) { const uint32_t src0_type_size = ggml_type_size(src0->type); const uint32_t src1_type_size = ggml_type_size(src1->type); const uint32_t dst_type_size = ggml_type_size(dst->type); @@ -8874,7 +8888,7 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) { } } -static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_tensor* tensor, int tensor_idx, bool use_fence, bool almost_ready); +static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, 
ggml_cgraph * cgraph, ggml_tensor* tensor, int tensor_idx, bool use_fence, bool almost_ready); // Returns true if node has enqueued work into the queue, false otherwise // If submit is true the current all operations queued so far are being submitted to Vulkan to overlap cmdlist creation and GPU execution. @@ -8919,6 +8933,8 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_GLU_OP_GEGLU: case GGML_GLU_OP_REGLU: case GGML_GLU_OP_SWIGLU: + case GGML_GLU_OP_GEGLU_ERF: + case GGML_GLU_OP_GEGLU_QUICK: break; default: return false; @@ -9133,9 +9149,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr // fused rms_norm + mul ggml_tensor *mul = cgraph->nodes[node_idx + 1]; ggml_tensor *other_src = mul->src[0] == node ? mul->src[1] : mul->src[0]; - ggml_vk_rms_norm(ctx, compute_ctx, src0, other_src, mul, dryrun); + ggml_vk_rms_norm(ctx, compute_ctx, src0, other_src, mul, (float *)node->op_params, dryrun); } else { - ggml_vk_rms_norm(ctx, compute_ctx, src0, src0, node, dryrun); + ggml_vk_rms_norm(ctx, compute_ctx, src0, src0, node, (float *)node->op_params, dryrun); } break; case GGML_OP_RMS_NORM_BACK: @@ -9166,6 +9182,8 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_GLU_OP_GEGLU: case GGML_GLU_OP_REGLU: case GGML_GLU_OP_SWIGLU: + case GGML_GLU_OP_GEGLU_ERF: + case GGML_GLU_OP_GEGLU_QUICK: ggml_vk_glu(ctx, compute_ctx, src0, src1, node, dryrun); break; default: @@ -9293,7 +9311,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr ctx->compute_ctx.reset(); - bool ok = ggml_vk_compute_forward(ctx, node_begin, node_idx_begin, false, almost_ready); + bool ok = ggml_vk_compute_forward(ctx, cgraph, node_begin, node_idx_begin, false, almost_ready); if (!ok) { if (node->op == GGML_OP_UNARY) { std::cerr << __func__ << ": error: op not supported UNARY " << node->name << " (" << ggml_unary_op_name(static_cast(node->op_params[0])) << ")" << std::endl; @@ -9308,7 +9326,8 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr return true; } -static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * tensor, int tensor_idx, bool use_fence = true, bool almost_ready = false) { +static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, ggml_tensor * tensor, int tensor_idx, bool use_fence = true, bool almost_ready = false) { + GGML_UNUSED(cgraph); ggml_backend_buffer * buf = nullptr; switch (tensor->op) { @@ -9384,6 +9403,8 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * case GGML_GLU_OP_GEGLU: case GGML_GLU_OP_REGLU: case GGML_GLU_OP_SWIGLU: + case GGML_GLU_OP_GEGLU_ERF: + case GGML_GLU_OP_GEGLU_QUICK: buf = tensor->buffer; break; default: @@ -9416,7 +9437,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * // Only run if ctx hasn't been submitted yet if (!subctx->seqs.empty()) { #ifdef GGML_VULKAN_CHECK_RESULTS - ggml_vk_check_results_0(tensor); + ggml_vk_check_results_0(ctx, cgraph, tensor_idx); use_fence = true; #endif @@ -9436,7 +9457,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * ggml_vk_wait_for_fence(ctx); } #ifdef GGML_VULKAN_CHECK_RESULTS - ggml_vk_check_results_1(tensor); + ggml_vk_check_results_1(ctx, cgraph, tensor_idx); #endif } @@ -9883,6 +9904,37 @@ static bool ggml_vk_is_empty(ggml_tensor * node) { return ggml_is_empty(node) || node->op == GGML_OP_NONE || 
node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE; } +static bool ggml_vk_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list ops) { + if (!ggml_can_fuse(cgraph, node_idx, ops)) { + return false; + } + + if (ops.size() == 2 && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) { + // additional constraints specific to this fusion + const ggml_tensor *rms_norm = cgraph->nodes[node_idx]; + const ggml_tensor *mul = cgraph->nodes[node_idx + 1]; + + GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(rms_norm->type == GGML_TYPE_F32); + // rms_norm only supports f32 + if (mul->src[0]->type != GGML_TYPE_F32 || + mul->src[1]->type != GGML_TYPE_F32 || + mul->type != GGML_TYPE_F32) { + return false; + } + // if rms_norm is the B operand, then we don't handle broadcast + if (rms_norm == mul->src[1] && + mul->src[0]->ne[1] != rms_norm->ne[1]) { + return false; + } + // rms_norm shader assumes contiguous rows + if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) { + return false; + } + } + return true; +} + static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)"); ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; @@ -9896,7 +9948,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg uint64_t total_mat_mul_bytes = 0; for (int i = 0; i < cgraph->n_nodes; i++) { - if (ggml_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) { + if (!ctx->device->disable_fusion && ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) { ctx->num_additional_fused_ops = 1; } ggml_vk_build_graph(ctx, cgraph, i, nullptr, 0, true, false, false, false); @@ -9966,7 +10018,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg mul_mat_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]); } - if (ggml_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) { + if (!ctx->device->disable_fusion && ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) { ctx->num_additional_fused_ops = 1; } @@ -10194,6 +10246,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_GLU_OP_GEGLU: case GGML_GLU_OP_REGLU: case GGML_GLU_OP_SWIGLU: + case GGML_GLU_OP_GEGLU_ERF: + case GGML_GLU_OP_GEGLU_QUICK: return ggml_is_contiguous(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) && (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && @@ -10287,12 +10341,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm if (op->src[3] && op->src[3]->type != GGML_TYPE_F16) { return false; } - // TODO: support broadcast - // note: this was initially implemented in https://github.com/ggml-org/llama.cpp/pull/14449, but - // the interface of ggml_flash_attn_ext() changed in https://github.com/ggml-org/llama.cpp/pull/14505 - if (op->src[0]->ne[3] != 1 || (op->src[3] && op->src[3]->ne[2] != 1)) { - return false; - } // It's straightforward to support different K/V dequant, but would // significantly increase the number of pipelines if (op->src[1]->type != op->src[2]->type) { @@ -10747,11 +10795,21 @@ void * comp_result; size_t comp_size; size_t comp_nb[GGML_MAX_DIMS]; size_t check_counter = 0; -static void ggml_vk_check_results_0(ggml_tensor * tensor) { +static void 
ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx) { + ggml_tensor * tensor = cgraph->nodes[tensor_idx]; if (tensor->op == GGML_OP_TRANSPOSE) { return; } + bool fused_rms_norm_mul = false; + int rms_norm_idx = -1; + if (ctx->num_additional_fused_ops == 1 && + tensor->op == GGML_OP_RMS_NORM && + cgraph->nodes[tensor_idx + 1]->op == GGML_OP_MUL) { + fused_rms_norm_mul = true; + tensor = cgraph->nodes[tensor_idx + 1]; + } + check_counter++; if (!(vk_output_tensor > 0 && vk_output_tensor == check_counter) && check_counter <= vk_skip_checks) { return; @@ -10779,6 +10837,15 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { for (int i = 0; i < 6; i++) { ggml_tensor * srci = tensor->src[i]; + if (fused_rms_norm_mul) { + rms_norm_idx = tensor->src[0]->op == GGML_OP_RMS_NORM ? 0 : 1; + ggml_tensor *rms_norm = tensor->src[rms_norm_idx]; + switch (i) { + case 0: srci = rms_norm->src[0]; break; + case 1: srci = tensor->src[1 - rms_norm_idx]; break; + default: continue; + } + } if (srci == nullptr) { continue; } @@ -10836,7 +10903,12 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { } else if (tensor->op == GGML_OP_SUB) { tensor_clone = ggml_sub(ggml_ctx, src_clone[0], src_clone[1]); } else if (tensor->op == GGML_OP_MUL) { - tensor_clone = ggml_mul(ggml_ctx, src_clone[0], src_clone[1]); + if (fused_rms_norm_mul) { + tensor_clone = ggml_rms_norm(ggml_ctx, src_clone[0], *(float *)tensor->src[rms_norm_idx]->op_params); + tensor_clone = ggml_mul(ggml_ctx, tensor_clone, src_clone[1 - rms_norm_idx]); + } else { + tensor_clone = ggml_mul(ggml_ctx, src_clone[0], src_clone[1]); + } } else if (tensor->op == GGML_OP_DIV) { tensor_clone = ggml_div(ggml_ctx, src_clone[0], src_clone[1]); } else if (tensor->op == GGML_OP_CONCAT) { @@ -11027,10 +11099,10 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { GGML_ABORT("fatal error"); } - ggml_cgraph * cgraph = ggml_new_graph(ggml_ctx); - ggml_build_forward_expand(cgraph, tensor_clone); + ggml_cgraph * cgraph_cpu = ggml_new_graph(ggml_ctx); + ggml_build_forward_expand(cgraph_cpu, tensor_clone); - ggml_graph_compute_with_ctx(ggml_ctx, cgraph, 8); + ggml_graph_compute_with_ctx(ggml_ctx, cgraph_cpu, 8); if (vk_output_tensor > 0 && vk_output_tensor == check_counter) { ggml_vk_print_tensor(tensor_clone, "tensor_clone"); @@ -11053,10 +11125,19 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { VK_LOG_DEBUG("END ggml_vk_check_results_0(" << tensor->name << ")"); } -static void ggml_vk_check_results_1(ggml_tensor * tensor) { +static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx) { + ggml_tensor * tensor = cgraph->nodes[tensor_idx]; if (tensor->op == GGML_OP_TRANSPOSE) { return; } + bool fused_rms_norm_mul = false; + if (ctx->num_additional_fused_ops == 1 && + tensor->op == GGML_OP_RMS_NORM && + cgraph->nodes[tensor_idx + 1]->op == GGML_OP_MUL) { + fused_rms_norm_mul = true; + tensor = cgraph->nodes[tensor_idx + 1]; + } + if (!(vk_output_tensor > 0 && vk_output_tensor == check_counter) && check_counter <= vk_skip_checks) { return; } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index 788a5e065d..45c6e7736a 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -101,8 +101,8 @@ void main() { uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2; #endif uint32_t m_offset = 0; - if (p.nem2 != 1) { - 
m_offset = (iq3 % p.nem2) * p.nem1 * KV; + if (p.nem2 != 1 || p.nem3 != 1) { + m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV; } [[dont_unroll]] @@ -149,7 +149,7 @@ void main() { } } - if (p.mask != 0) { + if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) { [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) { uint32_t c = (idx + tid) % Bc; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp index 6609f0bad3..7defe72b40 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp @@ -25,6 +25,7 @@ layout (push_constant) uniform parameter { uint32_t nev3; uint32_t nem1; uint32_t nem2; + uint32_t nem3; uint32_t nb01; uint32_t nb02; @@ -40,8 +41,7 @@ layout (push_constant) uniform parameter { float max_bias; float logit_softcap; - uint32_t mask; - uint32_t n_head_log2; + uint32_t mask_n_head_log2; float m0; float m1; @@ -50,6 +50,9 @@ layout (push_constant) uniform parameter { uint32_t k_num; } p; +#define MASK_ENABLE_BIT (1<<16) +#define N_LOG2_MASK 0xFFFF + layout (binding = 4) writeonly buffer O {D_TYPE data_o[];}; #if defined(A_TYPE_PACKED16) @@ -100,8 +103,10 @@ ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const i { const uint32_t h = iq2 + (r % p.gqa_ratio); - const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1); - const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1); + uint32_t n_head_log2 = p.mask_n_head_log2 & N_LOG2_MASK; + + const ACC_TYPE base = ACC_TYPE(h < n_head_log2 ? p.m0 : p.m1); + const int exph = int(h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1); return ACC_TYPE(pow(base, ACC_TYPE(exph))); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp index e74e2fa934..486735fe8b 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp @@ -126,8 +126,8 @@ void main() { uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2; #endif uint32_t m_offset = 0; - if (p.nem2 != 1) { - m_offset = (iq3 % p.nem2) * p.nem1 * KV; + if (p.nem2 != 1 || p.nem3 != 1) { + m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV; } [[dont_unroll]] @@ -182,7 +182,7 @@ void main() { barrier(); } - if (p.mask != 0) { + if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) { [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) { uint32_t c = (idx + tid) % Bc; uint32_t r = (idx + tid) / Bc; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp index 8792d5195e..274f48fcab 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp @@ -131,8 +131,8 @@ void main() { } uint32_t m_offset = 0; - if (p.nem2 != 1) { - m_offset = (iq3 % p.nem2) * p.nem1 * KV * 2 /*sizeof(float16_t)*/; + if (p.nem2 != 1 || p.nem3 != 1) { + m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV * 2 /*sizeof(float16_t)*/; } [[dont_unroll]] @@ -153,7 +153,7 @@ void main() { } } - if (p.mask != 0) { + if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) { tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp); tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV); tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 
1); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp b/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp new file mode 100644 index 0000000000..cbd4cb36bf --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp @@ -0,0 +1,27 @@ +#version 450 + +#include "glu_head.comp" + +// based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation +// ref: https://www.johndcook.com/blog/python_erf/ +const float p_erf = 0.3275911f; +const float a1_erf = 0.254829592f; +const float a2_erf = -0.284496736f; +const float a3_erf = 1.421413741f; +const float a4_erf = -1.453152027f; +const float a5_erf = 1.061405429f; + +const float SQRT_2_INV = 0.70710678118654752440084436210484f; + +float op(float a, float b) { + const float a_div_sqr2 = a * SQRT_2_INV; + const float sign_x = sign(a_div_sqr2); + const float x = abs(a_div_sqr2); + const float t = 1.0f / (1.0f + p_erf * x); + const float y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x); + const float erf_approx = sign_x * y; + + return 0.5f * a * (1.0f + erf_approx) * b; +} + +#include "glu_main.comp" diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp b/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp new file mode 100644 index 0000000000..3a2a6897bf --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp @@ -0,0 +1,11 @@ +#version 450 + +#include "glu_head.comp" + +const float GELU_QUICK_COEF = -1.702f; + +float op(float a, float b) { + return a * (1.0f / (1.0f + exp(GELU_QUICK_COEF * a))) * b; +} + +#include "glu_main.comp" diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp index 26163b167c..888ce79f6e 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp @@ -500,10 +500,9 @@ void main() { const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - const uint ib = idx / 128; // 2 values per idx - const uint ib32 = (idx % 128) / 16; // 0..7 - const uint ib8 = (idx % 128) / 4; - const int i8 = 2 * int(idx % 4); + const uint ib = idx / 32; // 8 values per idx + const uint ib32 = (idx % 32) / 4; // 0..7 + const uint ib8 = idx % 32; const float d = float(data_a[ib].d); const uint qh = data_a[ib].qh[ib32]; @@ -512,22 +511,16 @@ void main() { const float delta = ((qh & 0x8000) != 0) ? 
-IQ1S_DELTA : IQ1S_DELTA; const int16_t grid = int16_t(iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)]); - const ivec2 gvec = ivec2( - bitfieldExtract(grid, 2 * (i8), 2), - bitfieldExtract(grid, 2 * (i8 + 1), 2) - ); - const vec2 v = dl * (vec2(gvec) + delta); - - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); + [[unroll]] for (int k = 0; k < 8; ++k) { + buf_a[buf_idx + k] = FLOAT_TYPE(dl * (bitfieldExtract(grid, 2 * k, 2) + delta)); + } #elif defined(DATA_A_IQ1_M) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - const uint ib = idx / 128; // 2 values per idx - const uint ib8 = (idx % 128) / 4; + const uint ib = idx / 32; // 8 values per idx + const uint ib8 = idx % 32; const uint ib16 = ib8 / 2; - const int i8 = 2 * int(idx % 4); const uint16_t[4] scales = data_a[ib].scales; const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12; @@ -538,21 +531,17 @@ void main() { const float dl = d * (2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1); const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA; const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]); - const ivec2 gvec = ivec2( - bitfieldExtract(grid, 2 * (i8), 2), - bitfieldExtract(grid, 2 * (i8 + 1), 2) - ); - const vec2 v = dl * (vec2(gvec) + delta); - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); + [[unroll]] for (int k = 0; k < 8; ++k) { + buf_a[buf_idx + k] = FLOAT_TYPE(dl * (bitfieldExtract(grid, 2 * k, 2) + delta)); + } #elif defined(DATA_A_IQ2_XXS) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - const uint ib = idx / 128; // 2 values per idx - const uint ib32 = (idx % 128) / 16; // 0..7 - const uint ib8 = (idx / 4) % 4; + const uint ib = idx / 32; // 8 values per idx + const uint ib32 = (idx % 32) / 4; // 0..7 + const uint ib8 = idx % 4; const float d = float(data_a[ib].d); const uint qs = data_a[ib].qs[8 * ib32 + ib8]; @@ -562,63 +551,81 @@ void main() { data_a[ib].qs[8*ib32 + 6], data_a[ib].qs[8*ib32 + 7] )); - const float db = d * 0.25 * (0.5 + (signs >> 28)); + const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + (signs >> 28))); const uint32_t sign7 = bitfieldExtract(signs, 7 * int(ib8), 7); - const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4)); - const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign)))); - const uint grid = iq2xxs_grid[qs][(idx % 4) / 2] >> (16 * (idx & 1)); - const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147 + const uint sign = sign7 | (bitCount(sign7) << 7); + const uvec2 grid = iq2xxs_grid[qs]; + const vec4 grid0 = vec4(unpack8(grid.x)); + const vec4 grid1 = vec4(unpack8(grid.y)); - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); + buf_a[buf_idx ] = db * FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x); + buf_a[buf_idx + 1] = db * FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y); + buf_a[buf_idx + 2] = db * FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z); + buf_a[buf_idx + 3] = db * FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w); + buf_a[buf_idx + 4] = db * FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x); + buf_a[buf_idx + 5] = db * FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y); + buf_a[buf_idx + 6] = db * FLOAT_TYPE((sign & 64) != 0 ? 
-grid1.z : grid1.z); + buf_a[buf_idx + 7] = db * FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w); #elif defined(DATA_A_IQ2_XS) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - const uint ib = idx / 128; // 2 values per idx - const uint ib32 = (idx % 128) / 16; // 0..7 - const uint ib8 = (idx / 4) % 4; // 0..3 + const uint ib = idx / 32; // 8 values per idx + const uint ib32 = (idx % 32) / 4; // 0..7 + const uint ib8 = idx % 4; // 0..3 const float d = float(data_a[ib].d); const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf; - const float db = d * 0.25 * (0.5 + scale); + const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + scale)); const uint qs = data_a[ib].qs[4 * ib32 + ib8]; const uint sign7 = qs >> 9; - const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4)); - const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign)))); - const uint grid = iq2xs_grid[qs & 511][(idx % 4) / 2] >> (16 * (idx & 1)); - const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147 + const uint sign = sign7 | (bitCount(sign7) << 7); + const uvec2 grid = iq2xs_grid[qs & 511]; + const vec4 grid0 = vec4(unpack8(grid.x)); + const vec4 grid1 = vec4(unpack8(grid.y)); - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); + buf_a[buf_idx ] = db * FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x); + buf_a[buf_idx + 1] = db * FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y); + buf_a[buf_idx + 2] = db * FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z); + buf_a[buf_idx + 3] = db * FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w); + buf_a[buf_idx + 4] = db * FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x); + buf_a[buf_idx + 5] = db * FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y); + buf_a[buf_idx + 6] = db * FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z); + buf_a[buf_idx + 7] = db * FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w); #elif defined(DATA_A_IQ2_S) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - const uint ib = idx / 128; // 2 values per idx - const uint ib8 = (idx % 128) / 4; // 0..31 - const uint ib32 = ib8 / 4; // 0..7 + const uint ib = idx / 32; // 8 values per idx + const uint ib8 = idx % 32; // 0..31 + const uint ib32 = ib8 / 4; // 0..7 const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf; const uint qs = data_a[ib].qs[ib8]; const uint qh = data_a[ib].qh[ib32]; const uint qhshift = 2 * (ib8 % 4); - const uint sign = data_a[ib].qs[QUANT_K / 8 + ib8] >> (2 * (idx % 4)); + const uint sign = data_a[ib].qs[QUANT_K / 8 + ib8]; const float d = float(data_a[ib].d); - const float db = d * 0.25 * (0.5 + scale); - const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign)))); - const uint16_t grid = unpack16(iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(idx & 2) >> 1])[idx & 1]; - const vec2 v = db * vec2(sign01) * vec2(unpack8(uint32_t(grid)).xy); // vec4 used due to #12147 + const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + scale)); + const uvec2 grid = iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)]; + const vec4 grid0 = vec4(unpack8(grid.x)); + const vec4 grid1 = vec4(unpack8(grid.y)); - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); + buf_a[buf_idx ] = db * FLOAT_TYPE((sign & 1) != 0 ? 
-grid0.x : grid0.x); + buf_a[buf_idx + 1] = db * FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y); + buf_a[buf_idx + 2] = db * FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z); + buf_a[buf_idx + 3] = db * FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w); + buf_a[buf_idx + 4] = db * FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x); + buf_a[buf_idx + 5] = db * FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y); + buf_a[buf_idx + 6] = db * FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z); + buf_a[buf_idx + 7] = db * FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w); #elif defined(DATA_A_IQ3_XXS) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - const uint ib = idx / 128; // 2 values per idx - const uint iqs = (idx % 128) / 2; // 0..63 + const uint ib = idx / 64; // 4 values per idx + const uint iqs = idx % 64; // 0..63 const uint is = QUANT_K / 4 + 4 * (iqs / 8); // 8 values const float d = float(data_a[ib].d); @@ -631,33 +638,36 @@ void main() { )); const float db = d * 0.5 * (0.5 + (signs >> 28)); const uint32_t sign7 = bitfieldExtract(signs, 7 * (int(iqs / 2) % 4), 7); - const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4)); - const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign)))); - const uint grid = iq3xxs_grid[qs] >> (16 * (idx & 1)); - const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147 + const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (4 * (idx % 2)); + const uint grid = iq3xxs_grid[qs]; + const vec4 v = db * vec4(unpack8(grid)); - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); + buf_a[buf_idx ] = FLOAT_TYPE((sign & 1) != 0 ? -v.x : v.x); + buf_a[buf_idx + 1] = FLOAT_TYPE((sign & 2) != 0 ? -v.y : v.y); + buf_a[buf_idx + 2] = FLOAT_TYPE((sign & 4) != 0 ? -v.z : v.z); + buf_a[buf_idx + 3] = FLOAT_TYPE((sign & 8) != 0 ? -v.w : v.w); #elif defined(DATA_A_IQ3_S) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - const uint ib = idx / 128; // 2 values per idx - const uint iqs = (idx % 128) / 2; // 0..63 + const uint ib = idx / 64; // 4 values per idx + const uint iqs = idx % 64; // 0..63 const uint iqh = iqs / 8; const float d = float(data_a[ib].d); const uint qs = data_a[ib].qs[iqs]; const uint qh = data_a[ib].qh[iqh]; - const int8_t sign = int8_t(data_a[ib].signs[iqs / 2] >> (2 * (idx % 4))); + const int8_t sign = int8_t(data_a[ib].signs[iqs / 2] >> (4 * (idx % 2))); const uint scale = data_a[ib].scales[iqs / 16]; const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(sign << 1, sign))); const float db = d * (1 + 2 * ((scale >> (4 * (iqh & 1))) & 0xf)); - const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)] >> (16 * (idx % 2)); - const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147 + const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)]; + const vec4 v = db * vec4(unpack8(grid)); - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); + buf_a[buf_idx ] = FLOAT_TYPE((sign & 1) != 0 ? -v.x : v.x); + buf_a[buf_idx + 1] = FLOAT_TYPE((sign & 2) != 0 ? -v.y : v.y); + buf_a[buf_idx + 2] = FLOAT_TYPE((sign & 4) != 0 ? -v.z : v.z); + buf_a[buf_idx + 3] = FLOAT_TYPE((sign & 8) != 0 ? 
-v.w : v.w); #elif defined(DATA_A_IQ4_XS) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 297a2a7711..30f78fabb3 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -360,9 +360,9 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool for (const auto& tname : type_names) { std::string load_vec_quant = "2"; - if ((tname == "q4_0") || (tname == "q4_1")) + if ((tname == "q4_0") || (tname == "q4_1") || (tname == "iq1_s") || (tname == "iq1_m") || (tname == "iq2_xxs") || (tname == "iq2_xs") || (tname == "iq2_s")) load_vec_quant = "8"; - else if ((tname == "q5_0") || (tname == "q5_1") || (tname == "q8_0") || (tname == "iq4_nl")) + else if ((tname == "q5_0") || (tname == "q5_1") || (tname == "q8_0") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_nl")) load_vec_quant = "4"; if (tname == "bf16") { @@ -593,6 +593,10 @@ void process_shaders() { string_to_spv("reglu_f32", "reglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("swiglu_f16", "swiglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); string_to_spv("swiglu_f32", "swiglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("geglu_erf_f16", "geglu_erf.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("geglu_erf_f32", "geglu_erf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("geglu_quick_f16","geglu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("geglu_quick_f32","geglu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("silu_back_f32", "silu_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index b481193da3..75fc1e7072 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -1140,9 +1140,11 @@ static const char * GGML_GLU_OP_NAME[GGML_GLU_OP_COUNT] = { "REGLU", "GEGLU", "SWIGLU", + "GEGLU_ERF", + "GEGLU_QUICK", }; -static_assert(GGML_GLU_OP_COUNT == 3, "GGML_GLU_OP_COUNT != 3"); +static_assert(GGML_GLU_OP_COUNT == 5, "GGML_GLU_OP_COUNT != 5"); static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN"); @@ -2768,6 +2770,48 @@ struct ggml_tensor * ggml_swiglu_split( return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_SWIGLU, false); } +// ggml_geglu_erf + +struct ggml_tensor * ggml_geglu_erf( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU_ERF, false); +} + +struct ggml_tensor * ggml_geglu_erf_swapped( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU_ERF, true); +} + +struct ggml_tensor * ggml_geglu_erf_split( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_GEGLU_ERF, false); +} + +// ggml_geglu_quick + +struct ggml_tensor * ggml_geglu_quick( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU_QUICK, false); +} + +struct ggml_tensor * 
ggml_geglu_quick_swapped( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU_QUICK, true); +} + +struct ggml_tensor * ggml_geglu_quick_split( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_GEGLU_QUICK, false); +} + // ggml_norm static struct ggml_tensor * ggml_norm_impl( diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp index 91b1d6078a..3bc8554e51 100644 --- a/src/llama-batch.cpp +++ b/src/llama-batch.cpp @@ -166,6 +166,8 @@ bool llama_batch_allocr::init( // note: tracking the other way around is not necessary for now //seq_cpl[s0][s1] = true; + + has_cpl = true; } } } @@ -405,6 +407,10 @@ uint32_t llama_batch_allocr::get_n_outputs() const { return n_outputs; } +uint32_t llama_batch_allocr::get_n_used() const { + return n_used; +} + std::vector & llama_batch_allocr::get_out_ids() { return out_ids; } @@ -420,6 +426,8 @@ llama_pos llama_batch_allocr::seq_pos_max(llama_seq_id seq_id) const { void llama_batch_allocr::split_reset() { out_ids.clear(); + n_used = 0; + used.clear(); used.resize(get_n_tokens(), false); @@ -444,6 +452,7 @@ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) { idxs.push_back(cur_idx); used[cur_idx] = true; + ++n_used; ++cur_idx; @@ -459,9 +468,17 @@ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) { return ubatch_add(idxs, idxs.size(), false); } -llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) { +llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch, bool sequential) { + if (sequential && has_cpl) { + LLAMA_LOG_ERROR("%s: sequential split is not supported when there are coupled sequences in the input batch\n", __func__); + + return {}; + } + std::vector cur_seq_set; + llama_seq_id last_seq_id = -1; + // determine the non-overlapping sequence sets participating in this ubatch for (int32_t i = 0; i < batch.n_tokens; ++i) { if (used[i]) { @@ -478,9 +495,16 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) { } } + // accept only increasing sequence ids + if (sequential) { + add = add && (cur_seq_set.empty() || batch.seq_id[i][0] == last_seq_id + 1); + } + if (add) { cur_seq_set.push_back(seq_set[i]); + last_seq_id = batch.seq_id[i][0]; + if (cur_seq_set.size() > n_ubatch) { break; } @@ -529,6 +553,7 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) { idxs_per_seq[s].push_back(idx); used[idx] = true; + ++n_used; ++cur_idx[s]; } @@ -570,6 +595,7 @@ llama_ubatch llama_batch_allocr::split_seq(uint32_t n_ubatch) { idxs.push_back(cur_idx); used[cur_idx] = true; + ++n_used; if (idxs.size() >= n_ubatch) { break; diff --git a/src/llama-batch.h b/src/llama-batch.h index d2c5376188..3420803ff9 100644 --- a/src/llama-batch.h +++ b/src/llama-batch.h @@ -54,6 +54,7 @@ public: uint32_t get_n_tokens() const; uint32_t get_n_outputs() const; + uint32_t get_n_used() const; // the array of output indices in the order they were encountered during the ubatch splitting std::vector & get_out_ids(); @@ -69,7 +70,8 @@ public: llama_ubatch split_simple(uint32_t n_ubatch); // make ubatches of equal-length sequences sets - llama_ubatch split_equal(uint32_t n_ubatch); + // if sequential == true, the tokens in the ubatch will have increasing sequential sequence ids + llama_ubatch split_equal(uint32_t n_ubatch, bool sequential); // sequence-set-wise split - each ubatch contains a single sequence-set llama_ubatch split_seq(uint32_t n_ubatch); @@ -112,6 +114,9 @@ private: 
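For reference, the two-argument split_equal(n_ubatch, sequential) declared above enforces two extra rules when sequential == true: it refuses batches that contain coupled sequences (tracked by the new has_cpl flag) and it only admits a token group into the current ubatch when its primary sequence id directly follows the previously accepted one. A minimal C++ sketch of that acceptance rule, not part of the patch itself; the helper name accept_sequential is hypothetical, while seq_id/last_seq_id/cur_seq_set mirror the names used in the hunk:

    // Mirrors the condition added in llama_batch_allocr::split_equal():
    //   add = add && (cur_seq_set.empty() || batch.seq_id[i][0] == last_seq_id + 1);
    // The first accepted group may start at any sequence id; afterwards only the
    // immediately following id is allowed, so the resulting ubatch carries
    // consecutive, increasing sequence ids.
    static bool accept_sequential(llama_seq_id seq_id, llama_seq_id last_seq_id, bool cur_seq_set_empty) {
        return cur_seq_set_empty || seq_id == last_seq_id + 1;
    }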
using pos_set_t = std::set; using seq_cpl_t = std::vector; + // helper flag to quickly determine if there are any coupled sequences in the batch + bool has_cpl; + std::vector seq_pos; // seq_pos[s]: the set of positions in sequence s std::vector seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1 @@ -125,6 +130,8 @@ private: // batch indices of the output std::vector out_ids; + uint32_t n_used; + // used[i] indicates if token i has already been used in a previous ubatch std::vector used; diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 7c2e880066..55a059d097 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -1099,8 +1099,7 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con auto inp = std::make_unique(hparams, cparams); // note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch - inp->kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); - //cb(inp_kq_mask, "KQ_mask", -1); + inp->kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1); ggml_set_input(inp->kq_mask); inp->kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->kq_mask, GGML_TYPE_F16) : inp->kq_mask; @@ -1170,7 +1169,7 @@ static std::unique_ptr build_attn_inp_kv_unifie inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch); inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch); - inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); + inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1); ggml_set_input(inp->self_kq_mask); inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; @@ -1312,7 +1311,7 @@ llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const { const int32_t n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train; - inp->cross_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); + inp->cross_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1); ggml_set_input(inp->cross_kq_mask); inp->cross_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->cross_kq_mask, GGML_TYPE_F16) : inp->cross_kq_mask; @@ -1376,7 +1375,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch); inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch); - inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); + inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1); ggml_set_input(inp->self_kq_mask); inp->self_kq_mask_cnv = cparams.flash_attn ? 
ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; @@ -1390,7 +1389,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch); inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch); - inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); + inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1); ggml_set_input(inp->self_kq_mask_swa); inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa; diff --git a/src/llama-graph.h b/src/llama-graph.h index d8dc1e3307..54eaaac02b 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -228,8 +228,8 @@ public: ggml_tensor * get_kq_mask() const { return kq_mask_cnv; } - ggml_tensor * kq_mask = nullptr; // F32 [n_tokens, n_batch] - ggml_tensor * kq_mask_cnv = nullptr; // [n_tokens, n_batch] + ggml_tensor * kq_mask = nullptr; // F32 [n_tokens, n_batch, 1, 1] + ggml_tensor * kq_mask_cnv = nullptr; // [n_tokens, n_batch, 1, 1] const llama_hparams & hparams; const llama_cparams & cparams; @@ -257,8 +257,8 @@ public: ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch] ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] - ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch] - ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch] + ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch, 1, 1] + ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch, 1, 1] const llama_hparams & hparams; const llama_cparams & cparams; @@ -293,10 +293,10 @@ public: ggml_tensor * self_k_idxs_swa = nullptr; // I64 [n_batch] ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch] - ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch] - ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch] - ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch] - ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch] + ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch, 1, 1] + ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch, 1, 1] + ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch, 1, 1] + ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch, 1, 1] const llama_hparams & hparams; const llama_cparams & cparams; @@ -313,8 +313,8 @@ public: ggml_tensor * get_kq_mask_cross() const { return cross_kq_mask_cnv; } - ggml_tensor * cross_kq_mask = nullptr; // F32 [n_outputs_enc, n_batch] - ggml_tensor * cross_kq_mask_cnv = nullptr; // F32 [n_outputs_enc, n_batch] + ggml_tensor * cross_kq_mask = nullptr; // F32 [n_outputs_enc, n_batch, 1, 1] + ggml_tensor * cross_kq_mask_cnv = nullptr; // F32 [n_outputs_enc, n_batch, 1, 1] const llama_cross * cross = nullptr; }; diff --git a/src/llama-kv-cache-unified-iswa.cpp b/src/llama-kv-cache-unified-iswa.cpp index ee202cc710..fe207ad536 100644 --- a/src/llama-kv-cache-unified-iswa.cpp +++ b/src/llama-kv-cache-unified-iswa.cpp @@ -113,6 +113,11 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all ubatches.push_back(std::move(ubatch)); // NOLINT } + if (balloc.get_n_used() < balloc.get_n_tokens()) { + // failed to find a suitable split + break; + } + auto sinfos_base = kv_base->prepare(ubatches); if (sinfos_base.empty()) { break; @@ -135,7 +140,7 @@ llama_memory_context_ptr 
llama_kv_cache_unified_iswa::init_batch(llama_batch_all std::vector ubatches; while (true) { - auto ubatch = balloc.split_equal(n_ubatch); + auto ubatch = balloc.split_equal(n_ubatch, false); if (ubatch.n_tokens == 0) { break; @@ -144,6 +149,11 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all ubatches.push_back(std::move(ubatch)); // NOLINT } + if (balloc.get_n_used() < balloc.get_n_tokens()) { + // failed to find a suitable split + break; + } + auto sinfos_base = kv_base->prepare(ubatches); if (sinfos_base.empty()) { break; diff --git a/src/llama-kv-cache-unified.cpp b/src/llama-kv-cache-unified.cpp index ff22079851..d3129cc532 100644 --- a/src/llama-kv-cache-unified.cpp +++ b/src/llama-kv-cache-unified.cpp @@ -360,6 +360,11 @@ llama_memory_context_ptr llama_kv_cache_unified::init_batch( ubatches.push_back(std::move(ubatch)); // NOLINT } + if (balloc.get_n_used() < balloc.get_n_tokens()) { + // failed to find a suitable split + break; + } + auto sinfos = prepare(ubatches); if (sinfos.empty()) { break; diff --git a/src/llama-memory-hybrid.cpp b/src/llama-memory-hybrid.cpp index 03d974d852..6cd10db06b 100644 --- a/src/llama-memory-hybrid.cpp +++ b/src/llama-memory-hybrid.cpp @@ -70,7 +70,7 @@ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ba // if all tokens are output, split by sequence ubatch = balloc.split_seq(n_ubatch); } else { - ubatch = balloc.split_equal(n_ubatch); + ubatch = balloc.split_equal(n_ubatch, false); } if (ubatch.n_tokens == 0) { @@ -80,6 +80,11 @@ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ba ubatches.push_back(std::move(ubatch)); // NOLINT } + if (balloc.get_n_used() < balloc.get_n_tokens()) { + // failed to find a suitable split + break; + } + // prepare the recurrent batches first if (!mem_recr->prepare(ubatches)) { // TODO: will the recurrent cache be in an undefined context at this point? 
diff --git a/src/llama-memory-recurrent.cpp b/src/llama-memory-recurrent.cpp index 6ed84057cc..4b90dac7a3 100644 --- a/src/llama-memory-recurrent.cpp +++ b/src/llama-memory-recurrent.cpp @@ -374,10 +374,11 @@ llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr & // if all tokens are output, split by sequence ubatch = balloc.split_seq(n_ubatch); } else { - ubatch = balloc.split_equal(n_ubatch); + ubatch = balloc.split_equal(n_ubatch, false); } - if (ubatch.n_tokens == 0) { + if (balloc.get_n_used() < balloc.get_n_tokens()) { + // failed to find a suitable split break; } diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index c76635793f..652856a35d 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -24,10 +24,12 @@ #include #include #include +#include #include #include #include #include +#include #include #include #include @@ -317,6 +319,538 @@ enum test_mode { MODE_GRAD, }; +// Output format support similar to llama-bench +enum output_formats { CONSOLE, SQL }; + +static const char * output_format_str(output_formats format) { + switch (format) { + case CONSOLE: + return "console"; + case SQL: + return "sql"; + default: + GGML_ABORT("invalid output format"); + } +} + +static bool output_format_from_str(const std::string & s, output_formats & format) { + if (s == "console") { + format = CONSOLE; + } else if (s == "sql") { + format = SQL; + } else { + return false; + } + return true; +} + +// Test result structure for SQL output +struct test_result { + std::string test_time; + std::string build_commit; + std::string backend_name; + std::string op_name; + std::string op_params; + std::string test_mode; + bool supported; + bool passed; + std::string error_message; + double time_us; + double flops; + double bandwidth_gb_s; + size_t memory_kb; + int n_runs; + + test_result() { + // Initialize with default values + time_us = 0.0; + flops = 0.0; + bandwidth_gb_s = 0.0; + memory_kb = 0; + n_runs = 0; + supported = false; + passed = false; + + // Set test time + time_t t = time(NULL); + char buf[32]; + std::strftime(buf, sizeof(buf), "%FT%TZ", gmtime(&t)); + test_time = buf; + + // Set build info + build_commit = ggml_commit(); + } + + test_result(const std::string & backend_name, const std::string & op_name, const std::string & op_params, + const std::string & test_mode, bool supported, bool passed, const std::string & error_message = "", + double time_us = 0.0, double flops = 0.0, double bandwidth_gb_s = 0.0, size_t memory_kb = 0, + int n_runs = 0) : + backend_name(backend_name), + op_name(op_name), + op_params(op_params), + test_mode(test_mode), + supported(supported), + passed(passed), + error_message(error_message), + time_us(time_us), + flops(flops), + bandwidth_gb_s(bandwidth_gb_s), + memory_kb(memory_kb), + n_runs(n_runs) { + // Set test time + time_t t = time(NULL); + char buf[32]; + std::strftime(buf, sizeof(buf), "%FT%TZ", gmtime(&t)); + test_time = buf; + + // Set build info + build_commit = ggml_commit(); + } + + static const std::vector & get_fields() { + static const std::vector fields = { + "test_time", "build_commit", "backend_name", "op_name", "op_params", "test_mode", "supported", + "passed", "error_message", "time_us", "flops", "bandwidth_gb_s", "memory_kb", "n_runs" + }; + return fields; + } + + enum field_type { STRING, BOOL, INT, FLOAT }; + + static field_type get_field_type(const std::string & field) { + if (field == "supported" || field == "passed") { + return BOOL; + } + if (field == "memory_kb" || field == 
"n_runs") { + return INT; + } + if (field == "time_us" || field == "flops" || field == "bandwidth_gb_s") { + return FLOAT; + } + return STRING; + } + + std::vector get_values() const { + return { test_time, + build_commit, + backend_name, + op_name, + op_params, + test_mode, + std::to_string(supported), + std::to_string(passed), + error_message, + std::to_string(time_us), + std::to_string(flops), + std::to_string(bandwidth_gb_s), + std::to_string(memory_kb), + std::to_string(n_runs) }; + } +}; + +// Printer classes for different output formats +enum class test_status_t { NOT_SUPPORTED, OK, FAIL }; + +struct test_operation_info { + std::string op_name; + std::string op_params; + std::string backend_name; + test_status_t status = test_status_t::OK; + std::string failure_reason; + + // Additional information fields that were previously in separate structs + std::string error_component; + std::string error_details; + + // Gradient info + int64_t gradient_index = -1; + std::string gradient_param_name; + float gradient_value = 0.0f; + + // MAA error info + double maa_error = 0.0; + double maa_threshold = 0.0; + + // Flags for different types of information + bool has_error = false; + bool has_gradient_info = false; + bool has_maa_error = false; + bool is_compare_failure = false; + bool is_large_tensor_skip = false; + + test_operation_info() = default; + + test_operation_info(const std::string & op_name, const std::string & op_params, const std::string & backend_name, + test_status_t status = test_status_t::OK, const std::string & failure_reason = "") : + op_name(op_name), + op_params(op_params), + backend_name(backend_name), + status(status), + failure_reason(failure_reason) {} + + // Set error information + void set_error(const std::string & component, const std::string & details) { + has_error = true; + error_component = component; + error_details = details; + if (status == test_status_t::OK) { + status = test_status_t::FAIL; + } + } + + // Set gradient information + void set_gradient_info(int64_t index, const std::string & param_name, float value) { + has_gradient_info = true; + gradient_index = index; + gradient_param_name = param_name; + gradient_value = value; + if (status == test_status_t::OK) { + status = test_status_t::FAIL; + } + } + + // Set MAA error information + void set_maa_error(double error, double threshold) { + has_maa_error = true; + maa_error = error; + maa_threshold = threshold; + if (status == test_status_t::OK) { + status = test_status_t::FAIL; + } + } + + // Set compare failure + void set_compare_failure() { + is_compare_failure = true; + if (status == test_status_t::OK) { + status = test_status_t::FAIL; + } + } + + // Set large tensor skip + void set_large_tensor_skip() { is_large_tensor_skip = true; } +}; + +struct test_summary_info { + size_t tests_passed; + size_t tests_total; + bool is_backend_summary = false; // true for backend summary, false for test summary + + test_summary_info() = default; + + test_summary_info(size_t tests_passed, size_t tests_total, bool is_backend_summary = false) : + tests_passed(tests_passed), + tests_total(tests_total), + is_backend_summary(is_backend_summary) {} +}; + +struct testing_start_info { + size_t device_count; + + testing_start_info() = default; + + testing_start_info(size_t device_count) : device_count(device_count) {} +}; + +struct backend_init_info { + size_t device_index; + size_t total_devices; + std::string device_name; + bool skipped = false; + std::string skip_reason; + std::string description; + size_t memory_total_mb 
= 0; + size_t memory_free_mb = 0; + bool has_memory_info = false; + + backend_init_info() = default; + + backend_init_info(size_t device_index, size_t total_devices, const std::string & device_name, bool skipped = false, + const std::string & skip_reason = "", const std::string & description = "", + size_t memory_total_mb = 0, size_t memory_free_mb = 0, bool has_memory_info = false) : + device_index(device_index), + total_devices(total_devices), + device_name(device_name), + skipped(skipped), + skip_reason(skip_reason), + description(description), + memory_total_mb(memory_total_mb), + memory_free_mb(memory_free_mb), + has_memory_info(has_memory_info) {} +}; + +struct backend_status_info { + std::string backend_name; + test_status_t status; + + backend_status_info() = default; + + backend_status_info(const std::string & backend_name, test_status_t status) : + backend_name(backend_name), + status(status) {} +}; + +struct overall_summary_info { + size_t backends_passed; + size_t backends_total; + bool all_passed; + + overall_summary_info() = default; + + overall_summary_info(size_t backends_passed, size_t backends_total, bool all_passed) : + backends_passed(backends_passed), + backends_total(backends_total), + all_passed(all_passed) {} +}; + +struct printer { + virtual ~printer() {} + + FILE * fout = stdout; + + virtual void print_header() {} + + virtual void print_test_result(const test_result & result) = 0; + + virtual void print_footer() {} + + virtual void print_operation(const test_operation_info & info) { (void) info; } + + virtual void print_summary(const test_summary_info & info) { (void) info; } + + virtual void print_testing_start(const testing_start_info & info) { (void) info; } + + virtual void print_backend_init(const backend_init_info & info) { (void) info; } + + virtual void print_backend_status(const backend_status_info & info) { (void) info; } + + virtual void print_overall_summary(const overall_summary_info & info) { (void) info; } +}; + +struct console_printer : public printer { + void print_test_result(const test_result & result) override { + if (result.test_mode == "test") { + print_test_console(result); + } else if (result.test_mode == "perf") { + print_perf_console(result); + } + } + + void print_operation(const test_operation_info & info) override { + printf(" %s(%s): ", info.op_name.c_str(), info.op_params.c_str()); + fflush(stdout); + + // Handle large tensor skip first + if (info.is_large_tensor_skip) { + printf("skipping large tensors for speed \n"); + return; + } + + // Handle not supported status + if (info.status == test_status_t::NOT_SUPPORTED) { + if (!info.failure_reason.empty()) { + printf("not supported [%s]\n", info.failure_reason.c_str()); + } else { + printf("not supported [%s]\n", info.backend_name.c_str()); + } + return; + } + + // Handle errors and additional information + if (info.has_error) { + if (info.error_component == "allocation") { + fprintf(stderr, "failed to allocate tensors [%s] ", info.backend_name.c_str()); + } else if (info.error_component == "backend") { + fprintf(stderr, " Failed to initialize %s backend\n", info.backend_name.c_str()); + } else { + fprintf(stderr, "Error in %s: %s\n", info.error_component.c_str(), info.error_details.c_str()); + } + } + + // Handle gradient info + if (info.has_gradient_info) { + printf("[%s] nonfinite gradient at index %" PRId64 " (%s=%f) ", info.op_name.c_str(), info.gradient_index, + info.gradient_param_name.c_str(), info.gradient_value); + } + + // Handle MAA error + if (info.has_maa_error) { + 
printf("[%s] MAA = %.9f > %.9f ", info.op_name.c_str(), info.maa_error, info.maa_threshold); + } + + // Handle compare failure + if (info.is_compare_failure) { + printf("compare failed "); + } + + // Print final status + if (info.status == test_status_t::OK) { + printf("\033[1;32mOK\033[0m\n"); + } else { + printf("\033[1;31mFAIL\033[0m\n"); + } + } + + void print_summary(const test_summary_info & info) override { + if (info.is_backend_summary) { + printf("%zu/%zu backends passed\n", info.tests_passed, info.tests_total); + } else { + printf(" %zu/%zu tests passed\n", info.tests_passed, info.tests_total); + } + } + + void print_backend_status(const backend_status_info & info) override { + printf(" Backend %s: ", info.backend_name.c_str()); + if (info.status == test_status_t::OK) { + printf("\033[1;32mOK\033[0m\n"); + } else { + printf("\033[1;31mFAIL\033[0m\n"); + } + } + + void print_testing_start(const testing_start_info & info) override { + printf("Testing %zu devices\n\n", info.device_count); + } + + void print_backend_init(const backend_init_info & info) override { + printf("Backend %zu/%zu: %s\n", info.device_index + 1, info.total_devices, info.device_name.c_str()); + + if (info.skipped) { + printf(" %s\n", info.skip_reason.c_str()); + return; + } + + if (!info.description.empty()) { + printf(" Device description: %s\n", info.description.c_str()); + } + + if (info.has_memory_info) { + printf(" Device memory: %zu MB (%zu MB free)\n", info.memory_total_mb, info.memory_free_mb); + } + + printf("\n"); + } + + void print_overall_summary(const overall_summary_info & info) override { + printf("%zu/%zu backends passed\n", info.backends_passed, info.backends_total); + if (info.all_passed) { + printf("\033[1;32mOK\033[0m\n"); + } else { + printf("\033[1;31mFAIL\033[0m\n"); + } + } + + private: + void print_test_console(const test_result & result) { + printf(" %s(%s): ", result.op_name.c_str(), result.op_params.c_str()); + fflush(stdout); + + if (!result.supported) { + printf("not supported [%s] ", result.backend_name.c_str()); + printf("\n"); + return; + } + + if (result.passed) { + printf("\033[1;32mOK\033[0m\n"); + } else { + printf("\033[1;31mFAIL\033[0m\n"); + } + } + + void print_perf_console(const test_result & result) { + int len = printf(" %s(%s): ", result.op_name.c_str(), result.op_params.c_str()); + fflush(stdout); + + if (!result.supported) { + printf("not supported\n"); + return; + } + + // align while also leaving some margin for variations in parameters + int align = 8; + int last = (len + align - 1) / align * align; + if (last - len < 5) { + last += align; + } + printf("%*s", last - len, ""); + + printf(" %8d runs - %8.2f us/run - ", result.n_runs, result.time_us); + + if (result.flops > 0) { + auto format_flops = [](double flops) -> std::string { + char buf[256]; + if (flops >= 1e12) { + snprintf(buf, sizeof(buf), "%6.2f TFLOP", flops / 1e12); + } else if (flops >= 1e9) { + snprintf(buf, sizeof(buf), "%6.2f GFLOP", flops / 1e9); + } else if (flops >= 1e6) { + snprintf(buf, sizeof(buf), "%6.2f MFLOP", flops / 1e6); + } else { + snprintf(buf, sizeof(buf), "%6.2f kFLOP", flops / 1e3); + } + return buf; + }; + uint64_t op_flops_per_run = result.flops * result.time_us / 1e6; + printf("%s/run - \033[1;34m%sS\033[0m", format_flops(op_flops_per_run).c_str(), + format_flops(result.flops).c_str()); + } else { + printf("%8zu kB/run - \033[1;34m%7.2f GB/s\033[0m", result.memory_kb, result.bandwidth_gb_s); + } + printf("\n"); + } +}; + +struct sql_printer : public printer { + static 
std::string get_sql_field_type(const std::string & field) { + switch (test_result::get_field_type(field)) { + case test_result::STRING: + return "TEXT"; + case test_result::BOOL: + case test_result::INT: + return "INTEGER"; + case test_result::FLOAT: + return "REAL"; + default: + GGML_ABORT("invalid field type"); + } + } + + void print_header() override { + std::vector fields = test_result::get_fields(); + fprintf(fout, "CREATE TABLE IF NOT EXISTS test_backend_ops (\n"); + for (size_t i = 0; i < fields.size(); i++) { + fprintf(fout, " %s %s%s\n", fields[i].c_str(), get_sql_field_type(fields[i]).c_str(), + i < fields.size() - 1 ? "," : ""); + } + fprintf(fout, ");\n\n"); + } + + void print_test_result(const test_result & result) override { + fprintf(fout, "INSERT INTO test_backend_ops ("); + std::vector fields = test_result::get_fields(); + for (size_t i = 0; i < fields.size(); i++) { + fprintf(fout, "%s%s", fields[i].c_str(), i < fields.size() - 1 ? ", " : ""); + } + fprintf(fout, ") VALUES ("); + std::vector values = result.get_values(); + for (size_t i = 0; i < values.size(); i++) { + fprintf(fout, "'%s'%s", values[i].c_str(), i < values.size() - 1 ? ", " : ""); + } + fprintf(fout, ");\n"); + } +}; + +static std::unique_ptr create_printer(output_formats format) { + switch (format) { + case CONSOLE: + return std::make_unique(); + case SQL: + return std::make_unique(); + } + GGML_ABORT("invalid output format"); +} + struct test_case { virtual ~test_case() {} @@ -434,7 +968,7 @@ struct test_case { return t; } - bool eval(ggml_backend_t backend1, ggml_backend_t backend2, const char * op_name) { + bool eval(ggml_backend_t backend1, ggml_backend_t backend2, const char * op_name, printer * output_printer) { mode = MODE_TEST; ggml_init_params params = { @@ -451,29 +985,33 @@ struct test_case { add_sentinel(ctx); ggml_tensor * out = build_graph(ctx); - - if (op_name != nullptr && op_desc(out) != op_name) { + std::string current_op_name = op_desc(out); + if (op_name != nullptr && current_op_name != op_name) { //printf(" %s: skipping\n", op_desc(out).c_str()); ggml_free(ctx); return true; } - printf(" %s(%s): ", op_desc(out).c_str(), vars().c_str()); - fflush(stdout); - // check if the backends support the ops bool supported = true; for (ggml_backend_t backend : {backend1, backend2}) { for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { if (!ggml_backend_supports_op(backend, t)) { - printf("not supported [%s] ", ggml_backend_name(backend)); supported = false; break; } } } + if (!supported) { - printf("\n"); + // Create test result for unsupported operation + test_result result(ggml_backend_name(backend1), current_op_name, vars(), "test", + false, false, "not supported"); + + if (output_printer) { + output_printer->print_test_result(result); + } + ggml_free(ctx); return true; } @@ -578,24 +1116,24 @@ struct test_case { const bool cmp_ok = ggml_backend_compare_graph_backend(backend1, backend2, gf, callback, &ud, run_whole_graph() ? out : nullptr); - if (!cmp_ok) { - printf("compare failed "); - } - ggml_backend_buffer_free(buf); ggml_free(ctx); - if (ud.ok && cmp_ok) { - printf("\033[1;32mOK\033[0m\n"); - return true; + // Create test result + bool test_passed = ud.ok && cmp_ok; + std::string error_msg = test_passed ? "" : (!cmp_ok ? 
"compare failed" : "test failed"); + test_result result(ggml_backend_name(backend1), current_op_name, vars(), "test", supported, test_passed, + error_msg); + + if (output_printer) { + output_printer->print_test_result(result); } - printf("\033[1;31mFAIL\033[0m\n"); - return false; + return test_passed; } - bool eval_perf(ggml_backend_t backend, const char * op_name) { + bool eval_perf(ggml_backend_t backend, const char * op_name, printer * output_printer) { mode = MODE_PERF; static const size_t graph_nodes = 8192; @@ -608,30 +1146,26 @@ struct test_case { ggml_context_ptr ctx(ggml_init(params)); // smart ptr GGML_ASSERT(ctx); - ggml_tensor * out = build_graph(ctx.get()); - - if (op_name != nullptr && op_desc(out) != op_name) { + ggml_tensor * out = build_graph(ctx.get()); + std::string current_op_name = op_desc(out); + if (op_name != nullptr && current_op_name != op_name) { //printf(" %s: skipping\n", op_desc(out).c_str()); return true; } - int len = printf(" %s(%s): ", op_desc(out).c_str(), vars().c_str()); - fflush(stdout); - // check if backends support op if (!ggml_backend_supports_op(backend, out)) { - printf("not supported\n"); + // Create test result for unsupported performance test + test_result result(ggml_backend_name(backend), current_op_name, vars(), "perf", false, false, + "not supported"); + + if (output_printer) { + output_printer->print_test_result(result); + } + return true; } - // align while also leaving some margin for variations in parameters - int align = 8; - int last = (len + align - 1) / align * align; - if (last - len < 5) { - last += align; - } - printf("%*s", last - len, ""); - // allocate ggml_backend_buffer_ptr buf(ggml_backend_alloc_ctx_tensors(ctx.get(), backend)); // smart ptr @@ -715,40 +1249,24 @@ struct test_case { total_runs += n_runs; } while (total_time_us < 1000*1000); // run for at least 1 second - printf(" %8d runs - %8.2f us/run - ", - total_runs, - (double)total_time_us / total_runs); + // Create test result + double avg_time_us = (double) total_time_us / total_runs; + double calculated_flops = (op_flops(out) > 0) ? (op_flops(out) * total_runs) / (total_time_us / 1e6) : 0.0; + double calculated_bandwidth = + (op_flops(out) == 0) ? 
total_mem / (total_time_us / 1e6) / 1024.0 / 1024.0 / 1024.0 : 0.0; + size_t calculated_memory_kb = op_size(out) / 1024; - if (op_flops(out) > 0) { - double flops_per_sec = (op_flops(out) * total_runs) / (total_time_us / 1e6); - auto format_flops = [](double flops) -> std::string { - char buf[256]; - if (flops >= 1e12) { - snprintf(buf, sizeof(buf), "%6.2f TFLOP", flops / 1e12); - } else if (flops >= 1e9) { - snprintf(buf, sizeof(buf), "%6.2f GFLOP", flops / 1e9); - } else if (flops >= 1e6) { - snprintf(buf, sizeof(buf), "%6.2f MFLOP", flops / 1e6); - } else { - snprintf(buf, sizeof(buf), "%6.2f KFLOP", flops / 1e3); - } - return buf; - }; - printf("%s/run - \033[1;34m%sS\033[0m", - format_flops(op_flops(out)).c_str(), - format_flops(flops_per_sec).c_str()); + test_result result(ggml_backend_name(backend), current_op_name, vars(), "perf", true, true, "", avg_time_us, + calculated_flops, calculated_bandwidth, calculated_memory_kb, total_runs); - } else { - printf("%8zu kB/run - \033[1;34m%7.2f GB/s\033[0m", - op_size(out) / 1024, - total_mem / (total_time_us / 1e6) / 1024.0 / 1024.0 / 1024.0); + if (output_printer) { + output_printer->print_test_result(result); } - printf("\n"); return true; } - bool eval_grad(ggml_backend_t backend, const char * op_name) { + bool eval_grad(ggml_backend_t backend, const char * op_name, printer * output_printer) { mode = MODE_GRAD; const std::vector expect = grad_expect(); @@ -766,42 +1284,47 @@ struct test_case { ggml_tensor * out = build_graph(ctx.get()); if ((op_name != nullptr && op_desc(out) != op_name) || out->op == GGML_OP_OPT_STEP_ADAMW) { - //printf(" %s: skipping\n", op_desc(out).c_str()); return true; } - printf(" %s(%s): ", op_desc(out).c_str(), vars().c_str()); - fflush(stdout); - if (out->type != GGML_TYPE_F32) { - printf("not supported [%s->type != FP32]\n", out->name); + output_printer->print_operation(test_operation_info(op_desc(out), vars(), ggml_backend_name(backend), + test_status_t::NOT_SUPPORTED, + out->name + std::string("->type != FP32"))); return true; } + // Print operation info first + output_printer->print_operation(test_operation_info(op_desc(out), vars(), ggml_backend_name(backend))); + // check if the backend supports the ops - bool supported = true; - bool any_params = false; + bool supported = true; + bool any_params = false; + std::string failure_reason; + for (ggml_tensor * t = ggml_get_first_tensor(ctx.get()); t != NULL; t = ggml_get_next_tensor(ctx.get(), t)) { if (!ggml_backend_supports_op(backend, t)) { - printf("not supported [%s] ", ggml_backend_name(backend)); - supported = false; + supported = false; + failure_reason = ggml_backend_name(backend); break; } if ((t->flags & GGML_TENSOR_FLAG_PARAM)) { any_params = true; if (t->type != GGML_TYPE_F32) { - printf("not supported [%s->type != FP32] ", t->name); - supported = false; + supported = false; + failure_reason = std::string(t->name) + "->type != FP32"; break; } } } if (!any_params) { - printf("not supported [%s] \n", op_desc(out).c_str()); - supported = false; + supported = false; + failure_reason = op_desc(out); } + if (!supported) { - printf("\n"); + output_printer->print_operation(test_operation_info(op_desc(out), vars(), ggml_backend_name(backend), + test_status_t::NOT_SUPPORTED, failure_reason)); return true; } @@ -812,7 +1335,9 @@ struct test_case { } } if (ngrads > grad_nmax()) { - printf("skipping large tensors for speed \n"); + test_operation_info info(op_desc(out), vars(), ggml_backend_name(backend)); + info.set_large_tensor_skip(); + 
output_printer->print_operation(info); return true; } @@ -835,25 +1360,30 @@ struct test_case { for (ggml_tensor * t = ggml_get_first_tensor(ctx.get()); t != NULL; t = ggml_get_next_tensor(ctx.get(), t)) { if (!ggml_backend_supports_op(backend, t)) { - printf("not supported [%s] ", ggml_backend_name(backend)); + output_printer->print_operation(test_operation_info(op_desc(out), vars(), ggml_backend_name(backend), + test_status_t::NOT_SUPPORTED, + ggml_backend_name(backend))); supported = false; break; } if ((t->flags & GGML_TENSOR_FLAG_PARAM) && t->type != GGML_TYPE_F32) { - printf("not supported [%s->type != FP32] ", t->name); + output_printer->print_operation(test_operation_info(op_desc(out), vars(), ggml_backend_name(backend), + test_status_t::NOT_SUPPORTED, + std::string(t->name) + "->type != FP32")); supported = false; break; } } if (!supported) { - printf("\n"); return true; } // allocate ggml_backend_buffer_ptr buf(ggml_backend_alloc_ctx_tensors(ctx.get(), backend)); // smart ptr if (buf == NULL) { - printf("failed to allocate tensors [%s] ", ggml_backend_name(backend)); + test_operation_info info(op_desc(out), vars(), ggml_backend_name(backend)); + info.set_error("allocation", ""); + output_printer->print_operation(info); return false; } @@ -891,7 +1421,9 @@ struct test_case { for (int64_t i = 0; i < ne; ++i) { // gradient algebraic // check for nans if (!std::isfinite(ga[i])) { - printf("[%s] nonfinite gradient at index %" PRId64 " (%s=%f) ", ggml_op_desc(t), i, bn, ga[i]); + test_operation_info info(op_desc(out), vars(), ggml_backend_name(backend)); + info.set_gradient_info(i, bn, ga[i]); + output_printer->print_operation(info); ok = false; break; } @@ -959,7 +1491,9 @@ struct test_case { const double err = mean_abs_asymm(gn.data(), ga.data(), gn.size(), expect); if (err > max_maa_err()) { - printf("[%s] MAA = %.9f > %.9f ", ggml_op_desc(t), err, max_maa_err()); + test_operation_info info(op_desc(out), vars(), ggml_backend_name(backend)); + info.set_maa_error(err, max_maa_err()); + output_printer->print_operation(info); ok = false; break; } @@ -968,16 +1502,18 @@ struct test_case { } } + // Create final test result + test_operation_info final_info(op_desc(out), vars(), ggml_backend_name(backend)); if (!ok) { - printf("compare failed "); + final_info.set_compare_failure(); } + final_info.status = ok ? 
test_status_t::OK : test_status_t::FAIL; + output_printer->print_operation(final_info); if (ok) { - printf("\033[1;32mOK\033[0m\n"); return true; } - printf("\033[1;31mFAIL\033[0m\n"); return false; } }; @@ -2047,10 +2583,6 @@ struct test_rms_norm_mul : public test_case { } } - double max_nmse_err() override { - return 1e-6; - } - float grad_eps() override { return 1.0f; } @@ -4522,7 +5054,7 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_rms_norm_back(GGML_TYPE_F32, {64, 5, 4, 3}, eps)); test_cases.emplace_back(new test_l2_norm (GGML_TYPE_F32, {64, 5, 4, 3}, eps)); } - for (float eps : {0.0f, 1e-6f, 1e-4f, 1e-1f}) { + for (float eps : {0.0f, 1e-6f, 1e-4f, 1e-1f, 1.0f}) { test_cases.emplace_back(new test_rms_norm_mul(GGML_TYPE_F32, {64, 5, 4, 3}, eps)); } @@ -4989,7 +5521,8 @@ static std::vector> make_test_cases_perf() { return test_cases; } -static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op_name, const char * params_filter) { +static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op_name, const char * params_filter, + printer * output_printer) { auto filter_test_cases = [](std::vector> & test_cases, const char * params_filter) { if (params_filter == nullptr) { return; @@ -5012,17 +5545,19 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op filter_test_cases(test_cases, params_filter); ggml_backend_t backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, NULL); if (backend_cpu == NULL) { - printf(" Failed to initialize CPU backend\n"); + test_operation_info info("", "", "CPU"); + info.set_error("backend", "Failed to initialize CPU backend"); + output_printer->print_operation(info); return false; } size_t n_ok = 0; for (auto & test : test_cases) { - if (test->eval(backend, backend_cpu, op_name)) { + if (test->eval(backend, backend_cpu, op_name, output_printer)) { n_ok++; } } - printf(" %zu/%zu tests passed\n", n_ok, test_cases.size()); + output_printer->print_summary(test_summary_info(n_ok, test_cases.size(), false)); ggml_backend_free(backend_cpu); @@ -5034,11 +5569,11 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op filter_test_cases(test_cases, params_filter); size_t n_ok = 0; for (auto & test : test_cases) { - if (test->eval_grad(backend, op_name)) { + if (test->eval_grad(backend, op_name, output_printer)) { n_ok++; } } - printf(" %zu/%zu tests passed\n", n_ok, test_cases.size()); + output_printer->print_summary(test_summary_info(n_ok, test_cases.size(), false)); return n_ok == test_cases.size(); } @@ -5047,7 +5582,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op auto test_cases = make_test_cases_perf(); filter_test_cases(test_cases, params_filter); for (auto & test : test_cases) { - test->eval_perf(backend, op_name); + test->eval_perf(backend, op_name, output_printer); } return true; } @@ -5056,16 +5591,18 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op } static void usage(char ** argv) { - printf("Usage: %s [mode] [-o ] [-b ] [-p ]\n", argv[0]); + printf("Usage: %s [mode] [-o ] [-b ] [-p ] [--output ]\n", argv[0]); printf(" valid modes:\n"); printf(" - test (default, compare with CPU backend for correctness)\n"); printf(" - grad (compare gradients from backpropagation with method of finite differences)\n"); printf(" - perf (performance evaluation)\n"); printf(" op names for -o are as given by ggml_op_desc() (e.g. 
+    printf("    --output specifies output format (default: console)\n");
 }

 int main(int argc, char ** argv) {
     test_mode mode = MODE_TEST;
+    output_formats output_format = CONSOLE;
     const char * op_name_filter = nullptr;
     const char * backend_filter = nullptr;
     const char * params_filter = nullptr;
@@ -5098,6 +5635,16 @@ int main(int argc, char ** argv) {
                 usage(argv);
                 return 1;
             }
+        } else if (strcmp(argv[i], "--output") == 0) {
+            if (i + 1 < argc) {
+                if (!output_format_from_str(argv[++i], output_format)) {
+                    usage(argv);
+                    return 1;
+                }
+            } else {
+                usage(argv);
+                return 1;
+            }
         } else {
             usage(argv);
             return 1;
@@ -5107,23 +5654,29 @@ int main(int argc, char ** argv) {
     // load and enumerate backends
     ggml_backend_load_all();

-    printf("Testing %zu devices\n\n", ggml_backend_dev_count());
+    // Create printer for output format
+    std::unique_ptr<printer> output_printer = create_printer(output_format);
+    if (output_printer) {
+        output_printer->print_header();
+    }
+
+    output_printer->print_testing_start(testing_start_info(ggml_backend_dev_count()));

     size_t n_ok = 0;

     for (size_t i = 0; i < ggml_backend_dev_count(); i++) {
         ggml_backend_dev_t dev = ggml_backend_dev_get(i);

-        printf("Backend %zu/%zu: %s\n", i + 1, ggml_backend_dev_count(), ggml_backend_dev_name(dev));
-
         if (backend_filter != NULL && strcmp(backend_filter, ggml_backend_dev_name(dev)) != 0) {
-            printf("  Skipping\n");
+            output_printer->print_backend_init(
+                backend_init_info(i, ggml_backend_dev_count(), ggml_backend_dev_name(dev), true, "Skipping"));
             n_ok++;
             continue;
         }

         if (backend_filter == NULL && ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_CPU && mode != MODE_GRAD) {
-            printf("  Skipping CPU backend\n");
+            output_printer->print_backend_init(backend_init_info(
+                i, ggml_backend_dev_count(), ggml_backend_dev_name(dev), true, "Skipping CPU backend"));
             n_ok++;
             continue;
         }
@@ -5138,36 +5691,35 @@ int main(int argc, char ** argv) {
             ggml_backend_set_n_threads_fn(backend, std::thread::hardware_concurrency());
         }

-        printf("  Device description: %s\n", ggml_backend_dev_description(dev));
-        size_t free, total; // NOLINT
+        size_t free, total;  // NOLINT
         ggml_backend_dev_memory(dev, &free, &total);
-        printf("  Device memory: %zu MB (%zu MB free)\n", total / 1024 / 1024, free / 1024 / 1024);
-        printf("\n");
+        output_printer->print_backend_init(backend_init_info(i, ggml_backend_dev_count(), ggml_backend_dev_name(dev),
+                                                             false, "", ggml_backend_dev_description(dev),
+                                                             total / 1024 / 1024, free / 1024 / 1024, true));

-        bool ok = test_backend(backend, mode, op_name_filter, params_filter);
+        bool ok = test_backend(backend, mode, op_name_filter, params_filter, output_printer.get());

-        printf("  Backend %s: ", ggml_backend_name(backend));
         if (ok) {
-            printf("\033[1;32mOK\033[0m\n");
             n_ok++;
-        } else {
-            printf("\033[1;31mFAIL\033[0m\n");
         }
-
-        printf("\n");
+        output_printer->print_backend_status(
+            backend_status_info(ggml_backend_name(backend), ok ? test_status_t::OK : test_status_t::FAIL));

         ggml_backend_free(backend);
     }

     ggml_quantize_free();

-    printf("%zu/%zu backends passed\n", n_ok, ggml_backend_dev_count());
+    if (output_printer) {
+        output_printer->print_footer();
+    }
+
+    output_printer->print_overall_summary(
+        overall_summary_info(n_ok, ggml_backend_dev_count(), n_ok == ggml_backend_dev_count()));
+
     if (n_ok != ggml_backend_dev_count()) {
-        printf("\033[1;31mFAIL\033[0m\n");
         return 1;
     }

-    printf("\033[1;32mOK\033[0m\n");
     return 0;
 }
diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp
index a990520ed3..9146c9e9c4 100644
--- a/tools/mtmd/clip.cpp
+++ b/tools/mtmd/clip.cpp
@@ -1405,8 +1405,7 @@ struct clip_graph {
             ggml_tensor * x = embeddings;
             embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_2_w, embeddings);
             x = ggml_mul_mat(ctx0, model.mm_model_mlp_1_w,x);
-            embeddings = ggml_silu_inplace(ctx0, embeddings);
-            embeddings = ggml_mul(ctx0, embeddings,x);
+            embeddings = ggml_swiglu_split(ctx0, embeddings, x);
             embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_3_w, embeddings);
         }
         // arrangement of BOI/EOI token embeddings
@@ -1502,15 +1501,8 @@ struct clip_graph {
         cur = ggml_mul_mat(ctx0, model.mm_1_w, cur);

         // swiglu
-        {
-            int64_t split_point = cur->ne[0] / 2;
-            ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
-            ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
-
-            // see SwiGLU in ultravox_model.py, the second half passed through is silu, not the first half
-            x1 = ggml_silu(ctx0, x1);
-            cur = ggml_mul(ctx0, x0, x1);
-        }
+        // see SwiGLU in ultravox_model.py, the second half passed through is silu, not the first half
+        cur = ggml_swiglu_swapped(ctx0, cur);

         // mid-norm
         cur = ggml_rms_norm(ctx0, cur, 1e-6);
@@ -1769,35 +1761,42 @@ private:
             cur = tmp;
         }

+        // we only support parallel ffn for now
         switch (type_op) {
             case FFN_SILU:
-                {
+                if (gate) {
+                    cur = ggml_swiglu_split(ctx0, cur, tmp);
+                    cb(cur, "ffn_swiglu", il);
+                } else {
                     cur = ggml_silu(ctx0, cur);
                     cb(cur, "ffn_silu", il);
                 } break;
             case FFN_GELU:
-                {
+                if (gate) {
+                    cur = ggml_geglu_split(ctx0, cur, tmp);
+                    cb(cur, "ffn_geglu", il);
+                } else {
                     cur = ggml_gelu(ctx0, cur);
                     cb(cur, "ffn_gelu", il);
                 } break;
             case FFN_GELU_ERF:
-                {
+                if (gate) {
+                    cur = ggml_geglu_erf_split(ctx0, cur, tmp);
+                    cb(cur, "ffn_geglu_erf", il);
+                } else {
                     cur = ggml_gelu_erf(ctx0, cur);
-                    cb(cur, "ggml_gelu_erf", il);
+                    cb(cur, "ffn_gelu_erf", il);
                 } break;
             case FFN_GELU_QUICK:
-                {
+                if (gate) {
+                    cur = ggml_geglu_quick_split(ctx0, cur, tmp);
+                    cb(cur, "ffn_geglu_quick", il);
+                } else {
                     cur = ggml_gelu_quick(ctx0, cur);
-                    cb(cur, "ffn_relu", il);
+                    cb(cur, "ffn_gelu_quick", il);
                 } break;
         }

-        // we only support parallel ffn for now
-        if (gate) {
-            cur = ggml_mul(ctx0, cur, tmp);
-            cb(cur, "ffn_gate_par", il);
-        }
-
         if (down) {
             cur = ggml_mul_mat(ctx0, down, cur);
         }
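The clip.cpp changes above fold the hand-written FFN gate (activation applied to one tensor, then an elementwise multiply with the other) into the fused GLU "split" ops. A minimal sketch of that equivalence, assuming only the ggml calls visible in this patch; the helper name build_gated_ffn is made up for illustration and is not part of the patch:

// Illustrative sketch only, not part of the patch. `build_gated_ffn` is a
// hypothetical helper; ggml_silu, ggml_mul and ggml_swiglu_split are real ggml ops.
#include "ggml.h"

static ggml_tensor * build_gated_ffn(ggml_context * ctx, ggml_tensor * gate, ggml_tensor * up) {
    // before this patch: activation on the gate tensor, then an elementwise multiply
    //   ggml_tensor * act = ggml_silu(ctx, gate);
    //   return ggml_mul(ctx, act, up);
    // after this patch: a single fused op with the same result
    return ggml_swiglu_split(ctx, gate, up);
}

ggml_swiglu_swapped covers the single-tensor case used by the ultravox projector, where both halves live in one tensor and, per the comment kept in the patch, the activation is applied to the second half rather than the first.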
"text", "text": "Wh"}, {"type": "text", "text": "ill"}], "Whill"), +]) +def test_chat_template_assistant_prefill(prefill, re_prefill): + global server + server.chat_template = "llama3" + server.debug = True # to get the "__verbose" object in the response + server.start() + res = server.make_request("POST", "/chat/completions", data={ + "max_tokens": 8, + "messages": [ + {"role": "system", "content": "Book"}, + {"role": "user", "content": "What is the best book"}, + {"role": "assistant", "content": prefill}, + ] + }) + assert res.status_code == 200 + assert "__verbose" in res.body + assert res.body["__verbose"]["prompt"] == f" <|start_header_id|>system<|end_header_id|>\n\nBook<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat is the best book<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n{re_prefill}" + + def test_apply_chat_template(): global server server.chat_template = "command-r" @@ -228,6 +250,7 @@ def test_completion_with_grammar(jinja: bool, grammar: str, n_predicted: int, re [{"role": "system", "content": 123}], # [{"content": "hello"}], # TODO: should not be a valid case [{"role": "system", "content": "test"}, {}], + [{"role": "user", "content": "test"}, {"role": "assistant", "content": "test"}, {"role": "assistant", "content": "test"}], ]) def test_invalid_chat_completion_req(messages): global server diff --git a/tools/server/utils.hpp b/tools/server/utils.hpp index 2ef9a16451..6c2e91359a 100644 --- a/tools/server/utils.hpp +++ b/tools/server/utils.hpp @@ -792,7 +792,13 @@ static json oaicompat_chat_params_parse( /* Append assistant prefilled message */ if (prefill_assistant_message) { - chat_params.prompt += last_message.content; + if (!last_message.content_parts.empty()) { + for (auto & p : last_message.content_parts) { + chat_params.prompt += p.text; + } + } else { + chat_params.prompt += last_message.content; + } } llama_params["chat_format"] = static_cast(chat_params.format);