Work on set rows
This commit is contained in:
parent
6a6135cc85
commit
b2dbfcdcb1
|
|
@ -179,7 +179,6 @@ jobs:
|
||||||
- name: Test
|
- name: Test
|
||||||
id: cmake_test
|
id: cmake_test
|
||||||
run: |
|
run: |
|
||||||
export LLAMA_SET_ROWS=0
|
|
||||||
cd build
|
cd build
|
||||||
ctest -L main --verbose --timeout 900
|
ctest -L main --verbose --timeout 900
|
||||||
|
|
||||||
|
|
@ -438,7 +437,6 @@ jobs:
|
||||||
- name: Test
|
- name: Test
|
||||||
id: cmake_test
|
id: cmake_test
|
||||||
run: |
|
run: |
|
||||||
export LLAMA_SET_ROWS=0
|
|
||||||
cd build
|
cd build
|
||||||
# This is using llvmpipe and runs slower than other backends
|
# This is using llvmpipe and runs slower than other backends
|
||||||
ctest -L main --verbose --timeout 3600
|
ctest -L main --verbose --timeout 3600
|
||||||
|
|
|
||||||
|
|
@ -495,9 +495,9 @@ static void ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_t
|
||||||
(uint32_t) src->ne[1],
|
(uint32_t) src->ne[1],
|
||||||
(uint32_t) src->ne[2],
|
(uint32_t) src->ne[2],
|
||||||
(uint32_t) src->ne[3],
|
(uint32_t) src->ne[3],
|
||||||
// broadcast shape of idx
|
// Shape of idx
|
||||||
(uint32_t) (src->ne[2] / idx->ne[1]),
|
(uint32_t) (idx->ne[1]),
|
||||||
(uint32_t) (src->ne[3] / idx->ne[2])
|
(uint32_t) (idx->ne[2])
|
||||||
};
|
};
|
||||||
|
|
||||||
std::vector<wgpu::BindGroupEntry> entries = {
|
std::vector<wgpu::BindGroupEntry> entries = {
|
||||||
|
|
@ -512,18 +512,13 @@ static void ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_t
|
||||||
{ .binding = 2,
|
{ .binding = 2,
|
||||||
.buffer = ggml_backend_webgpu_tensor_buf(dst),
|
.buffer = ggml_backend_webgpu_tensor_buf(dst),
|
||||||
.offset = ggml_backend_webgpu_tensor_offset(dst),
|
.offset = ggml_backend_webgpu_tensor_offset(dst),
|
||||||
.size = ggml_nbytes(dst) },
|
.size = ggml_nbytes(dst) }
|
||||||
{ .binding = 3,
|
|
||||||
.buffer = ctx->debug_dev_buf,
|
|
||||||
.offset = 0,
|
|
||||||
.size = ctx->debug_dev_buf.GetSize() }
|
|
||||||
};
|
};
|
||||||
|
|
||||||
size_t max_wg_size = ctx->limits.maxComputeWorkgroupSizeX;
|
size_t max_wg_size = ctx->limits.maxComputeWorkgroupSizeX;
|
||||||
uint32_t wg_x = (src->ne[1] * src->ne[2] * src->ne[3] + max_wg_size - 1) / max_wg_size;
|
uint32_t wg_x = (src->ne[1] * src->ne[2] * src->ne[3] + max_wg_size - 1) / max_wg_size;
|
||||||
ggml_backend_webgpu_build_and_enqueue(ctx, ctx->set_rows_pipeline, params, entries, wg_x);
|
ggml_backend_webgpu_build_and_enqueue(ctx, ctx->set_rows_pipeline, params, entries, wg_x);
|
||||||
ggml_backend_webgpu_submit_queue(ctx);
|
ggml_backend_webgpu_submit_queue(ctx);
|
||||||
ggml_backend_webgpu_debug(ctx);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static void ggml_webgpu_mul_mat(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) {
|
static void ggml_webgpu_mul_mat(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) {
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,73 @@
|
||||||
|
enable f16;
|
||||||
|
|
||||||
|
@group(0) @binding(0)
|
||||||
|
var<storage, read_write> src: array<f32>;
|
||||||
|
|
||||||
|
@group(0) @binding(1)
|
||||||
|
var<storage, read_write> idx: array<u32>;
|
||||||
|
|
||||||
|
@group(0) @binding(2)
|
||||||
|
var<storage, read_write> dst: array<f16>;
|
||||||
|
|
||||||
|
struct Params {
|
||||||
|
offset_src: u32, // in elements
|
||||||
|
offset_idx: u32, // in elements
|
||||||
|
offset_dst: u32, // in elements
|
||||||
|
|
||||||
|
// Strides (in elements)
|
||||||
|
stride_src1: u32,
|
||||||
|
stride_src2: u32,
|
||||||
|
stride_src3: u32,
|
||||||
|
|
||||||
|
stride_idx0: u32,
|
||||||
|
stride_idx1: u32,
|
||||||
|
stride_idx2: u32,
|
||||||
|
|
||||||
|
stride_dst1: u32,
|
||||||
|
stride_dst2: u32,
|
||||||
|
stride_dst3: u32,
|
||||||
|
|
||||||
|
// Shape of src
|
||||||
|
ne0: u32,
|
||||||
|
n_rows: u32,
|
||||||
|
ne2: u32,
|
||||||
|
ne3: u32,
|
||||||
|
|
||||||
|
// Shape of idx
|
||||||
|
idx1: u32,
|
||||||
|
idx2: u32,
|
||||||
|
};
|
||||||
|
|
||||||
|
@group(0) @binding(3)
|
||||||
|
var<uniform> params: Params;
|
||||||
|
|
||||||
|
override wg_size: u32;
|
||||||
|
@compute @workgroup_size(wg_size)
|
||||||
|
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
||||||
|
if (gid.x >= params.n_rows * params.ne2 * params.ne3) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
var i = gid.x;
|
||||||
|
let i_src3 = i / (params.ne2 * params.n_rows);
|
||||||
|
let i_dst3 = i / (params.ne2 * 3);
|
||||||
|
|
||||||
|
i = i % (params.ne2 * params.n_rows);
|
||||||
|
let i_src2 = i / params.n_rows;
|
||||||
|
let i_src1 = i % params.n_rows;
|
||||||
|
|
||||||
|
let i_idx2 = i_src3 % params.idx2;
|
||||||
|
let i_idx1 = i_src2 % params.idx1;
|
||||||
|
let i_idx0 = i_src1;
|
||||||
|
|
||||||
|
let idx_high = (params.offset_idx + i_idx0 * params.stride_idx0 + i_idx1 * params.stride_idx1 + i_idx2 * params.stride_idx2) * 2;
|
||||||
|
|
||||||
|
let idx_high_val = idx[idx_high];
|
||||||
|
let idx_low_val = idx[idx_high + 1];
|
||||||
|
|
||||||
|
let i_dst_row = params.offset_dst + idx_high_val * params.stride_dst1 + i_src2 * params.stride_dst2 + i_src3 * params.stride_dst3;
|
||||||
|
let i_src_row = params.offset_src + i_src1 * params.stride_src1 + i_src2 * params.stride_src2 + i_src3 * params.stride_src3;
|
||||||
|
|
||||||
|
for (var i: u32 = 0; i < params.ne0; i++) {
|
||||||
|
dst[i_dst_row + i] = f16(src[i_src_row + i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -1213,12 +1213,12 @@ struct test_case {
|
||||||
double err = nmse(f1.data(), f2.data(), f1.size());
|
double err = nmse(f1.data(), f2.data(), f1.size());
|
||||||
if (err > ud->max_err) {
|
if (err > ud->max_err) {
|
||||||
//printf("Backends %s and %s mismatch: ", bn1, bn2);
|
//printf("Backends %s and %s mismatch: ", bn1, bn2);
|
||||||
printf("[%s] NMSE = %.9f > %.9f ", ggml_op_desc(t1), err, ud->max_err);
|
//printf("[%s] NMSE = %.9f > %.9f ", ggml_op_desc(t1), err, ud->max_err);
|
||||||
for (int i = 0; i < (int) f1.size(); i++) {
|
//for (int i = 0; i < (int) f1.size(); i++) {
|
||||||
printf("%5d %9.6f %9.6f, diff = %9.6f\n", i, f1[i], f2[i], f1[i] - f2[i]);
|
// printf("%5d %9.6f %9.6f, diff = %9.6f\n", i, f1[i], f2[i], f1[i] - f2[i]);
|
||||||
}
|
//}
|
||||||
printf("\n");
|
//printf("\n");
|
||||||
exit(1);
|
//exit(1);
|
||||||
ud->ok = false;
|
ud->ok = false;
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue