ggml : add circular tiling support to pad, for Vulkan, CUDA, and CPU (used for making seamless textures) (#16985)
* Feat: Added Vulkan circular tiling support
* Feat: Added CPU circular padding
* Feat: Added CUDA kernels
* Added tests
* Added tests
* Removed non-pad operations
* Removed unneeded changes
* Removed backend non-pad tests
* Update test-backend-ops.cpp
* Fixed comment on pad test
* Removed trailing whitespace
* Removed unneeded test in test-backend-ops
* Removed removed test from calls
* Update ggml/src/ggml-vulkan/vulkan-shaders/pad.comp
  Co-authored-by: Ruben Ortlam <picard12@live.de>
* Fixed alignment
* Formatting
  Co-authored-by: Aman Gupta <amangupta052@gmail.com>
* Format pad
* Format
* Clang format
* Format
* Format
* Don't change so much stuff
* Clang format and update to bool
* Fix duplicates
* Don't need to fix the padding
* Make circular bool
* Duplicate again
* Rename Vulkan helper to wrap around
* Don't need indent
* Moved to constexpr
* Removed unneeded extra line break
* More readable method calls
* Minor wording changes
* Added final newline
* Update ggml/include/ggml.h
  Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
* Update ggml/include/ggml.h
  Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
* Added circular pad ext tests
* Gate non-circular pad devices
* Cleaned up gating of non-circular pad devices

---------

Co-authored-by: Phylliida <phylliidadev@gmail.com>
Co-authored-by: Ruben Ortlam <picard12@live.de>
Co-authored-by: Aman Gupta <amangupta052@gmail.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
parent f334b79494
commit 09c7c50e64
@@ -2196,6 +2196,15 @@ extern "C" {
             int                   p2,
             int                   p3);
 
+    // pad each dimension with values on the other side of the torus (looping around)
+    GGML_API struct ggml_tensor * ggml_pad_circular(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   p0,
+            int                   p1,
+            int                   p2,
+            int                   p3);
+
     GGML_API struct ggml_tensor * ggml_pad_ext(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
@@ -2209,6 +2218,19 @@ extern "C" {
             int                   rp3
             );
 
+    // pad each dimension with values on the other side of the torus (looping around)
+    GGML_API struct ggml_tensor * ggml_pad_ext_circular(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   lp0,
+            int                   rp0,
+            int                   lp1,
+            int                   rp1,
+            int                   lp2,
+            int                   rp2,
+            int                   lp3,
+            int                   rp3);
+
     // pad each dimension with reflection: [a, b, c, d] -> [b, a, b, c, d, c]
     GGML_API struct ggml_tensor * ggml_pad_reflect_1d(
             struct ggml_context * ctx,
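A minimal usage sketch of the new API (the tensor shape and function name below are illustrative, not from this commit). `ggml_pad_circular` pads only the high end of each dimension, mirroring `ggml_pad`; `ggml_pad_ext_circular` takes separate left/right amounts per dimension:

#include "ggml.h"

// Make a 64x64 RGB "texture" tileable by copying 8 texels from the opposite
// edges onto both borders of dims 0 and 1 (a sketch, assuming a previously
// created context `ctx`).
struct ggml_tensor * make_tileable(struct ggml_context * ctx) {
    struct ggml_tensor * tex = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 64, 64, 3);
    // result is 80x80x3: lp0=8, rp0=8, lp1=8, rp1=8, dims 2/3 untouched
    return ggml_pad_ext_circular(ctx, tex, 8, 8, 8, 8, 0, 0, 0, 0);
}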
@@ -2551,6 +2551,8 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
         case GGML_OP_ACC:
         case GGML_OP_GROUP_NORM:
         case GGML_OP_PAD:
+            // TODO: add circular padding support for cann, see https://github.com/ggml-org/llama.cpp/pull/16985
+            return ggml_get_op_params_i32(op, 8) == 0;
         case GGML_OP_ARANGE:
         case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_LEAKY_RELU:
@@ -6554,8 +6554,13 @@ static void ggml_call_mul_mat(ggml_type type, const ggml_compute_params * params
     ggml_compute_forward_mul_mat(params, &dst);
 }
 
+static inline int64_t ggml_wrap_around(int64_t coord, int64_t size) {
+    return (coord + size) % size; // adding size avoids negative number weirdness
+}
+
 // ggml_compute_forward_conv_2d
 
 static void ggml_compute_forward_conv_2d_impl(const ggml_compute_params * params,
                                               const ggml_tensor * kernel, // [KW, KH, IC, OC]
                                               const ggml_tensor * src,    // [W, H, C, N]
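The `+ size` in the helper matters because C's `%` truncates toward zero for negative operands. A small self-contained check (illustrative, not part of the commit):

#include <assert.h>
#include <stdint.h>

static inline int64_t wrap_around(int64_t coord, int64_t size) {
    return (coord + size) % size;
}

int main(void) {
    // plain -1 % 8 is -1 in C; adding size first keeps the operand
    // non-negative for any coord in [-size, size)
    assert(wrap_around(-1, 8) == 7); // one step left of 0 wraps to the far edge
    assert(wrap_around( 8, 8) == 0); // one step past the end wraps back to 0
    return 0;
}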
@@ -7591,6 +7596,7 @@ void ggml_compute_forward_upscale(
 
 // ggml_compute_forward_pad
 
+template<bool circular_t>
 static void ggml_compute_forward_pad_f32(
         const ggml_compute_params * params,
         ggml_tensor * dst) {
@@ -7615,13 +7621,29 @@ static void ggml_compute_forward_pad_f32(
     const int32_t lp3 = ggml_get_op_params_i32(dst, 6);
     const int32_t rp3 = ggml_get_op_params_i32(dst, 7);
 
     // TODO: optimize
 
     for (int64_t i2 = 0; i2 < ne2; ++i2) {
         for (int64_t i1 = ith; i1 < ne1; i1 += nth) {
             for (int64_t i0 = 0; i0 < ne0; ++i0) {
                 for (int64_t i3 = 0; i3 < ne3; ++i3) {
+                    // circular means wrap around on a torus, so x and y loop around
+                    if constexpr (circular_t) {
+                        const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
+                        const int64_t src_i0 = ggml_wrap_around(i0 - lp0, ne00);
+                        const int64_t src_i1 = ggml_wrap_around(i1 - lp1, ne01);
+                        const int64_t src_i2 = ggml_wrap_around(i2 - lp2, ne02);
+                        const int64_t src_i3 = ggml_wrap_around(i3 - lp3, ne03);
+
+                        const int64_t src_idx =
+                            src_i3*nb03 +
+                            src_i2*nb02 +
+                            src_i1*nb01 +
+                            src_i0*nb00;
+
+                        const float * src_ptr = (const float *)((char *) src0->data + src_idx);
+                        dst_ptr[dst_idx] = *src_ptr;
+                    } else {
                     const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
                     if ((i0 >= lp0 && i0 < ne0 - rp0) \
                         && (i1 >= lp1 && i1 < ne1 - rp1) \
@@ -7637,18 +7659,23 @@ static void ggml_compute_forward_pad_f32(
                 }
             }
         }
+    }
 }
 
 void ggml_compute_forward_pad(
         const ggml_compute_params * params,
         ggml_tensor * dst) {
 
     const ggml_tensor * src0 = dst->src[0];
+    const bool circular = (bool) ggml_get_op_params_i32(dst, 8);
 
     switch (src0->type) {
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_pad_f32(params, dst);
+                if (circular) {
+                    ggml_compute_forward_pad_f32<true>(params, dst);
+                } else {
+                    ggml_compute_forward_pad_f32<false>(params, dst);
+                }
             } break;
         default:
             {
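To pin down what the circular (`<true>`) branch computes, a standalone sketch for one row (the values are illustrative):

#include <stdio.h>

int main(void) {
    const char src[4] = {'a', 'b', 'c', 'd'};
    const int lp0 = 1, rp0 = 2, ne00 = 4;
    // dst has ne0 = lp0 + ne00 + rp0 = 7 elements
    for (int i0 = 0; i0 < lp0 + ne00 + rp0; ++i0) {
        int src_i0 = (i0 - lp0 + ne00) % ne00; // same wrap-around as the kernel
        printf("%c", src[src_i0]);
    }
    printf("\n"); // prints "dabcdab": every border value comes from the opposite edge
    return 0;
}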
@@ -1,9 +1,17 @@
 #include "pad.cuh"
 
+#include <stdint.h>
+
+__device__ __forceinline__ int64_t wrap_around(int64_t coord, int64_t size) {
+    // + size ensures negatives are handled properly
+    return (coord + size) % size;
+}
+
 static __global__ void pad_f32(const float * src, float * dst,
     const int lp0, const int rp0, const int lp1, const int rp1,
     const int lp2, const int rp2, const int lp3, const int rp3,
-    const int ne0, const int ne1, const int ne2, const int ne3) {
+    const int ne0, const int ne1, const int ne2, const int ne3,
+    const bool circular) {
     // blockIdx.z: i3*ne2+i2
     // blockIdx.y: i1
     // blockIDx.x: i0 / CUDA_PAD_BLOCK_SIZE
@@ -12,15 +20,15 @@ static __global__ void pad_f32(const float * src, float * dst,
     int i1 = blockIdx.y;
     int i2 = blockIdx.z % ne2;
     int i3 = blockIdx.z / ne2;
 
     if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
         return;
     }
 
-    // operation
-    const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
+    const int64_t dst_idx = i3 * (ne0 * ne1 * ne2) + i2 * (ne0 * ne1) + i1 * ne0 + i0;
 
-    if ((i0 >= lp0 && i0 < ne0 - rp0) &&
-        (i1 >= lp1 && i1 < ne1 - rp1) &&
-        (i2 >= lp2 && i2 < ne2 - rp2) &&
+    if (!circular) {
+        if ((i0 >= lp0 && i0 < ne0 - rp0) && (i1 >= lp1 && i1 < ne1 - rp1) && (i2 >= lp2 && i2 < ne2 - rp2) &&
             (i3 >= lp3 && i3 < ne3 - rp3)) {
         const int64_t i00 = i0 - lp0;
         const int64_t i01 = i1 - lp1;
@@ -30,43 +38,66 @@ static __global__ void pad_f32(const float * src, float * dst,
         const int64_t ne01 = ne1 - lp1 - rp1;
         const int64_t ne00 = ne0 - lp0 - rp0;
 
-        const int64_t src_idx = i03*(ne00*ne01*ne02) + i02*(ne00*ne01) + i01*ne00 + i00;
+        const int64_t src_idx = i03 * (ne00 * ne01 * ne02) + i02 * (ne00 * ne01) + i01 * ne00 + i00;
 
         dst[dst_idx] = src[src_idx];
     } else {
         dst[dst_idx] = 0.0f;
     }
+    }
+    // circular means on a torus, so x and y wrap around
+    else {
+        const int64_t ne00 = ne0 - lp0 - rp0;
+        const int64_t ne01 = ne1 - lp1 - rp1;
+        const int64_t ne02 = ne2 - lp2 - rp2;
+        const int64_t ne03 = ne3 - lp3 - rp3;
+
+        const int64_t i00 = wrap_around(i0 - lp0, ne00);
+        const int64_t i01 = wrap_around(i1 - lp1, ne01);
+        const int64_t i02 = wrap_around(i2 - lp2, ne02);
+        const int64_t i03 = wrap_around(i3 - lp3, ne03);
+
+        const int64_t src_idx = i03 * (ne00 * ne01 * ne02) + i02 * (ne00 * ne01) + i01 * ne00 + i00;
+
+        dst[dst_idx] = src[src_idx];
+    }
 }
 
 static void pad_f32_cuda(const float * src, float * dst,
     const int lp0, const int rp0, const int lp1, const int rp1,
     const int lp2, const int rp2, const int lp3, const int rp3,
-    const int ne0, const int ne1, const int ne2, const int ne3, cudaStream_t stream) {
+    const int ne0, const int ne1, const int ne2, const int ne3,
+    const bool circular, cudaStream_t stream) {
     int num_blocks = (ne0 + CUDA_PAD_BLOCK_SIZE - 1) / CUDA_PAD_BLOCK_SIZE;
-    dim3 gridDim(num_blocks, ne1, ne2*ne3);
-    pad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(src, dst, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3, ne0, ne1, ne2, ne3);
+    dim3 gridDim(num_blocks, ne1, ne2 * ne3);
+    pad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(src, dst,
+                                                         lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3,
+                                                         ne0, ne1, ne2, ne3, circular);
 }
 
 void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
-    const float * src0_d = (const float *)src0->data;
-    float * dst_d = (float *)dst->data;
+    const float * src0_d = (const float *) src0->data;
+    float * dst_d = (float *) dst->data;
     cudaStream_t stream = ctx.stream();
 
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT(dst->type == GGML_TYPE_F32);
     GGML_ASSERT(ggml_is_contiguous(src0));
 
-    const int32_t lp0 = ((const int32_t*)(dst->op_params))[0];
-    const int32_t rp0 = ((const int32_t*)(dst->op_params))[1];
-    const int32_t lp1 = ((const int32_t*)(dst->op_params))[2];
-    const int32_t rp1 = ((const int32_t*)(dst->op_params))[3];
-    const int32_t lp2 = ((const int32_t*)(dst->op_params))[4];
-    const int32_t rp2 = ((const int32_t*)(dst->op_params))[5];
-    const int32_t lp3 = ((const int32_t*)(dst->op_params))[6];
-    const int32_t rp3 = ((const int32_t*)(dst->op_params))[7];
+    const int32_t lp0 = ((const int32_t *) (dst->op_params))[0];
+    const int32_t rp0 = ((const int32_t *) (dst->op_params))[1];
+    const int32_t lp1 = ((const int32_t *) (dst->op_params))[2];
+    const int32_t rp1 = ((const int32_t *) (dst->op_params))[3];
+    const int32_t lp2 = ((const int32_t *) (dst->op_params))[4];
+    const int32_t rp2 = ((const int32_t *) (dst->op_params))[5];
+    const int32_t lp3 = ((const int32_t *) (dst->op_params))[6];
+    const int32_t rp3 = ((const int32_t *) (dst->op_params))[7];
+    const int32_t circular = ((const int32_t *) (dst->op_params))[8];
 
     pad_f32_cuda(src0_d, dst_d,
         lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3,
-        dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream);
+        dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
+        (bool) circular, stream);
 }
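The launcher maps one thread to one destination element: threads cover dim 0 in blocks, blockIdx.y covers dim 1, and blockIdx.z flattens dims 2 and 3. A sketch of the resulting geometry in plain C (the block size of 256 is assumed here for illustration; the real value comes from CUDA_PAD_BLOCK_SIZE in the CUDA headers):

#include <stdio.h>

#define PAD_BLOCK_SIZE 256 // illustrative stand-in for CUDA_PAD_BLOCK_SIZE

int main(void) {
    // destination extents after padding a 512x512x3x1 tensor by 1 on each side of dims 0/1
    const int ne0 = 514, ne1 = 514, ne2 = 3, ne3 = 1;
    const int num_blocks = (ne0 + PAD_BLOCK_SIZE - 1) / PAD_BLOCK_SIZE;
    // grid = (blocks along i0, i1, i3*ne2 + i2); one thread per dst element
    printf("grid = (%d, %d, %d), block = %d\n", num_blocks, ne1, ne2 * ne3, PAD_BLOCK_SIZE);
    return 0; // prints: grid = (3, 514, 3), block = 256
}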
@@ -1037,6 +1037,11 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
         case GGML_OP_POOL_2D:
             return op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_PAD:
+            // TODO: add circular padding support for metal, see https://github.com/ggml-org/llama.cpp/pull/16985
+            if (ggml_get_op_params_i32(op, 8) != 0) {
+                return false;
+            }
+
             return (ggml_get_op_params_i32(op, 0) == 0) && (ggml_get_op_params_i32(op, 2) == 0) &&
                    (ggml_get_op_params_i32(op, 4) == 0) && (ggml_get_op_params_i32(op, 6) == 0);
         case GGML_OP_PAD_REFLECT_1D:
@@ -3083,6 +3083,10 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
         case GGML_OP_REPEAT:
             return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; // Assuming F32 for now, can be expanded
         case GGML_OP_PAD:
+            // TODO: add circular padding support for opencl, see https://github.com/ggml-org/llama.cpp/pull/16985
+            if (ggml_get_op_params_i32(op, 8) != 0) {
+                return false;
+            }
             return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
         case GGML_OP_UPSCALE: {
             ggml_scale_mode mode = (ggml_scale_mode)(ggml_get_op_params_i32(op, 0) & 0xFF);
@@ -4613,6 +4613,10 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_ACC:
             return true;
         case GGML_OP_PAD:
+            // TODO: add circular padding support for sycl, see https://github.com/ggml-org/llama.cpp/pull/16985
+            if (ggml_get_op_params_i32(op, 8) != 0) {
+                return false;
+            }
             return ggml_is_contiguous(op->src[0]);
         case GGML_OP_LEAKY_RELU:
         case GGML_OP_TIMESTEP_EMBEDDING:
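All four of these backends (CANN, Metal, OpenCL, SYCL) use the same gate: if op param 8, the circular flag, is set, the op is reported as unsupported and the graph falls back to a backend that implements it (CPU, CUDA, or Vulkan in this commit). The idiom as a standalone sketch (the function name is made up; ggml_get_op_params_i32 is the internal helper from ggml-impl.h):

#include "ggml-impl.h" // for ggml_get_op_params_i32

// Illustrative gate: a backend without circular padding rejects the op,
// so the scheduler routes it to a backend that has it.
static bool backend_supports_pad(const struct ggml_tensor * op) {
    return ggml_get_op_params_i32(op, 8) == 0; // op param 8 is the circular flag
}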
@@ -1050,6 +1050,7 @@ struct vk_op_pad_push_constants {
     uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
     uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; uint32_t nb10; uint32_t nb11; uint32_t nb12; uint32_t nb13;
     uint32_t misalign_offsets;
+    uint32_t circular;
 
     uint32_t lp0; uint32_t rp0;
     uint32_t lp1; uint32_t rp1;
@@ -1092,6 +1093,7 @@ static vk_op_pad_push_constants vk_op_pad_push_constants_init(const ggml_tensor
     p.rp2 = dst->op_params[5];
     p.lp3 = dst->op_params[6];
     p.rp3 = dst->op_params[7];
+    p.circular = dst->op_params[8];
 
     return p; // fastdiv values and offsets are initialized later in ggml_vk_op
 }
@@ -8,6 +8,7 @@ layout (push_constant) uniform parameter
     uint ne00; uint ne01; uint ne02; uint ne03; uint nb00; uint nb01; uint nb02; uint nb03;
     uint ne10; uint ne11; uint ne12; uint ne13; uint nb10; uint nb11; uint nb12; uint nb13;
     uint misalign_offsets;
+    uint circular;
 
     uint lp0; uint rp0;
     uint lp1; uint rp1;
@@ -18,6 +19,10 @@ layout (push_constant) uniform parameter
 uint get_aoffset() { return p.misalign_offsets >> 16; }
 uint get_doffset() { return p.misalign_offsets & 0xFFFF; }
 
+uint wrap_around(int coord, uint size) {
+    return (uint(coord + int(size))) % size; // add size to avoid issues with negative
+}
+
 layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
 layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
@@ -40,10 +45,20 @@ void main() {
     const uint src0_idx = (i3 - p.lp3)*p.nb03 + (i2 - p.lp2)*p.nb02 + (i1 - p.lp1)*p.nb01 + (i0 - p.lp0)*p.nb00;
     const uint dst_idx = i3*p.nb13 + i2*p.nb12 + i1*p.nb11 + i0*p.nb10;
 
+    if (p.circular != 0u) {
+        const uint ci0 = wrap_around(int(i0) - int(p.lp0), p.ne00);
+        const uint ci1 = wrap_around(int(i1) - int(p.lp1), p.ne01);
+        const uint ci2 = wrap_around(int(i2) - int(p.lp2), p.ne02);
+        const uint ci3 = wrap_around(int(i3) - int(p.lp3), p.ne03);
+        const uint circular_src_idx = ci3*p.nb03 + ci2*p.nb02 + ci1*p.nb01 + ci0*p.nb00;
+        data_d[get_doffset() + dst_idx] = D_TYPE(data_a[get_aoffset() + circular_src_idx]);
+    } else {
     const bool is_src0 = i0 >= p.lp0 && i0 < p.ne10 - p.rp0 &&
                          i1 >= p.lp1 && i1 < p.ne11 - p.rp1 &&
                          i2 >= p.lp2 && i2 < p.ne12 - p.rp2 &&
                          i3 >= p.lp3 && i3 < p.ne13 - p.rp3;
 
     data_d[get_doffset() + dst_idx] = D_TYPE(is_src0 ? data_a[get_aoffset() + src0_idx] : 0.0f);
+    }
 
 }
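One subtlety in the shader helper above: the coordinate is signed (i0 - p.lp0 can go negative) while sizes are unsigned, so size is converted to int before the add to keep the arithmetic signed until the value is known to be non-negative. A C analogue of the same pattern (illustrative, not from the commit):

#include <stdint.h>

// Do the add in signed arithmetic, then convert for the modulo; mixing
// a negative signed coord into unsigned arithmetic without the cast
// would wrap to a huge value before the % is applied.
static uint32_t wrap_around_u(int32_t coord, uint32_t size) {
    return (uint32_t)(coord + (int32_t)size) % size;
}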
@@ -4947,6 +4947,18 @@ struct ggml_tensor * ggml_pad(
     return ggml_pad_ext(ctx, a, 0, p0, 0, p1, 0, p2, 0, p3);
 }
 
+// ggml_pad_circular
+
+struct ggml_tensor * ggml_pad_circular(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   p0,
+        int                   p1,
+        int                   p2,
+        int                   p3) {
+    return ggml_pad_ext_circular(ctx, a, 0, p0, 0, p1, 0, p2, 0, p3);
+}
+
 struct ggml_tensor * ggml_pad_ext(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
@@ -4973,6 +4985,7 @@ struct ggml_tensor * ggml_pad_ext(
     ggml_set_op_params_i32(result, 5, rp2);
     ggml_set_op_params_i32(result, 6, lp3);
     ggml_set_op_params_i32(result, 7, rp3);
+    ggml_set_op_params_i32(result, 8, 0); // not circular by default
 
     result->op = GGML_OP_PAD;
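Taken together, GGML_OP_PAD now carries nine int32 op params. A restatement of the layout as set above (the enum names are illustrative; ggml itself uses raw indices):

// Slot layout of dst->op_params for GGML_OP_PAD after this commit
enum pad_op_param {
    PAD_PARAM_LP0 = 0, PAD_PARAM_RP0 = 1, // left/right padding of dim 0
    PAD_PARAM_LP1 = 2, PAD_PARAM_RP1 = 3, // left/right padding of dim 1
    PAD_PARAM_LP2 = 4, PAD_PARAM_RP2 = 5, // left/right padding of dim 2
    PAD_PARAM_LP3 = 6, PAD_PARAM_RP3 = 7, // left/right padding of dim 3
    PAD_PARAM_CIRCULAR = 8,               // 0 = zero padding (default), 1 = wrap around
};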
@@ -4981,6 +4994,25 @@ struct ggml_tensor * ggml_pad_ext(
     return result;
 }
 
+// ggml_pad_ext_circular
+
+struct ggml_tensor * ggml_pad_ext_circular(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   lp0,
+        int                   rp0,
+        int                   lp1,
+        int                   rp1,
+        int                   lp2,
+        int                   rp2,
+        int                   lp3,
+        int                   rp3
+        ) {
+    struct ggml_tensor * result = ggml_pad_ext(ctx, a, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3);
+    ggml_set_op_params_i32(result, 8, 1); // circular
+    return result;
+}
+
 // ggml_pad_reflect_1d
 
 struct ggml_tensor * ggml_pad_reflect_1d(
@@ -5604,21 +5604,24 @@ struct test_pad : public test_case {
     const std::array<int64_t, 4> ne_a;
     const int pad_0;
     const int pad_1;
+    const bool circular;
 
     std::string vars() override {
-        return VARS_TO_STR4(type, ne_a, pad_0, pad_1);
+        return VARS_TO_STR5(type, ne_a, pad_0, pad_1, circular);
     }
 
     test_pad(ggml_type type = GGML_TYPE_F32,
             std::array<int64_t, 4> ne_a = {512, 512, 1, 1},
-            int pad_0 = 1, int pad_1 = 1)
-        : type(type), ne_a(ne_a), pad_0(pad_0), pad_1(pad_1) {}
+            int pad_0 = 1, int pad_1 = 1, bool circular = false)
+        : type(type), ne_a(ne_a), pad_0(pad_0), pad_1(pad_1), circular(circular) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());
         ggml_set_name(a, "a");
 
-        ggml_tensor * out = ggml_pad(ctx, a, pad_0, pad_1, 0, 0);
+        ggml_tensor * out = circular
+            ? ggml_pad_circular(ctx, a, pad_0, pad_1, 0, 0)
+            : ggml_pad(ctx, a, pad_0, pad_1, 0, 0);
         ggml_set_name(out, "out");
 
         return out;
@@ -5638,17 +5641,19 @@ struct test_pad_ext : public test_case {
     const int lp3;
     const int rp3;
     const bool v;
+    const bool circular;
 
     std::string vars() override {
-        return VARS_TO_STR11(type, ne_a, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3, v);
+        return VARS_TO_STR12(type, ne_a, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3, v, circular);
     }
 
     test_pad_ext(ggml_type type = GGML_TYPE_F32,
             std::array<int64_t, 4> ne_a = {512, 512, 3, 1},
             int lp0 = 1, int rp0 = 1, int lp1 = 1, int rp1 = 1,
             int lp2 = 1, int rp2 = 1, int lp3 = 1, int rp3 = 1,
-            bool v = false)
-        : type(type), ne_a(ne_a), lp0(lp0), rp0(rp0), lp1(lp1), rp1(rp1), lp2(lp2), rp2(rp2), lp3(lp3), rp3(rp3), v(v) {}
+            bool v = false, bool circular = false)
+        : type(type), ne_a(ne_a), lp0(lp0), rp0(rp0), lp1(lp1), rp1(rp1), lp2(lp2), rp2(rp2), lp3(lp3), rp3(rp3),
+          v(v), circular(circular) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());
@@ -5659,7 +5664,9 @@ struct test_pad_ext : public test_case {
             ggml_set_name(a, "view of a");
         }
 
-        ggml_tensor * out = ggml_pad_ext(ctx, a, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3);
+        ggml_tensor * out = circular
+            ? ggml_pad_ext_circular(ctx, a, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3)
+            : ggml_pad_ext(ctx, a, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3);
         ggml_set_name(out, "out");
 
         return out;
@@ -7782,6 +7789,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_group_norm_mul_add(GGML_TYPE_F32, {9, 9, 1280, 1}));
     test_cases.emplace_back(new test_acc());
     test_cases.emplace_back(new test_pad());
+    test_cases.emplace_back(new test_pad(GGML_TYPE_F32, {33, 17, 2, 1}, 4, 3, true)); // circular
     test_cases.emplace_back(new test_pad_ext());
     test_cases.emplace_back(new test_pad_reflect_1d());
     test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {3000, 384, 4, 1}));
@@ -7829,8 +7837,10 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 4, 4 }, { 300, 64, 4, 4 }));
 
     for (bool v : {false, true}) {
-        test_cases.emplace_back(new test_pad_ext(GGML_TYPE_F32, {512, 512, 1, 1}, 0, 1, 0, 1, 0, 0, 0, 0, v));
-        test_cases.emplace_back(new test_pad_ext(GGML_TYPE_F32, {11, 22, 33, 44}, 1, 2, 3, 4, 5, 6, 7, 8, v));
+        for (bool circular : {false, true}) {
+            test_cases.emplace_back(new test_pad_ext(GGML_TYPE_F32, {512, 512, 1, 1}, 0, 1, 0, 1, 0, 0, 0, 0, v, circular));
+            test_cases.emplace_back(new test_pad_ext(GGML_TYPE_F32, {11, 22, 33, 44}, 1, 2, 3, 4, 5, 6, 7, 8, v, circular));
+        }
     }
 
     for (int hsk : { 40, 64, 72, 80, 96, 128, 192, 256, 576 }) {