ggml : add new GGML_OP_WHERE_ID (akin to torch.where, but using indices)

This commit is contained in:
Stanisław Szymczyk 2026-03-15 21:58:49 +01:00
parent 3eb340ed4b
commit 08dc7fd9d9
8 changed files with 203 additions and 2 deletions

View File

@ -558,6 +558,7 @@ extern "C" {
GGML_OP_SOLVE_TRI,
GGML_OP_GATED_DELTA_NET,
GGML_OP_HADAMARD,
GGML_OP_WHERE_ID,
GGML_OP_UNARY,
@ -2479,6 +2480,12 @@ extern "C" {
struct ggml_tensor * a,
int n);
GGML_API struct ggml_tensor * ggml_where_id(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
struct ggml_tensor * ids);
// custom operators
typedef void (*ggml_custom1_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int nth, void * userdata);

View File

@ -2029,6 +2029,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
ggml_compute_forward_hadamard(params, tensor);
} break;
case GGML_OP_WHERE_ID:
{
ggml_compute_forward_where_id(params, tensor);
} break;
case GGML_OP_MAP_CUSTOM1:
{
ggml_compute_forward_map_custom1(params, tensor);
@ -2352,6 +2356,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
case GGML_OP_GATED_LINEAR_ATTN:
case GGML_OP_RWKV_WKV7:
case GGML_OP_HADAMARD:
case GGML_OP_WHERE_ID:
{
n_tasks = n_threads;
} break;

View File

@ -11256,3 +11256,81 @@ void ggml_compute_forward_hadamard(
}
}
}
// ggml_compute_forward_where_id

// where_id: dst is first filled with src1 (b), then for every index id listed
// in the row of src2 (ids), dst[id] is overwritten with src0[id] (a).
//
// src0 = a   (F32) - values to select
// src1 = b   (F32) - default values, copied wholesale per row
// src2 = ids (I32) - per-row column indices into src0/dst
//
// NOTE(review): the (i1,i2,i3) row indices are applied to ids as well, so this
// assumes ids matches src0 in dims 1..3 (one index row per value row) — the
// constructor only asserts ne[1]; confirm whether broadcasting is intended.
static void ggml_compute_forward_where_id_f32(
        const ggml_compute_params * params,
        ggml_tensor * dst) {

    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];
    const ggml_tensor * src2 = dst->src[2];

    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));

    GGML_ASSERT(dst->type  == GGML_TYPE_F32);
    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
    GGML_ASSERT(src2->type == GGML_TYPE_I32);

    GGML_ASSERT(src0->nb[0] == sizeof(float));
    GGML_ASSERT(src1->nb[0] == sizeof(float));

    const int ith = params->ith;
    const int nth = params->nth;

    const int nr = ggml_nrows(src0);

    GGML_TENSOR_TERNARY_OP_LOCALS

    GGML_ASSERT( nb0 == sizeof(float));
    GGML_ASSERT(nb00 == sizeof(float));

    // rows per thread
    const int dr = (nr + nth - 1)/nth;

    // row range for this thread
    const int ir0 = dr*ith;
    const int ir1 = MIN(ir0 + dr, nr);

    for (int ir = ir0; ir < ir1; ++ir) {
        // src0 indices
        const int i3 = ir/(ne2*ne1);
        const int i2 = (ir - i3*ne2*ne1)/ne1;
        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);

        // use each tensor's own strides - src0 previously (and incorrectly) used
        // dst's strides (nb1/nb2/nb3), which only worked when the layouts matched
        const float   * src0_ptr = (float   *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
        const float   * src1_ptr = (float   *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
        const int32_t * ids_ptr  = (int32_t *) ((char *) src2->data + i3*nb23 + i2*nb22 + i1*nb21);
        float         * dst_ptr  = (float   *) ((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1 );

        // copy whole row from src1
        ggml_vec_cpy_f32(ne00, dst_ptr, src1_ptr);

        // copy only values from src0 indicated by indices in src2
        for (int j = 0; j < ne20; ++j) {
            const int id = ids_ptr[j];
            GGML_ASSERT(id >= 0 && id < ne00); // reject out-of-range indices
            dst_ptr[id] = src0_ptr[id];
        }
    }
}
// ggml_compute_forward_where_id: dispatch on the type of the value tensor (src0)
void ggml_compute_forward_where_id(
        const ggml_compute_params * params,
        ggml_tensor * dst) {

    const ggml_tensor * src0 = dst->src[0];

    if (src0->type == GGML_TYPE_F32) {
        ggml_compute_forward_where_id_f32(params, dst);
        return;
    }

    GGML_ABORT("unsupported type for ggml_compute_forward_where_id: %s", ggml_type_name(src0->type));
}

View File

@ -104,6 +104,7 @@ void ggml_compute_forward_solve_tri(const struct ggml_compute_params * params, s
void ggml_compute_forward_gla(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_gated_delta_net(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_hadamard(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_where_id(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_map_custom1(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_map_custom2(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_map_custom3(const struct ggml_compute_params * params, struct ggml_tensor * dst);

View File

@ -62,6 +62,7 @@
#include "ggml-cuda/cumsum.cuh"
#include "ggml-cuda/fill.cuh"
#include "ggml-cuda/hadamard.cuh"
#include "ggml-cuda/where-id.cuh"
#include "ggml.h"
#include <algorithm>
@ -2775,6 +2776,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_HADAMARD:
ggml_cuda_op_hadamard(ctx, dst);
break;
case GGML_OP_WHERE_ID:
ggml_cuda_op_where_id(ctx, dst);
break;
default:
return false;
}
@ -5016,6 +5020,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_TRI:
case GGML_OP_DIAG:
case GGML_OP_SOLVE_TRI:
case GGML_OP_WHERE_ID:
return true;
case GGML_OP_HADAMARD:
return (op->ne[0] == 64 || op->ne[0] == 128 || op->ne[0] == 256) && op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32;

View File

@ -0,0 +1,77 @@
#include "where-id.cuh"
// Scatter-copy kernel: for each row, copy src0[id] into dst[id] for every index
// id listed in the corresponding row of src1 (the ids tensor, I32).
// dst must already hold the default values (copied from b by the caller).
static __global__ void where_id_kernel(
        const float * src0, const int32_t * src1, float * dst,
        int64_t ne10, int64_t ne11, int64_t ne12,
        size_t nb1, size_t nb2,
        size_t nb01, size_t nb02,
        size_t nb11, size_t nb12
) {
    const int64_t n_rows = ne11*ne12;

    // one block per row; stride over the grid in case n_rows exceeds gridDim.x
    for (int64_t row = blockIdx.x; row < n_rows; row += gridDim.x) {
        const int64_t r1 = row % ne11;
        const int64_t r2 = row / ne11;

        float         * dst_row  = (float         *)((char       *) dst  + r1*nb1  + r2*nb2);
        const float   * src0_row = (const float   *)((const char *) src0 + r1*nb01 + r2*nb02);
        const int32_t * ids_row  = (const int32_t *)((const char *) src1 + r1*nb11 + r2*nb12);

        // threads of the block cooperatively walk the ne10 indices of this row
        for (int64_t k = threadIdx.x; k < ne10; k += blockDim.x) {
            const int32_t id = ids_row[k];
            dst_row[id] = src0_row[id];
        }
    }
}
// CUDA implementation of GGML_OP_WHERE_ID:
//   1. dst = src1 (bulk device-to-device copy of the default values)
//   2. dst[id] = src0[id] for every id in src2 (scatter kernel)
void ggml_cuda_op_where_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0]; // a   - values to select (F32)
    const ggml_tensor * src1 = dst->src[1]; // b   - default values (F32)
    const ggml_tensor * src2 = dst->src[2]; // ids - column indices (I32)

    GGML_TENSOR_TERNARY_OP_LOCALS

    GGML_ASSERT(ggml_is_contiguous(dst));
    GGML_ASSERT(ggml_is_contiguous(src0));
    GGML_ASSERT(ggml_is_contiguous(src1));
    GGML_ASSERT(ggml_is_contiguous(src2));

    GGML_ASSERT(dst->type  == GGML_TYPE_F32);
    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
    GGML_ASSERT(src2->type == GGML_TYPE_I32);

    GGML_ASSERT(nb00 == sizeof(float));
    GGML_ASSERT(nb10 == sizeof(float));
    GGML_ASSERT(nb20 == sizeof(int32_t));

    GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1));
    GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(dst));

    // step 1 - copy whole src1 to dst
    cudaStream_t main_stream = ctx.stream();

    char * dst_ddc  = (char *) dst->data;
    char * src1_ddc = (char *) src1->data;
    CUDA_CHECK(cudaMemcpyAsync(dst_ddc, src1_ddc, ggml_nbytes(src1), cudaMemcpyDeviceToDevice, main_stream));

    // step 2 - copy elements from src0 indicated by ids to dst
    const int64_t total_blocks = ne21*ne22;

    // an empty ids tensor means there is nothing to scatter; launching a kernel
    // with 0 threads or 0 blocks is an invalid CUDA launch configuration
    if (ne20 == 0 || total_blocks == 0) {
        return;
    }

    const float   * src0_d = (const float   *) src0->data;
    const int32_t * src2_d = (const int32_t *) src2->data;
    float         * dst_d  = (float         *) dst->data;

    const int threads = std::min((int) ne20, 768); // ids per row
    const int blocks  = (int) std::min((int64_t) 65535, total_blocks); // 1D grid dim limit

    where_id_kernel<<<blocks, threads, 0, main_stream>>>(
        src0_d, src2_d, dst_d,
        ne20, ne21, ne22,
        nb1, nb2,
        nb01, nb02,
        nb21, nb22
    );
}

View File

@ -0,0 +1,3 @@
#include "common.cuh"
void ggml_cuda_op_where_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@ -1033,6 +1033,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"SOLVE_TRI",
"GATED_DELTA_NET",
"HADAMARD",
"WHERE_ID",
"UNARY",
@ -1050,7 +1051,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"GLU",
};
static_assert(GGML_OP_COUNT == 97, "GGML_OP_COUNT != 97");
static_assert(GGML_OP_COUNT == 98, "GGML_OP_COUNT != 98");
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
@ -1144,6 +1145,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"A X = B, A triangular, solve X",
"gated_delta_net(q, k, v, g, beta, s)",
"hadamard(x)",
"where_id(x,y,ids)",
"unary(x)",
@ -1161,7 +1163,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"glu(x)",
};
static_assert(GGML_OP_COUNT == 97, "GGML_OP_COUNT != 97");
static_assert(GGML_OP_COUNT == 98, "GGML_OP_COUNT != 98");
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
@ -6201,6 +6203,29 @@ struct ggml_tensor * ggml_hadamard(
return result;
}
// ggml_where_id

// result[row] = b[row], then result[row][id] = a[row][id] for every index id
// in ids[row]. Similar to torch.where, but the selection is given as explicit
// column indices (one row of I32 indices per row of a) instead of a boolean mask.
struct ggml_tensor * ggml_where_id(
        struct ggml_context * ctx,
        struct ggml_tensor  * a,
        struct ggml_tensor  * b,
        struct ggml_tensor  * ids) {
    GGML_ASSERT(a->type   == GGML_TYPE_F32);
    GGML_ASSERT(b->type   == GGML_TYPE_F32);
    GGML_ASSERT(ids->type == GGML_TYPE_I32);
    // the backends require a and b to have identical shapes - fail at graph
    // construction rather than at compute time
    GGML_ASSERT(ggml_are_same_shape(a, b));
    GGML_ASSERT(a->ne[1] == ids->ne[1]); // one row of indices per row of a

    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, a->ne);

    result->op     = GGML_OP_WHERE_ID;
    result->src[0] = a;
    result->src[1] = b;
    result->src[2] = ids;

    return result;
}
////////////////////////////////////////////////////////////////////////////////
struct ggml_hash_set ggml_hash_set_new(size_t size) {