ggml : added inplace version of GGML_OP_SCATTER and tests for this OP

This commit is contained in:
Stanisław Szymczyk 2026-03-24 20:32:45 +01:00
parent 9b0a4eea57
commit 0ee5d80ed3
5 changed files with 111 additions and 13 deletions

View File

@ -2486,6 +2486,12 @@ extern "C" {
struct ggml_tensor * ids,
float c);
// In-place variant of ggml_scatter: the result tensor is a view of `a`
// (see ggml_scatter_impl, which calls ggml_view_tensor when inplace),
// so the scatter op writes the constant `c` directly into a's data at the
// positions given by `ids`.
GGML_API struct ggml_tensor * ggml_scatter_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * ids,
float c);
// custom operators
typedef void (*ggml_custom1_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int nth, void * userdata);

View File

@ -11265,7 +11265,9 @@ static void ggml_compute_forward_scatter_f32(
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
const float c = ggml_get_op_params_f32(dst, 0);
const bool inplace = ggml_get_op_params_i32(dst, 1);
GGML_ASSERT(ggml_are_same_shape(src0, dst));
@ -11303,7 +11305,9 @@ static void ggml_compute_forward_scatter_f32(
float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 );
// copy whole row from src0
ggml_vec_cpy_f32(ne00, dst_ptr, src0_ptr);
if (!inplace) {
ggml_vec_cpy_f32(ne00, dst_ptr, src0_ptr);
}
// set dst elements indicated by indices in src1 to c
for (int j = 0; j < ne10; ++j) {

View File

@ -44,21 +44,23 @@ void ggml_cuda_op_scatter(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(dst));
float c;
memcpy(&c, (float *) dst->op_params + 0, sizeof(float));
float c = ggml_get_op_params_f32(dst, 0);
bool inplace = ggml_get_op_params_i32(dst, 1);
// step 1 - copy whole src0 to dst
cudaStream_t main_stream = ctx.stream();
char * dst_ddc = (char *) dst->data;
char * src0_ddc = (char *) src0->data;
if (!inplace) {
cudaStream_t main_stream = ctx.stream();
char * dst_ddc = (char *) dst->data;
char * src0_ddc = (char *) src0->data;
CUDA_CHECK(cudaMemcpyAsync(dst_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
CUDA_CHECK(cudaMemcpyAsync(dst_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
}
// step 2 - set elements in dst indicated by ids to c
const int32_t * src1_d = (const int32_t *) src1->data;
float * dst_d = (float *) dst->data;
int threads = std::min((int) ne10, 768); // ids
int threads = std::min((int) ne10, 512); // ids
int64_t total_blocks = ne11 * ne12 * ne13;
int blocks = (int) std::min((int64_t) 65535, total_blocks);

View File

@ -6205,20 +6205,21 @@ struct ggml_tensor * ggml_hadamard(
// ggml_scatter
struct ggml_tensor * ggml_scatter(
static struct ggml_tensor * ggml_scatter_impl(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * ids,
float c) {
float c,
bool inplace) {
GGML_ASSERT(a->type == GGML_TYPE_F32);
GGML_ASSERT(ids->type == GGML_TYPE_I32);
GGML_ASSERT(a->ne[1] == ids->ne[1]);
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, a->ne);
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
float params[1] = { c };
ggml_set_op_params(result, &params, sizeof(params));
ggml_set_op_params_f32(result, 0, c);
ggml_set_op_params_i32(result, 1, inplace ? 1 : 0);
result->op = GGML_OP_SCATTER;
result->src[0] = a;
@ -6227,6 +6228,22 @@ struct ggml_tensor * ggml_scatter(
return result;
}
struct ggml_tensor * ggml_scatter(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * ids,
float c) {
return ggml_scatter_impl(ctx, a, ids, c, false);
}
struct ggml_tensor * ggml_scatter_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * ids,
float c) {
return ggml_scatter_impl(ctx, a, ids, c, true);
}
////////////////////////////////////////////////////////////////////////////////
struct ggml_hash_set ggml_hash_set_new(size_t size) {

View File

@ -6648,6 +6648,65 @@ struct test_diag : public test_case {
}
};
// GGML_OP_SCATTER
// Exercises ggml_scatter / ggml_scatter_inplace: sets elements of `a`
// (indexed per row by `ids`) to the constant `c`.
struct test_scatter : public test_case {
    const ggml_type type_a;                 // element type of the scattered tensor (F32)
    const ggml_type type_ids;               // element type of the index tensor (I32)
    const std::array<int64_t, 4> ne_a;      // shape of a
    const std::array<int64_t, 4> ne_ids;    // shape of ids; ne_ids[1] must match ne_a[1]
    float c;                                // value written at the indexed positions
    bool inplace;                           // build with ggml_scatter_inplace instead of ggml_scatter

    std::string vars() override {
        return VARS_TO_STR6(type_a, type_ids, ne_a, ne_ids, c, inplace);
    }

    test_scatter(ggml_type type_a = GGML_TYPE_F32,
            ggml_type type_ids = GGML_TYPE_I32,
            std::array<int64_t, 4> ne_a = {10, 10, 10, 10},
            std::array<int64_t, 4> ne_ids = {3, 10, 10, 10},
            float c = 2.0f,
            bool inplace = false)
        : type_a(type_a), type_ids(type_ids), ne_a(ne_a), ne_ids(ne_ids), c(c), inplace(inplace) {}

    ggml_tensor * build_graph(ggml_context * ctx) override {
        ggml_tensor * a = ggml_new_tensor(ctx, type_a, 4, ne_a.data());
        ggml_set_param(a);
        ggml_set_name(a, "a");

        ggml_tensor * ids = ggml_new_tensor(ctx, type_ids, 4, ne_ids.data());
        ggml_set_param(ids);
        ggml_set_name(ids, "ids");

        ggml_tensor * out;
        if (inplace) {
            out = ggml_scatter_inplace(ctx, a, ids, c);
        } else {
            out = ggml_scatter(ctx, a, ids, c);
        }
        ggml_set_name(out, "out");

        return out;
    }

    void initialize_tensors(ggml_context * ctx) override {
        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
            if (t->type == GGML_TYPE_I32) {
                // ids: random positions valid along dim 0 of a.
                // int64_t avoids narrowing — ggml_nelements returns int64_t.
                const int64_t num_pos_ids = ggml_nelements(t);
                std::vector<int32_t> data(num_pos_ids);
                for (int64_t i = 0; i < num_pos_ids; i++) {
                    data[i] = (int32_t)(rand() % ne_a[0]);
                }
                // sizeof(int32_t) matches the element type actually stored in `data`
                ggml_backend_tensor_set(t, data.data(), 0, num_pos_ids * sizeof(int32_t));
            } else {
                init_tensor_uniform(t);
            }
        }
    }
};
enum llm_norm_type {
LLM_NORM,
@ -8474,6 +8533,12 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_falcon(2));
#endif
// scatter
test_cases.emplace_back(new test_scatter(GGML_TYPE_F32, GGML_TYPE_I32, {10, 1, 1, 1}, {3, 1, 1, 1}, 0.0f, true));
test_cases.emplace_back(new test_scatter(GGML_TYPE_F32, GGML_TYPE_I32, {10, 1, 1, 1}, {3, 1, 1, 1}, 0.0f, false));
test_cases.emplace_back(new test_scatter(GGML_TYPE_F32, GGML_TYPE_I32, {10, 10, 10, 10}, {3, 10, 10, 10}, 0.0f, true));
test_cases.emplace_back(new test_scatter(GGML_TYPE_F32, GGML_TYPE_I32, {10, 10, 10, 10}, {3, 10, 10, 10}, 0.0f, false));
return test_cases;
}
#ifdef _MSC_VER
@ -8730,6 +8795,10 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
test_cases.emplace_back(new test_acc(GGML_TYPE_F32, {256, 17, 2, 3}, {128, 16, 2, 3}, 2));
test_cases.emplace_back(new test_acc(GGML_TYPE_F32, {256, 17, 2, 3}, {64, 16, 2, 3}, 3));
// scatter
test_cases.emplace_back(new test_scatter(GGML_TYPE_F32, GGML_TYPE_I32, {65536, 1, 1, 1}, {2048, 1, 1, 1}, 0.0f, true));
test_cases.emplace_back(new test_scatter(GGML_TYPE_F32, GGML_TYPE_I32, {65536, 1, 1, 1}, {2048, 1, 1, 1}, 0.0f, false));
return test_cases;
}