cuda : extend GGML_OP_PAD to work with non-cont src0 (#19429)
* cuda : extend GGML_OP_PAD to work with non-cont src0 * tests : add permuted pad
This commit is contained in:
parent
98e57ca422
commit
a0d585537c
|
|
@ -7629,8 +7629,7 @@ static void ggml_compute_forward_pad_f32(
|
|||
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
GGML_ASSERT(src0->nb[0] == sizeof(float));
|
||||
GGML_ASSERT( dst->nb[0] == sizeof(float));
|
||||
assert(dst->nb[0] == sizeof(float));
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
|
|
|||
|
|
@ -4834,8 +4834,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
|||
case GGML_OP_SUM_ROWS:
|
||||
case GGML_OP_MEAN:
|
||||
case GGML_OP_GROUP_NORM:
|
||||
case GGML_OP_PAD:
|
||||
return ggml_is_contiguous(op->src[0]);
|
||||
case GGML_OP_PAD:
|
||||
return true;
|
||||
case GGML_OP_UPSCALE:
|
||||
case GGML_OP_PAD_REFLECT_1D:
|
||||
case GGML_OP_ARANGE:
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ __device__ __forceinline__ int64_t wrap_around(int64_t coord, int64_t size) {
|
|||
return (coord + size) % size;
|
||||
}
|
||||
|
||||
static __global__ void pad_f32(const float * src, float * dst,
|
||||
static __global__ void pad_f32(const float * src, size_t s00, size_t s01, size_t s02, size_t s03, float * dst,
|
||||
const int lp0, const int rp0, const int lp1, const int rp1,
|
||||
const int lp2, const int rp2, const int lp3, const int rp3,
|
||||
const int ne0, const int ne1, const int ne2, const int ne3,
|
||||
|
|
@ -34,11 +34,8 @@ static __global__ void pad_f32(const float * src, float * dst,
|
|||
const int64_t i01 = i1 - lp1;
|
||||
const int64_t i02 = i2 - lp2;
|
||||
const int64_t i03 = i3 - lp3;
|
||||
const int64_t ne02 = ne2 - lp2 - rp2;
|
||||
const int64_t ne01 = ne1 - lp1 - rp1;
|
||||
const int64_t ne00 = ne0 - lp0 - rp0;
|
||||
|
||||
const int64_t src_idx = i03 * (ne00 * ne01 * ne02) + i02 * (ne00 * ne01) + i01 * ne00 + i00;
|
||||
const int64_t src_idx = i03 * s03 + i02 * s02 + i01 * s01 + i00 * s00;
|
||||
|
||||
dst[dst_idx] = src[src_idx];
|
||||
} else {
|
||||
|
|
@ -57,21 +54,21 @@ static __global__ void pad_f32(const float * src, float * dst,
|
|||
const int64_t i02 = wrap_around(i2 - lp2, ne02);
|
||||
const int64_t i03 = wrap_around(i3 - lp3, ne03);
|
||||
|
||||
const int64_t src_idx = i03 * (ne00 * ne01 * ne02) + i02 * (ne00 * ne01) + i01 * ne00 + i00;
|
||||
const int64_t src_idx = i03 * s03 + i02 * s02 + i01 * s01 + i00 * s00;
|
||||
|
||||
dst[dst_idx] = src[src_idx];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
static void pad_f32_cuda(const float * src, float * dst,
|
||||
static void pad_f32_cuda(const float * src, size_t s00, size_t s01, size_t s02, size_t s03, float * dst,
|
||||
const int lp0, const int rp0, const int lp1, const int rp1,
|
||||
const int lp2, const int rp2, const int lp3, const int rp3,
|
||||
const int ne0, const int ne1, const int ne2, const int ne3,
|
||||
const bool circular, cudaStream_t stream) {
|
||||
int num_blocks = (ne0 + CUDA_PAD_BLOCK_SIZE - 1) / CUDA_PAD_BLOCK_SIZE;
|
||||
dim3 gridDim(num_blocks, ne1, ne2 * ne3);
|
||||
pad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(src, dst,
|
||||
pad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(src, s00, s01, s02, s03, dst,
|
||||
lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3,
|
||||
ne0, ne1, ne2, ne3, circular);
|
||||
}
|
||||
|
|
@ -82,9 +79,10 @@ void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|||
float * dst_d = (float *) dst->data;
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
GGML_TENSOR_UNARY_OP_LOCALS;
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
|
||||
const int32_t lp0 = ((const int32_t *) (dst->op_params))[0];
|
||||
const int32_t rp0 = ((const int32_t *) (dst->op_params))[1];
|
||||
|
|
@ -96,7 +94,12 @@ void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|||
const int32_t rp3 = ((const int32_t *) (dst->op_params))[7];
|
||||
const int32_t circular = ((const int32_t *) (dst->op_params))[8];
|
||||
|
||||
pad_f32_cuda(src0_d, dst_d,
|
||||
const size_t s00 = nb00 / ggml_type_size(src0->type);
|
||||
const size_t s01 = nb01 / ggml_type_size(src0->type);
|
||||
const size_t s02 = nb02 / ggml_type_size(src0->type);
|
||||
const size_t s03 = nb03 / ggml_type_size(src0->type);
|
||||
|
||||
pad_f32_cuda(src0_d, s00, s01, s02, s03, dst_d,
|
||||
lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3,
|
||||
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
|
||||
(bool) circular, stream);
|
||||
|
|
|
|||
|
|
@ -5894,33 +5894,36 @@ struct test_pad_ext : public test_case {
|
|||
const int rp2;
|
||||
const int lp3;
|
||||
const int rp3;
|
||||
const bool v;
|
||||
const int tfrm; // 0 - none, 1 - non-cont, 2 - perm
|
||||
const bool circular;
|
||||
|
||||
std::string vars() override {
|
||||
return VARS_TO_STR12(type, ne_a, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3, v, circular);
|
||||
return VARS_TO_STR12(type, ne_a, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3, tfrm, circular);
|
||||
}
|
||||
|
||||
test_pad_ext(ggml_type type = GGML_TYPE_F32,
|
||||
std::array<int64_t, 4> ne_a = {512, 512, 3, 1},
|
||||
int lp0 = 1, int rp0 = 1, int lp1 = 1, int rp1 = 1,
|
||||
int lp2 = 1, int rp2 = 1, int lp3 = 1, int rp3 = 1,
|
||||
bool v = false, bool circular = false)
|
||||
int tfrm = 0, bool circular = false)
|
||||
: type(type), ne_a(ne_a), lp0(lp0), rp0(rp0), lp1(lp1), rp1(rp1), lp2(lp2), rp2(rp2), lp3(lp3), rp3(rp3),
|
||||
v(v), circular(circular) {}
|
||||
tfrm(tfrm), circular(circular) {}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());
|
||||
ggml_set_name(a, "a");
|
||||
|
||||
if (v) {
|
||||
if (tfrm == 1) {
|
||||
a = ggml_view_4d(ctx, a, (a->ne[0] + 1) / 2, (a->ne[1] + 1) / 2, (a->ne[2] + 1) / 2, (a->ne[3] + 1) / 2, a->nb[1], a->nb[2], a->nb[3], 0);
|
||||
ggml_set_name(a, "view of a");
|
||||
} else if (tfrm == 2) {
|
||||
a = ggml_permute(ctx, a, 2, 1, 0, 3);
|
||||
ggml_set_name(a, "permuted a");
|
||||
}
|
||||
|
||||
ggml_tensor * out = circular
|
||||
? ggml_pad_ext_circular(ctx, a, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3)
|
||||
: ggml_pad_ext(ctx, a, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3);
|
||||
: ggml_pad_ext (ctx, a, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3);
|
||||
ggml_set_name(out, "out");
|
||||
|
||||
return out;
|
||||
|
|
@ -8198,10 +8201,10 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
|||
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 4, 4 }, { 200, 64, 4, 4 }));
|
||||
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 4, 4 }, { 384, 64, 4, 4 }));
|
||||
|
||||
for (bool v : {false, true}) {
|
||||
for (int tfrm : {0, 1, 2}) {
|
||||
for (bool circular : {false, true}) {
|
||||
test_cases.emplace_back(new test_pad_ext(GGML_TYPE_F32, {512, 512, 1, 1}, 0, 1, 0, 1, 0, 0, 0, 0, v, circular));
|
||||
test_cases.emplace_back(new test_pad_ext(GGML_TYPE_F32, {11, 22, 33, 44}, 1, 2, 3, 4, 5, 6, 7, 8, v, circular));
|
||||
test_cases.emplace_back(new test_pad_ext(GGML_TYPE_F32, {512, 512, 1, 1}, 0, 1, 0, 1, 0, 0, 0, 0, tfrm, circular));
|
||||
test_cases.emplace_back(new test_pad_ext(GGML_TYPE_F32, {11, 22, 33, 44}, 1, 2, 3, 4, 5, 6, 7, 8, tfrm, circular));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue