diff --git a/docs/backend/OPENCL.md b/docs/backend/OPENCL.md index e52baffdff..ce6c7b5605 100644 --- a/docs/backend/OPENCL.md +++ b/docs/backend/OPENCL.md @@ -17,7 +17,7 @@ OpenCL (Open Computing Language) is an open, royalty-free standard for cross-pla ### Llama.cpp + OpenCL -The llama.cpp OpenCL backend is designed to enable llama.cpp on **Qualcomm Adreno GPU** firstly via OpenCL. Thanks to the portabilty of OpenCL, the OpenCL backend can also run on certain Intel GPUs although the performance is not optimal. +The llama.cpp OpenCL backend is designed to enable llama.cpp on **Qualcomm Adreno GPU** firstly via OpenCL. Thanks to the portabilty of OpenCL, the OpenCL backend can also run on certain Intel GPUs such as those that do not have [SYCL](/docs/backend/SYCL.md) support although the performance is not optimal. ## OS diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index 3688abdd58..2180a06fd0 100644 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -3702,3 +3702,106 @@ void ggml_cann_out_prod(ggml_backend_cann_context & ctx, ggml_tensor * dst) { break; } } + +void ggml_cann_ssm_conv(ggml_backend_cann_context & ctx, ggml_tensor * dst) { + ggml_tensor * src0 = dst->src[0]; // conv_x + ggml_tensor * src1 = dst->src[1]; // conv1d.weight + + // This op is currently defined only for F32 in ggml_cpu + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + // Shapes follow ggml_compute_forward_ssm_conv_f32 + const int64_t nc = src1->ne[0]; // d_conv + const int64_t ncs = src0->ne[0]; // d_conv - 1 + n_t + const int64_t nr = src0->ne[1]; // d_inner + const int64_t n_s = src0->ne[2]; // n_seqs + + const int64_t n_t = dst->ne[1]; // tokens per sequence + + GGML_ASSERT(dst->ne[0] == nr); // dst: {d_inner, n_t, n_s} + GGML_ASSERT(src1->ne[1] == nr); // weight: {d_conv, d_inner} + GGML_ASSERT(ncs == nc - 1 + n_t); // conv_x: {d_conv - 1 + n_t, d_inner, n_s} + GGML_ASSERT(src0->nb[0] == sizeof(float)); + GGML_ASSERT(src1->nb[0] == sizeof(float)); + + // --- Build CANN tensors --- + + // 1) Input: conv_x as NCL + // + // src0->ne = { ncs, nr, n_s, 1 } // {L_in, C, N} + // Passing ACL_FORMAT_NCL here means: + // reversed dims -> [N, C, L_in] = [n_s, nr, ncs] + acl_tensor_ptr acl_x = ggml_cann_create_tensor(src0, src0->ne, src0->nb, 3, ACL_FORMAT_NCL); + + // 2) Weights: depthwise conv kernel, view src1 as {K, 1, C} + // + // src1 original: ne = { nc, nr, 1, 1 } // [K, C, 1, 1] + // we want a view: ne_w = { nc, 1, nr } // [K, 1, C] + // so that reversed dims -> [C, 1, K] which matches + // [out_channels, in_channels/groups, kernel_size] + int64_t w_ne[GGML_MAX_DIMS] = { nc, 1, nr, 1 }; // [K, 1 input ch. per group, C groups] + // Layout: src1 data is [K, C] with + // offset(k, c) = k*nb0 + c*nb1 + // We want offset_w(k, 0, c) = k*nb0 + c*nb1, + // so we can reuse nb0 and nb1, and set nb2 = nb1. + size_t w_nb[GGML_MAX_DIMS] = { src1->nb[0], src1->nb[1], src1->nb[1], src1->nb[3] }; // same as src1 + + acl_tensor_ptr acl_w = ggml_cann_create_tensor( + src1->data, ggml_cann_type_mapping(src1->type), ggml_type_size(src1->type), w_ne, w_nb, 3, ACL_FORMAT_NCL); + + // 3) Output: dst is { d_inner, n_t, n_s } (CLN) + // + // We need an NCL view of the same buffer: + // desired NCL logical shape: { L_out = n_t, C = nr, N = n_s } + // + // Original CLN layout: + // dst->ne = { nr, n_t, n_s } + // dst->nb[0] = sizeof(float) + // dst->nb[1] = nr * sizeof(float) + // dst->nb[2] = nr * n_t * sizeof(float) + // + // We want offset_new(L, C, N) = offset_orig(C, L, N). + // Choose: + // nb_y[0] = nr * sizeof(float); // step in L + // nb_y[1] = sizeof(float); // step in C + // nb_y[2] = nr * n_t * sizeof(float); // step in N + int64_t y_ne[GGML_MAX_DIMS] = { n_t, nr, n_s, 1 }; // [L_out, C, N] + size_t y_nb[GGML_MAX_DIMS] = { dst->ne[0] * sizeof(float), sizeof(float), dst->ne[0] * dst->ne[1] * sizeof(float), dst->nb[3] }; // [nr, 1, nr * n_t] + + acl_tensor_ptr acl_y = ggml_cann_create_tensor( + dst->data, ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type), y_ne, y_nb, 3, ACL_FORMAT_NCL); + + // --- Conv1d parameters: depthwise, stride 1, no padding ("valid") --- + int64_t strideVal[1] = { 1 }; + int64_t paddingVal[1] = { 0 }; + int64_t dilationVal[1] = { 1 }; + + acl_int_array_ptr stride = ggml_cann_create_int_array(strideVal, 1); + acl_int_array_ptr padding = ggml_cann_create_int_array(paddingVal, 1); + acl_int_array_ptr dilation = ggml_cann_create_int_array(dilationVal, 1); + + const bool transposed = false; + const int64_t groups = nr; // depthwise: one group per inner dim + int8_t cubeMathType = 0; + +#ifdef ASCEND_310P + cubeMathType = 1; +#endif + + GGML_CANN_CALL_ACLNN_OP(ctx, + Convolution, + acl_x.get(), // input: N, C, L_in = ncs + acl_w.get(), // weight: [C, 1, K] with groups=nr + nullptr, // bias + stride.get(), + padding.get(), + dilation.get(), + transposed, + padding.get(), // output padding (unused for non-transposed) + groups, + acl_y.get(), + cubeMathType); +} + diff --git a/ggml/src/ggml-cann/aclnn_ops.h b/ggml/src/ggml-cann/aclnn_ops.h index 6ec4640289..a6ea016c54 100644 --- a/ggml/src/ggml-cann/aclnn_ops.h +++ b/ggml/src/ggml-cann/aclnn_ops.h @@ -1033,6 +1033,8 @@ void ggml_cann_op_unary(std::functionpipeline_upscale_nearest_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_NEAREST}, 1); ggml_vk_create_pipeline(device, device->pipeline_upscale_bilinear_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BILINEAR}, 1); ggml_vk_create_pipeline(device, device->pipeline_upscale_bicubic_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BICUBIC}, 1); + ggml_vk_create_pipeline(device, device->pipeline_upscale_bilinear_antialias_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ANTIALIAS}, 1); ggml_vk_create_pipeline(device, device->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); @@ -8432,7 +8434,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return nullptr; case GGML_OP_UPSCALE: if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - ggml_scale_mode mode = (ggml_scale_mode)(ggml_get_op_params_i32(dst, 0) & 0xFF); + uint32_t mode = (ggml_get_op_params_i32(dst, 0) & (0xFF | GGML_SCALE_FLAG_ANTIALIAS)); switch (mode) { case GGML_SCALE_MODE_NEAREST: return ctx->device->pipeline_upscale_nearest_f32; @@ -8440,6 +8442,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_upscale_bilinear_f32; case GGML_SCALE_MODE_BICUBIC: return ctx->device->pipeline_upscale_bicubic_f32; + case GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ANTIALIAS: + return ctx->device->pipeline_upscale_bilinear_antialias_f32; default: return nullptr; } @@ -9090,10 +9094,20 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co elements = { num_groups * (uint32_t)src0->ne[3], 1, 1 }; } break; case GGML_OP_DIAG_MASK_INF: - case GGML_OP_ROPE: - case GGML_OP_ROPE_BACK: elements = { (uint32_t)ggml_nrows(src0), (uint32_t)ne00, 1 }; break; + case GGML_OP_ROPE: + case GGML_OP_ROPE_BACK: + { + uint32_t nrows = (uint32_t)ggml_nrows(src0); + uint32_t z = 1; + if (nrows > ctx->device->properties.limits.maxComputeWorkGroupCount[0]) { + z = CEIL_DIV(nrows, 32768); + nrows = 32768; + } + elements = { nrows, (uint32_t)ne00, z }; + + } break; case GGML_OP_GET_ROWS: elements = { (uint32_t)ne00, (uint32_t)ne10, (uint32_t)(ne11 * ne12) }; elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]); @@ -10021,7 +10035,7 @@ static vk_op_rope_push_constants ggml_vk_make_rope_constants(const ggml_tensor * uint32_t nb02 = src0->nb[2] / ggml_type_size(src0->type); vk_op_rope_push_constants rope { - (uint32_t)mode, (uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1], + (uint32_t)mode, (uint32_t)src0->ne[0], (uint32_t)ggml_nrows(src0), (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1], freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale, has_ff, (uint32_t)src0->ne[2], nb01, nb02, { sections[0], sections[1], sections[2], sections[3] }, is_imrope, backprop, set_rows_stride, @@ -14330,7 +14344,12 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm } return true; case GGML_OP_UPSCALE: - return op->src[0]->type == GGML_TYPE_F32 && !(op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS); + if (op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS) { + if ((op->op_params[0] & 0xFF) != GGML_SCALE_MODE_BILINEAR) { + return false; + } + } + return op->src[0]->type == GGML_TYPE_F32; case GGML_OP_ACC: return op->src[0]->type == GGML_TYPE_F32; case GGML_OP_CONCAT: diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp index 7c1fb1cd22..f7587468a8 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp @@ -6,6 +6,9 @@ void main() { const uint i0 = 2*gl_GlobalInvocationID.y; // i1 is actually i2*nb2+i1, but the rows are contiguous - const uint i1 = gl_GlobalInvocationID.x; + const uint i1 = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z; + if (i1 >= pc.nrows) { + return; + } rope_multi(i0, i1, pc); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp index 68f00c180b..acb8ed7815 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp @@ -6,6 +6,9 @@ void main() { const uint i0 = 2*gl_GlobalInvocationID.y; // i1 is actually i2*nb2+i1, but the rows are contiguous - const uint i1 = gl_GlobalInvocationID.x; + const uint i1 = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z; + if (i1 >= pc.nrows) { + return; + } rope_neox(i0, i1, pc); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp index 28a939ec6a..0033cdb224 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp @@ -6,6 +6,9 @@ void main() { const uint i0 = 2*gl_GlobalInvocationID.y; // i1 is actually i2*nb2+i1, but the rows are contiguous - const uint i1 = gl_GlobalInvocationID.x; + const uint i1 = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z; + if (i1 >= pc.nrows) { + return; + } rope_norm(i0, i1, pc); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl index 82f39cee34..939cf3c51c 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl @@ -6,6 +6,7 @@ struct rope_params { uint rope_mode; uint ncols; + uint nrows; uint n_dims; float freq_scale; uint p_delta_rows; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp index ea1e0fdb41..d93800b5e7 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp @@ -6,6 +6,9 @@ void main() { const uint i0 = 2*gl_GlobalInvocationID.y; // i1 is actually i2*nb2+i1, but the rows are contiguous - const uint i1 = gl_GlobalInvocationID.x; + const uint i1 = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z; + if (i1 >= pc.nrows) { + return; + } rope_vision(i0, i1, pc); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp b/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp index 037ab0c78f..f7d12a8dda 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp @@ -21,6 +21,7 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; #define NEAREST 0 #define BILINEAR 1 #define BICUBIC 2 +#define BILINEAR_ANTIALIAS 513 layout (constant_id = 0) const uint scale_mode = 0; @@ -62,6 +63,56 @@ float interpolate_bilinear(uint i10, uint i11, uint i12, uint i13) { return fetch_bilinear(c0, c1, d, i12, i13); } +float triangle_filter(float x) { + return max(1.0f - abs(x), 0.0f); +} + +float interpolate_bilinear_antialias(uint i10, uint i11, uint i12, uint i13) { + const float support1 = max(1.0f, 1.0f / p.sf1); + const float invscale1 = 1.0f / support1; + const float support0 = max(1.0f, 1.0f / p.sf0); + const float invscale0 = 1.0f / support0; + + const uint i02 = uint(i12 / p.sf2); + const uint i03 = uint(i13 / p.sf3); + + const float y = (float(i11) + p.pixel_offset) / p.sf1; + const float x = (float(i10) + p.pixel_offset) / p.sf0; + + // the range of source pixels that contribute + const int x_min = max(int(x - support0 + p.pixel_offset), 0); + const int x_max = min(int(x + support0 + p.pixel_offset), int(p.ne00)); + const int y_min = max(int(y - support1 + p.pixel_offset), 0); + const int y_max = min(int(y + support1 + p.pixel_offset), int(p.ne01)); + + // bilinear filter with antialiasing + float val = 0.0f; + float total_weight = 0.0f; + + for (int sy = y_min; sy < y_max; sy++) { + const float weight_y = triangle_filter((sy - y + p.pixel_offset) * invscale1); + + for (int sx = x_min; sx < x_max; sx++) { + const float weight_x = triangle_filter((sx - x + p.pixel_offset) * invscale0); + const float weight = weight_x * weight_y; + + if (weight <= 0.0f) { + continue; + } + + const float pixel = data_a[p.a_offset + i03 * p.nb03 + i02 * p.nb02 + sy * p.nb01 + sx * p.nb00]; + val += pixel * weight; + total_weight += weight; + } + } + + if (total_weight > 0.0f) { + val /= total_weight; + } + + return val; +} + // Bicubic interpolation with alpha = -0.75 // https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm const vec4 bcoeffs1 = vec4( 1.25, -2.25, 0.0, 1.0); @@ -118,6 +169,9 @@ void main() { case BICUBIC: result = interpolate_bicubic(i10, i11, i12, i13); break; + case BILINEAR_ANTIALIAS: + result = interpolate_bilinear_antialias(i10, i11, i12, i13); + break; } data_d[p.d_offset + idx] = D_TYPE(result); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 6b65f6e1c7..0b981b1788 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -402,12 +402,20 @@ static std::string var_to_str(ggml_op_pool pool) { } static std::string var_to_str(ggml_scale_mode mode) { - switch (mode) { - case GGML_SCALE_MODE_NEAREST: return "nearest"; - case GGML_SCALE_MODE_BILINEAR: return "bilinear"; - case GGML_SCALE_MODE_BICUBIC: return "bicubic"; - default: return std::to_string(mode); + std::string str; + switch (mode & 0xFF) { + case GGML_SCALE_MODE_NEAREST: str = "nearest"; break; + case GGML_SCALE_MODE_BILINEAR: str = "bilinear"; break; + case GGML_SCALE_MODE_BICUBIC: str = "bicubic"; break; + default: str = std::to_string(mode); break; } + if (mode & GGML_SCALE_FLAG_ALIGN_CORNERS) { + str += "|align_corners"; + } + if (mode & GGML_SCALE_FLAG_ANTIALIAS) { + str += "|antialias"; + } + return str; } #define VAR_TO_STR(x) (#x "=" + var_to_str(x)) @@ -5535,18 +5543,16 @@ struct test_interpolate : public test_case { const ggml_type type; const std::array ne; const std::array ne_tgt; - const uint32_t mode = GGML_SCALE_MODE_NEAREST; + const ggml_scale_mode mode = GGML_SCALE_MODE_NEAREST; std::string vars() override { - ggml_scale_mode mode = (ggml_scale_mode)(this->mode & 0xFF); - std::string flags = (this->mode & GGML_SCALE_FLAG_ALIGN_CORNERS) ? "align_corners" : "none"; - return VARS_TO_STR5(type, ne, ne_tgt, mode, flags); + return VARS_TO_STR4(type, ne, ne_tgt, mode); } test_interpolate(ggml_type type = GGML_TYPE_F32, std::array ne = {2, 5, 7, 11}, std::array ne_tgt = {5, 7, 11, 13}, - uint32_t mode = GGML_SCALE_MODE_NEAREST) + ggml_scale_mode mode = GGML_SCALE_MODE_NEAREST) : type(type), ne(ne), ne_tgt(ne_tgt), mode(mode) {} ggml_tensor * build_graph(ggml_context * ctx) override { @@ -7775,6 +7781,7 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_rope(type, {128, 40, 2, 1}, 128, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw)); // llama 13B test_cases.emplace_back(new test_rope(type, {128, 52, 2, 1}, 128, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw)); // llama 30B test_cases.emplace_back(new test_rope(type, {128, 64, 2, 1}, 128, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw)); // llama 65B + test_cases.emplace_back(new test_rope(type, {16, 16, 8192, 1}, 16, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw)); } if (all) { @@ -7789,6 +7796,7 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 20, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (stablelm) test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 32, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (phi-2) test_cases.emplace_back(new test_rope(type, { 80, 32, 4, 1}, 32, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (phi-2) + test_cases.emplace_back(new test_rope(type, { 16, 16, 8192, 1}, 16, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); } if (all) { @@ -7802,6 +7810,7 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_rope(type, {128, 28, 2, 1}, 32, GGML_ROPE_TYPE_IMROPE, 512, fs, ef, af, ff, v, fw)); test_cases.emplace_back(new test_rope(type, { 80, 16, 2, 1}, 80, GGML_ROPE_TYPE_VISION, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl ViT) test_cases.emplace_back(new test_rope(type, {128, 16, 2, 1}, 128, GGML_ROPE_TYPE_IMROPE, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen3vl) + test_cases.emplace_back(new test_rope(type, {16, 16, 8192, 1}, 16, GGML_ROPE_TYPE_IMROPE, 512, fs, ef, af, ff, v, fw)); } test_cases.emplace_back(new test_rope(type, { 64, 128, 2, 1}, 64, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (falcon 40B) @@ -7880,9 +7889,9 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_interpolate(GGML_TYPE_F32, {5, 7, 11, 13}, {2, 5, 7, 11}, mode)); } for (ggml_scale_mode mode : {GGML_SCALE_MODE_BILINEAR, GGML_SCALE_MODE_BICUBIC}) { - test_cases.emplace_back(new test_interpolate(GGML_TYPE_F32, {2, 5, 7, 11}, {5, 7, 11, 13}, mode | GGML_SCALE_FLAG_ALIGN_CORNERS)); - test_cases.emplace_back(new test_interpolate(GGML_TYPE_F32, {1, 4, 3, 2}, {2, 8, 3, 2}, mode | GGML_SCALE_FLAG_ALIGN_CORNERS)); - test_cases.emplace_back(new test_interpolate(GGML_TYPE_F32, {4, 1, 3, 2}, {1, 1, 3, 2}, mode | GGML_SCALE_FLAG_ALIGN_CORNERS)); + test_cases.emplace_back(new test_interpolate(GGML_TYPE_F32, {2, 5, 7, 11}, {5, 7, 11, 13}, (ggml_scale_mode)(mode | GGML_SCALE_FLAG_ALIGN_CORNERS))); + test_cases.emplace_back(new test_interpolate(GGML_TYPE_F32, {1, 4, 3, 2}, {2, 8, 3, 2}, (ggml_scale_mode)(mode | GGML_SCALE_FLAG_ALIGN_CORNERS))); + test_cases.emplace_back(new test_interpolate(GGML_TYPE_F32, {4, 1, 3, 2}, {1, 1, 3, 2}, (ggml_scale_mode)(mode | GGML_SCALE_FLAG_ALIGN_CORNERS))); } test_cases.emplace_back(new test_sum()); diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 94825dc862..1abbf6d6d9 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -1007,8 +1007,10 @@ private: return ret; } - void clear_slot(server_slot & slot) const { - GGML_ASSERT(!slot.is_processing()); + void clear_slot(server_slot & slot, bool allow_processing = false) const { + if (!allow_processing) { + GGML_ASSERT(!slot.is_processing()); + } SLT_WRN(slot, "clearing slot with %zu tokens\n", slot.prompt.tokens.size()); @@ -2336,7 +2338,7 @@ private: if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, p0, -1)) { SLT_WRN(slot, "failed to truncate tokens with position >= %d - clearing the memory\n", p0); - clear_slot(slot); + clear_slot(slot, /*allow_processing=*/true); // there is no common part left slot.n_prompt_tokens_cache = 0;