ggml : extend ggml_pool_1d + metal (#16429)
* chore: resolve conflicts * feat: ggml metal impl * fix: ggml_metal_kargs_pool_1d struct * fix: require contiguous input * chore: test pool_1d * chore: limit pool1d test cases to p0=0 and s0=k0 to conform with asserts * chore: add p0 and s0 to testing * fix: allow padding for cpu and metal * Update ggml/src/ggml-metal/ggml-metal.metal * fix: correct single-threaded loop * ggml : cleanup * tests : add ne[1] != 1 tests * fix: ne[1] handling in np * cont : fixes --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
parent
6ba6a3c76f
commit
388ce82241
|
|
@ -7,10 +7,9 @@
|
|||
#include "unary-ops.h"
|
||||
#include "vec.h"
|
||||
|
||||
#include <cfloat>
|
||||
#include <algorithm>
|
||||
#include <cfloat>
|
||||
#include <cmath>
|
||||
#include <functional>
|
||||
|
||||
// ggml_compute_forward_dup
|
||||
|
||||
|
|
@ -7110,12 +7109,13 @@ void ggml_compute_forward_conv_2d_dw(
|
|||
}
|
||||
}
|
||||
|
||||
// ggml_compute_forward_pool_1d_sk_p0
|
||||
|
||||
static void ggml_compute_forward_pool_1d_sk_p0(
|
||||
// ggml_compute_forward_pool_1d_ksp
|
||||
static void ggml_compute_forward_pool_1d_ksp(
|
||||
const ggml_compute_params * params,
|
||||
const ggml_op_pool op,
|
||||
const int k,
|
||||
const int s,
|
||||
const int p,
|
||||
ggml_tensor * dst) {
|
||||
|
||||
const ggml_tensor * src = dst->src[0];
|
||||
|
|
@ -7126,39 +7126,56 @@ static void ggml_compute_forward_pool_1d_sk_p0(
|
|||
return;
|
||||
}
|
||||
|
||||
const char * cdata = (const char *)src->data;
|
||||
const char * const data_end = cdata + ggml_nbytes(src);
|
||||
float * drow = (float *)dst->data;
|
||||
const int64_t IW = src->ne[0];
|
||||
const int64_t OW = dst->ne[0];
|
||||
|
||||
const int64_t rs = dst->ne[0];
|
||||
const int64_t nr = ggml_nrows(src);
|
||||
|
||||
while (cdata < data_end) {
|
||||
const void * srow = (const void *)cdata;
|
||||
int j = 0;
|
||||
for (int64_t i = 0; i < rs; ++i) {
|
||||
for (int64_t ir = 0; ir < nr; ++ir) {
|
||||
const char * srow_bytes = (const char *) src->data + ir * src->nb[1];
|
||||
float * drow = (float *) (( char *) dst->data + ir * dst->nb[1]);
|
||||
|
||||
for (int64_t ow = 0; ow < OW; ++ow) {
|
||||
float res = 0;
|
||||
switch (op) {
|
||||
case GGML_OP_POOL_AVG: drow[i] = 0; break;
|
||||
case GGML_OP_POOL_MAX: drow[i] = -FLT_MAX; break;
|
||||
case GGML_OP_POOL_AVG: res = 0.0f; break;
|
||||
case GGML_OP_POOL_MAX: res = -FLT_MAX; break;
|
||||
case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
|
||||
}
|
||||
|
||||
int count = 0;
|
||||
const int base = (int) ow * s - p;
|
||||
|
||||
for (int ki = 0; ki < k; ++ki) {
|
||||
const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]);
|
||||
switch (op) {
|
||||
case GGML_OP_POOL_AVG: drow[i] += srow_j; break;
|
||||
case GGML_OP_POOL_MAX: if (srow_j > drow[i]) drow[i] = srow_j; break;
|
||||
case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
|
||||
const int j = base + ki;
|
||||
if (j < 0 || j >= (int) IW) {
|
||||
continue;
|
||||
}
|
||||
++j;
|
||||
|
||||
float v;
|
||||
if (src->type == GGML_TYPE_F32) {
|
||||
v = ((const float *) srow_bytes)[j];
|
||||
} else {
|
||||
v = GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t *) srow_bytes)[j]);
|
||||
}
|
||||
|
||||
switch (op) {
|
||||
case GGML_OP_POOL_AVG: res += v; break;
|
||||
case GGML_OP_POOL_MAX: res = std::max(v, res); break;
|
||||
case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
|
||||
}
|
||||
|
||||
++count;
|
||||
}
|
||||
|
||||
switch (op) {
|
||||
case GGML_OP_POOL_AVG: drow[i] /= k; break;
|
||||
case GGML_OP_POOL_MAX: break;
|
||||
case GGML_OP_POOL_AVG: res = (count > 0) ? (res / count) : 0.0f; break;
|
||||
case GGML_OP_POOL_MAX: break;
|
||||
case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
|
||||
}
|
||||
}
|
||||
|
||||
cdata += src->nb[1];
|
||||
drow += rs;
|
||||
drow[ow] = res;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -7173,10 +7190,8 @@ void ggml_compute_forward_pool_1d(
|
|||
const int k0 = opts[1];
|
||||
const int s0 = opts[2];
|
||||
const int p0 = opts[3];
|
||||
GGML_ASSERT(p0 == 0); // padding not supported
|
||||
GGML_ASSERT(k0 == s0); // only s = k supported
|
||||
|
||||
ggml_compute_forward_pool_1d_sk_p0(params, op, k0, dst);
|
||||
ggml_compute_forward_pool_1d_ksp(params, op, k0, s0, p0, dst);
|
||||
}
|
||||
|
||||
// ggml_compute_forward_pool_2d
|
||||
|
|
@ -7194,6 +7209,7 @@ void ggml_compute_forward_pool_2d(
|
|||
}
|
||||
|
||||
const int32_t * opts = (const int32_t *)dst->op_params;
|
||||
|
||||
ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);
|
||||
const int k0 = opts[1];
|
||||
const int k1 = opts[2];
|
||||
|
|
@ -7217,11 +7233,13 @@ void ggml_compute_forward_pool_2d(
|
|||
while (cdata < data_end) {
|
||||
for (int oy = 0; oy < py; ++oy) {
|
||||
float * const drow = dplane + oy * px;
|
||||
float * const out = drow;
|
||||
|
||||
for (int ox = 0; ox < px; ++ox) {
|
||||
float * const out = drow + ox;
|
||||
float res = 0;
|
||||
switch (op) {
|
||||
case GGML_OP_POOL_AVG: *out = 0; break;
|
||||
case GGML_OP_POOL_MAX: *out = -FLT_MAX; break;
|
||||
case GGML_OP_POOL_AVG: res = 0; break;
|
||||
case GGML_OP_POOL_MAX: res = -FLT_MAX; break;
|
||||
case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
|
||||
}
|
||||
|
||||
|
|
@ -7229,24 +7247,32 @@ void ggml_compute_forward_pool_2d(
|
|||
const int iy = offset1 + oy * s1;
|
||||
|
||||
for (int ky = 0; ky < k1; ++ky) {
|
||||
if (iy + ky < 0 || iy + ky >= src->ne[1]) continue;
|
||||
if (iy + ky < 0 || iy + ky >= src->ne[1]) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const void * srow = (const void *)(cdata + src->nb[1] * (iy + ky));
|
||||
for (int kx = 0; kx < k0; ++kx) {
|
||||
int j = ix + kx;
|
||||
if (j < 0 || j >= src->ne[0]) continue;
|
||||
if (j < 0 || j >= src->ne[0]) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]);
|
||||
switch (op) {
|
||||
case GGML_OP_POOL_AVG: *out += srow_j; break;
|
||||
case GGML_OP_POOL_MAX: if (srow_j > *out) *out = srow_j; break;
|
||||
case GGML_OP_POOL_AVG: res += srow_j; break;
|
||||
case GGML_OP_POOL_MAX: res = std::max(srow_j, res); break;
|
||||
case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
|
||||
}
|
||||
}
|
||||
}
|
||||
switch (op) {
|
||||
case GGML_OP_POOL_AVG: *out /= ka; break;
|
||||
case GGML_OP_POOL_MAX: break;
|
||||
case GGML_OP_POOL_AVG: res /= ka; break;
|
||||
case GGML_OP_POOL_MAX: break;
|
||||
case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
|
||||
}
|
||||
|
||||
out[ox] = res;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -94,6 +94,31 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cpy(ggml_metal_l
|
|||
return res;
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_1d(ggml_metal_library_t lib, const ggml_tensor * op, ggml_op_pool op_pool) {
|
||||
GGML_ASSERT(ggml_is_contiguous(op->src[0]));
|
||||
GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32 && op->src[0]->type == op->type);
|
||||
|
||||
const char * pool_str = "undefined";
|
||||
switch (op_pool) {
|
||||
case GGML_OP_POOL_AVG: pool_str = "avg"; break;
|
||||
case GGML_OP_POOL_MAX: pool_str = "max"; break;
|
||||
default: GGML_ASSERT(false && "not implemented");
|
||||
};
|
||||
|
||||
char base[256];
|
||||
char name[256];
|
||||
|
||||
snprintf(base, sizeof(base), "kernel_pool_1d_%s_%s", pool_str, ggml_type_name(op->src[0]->type));
|
||||
snprintf(name, sizeof(name), "%s", base);
|
||||
|
||||
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
||||
if (!res.pipeline) {
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_2d(ggml_metal_library_t lib, const ggml_tensor * op, ggml_op_pool op_pool) {
|
||||
GGML_ASSERT(ggml_is_contiguous(op->src[0]));
|
||||
GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32 && op->src[0]->type == op->type);
|
||||
|
|
|
|||
|
|
@ -104,6 +104,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_compile_pipeline(ggml_
|
|||
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_base (ggml_metal_library_t lib, enum ggml_op op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cpy (ggml_metal_library_t lib, enum ggml_type tsrc, enum ggml_type tdst);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_1d (ggml_metal_library_t lib, const struct ggml_tensor * op, enum ggml_op_pool op_pool);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_2d (ggml_metal_library_t lib, const struct ggml_tensor * op, enum ggml_op_pool op_pool);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_get_rows (ggml_metal_library_t lib, enum ggml_type tsrc);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_set_rows (ggml_metal_library_t lib, enum ggml_type tidx, enum ggml_type tdst);
|
||||
|
|
|
|||
|
|
@ -1044,10 +1044,10 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
|||
op->src[1]->type == GGML_TYPE_F32 &&
|
||||
op->type == GGML_TYPE_F32 &&
|
||||
(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32);
|
||||
case GGML_OP_POOL_1D:
|
||||
return false;
|
||||
case GGML_OP_UPSCALE:
|
||||
return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST && !(op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS);
|
||||
case GGML_OP_POOL_1D:
|
||||
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
|
||||
case GGML_OP_POOL_2D:
|
||||
return op->src[0]->type == GGML_TYPE_F32;
|
||||
case GGML_OP_PAD:
|
||||
|
|
|
|||
|
|
@ -928,6 +928,15 @@ typedef struct {
|
|||
int64_t np;
|
||||
} ggml_metal_kargs_pool_2d;
|
||||
|
||||
typedef struct {
|
||||
int32_t k0;
|
||||
int32_t s0;
|
||||
int32_t p0;
|
||||
int64_t IW;
|
||||
int64_t OW;
|
||||
int64_t np;
|
||||
} ggml_metal_kargs_pool_1d;
|
||||
|
||||
typedef struct {
|
||||
int64_t ne00;
|
||||
uint64_t nb01;
|
||||
|
|
|
|||
|
|
@ -432,6 +432,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
|||
{
|
||||
n_fuse = ggml_metal_op_cpy(ctx, idx);
|
||||
} break;
|
||||
case GGML_OP_POOL_1D:
|
||||
{
|
||||
n_fuse = ggml_metal_op_pool_1d(ctx, idx);
|
||||
} break;
|
||||
case GGML_OP_POOL_2D:
|
||||
{
|
||||
n_fuse = ggml_metal_op_pool_2d(ctx, idx);
|
||||
|
|
@ -1622,6 +1626,54 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
|
|||
return 1;
|
||||
}
|
||||
|
||||
int ggml_metal_op_pool_1d(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
const int32_t * opts = op->op_params;
|
||||
ggml_op_pool op_pool = (ggml_op_pool) opts[0];
|
||||
|
||||
const int32_t k0 = opts[1];
|
||||
const int32_t s0 = opts[2];
|
||||
const int32_t p0 = opts[3];
|
||||
|
||||
const int64_t IW = op->src[0]->ne[0];
|
||||
const int64_t OW = op->ne[0];
|
||||
|
||||
const int64_t np = ggml_nelements(op);
|
||||
|
||||
ggml_metal_kargs_pool_1d args_pool_1d = {
|
||||
/* .k0 = */ k0,
|
||||
/* .s0 = */ s0,
|
||||
/* .p0 = */ p0,
|
||||
/* .IW = */ IW,
|
||||
/* .OW = */ OW,
|
||||
/* .np = */ np
|
||||
};
|
||||
|
||||
auto pipeline = ggml_metal_library_get_pipeline_pool_1d(lib, op, op_pool);
|
||||
|
||||
const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), (int) np);
|
||||
const int ntg = (np + nth - 1) / nth;
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args_pool_1d, sizeof(args_pool_1d), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, ntg, 1, 1, nth, 1, 1);
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
|
||||
int ggml_metal_op_pool_2d(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
|
|
|
|||
|
|
@ -61,6 +61,7 @@ int ggml_metal_op_ssm_conv (ggml_metal_op_t ctx, int idx);
|
|||
int ggml_metal_op_ssm_scan (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_rwkv (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_cpy (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_pool_1d (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_pool_2d (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_mul_mat (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_mul_mat_id (ggml_metal_op_t ctx, int idx);
|
||||
|
|
|
|||
|
|
@ -9869,6 +9869,74 @@ kernel void kernel_pool_2d_avg_f32(
|
|||
o_ptr[cur_oh * args.OW + cur_ow] = res;
|
||||
}
|
||||
|
||||
|
||||
kernel void kernel_pool_1d_max_f32(
|
||||
constant ggml_metal_kargs_pool_1d & args,
|
||||
device const float * src,
|
||||
device float * dst,
|
||||
uint gid [[thread_position_in_grid]]
|
||||
) {
|
||||
|
||||
if (gid >= args.np) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int ow = (int)gid % args.OW;
|
||||
const int row = (int)gid / args.OW;
|
||||
|
||||
const int base = ow * args.s0 - args.p0;
|
||||
|
||||
float acc = -INFINITY;
|
||||
|
||||
const int src_off = row * args.IW;
|
||||
const int dst_off = row * args.OW;
|
||||
|
||||
for (int ki = 0; ki < args.k0; ++ki) {
|
||||
int j = base + ki;
|
||||
if (j < 0 || j >= args.IW){
|
||||
continue;
|
||||
}
|
||||
float v = src[src_off + j];
|
||||
acc = max(acc, v);
|
||||
}
|
||||
|
||||
dst[dst_off + ow] = acc;
|
||||
}
|
||||
|
||||
kernel void kernel_pool_1d_avg_f32(
|
||||
constant ggml_metal_kargs_pool_1d & args,
|
||||
device const float * src,
|
||||
device float * dst,
|
||||
uint gid [[thread_position_in_grid]]
|
||||
) {
|
||||
|
||||
if (gid >= args.np) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int ow = (int)gid % args.OW;
|
||||
const int row = (int)gid / args.OW;
|
||||
|
||||
const int base = ow * args.s0 - args.p0;
|
||||
|
||||
float acc = 0.0f;
|
||||
int cnt = 0;
|
||||
|
||||
const int src_off = row * args.IW;
|
||||
const int dst_off = row * args.OW;
|
||||
|
||||
for (int ki = 0; ki < args.k0; ++ki) {
|
||||
const int j = base + ki;
|
||||
if (j < 0 || j >= args.IW) {
|
||||
continue;
|
||||
}
|
||||
acc += src[src_off + j];
|
||||
cnt += 1;
|
||||
}
|
||||
|
||||
dst[dst_off + ow] = (cnt > 0) ? (acc / (float)cnt) : 0.0f;
|
||||
}
|
||||
|
||||
kernel void kernel_opt_step_adamw_f32(
|
||||
constant ggml_metal_kargs_opt_step_adamw & args,
|
||||
device float * x,
|
||||
|
|
|
|||
|
|
@ -4838,6 +4838,8 @@ struct ggml_tensor * ggml_pool_1d(
|
|||
a->ne[2],
|
||||
a->ne[3],
|
||||
};
|
||||
GGML_ASSERT(ne[0] > 0);
|
||||
|
||||
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
|
||||
|
||||
int32_t params[] = { op, k0, s0, p0 };
|
||||
|
|
@ -4868,6 +4870,9 @@ struct ggml_tensor * ggml_pool_2d(
|
|||
a->ne[2],
|
||||
a->ne[3],
|
||||
};
|
||||
GGML_ASSERT(ne[0] > 0);
|
||||
GGML_ASSERT(ne[1] > 0);
|
||||
|
||||
result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
|
||||
|
||||
int32_t params[] = { op, k0, k1, s0, s1, p0, p1 };
|
||||
|
|
|
|||
|
|
@ -4679,6 +4679,37 @@ struct test_pool2d : public test_case {
|
|||
}
|
||||
};
|
||||
|
||||
// GGML_OP_POOL1D
|
||||
struct test_pool1d : public test_case {
|
||||
enum ggml_op_pool pool_type;
|
||||
const ggml_type type_input;
|
||||
const std::array<int64_t, 4> ne_input;
|
||||
const int k0;
|
||||
const int s0;
|
||||
const int p0;
|
||||
|
||||
std::string vars() override {
|
||||
return VARS_TO_STR6(pool_type, type_input, ne_input, k0, s0, p0);
|
||||
}
|
||||
|
||||
test_pool1d(ggml_op_pool pool_type = GGML_OP_POOL_AVG,
|
||||
ggml_type type_input = GGML_TYPE_F32,
|
||||
std::array<int64_t,4> ne_input = {10, 1, 1, 1},
|
||||
int k0 = 3, int s0 = 3, int p0 = 0)
|
||||
: pool_type(pool_type), type_input(type_input), ne_input(ne_input), k0(k0), s0(s0), p0(p0) {}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
ggml_tensor * input = ggml_new_tensor(ctx, type_input, 4, ne_input.data());
|
||||
ggml_set_param(input);
|
||||
ggml_set_name(input, "input");
|
||||
|
||||
ggml_tensor * out = ggml_pool_1d(ctx, input, pool_type, k0, s0, p0);
|
||||
ggml_set_name(out, "out");
|
||||
|
||||
return out;
|
||||
}
|
||||
};
|
||||
|
||||
// GGML_OP_CONV_TRANSPOSE_1D
|
||||
struct test_conv_transpose_1d : public test_case {
|
||||
const std::array<int64_t, 4> ne_input;
|
||||
|
|
@ -7058,6 +7089,20 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
|||
}
|
||||
}
|
||||
|
||||
for (ggml_type type_input : {GGML_TYPE_F32}) {
|
||||
for (ggml_op_pool pool_type : {GGML_OP_POOL_AVG, GGML_OP_POOL_MAX}) {
|
||||
for (int k0 : {1, 3}) {
|
||||
for (int s0 : {1, 2}) {
|
||||
for (int p0 : {0, 1}) {
|
||||
test_cases.emplace_back(new test_pool1d(pool_type, type_input, { 10, 3, 2, 1 }, k0, s0, p0));
|
||||
test_cases.emplace_back(new test_pool1d(pool_type, type_input, { 11, 1, 3, 2 }, k0, s0, p0));
|
||||
test_cases.emplace_back(new test_pool1d(pool_type, type_input, { 128, 2, 1, 3 }, k0, s0, p0));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#if 0
|
||||
// >4GB im2col destination. Too slow to run by default.
|
||||
// Test cases taken from Wan2.1 T2V 1.3B.
|
||||
|
|
|
|||
Loading…
Reference in New Issue