#include "repeat_back.hpp"

#include "common.hpp"

#define GGML_ASSERT_TENSOR_FITS_INT(t) \
    GGML_ASSERT((t)->ne[0] < INT_MAX && (t)->ne[1] < INT_MAX && (t)->ne[2] < INT_MAX && (t)->ne[3] < INT_MAX)

// Reduces the F32 tensor dst->src[0] back to the shape of dst by summing tiles.
//
// Each source extent must be an integer multiple of the matching dst extent
// (nrK = ne0K / neK tiles along dimension K). Every dst element (i0, i1, i2, i3)
// receives the sum of the source elements
//   (i0 + j0*ne0, i1 + j1*ne1, i2 + j2*ne2, i3 + j3*ne3)
// over all jK in [0, nrK). Presumably this is the gradient of a repeat/broadcast
// (hence the name); confirm against the ggml op definition.
//
// The result is written through dst->nb, so dst does not have to be contiguous.
// One work-item computes one dst element on the context's SYCL queue.
void ggml_sycl_op_repeat_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);

    // Conservative per-dimension size limit. It is kept from the original;
    // all index/offset arithmetic below is 64-bit regardless.
    GGML_ASSERT_TENSOR_FITS_INT(dst);
    GGML_ASSERT_TENSOR_FITS_INT(src0);

    const int64_t ne0 = dst->ne[0];
    const int64_t ne1 = dst->ne[1];
    const int64_t ne2 = dst->ne[2];
    const int64_t ne3 = dst->ne[3];

    const int64_t total = ne0 * ne1 * ne2 * ne3;
    if (total == 0) {
        // Nothing to write. Returning here also avoids dividing by a zero extent below.
        return;
    }

    const int64_t ne00 = src0->ne[0];
    const int64_t ne01 = src0->ne[1];
    const int64_t ne02 = src0->ne[2];
    const int64_t ne03 = src0->ne[3];

    // Non-divisible shapes would silently drop the trailing partial tile.
    GGML_ASSERT(ne00 % ne0 == 0 && ne01 % ne1 == 0 && ne02 % ne2 == 0 && ne03 % ne3 == 0);

    // Number of tiles along each dimension.
    const int64_t nr0 = ne00 / ne0;
    const int64_t nr1 = ne01 / ne1;
    const int64_t nr2 = ne02 / ne2;
    const int64_t nr3 = ne03 / ne3;

    // Byte strides are kept 64-bit: narrowing them to int overflows
    // for buffers of 2 GiB or more.
    const int64_t nb00 = (int64_t) src0->nb[0];
    const int64_t nb01 = (int64_t) src0->nb[1];
    const int64_t nb02 = (int64_t) src0->nb[2];
    const int64_t nb03 = (int64_t) src0->nb[3];

    const int64_t nb0 = (int64_t) dst->nb[0];
    const int64_t nb1 = (int64_t) dst->nb[1];
    const int64_t nb2 = (int64_t) dst->nb[2];
    const int64_t nb3 = (int64_t) dst->nb[3];

    // Byte distance in src0 between the same element of two adjacent tiles.
    const int64_t step0 = ne0 * nb00;
    const int64_t step1 = ne1 * nb01;
    const int64_t step2 = ne2 * nb02;
    const int64_t step3 = ne3 * nb03;

    const char * src0_base = (const char *) src0->data;
    char *       dst_base  = (char *) dst->data;

    constexpr int64_t BLOCK_SIZE = 256;
    const int64_t     num_blocks = (total + BLOCK_SIZE - 1) / BLOCK_SIZE;

    queue_ptr stream = ctx.stream();
    stream->parallel_for(
        sycl::nd_range<1>(sycl::range<1>((size_t) (num_blocks * BLOCK_SIZE)), sycl::range<1>((size_t) BLOCK_SIZE)),
        [=](sycl::nd_item<1> item_ct1) {
            const int64_t i = (int64_t) item_ct1.get_global_linear_id();
            if (i >= total) {
                return;  // padding work-items of the last block
            }

            // Exact integer decomposition of the linear dst index. A float
            // reciprocal multiply loses low bits once i >= 2^24 and can
            // produce out-of-range (even negative) coordinates.
            int64_t       q  = i;
            const int64_t i0 = q % ne0;
            q /= ne0;
            const int64_t i1 = q % ne1;
            q /= ne1;
            const int64_t i2 = q % ne2;
            q /= ne2;
            const int64_t i3 = q;

            // Position of this element inside the first tile of src0.
            const char * tile0 = src0_base + i0 * nb00 + i1 * nb01 + i2 * nb02 + i3 * nb03;

            // Same summation order as before: j0 varies fastest, j3 slowest.
            float acc = 0.0f;
            for (int64_t j3 = 0; j3 < nr3; ++j3) {
                const char * p3 = tile0 + j3 * step3;
                for (int64_t j2 = 0; j2 < nr2; ++j2) {
                    const char * p2 = p3 + j2 * step2;
                    for (int64_t j1 = 0; j1 < nr1; ++j1) {
                        const char * p1 = p2 + j1 * step1;
                        for (int64_t j0 = 0; j0 < nr0; ++j0) {
                            acc += *(const float *) (p1 + j0 * step0);
                        }
                    }
                }
            }

            *(float *) (dst_base + i0 * nb0 + i1 * nb1 + i2 * nb2 + i3 * nb3) = acc;
        });
}