Work on set rows

This commit is contained in:
Reese Levine 2025-08-05 16:33:15 -07:00
parent 6a6135cc85
commit b2dbfcdcb1
4 changed files with 83 additions and 17 deletions

View File

@ -179,7 +179,6 @@ jobs:
- name: Test
id: cmake_test
run: |
export LLAMA_SET_ROWS=0
cd build
ctest -L main --verbose --timeout 900
@ -438,7 +437,6 @@ jobs:
- name: Test
id: cmake_test
run: |
export LLAMA_SET_ROWS=0
cd build
# This is using llvmpipe and runs slower than other backends
ctest -L main --verbose --timeout 3600

View File

@ -495,9 +495,9 @@ static void ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_t
(uint32_t) src->ne[1],
(uint32_t) src->ne[2],
(uint32_t) src->ne[3],
// broadcast shape of idx
(uint32_t) (src->ne[2] / idx->ne[1]),
(uint32_t) (src->ne[3] / idx->ne[2])
// Shape of idx
(uint32_t) (idx->ne[1]),
(uint32_t) (idx->ne[2])
};
std::vector<wgpu::BindGroupEntry> entries = {
@ -512,18 +512,13 @@ static void ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_t
{ .binding = 2,
.buffer = ggml_backend_webgpu_tensor_buf(dst),
.offset = ggml_backend_webgpu_tensor_offset(dst),
.size = ggml_nbytes(dst) },
{ .binding = 3,
.buffer = ctx->debug_dev_buf,
.offset = 0,
.size = ctx->debug_dev_buf.GetSize() }
.size = ggml_nbytes(dst) }
};
size_t max_wg_size = ctx->limits.maxComputeWorkgroupSizeX;
uint32_t wg_x = (src->ne[1] * src->ne[2] * src->ne[3] + max_wg_size - 1) / max_wg_size;
ggml_backend_webgpu_build_and_enqueue(ctx, ctx->set_rows_pipeline, params, entries, wg_x);
ggml_backend_webgpu_submit_queue(ctx);
ggml_backend_webgpu_debug(ctx);
}
static void ggml_webgpu_mul_mat(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) {

View File

@ -0,0 +1,73 @@
// Required for the f16 element type of `dst` and the f16() conversion in main.
enable f16;
// Source values; only read by main, converted to f16 on write.
@group(0) @binding(0)
var<storage, read_write> src: array<f32>;
// Row indices, viewed as u32 words; main addresses them in pairs (two words per index).
// NOTE(review): presumably 64-bit indices on the host side - confirm.
@group(0) @binding(1)
var<storage, read_write> idx: array<u32>;
// Destination buffer; main writes one f16 per source element into the selected row.
@group(0) @binding(2)
var<storage, read_write> dst: array<f16>;
// Uniform parameters for the set_rows kernel. Field order defines the uniform
// layout, so it has to match the order in which the host fills the buffer.
struct Params {
// Base offsets added to every index computed into src / idx / dst.
offset_src: u32, // in elements
offset_idx: u32, // in elements
offset_dst: u32, // in elements
// Strides (in elements)
// src has no stride for dim 0: main walks a row with consecutive indices.
stride_src1: u32,
stride_src2: u32,
stride_src3: u32,
// idx strides are in index elements; main doubles the result to get a u32 word position.
stride_idx0: u32,
stride_idx1: u32,
stride_idx2: u32,
// stride_dst1 is multiplied by the row index read from idx, not by the source row number.
stride_dst1: u32,
stride_dst2: u32,
stride_dst3: u32,
// Shape of src
ne0: u32, // elements per row (inner copy loop bound)
n_rows: u32, // extent of src dim 1; one invocation handles one row
ne2: u32,
ne3: u32,
// Shape of idx
// Used as modulo divisors, so idx repeats along src dims 2 and 3.
idx1: u32,
idx2: u32,
};
// Kernel parameters, bound after the three storage buffers.
@group(0) @binding(3)
var<uniform> params: Params;
// Workgroup size along x. Declared without a default, so it must be supplied
// when the pipeline is created.
// NOTE(review): presumably set to the device's maxComputeWorkgroupSizeX - confirm on the host side.
override wg_size: u32;
// set_rows kernel: each invocation copies one row of `src` (f32) into `dst`
// (f16). The destination row along dim 1 is looked up in `idx`; dims 2 and 3
// of the destination follow the source coordinates.
//
// gid.x is the flat source row id in [0, n_rows * ne2 * ne3).
@compute @workgroup_size(wg_size)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    // The dispatch is rounded up to whole workgroups; surplus invocations exit.
    if (gid.x >= params.n_rows * params.ne2 * params.ne3) {
        return;
    }

    // Split the flat row id into (i_src1, i_src2, i_src3) source coordinates.
    let rows_per_dim3 = params.ne2 * params.n_rows;
    let i_src3 = gid.x / rows_per_dim3;
    let rem = gid.x % rows_per_dim3;
    let i_src2 = rem / params.n_rows;
    let i_src1 = rem % params.n_rows;

    // idx may be smaller than src along dims 2 and 3; wrap around with modulo.
    let i_idx2 = i_src3 % params.idx2;
    let i_idx1 = i_src2 % params.idx1;
    let i_idx0 = i_src1;

    // Every index occupies two consecutive u32 words, hence the factor of 2.
    // Only the first word is used as the destination row.
    // NOTE(review): presumably a 64-bit index stored low word first
    // (little-endian) - confirm. The second word is ignored, so an index that
    // does not fit in 32 bits is silently truncated rather than reported.
    let idx_word = (params.offset_idx + i_idx0 * params.stride_idx0 + i_idx1 * params.stride_idx1 + i_idx2 * params.stride_idx2) * 2;
    let dst_row = idx[idx_word];

    let i_dst_row = params.offset_dst + dst_row * params.stride_dst1 + i_src2 * params.stride_dst2 + i_src3 * params.stride_dst3;
    let i_src_row = params.offset_src + i_src1 * params.stride_src1 + i_src2 * params.stride_src2 + i_src3 * params.stride_src3;

    // Rows are walked with consecutive element indices in both buffers.
    for (var col: u32 = 0; col < params.ne0; col++) {
        dst[i_dst_row + col] = f16(src[i_src_row + col]);
    }
}

View File

@ -1213,12 +1213,12 @@ struct test_case {
double err = nmse(f1.data(), f2.data(), f1.size());
if (err > ud->max_err) {
//printf("Backends %s and %s mismatch: ", bn1, bn2);
printf("[%s] NMSE = %.9f > %.9f ", ggml_op_desc(t1), err, ud->max_err);
for (int i = 0; i < (int) f1.size(); i++) {
printf("%5d %9.6f %9.6f, diff = %9.6f\n", i, f1[i], f2[i], f1[i] - f2[i]);
}
printf("\n");
exit(1);
//printf("[%s] NMSE = %.9f > %.9f ", ggml_op_desc(t1), err, ud->max_err);
//for (int i = 0; i < (int) f1.size(); i++) {
// printf("%5d %9.6f %9.6f, diff = %9.6f\n", i, f1[i], f2[i], f1[i] - f2[i]);
//}
//printf("\n");
//exit(1);
ud->ok = false;
}
return true;