Work on set rows
This commit is contained in:
parent
6a6135cc85
commit
b2dbfcdcb1
|
|
@ -179,7 +179,6 @@ jobs:
|
|||
- name: Test
|
||||
id: cmake_test
|
||||
run: |
|
||||
export LLAMA_SET_ROWS=0
|
||||
cd build
|
||||
ctest -L main --verbose --timeout 900
|
||||
|
||||
|
|
@ -438,7 +437,6 @@ jobs:
|
|||
- name: Test
|
||||
id: cmake_test
|
||||
run: |
|
||||
export LLAMA_SET_ROWS=0
|
||||
cd build
|
||||
# This is using llvmpipe and runs slower than other backends
|
||||
ctest -L main --verbose --timeout 3600
|
||||
|
|
|
|||
|
|
@ -495,9 +495,9 @@ static void ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_t
|
|||
(uint32_t) src->ne[1],
|
||||
(uint32_t) src->ne[2],
|
||||
(uint32_t) src->ne[3],
|
||||
// broadcast shape of idx
|
||||
(uint32_t) (src->ne[2] / idx->ne[1]),
|
||||
(uint32_t) (src->ne[3] / idx->ne[2])
|
||||
// Shape of idx
|
||||
(uint32_t) (idx->ne[1]),
|
||||
(uint32_t) (idx->ne[2])
|
||||
};
|
||||
|
||||
std::vector<wgpu::BindGroupEntry> entries = {
|
||||
|
|
@ -512,18 +512,13 @@ static void ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_t
|
|||
{ .binding = 2,
|
||||
.buffer = ggml_backend_webgpu_tensor_buf(dst),
|
||||
.offset = ggml_backend_webgpu_tensor_offset(dst),
|
||||
.size = ggml_nbytes(dst) },
|
||||
{ .binding = 3,
|
||||
.buffer = ctx->debug_dev_buf,
|
||||
.offset = 0,
|
||||
.size = ctx->debug_dev_buf.GetSize() }
|
||||
.size = ggml_nbytes(dst) }
|
||||
};
|
||||
|
||||
size_t max_wg_size = ctx->limits.maxComputeWorkgroupSizeX;
|
||||
uint32_t wg_x = (src->ne[1] * src->ne[2] * src->ne[3] + max_wg_size - 1) / max_wg_size;
|
||||
ggml_backend_webgpu_build_and_enqueue(ctx, ctx->set_rows_pipeline, params, entries, wg_x);
|
||||
ggml_backend_webgpu_submit_queue(ctx);
|
||||
ggml_backend_webgpu_debug(ctx);
|
||||
}
|
||||
|
||||
static void ggml_webgpu_mul_mat(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) {
|
||||
|
|
|
|||
|
|
@ -0,0 +1,73 @@
|
|||
enable f16;
|
||||
|
||||
@group(0) @binding(0)
|
||||
var<storage, read_write> src: array<f32>;
|
||||
|
||||
@group(0) @binding(1)
|
||||
var<storage, read_write> idx: array<u32>;
|
||||
|
||||
@group(0) @binding(2)
|
||||
var<storage, read_write> dst: array<f16>;
|
||||
|
||||
struct Params {
|
||||
offset_src: u32, // in elements
|
||||
offset_idx: u32, // in elements
|
||||
offset_dst: u32, // in elements
|
||||
|
||||
// Strides (in elements)
|
||||
stride_src1: u32,
|
||||
stride_src2: u32,
|
||||
stride_src3: u32,
|
||||
|
||||
stride_idx0: u32,
|
||||
stride_idx1: u32,
|
||||
stride_idx2: u32,
|
||||
|
||||
stride_dst1: u32,
|
||||
stride_dst2: u32,
|
||||
stride_dst3: u32,
|
||||
|
||||
// Shape of src
|
||||
ne0: u32,
|
||||
n_rows: u32,
|
||||
ne2: u32,
|
||||
ne3: u32,
|
||||
|
||||
// Shape of idx
|
||||
idx1: u32,
|
||||
idx2: u32,
|
||||
};
|
||||
|
||||
@group(0) @binding(3)
|
||||
var<uniform> params: Params;
|
||||
|
||||
override wg_size: u32;
|
||||
@compute @workgroup_size(wg_size)
|
||||
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
||||
if (gid.x >= params.n_rows * params.ne2 * params.ne3) {
|
||||
return;
|
||||
}
|
||||
var i = gid.x;
|
||||
let i_src3 = i / (params.ne2 * params.n_rows);
|
||||
let i_dst3 = i / (params.ne2 * 3);
|
||||
|
||||
i = i % (params.ne2 * params.n_rows);
|
||||
let i_src2 = i / params.n_rows;
|
||||
let i_src1 = i % params.n_rows;
|
||||
|
||||
let i_idx2 = i_src3 % params.idx2;
|
||||
let i_idx1 = i_src2 % params.idx1;
|
||||
let i_idx0 = i_src1;
|
||||
|
||||
let idx_high = (params.offset_idx + i_idx0 * params.stride_idx0 + i_idx1 * params.stride_idx1 + i_idx2 * params.stride_idx2) * 2;
|
||||
|
||||
let idx_high_val = idx[idx_high];
|
||||
let idx_low_val = idx[idx_high + 1];
|
||||
|
||||
let i_dst_row = params.offset_dst + idx_high_val * params.stride_dst1 + i_src2 * params.stride_dst2 + i_src3 * params.stride_dst3;
|
||||
let i_src_row = params.offset_src + i_src1 * params.stride_src1 + i_src2 * params.stride_src2 + i_src3 * params.stride_src3;
|
||||
|
||||
for (var i: u32 = 0; i < params.ne0; i++) {
|
||||
dst[i_dst_row + i] = f16(src[i_src_row + i]);
|
||||
}
|
||||
}
|
||||
|
|
@ -1213,12 +1213,12 @@ struct test_case {
|
|||
double err = nmse(f1.data(), f2.data(), f1.size());
|
||||
if (err > ud->max_err) {
|
||||
//printf("Backends %s and %s mismatch: ", bn1, bn2);
|
||||
printf("[%s] NMSE = %.9f > %.9f ", ggml_op_desc(t1), err, ud->max_err);
|
||||
for (int i = 0; i < (int) f1.size(); i++) {
|
||||
printf("%5d %9.6f %9.6f, diff = %9.6f\n", i, f1[i], f2[i], f1[i] - f2[i]);
|
||||
}
|
||||
printf("\n");
|
||||
exit(1);
|
||||
//printf("[%s] NMSE = %.9f > %.9f ", ggml_op_desc(t1), err, ud->max_err);
|
||||
//for (int i = 0; i < (int) f1.size(); i++) {
|
||||
// printf("%5d %9.6f %9.6f, diff = %9.6f\n", i, f1[i], f2[i], f1[i] - f2[i]);
|
||||
//}
|
||||
//printf("\n");
|
||||
//exit(1);
|
||||
ud->ok = false;
|
||||
}
|
||||
return true;
|
||||
|
|
|
|||
Loading…
Reference in New Issue