Merge 17c99108e6 into 9db77a020c
This commit is contained in:
commit
a984d87457
|
|
@ -248,6 +248,27 @@ struct ggml_webgpu_ssm_conv_pipeline_key {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/** CONV 2D */
|
||||||
|
struct ggml_webgpu_conv2d_pipeline_key {
|
||||||
|
ggml_type weight_type;
|
||||||
|
ggml_type input_type;
|
||||||
|
ggml_type output_type;
|
||||||
|
|
||||||
|
bool operator==(const ggml_webgpu_conv2d_pipeline_key & other) const {
|
||||||
|
return weight_type == other.weight_type && input_type == other.input_type && output_type == other.output_type;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct ggml_webgpu_conv2d_pipeline_key_hash {
|
||||||
|
size_t operator()(const ggml_webgpu_conv2d_pipeline_key & key) const {
|
||||||
|
size_t seed = 0;
|
||||||
|
ggml_webgpu_hash_combine(seed, key.weight_type);
|
||||||
|
ggml_webgpu_hash_combine(seed, key.input_type);
|
||||||
|
ggml_webgpu_hash_combine(seed, key.output_type);
|
||||||
|
return seed;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
/** Gated Delta Net **/
|
/** Gated Delta Net **/
|
||||||
struct ggml_webgpu_gated_delta_net_pipeline_key {
|
struct ggml_webgpu_gated_delta_net_pipeline_key {
|
||||||
int type;
|
int type;
|
||||||
|
|
@ -831,6 +852,8 @@ class ggml_webgpu_shader_lib {
|
||||||
rope_pipelines;
|
rope_pipelines;
|
||||||
std::unordered_map<ggml_webgpu_soft_max_pipeline_key, webgpu_pipeline, ggml_webgpu_soft_max_pipeline_key_hash>
|
std::unordered_map<ggml_webgpu_soft_max_pipeline_key, webgpu_pipeline, ggml_webgpu_soft_max_pipeline_key_hash>
|
||||||
soft_max_pipelines;
|
soft_max_pipelines;
|
||||||
|
std::unordered_map<ggml_webgpu_conv2d_pipeline_key, webgpu_pipeline, ggml_webgpu_conv2d_pipeline_key_hash>
|
||||||
|
conv2d_pipelines;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
ggml_webgpu_shader_lib(wgpu::Device device) { this->device = device; }
|
ggml_webgpu_shader_lib(wgpu::Device device) { this->device = device; }
|
||||||
|
|
@ -1115,8 +1138,7 @@ class ggml_webgpu_shader_lib {
|
||||||
std::string type_upper = type_str;
|
std::string type_upper = type_str;
|
||||||
std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
|
std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
|
||||||
|
|
||||||
switch (key.src_type)
|
switch (key.src_type) {
|
||||||
{
|
|
||||||
case GGML_TYPE_Q4_0:
|
case GGML_TYPE_Q4_0:
|
||||||
case GGML_TYPE_Q5_0:
|
case GGML_TYPE_Q5_0:
|
||||||
case GGML_TYPE_Q8_0:
|
case GGML_TYPE_Q8_0:
|
||||||
|
|
@ -1136,9 +1158,9 @@ class ggml_webgpu_shader_lib {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
{
|
{
|
||||||
defines.push_back(std::string("SRC_TYPE=") + type_str);
|
defines.push_back(std::string("SRC_TYPE=") + type_str);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
defines.push_back("BYTE_HELPERS");
|
defines.push_back("BYTE_HELPERS");
|
||||||
|
|
@ -1621,8 +1643,7 @@ class ggml_webgpu_shader_lib {
|
||||||
std::string type_upper = src0_name;
|
std::string type_upper = src0_name;
|
||||||
std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
|
std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
|
||||||
|
|
||||||
switch (context.src0->type)
|
switch (context.src0->type) {
|
||||||
{
|
|
||||||
case GGML_TYPE_Q4_0:
|
case GGML_TYPE_Q4_0:
|
||||||
case GGML_TYPE_Q5_0:
|
case GGML_TYPE_Q5_0:
|
||||||
case GGML_TYPE_Q8_0:
|
case GGML_TYPE_Q8_0:
|
||||||
|
|
@ -1642,9 +1663,9 @@ class ggml_webgpu_shader_lib {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
{
|
{
|
||||||
defines.push_back(std::string("SRC0_TYPE=") + src0_name);
|
defines.push_back(std::string("SRC0_TYPE=") + src0_name);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
defines.push_back("BYTE_HELPERS");
|
defines.push_back("BYTE_HELPERS");
|
||||||
|
|
@ -2340,6 +2361,47 @@ class ggml_webgpu_shader_lib {
|
||||||
return soft_max_pipelines[key];
|
return soft_max_pipelines[key];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
webgpu_pipeline get_conv2d_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
||||||
|
ggml_webgpu_conv2d_pipeline_key key = {
|
||||||
|
.weight_type = context.src0->type,
|
||||||
|
.input_type = context.src1->type,
|
||||||
|
.output_type = context.dst->type,
|
||||||
|
};
|
||||||
|
|
||||||
|
auto it = conv2d_pipelines.find(key);
|
||||||
|
if (it != conv2d_pipelines.end()) {
|
||||||
|
return it->second;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<std::string> defines;
|
||||||
|
std::string variant = "conv_2d";
|
||||||
|
|
||||||
|
auto push_type_defines = [&](const char * prefix, ggml_type type) {
|
||||||
|
std::string s_prefix = prefix;
|
||||||
|
if (type == GGML_TYPE_F32) {
|
||||||
|
defines.push_back(s_prefix + "_F32");
|
||||||
|
} else if (type == GGML_TYPE_F16) {
|
||||||
|
defines.push_back(s_prefix + "_F16");
|
||||||
|
} else {
|
||||||
|
GGML_ABORT("Unsupported type for CONV_2D shader");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
push_type_defines("WEIGHT", key.weight_type);
|
||||||
|
push_type_defines("INPUT", key.input_type);
|
||||||
|
push_type_defines("OUTPUT", key.output_type);
|
||||||
|
|
||||||
|
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
|
||||||
|
|
||||||
|
auto processed = preprocessor.preprocess(wgsl_conv2d, defines);
|
||||||
|
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
|
||||||
|
decisions->wg_size = context.max_wg_size;
|
||||||
|
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
|
||||||
|
pipeline.context = decisions;
|
||||||
|
conv2d_pipelines[key] = pipeline;
|
||||||
|
return conv2d_pipelines[key];
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
static webgpu_pipeline ggml_webgpu_create_pipeline(wgpu::Device & device,
|
static webgpu_pipeline ggml_webgpu_create_pipeline(wgpu::Device & device,
|
||||||
std::string shader_code,
|
std::string shader_code,
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,7 @@
|
||||||
#include "ggml-backend-impl.h"
|
#include "ggml-backend-impl.h"
|
||||||
#include "ggml-impl.h"
|
#include "ggml-impl.h"
|
||||||
#include "ggml-webgpu-shader-lib.hpp"
|
#include "ggml-webgpu-shader-lib.hpp"
|
||||||
|
#include "ggml.h"
|
||||||
|
|
||||||
#ifdef __EMSCRIPTEN__
|
#ifdef __EMSCRIPTEN__
|
||||||
# include <emscripten/emscripten.h>
|
# include <emscripten/emscripten.h>
|
||||||
|
|
@ -83,7 +84,7 @@ static inline void compute_2d_workgroups(uint32_t total_wg, uint32_t max_per_dim
|
||||||
#define WEBGPU_NUM_PARAM_SLOT_SAFETY_MARGIN 10u
|
#define WEBGPU_NUM_PARAM_SLOT_SAFETY_MARGIN 10u
|
||||||
#define WEBGPU_RUNTIME_WAIT_TIMEOUT_MS 30000u
|
#define WEBGPU_RUNTIME_WAIT_TIMEOUT_MS 30000u
|
||||||
#define WEBGPU_RUNTIME_WAIT_TIMEOUT_NS (WEBGPU_RUNTIME_WAIT_TIMEOUT_MS * 1e6)
|
#define WEBGPU_RUNTIME_WAIT_TIMEOUT_NS (WEBGPU_RUNTIME_WAIT_TIMEOUT_MS * 1e6)
|
||||||
#define WEBGPU_PARAMS_BUF_SIZE_BYTES 128 // enough for 32 parameters
|
#define WEBGPU_PARAMS_BUF_SIZE_BYTES 256 // enough for 64 parameters
|
||||||
#define WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES 4
|
#define WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES 4
|
||||||
#define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4
|
#define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4
|
||||||
|
|
||||||
|
|
@ -907,6 +908,97 @@ static webgpu_encoded_op ggml_webgpu_solve_tri(webgpu_context & ctx,
|
||||||
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y);
|
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static webgpu_encoded_op ggml_webgpu_conv_2d(webgpu_context & ctx,
|
||||||
|
ggml_tensor * src0,
|
||||||
|
ggml_tensor * src1,
|
||||||
|
ggml_tensor * dst) {
|
||||||
|
const int32_t s0 = ggml_get_op_params_i32(dst, 0);
|
||||||
|
const int32_t s1 = ggml_get_op_params_i32(dst, 1);
|
||||||
|
const int32_t p0 = ggml_get_op_params_i32(dst, 2);
|
||||||
|
const int32_t p1 = ggml_get_op_params_i32(dst, 3);
|
||||||
|
const int32_t d0 = ggml_get_op_params_i32(dst, 4);
|
||||||
|
const int32_t d1 = ggml_get_op_params_i32(dst, 5);
|
||||||
|
|
||||||
|
std::vector<uint32_t> params = {
|
||||||
|
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
|
||||||
|
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
|
||||||
|
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
|
||||||
|
|
||||||
|
(uint32_t) (src0->nb[0] / ggml_type_size(src0->type)),
|
||||||
|
(uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
|
||||||
|
(uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
|
||||||
|
(uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
|
||||||
|
|
||||||
|
(uint32_t) (src1->nb[0] / ggml_type_size(src1->type)),
|
||||||
|
(uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
|
||||||
|
(uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),
|
||||||
|
(uint32_t) (src1->nb[3] / ggml_type_size(src1->type)),
|
||||||
|
|
||||||
|
(uint32_t) (dst->nb[0] / ggml_type_size(dst->type)),
|
||||||
|
(uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
|
||||||
|
(uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
|
||||||
|
(uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
|
||||||
|
|
||||||
|
(uint32_t) src0->ne[0],
|
||||||
|
(uint32_t) src0->ne[1],
|
||||||
|
(uint32_t) src0->ne[2],
|
||||||
|
|
||||||
|
(uint32_t) src1->ne[0],
|
||||||
|
(uint32_t) src1->ne[1],
|
||||||
|
|
||||||
|
(uint32_t) dst->ne[0],
|
||||||
|
(uint32_t) dst->ne[1],
|
||||||
|
(uint32_t) dst->ne[2],
|
||||||
|
(uint32_t) dst->ne[3],
|
||||||
|
|
||||||
|
(uint32_t) s0,
|
||||||
|
(uint32_t) s1,
|
||||||
|
(uint32_t) p0,
|
||||||
|
(uint32_t) p1,
|
||||||
|
(uint32_t) d0,
|
||||||
|
(uint32_t) d1,
|
||||||
|
};
|
||||||
|
|
||||||
|
std::vector<wgpu::BindGroupEntry> entries = {
|
||||||
|
{ .binding = 0,
|
||||||
|
.buffer = ggml_webgpu_tensor_buf(src0),
|
||||||
|
.offset = ggml_webgpu_tensor_align_offset(ctx, src0),
|
||||||
|
.size = ggml_webgpu_tensor_binding_size(ctx, src0) },
|
||||||
|
{ .binding = 1,
|
||||||
|
.buffer = ggml_webgpu_tensor_buf(src1),
|
||||||
|
.offset = ggml_webgpu_tensor_align_offset(ctx, src1),
|
||||||
|
.size = ggml_webgpu_tensor_binding_size(ctx, src1) },
|
||||||
|
{ .binding = 2,
|
||||||
|
.buffer = ggml_webgpu_tensor_buf(dst),
|
||||||
|
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
|
||||||
|
.size = ggml_webgpu_tensor_binding_size(ctx, dst) },
|
||||||
|
};
|
||||||
|
|
||||||
|
uint32_t max_wg_size =
|
||||||
|
std::min((uint32_t) WEBGPU_MAX_WG_SIZE, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupSizeX);
|
||||||
|
uint32_t wg_size =
|
||||||
|
std::min((uint32_t) ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, max_wg_size);
|
||||||
|
|
||||||
|
ggml_webgpu_shader_lib_context shader_lib_ctx = {
|
||||||
|
.src0 = src0,
|
||||||
|
.src1 = src1,
|
||||||
|
.dst = dst,
|
||||||
|
.max_wg_size = wg_size,
|
||||||
|
};
|
||||||
|
|
||||||
|
webgpu_pipeline pipeline = ctx->shader_lib->get_conv2d_pipeline(shader_lib_ctx);
|
||||||
|
|
||||||
|
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
|
||||||
|
|
||||||
|
uint32_t n_out = ggml_nelements(dst);
|
||||||
|
uint32_t total_wg = CEIL_DIV(n_out, decisions->wg_size);
|
||||||
|
uint32_t max_wg = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
|
||||||
|
uint32_t wg_x = std::min(total_wg, max_wg);
|
||||||
|
uint32_t wg_y = CEIL_DIV(total_wg, wg_x);
|
||||||
|
|
||||||
|
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y);
|
||||||
|
}
|
||||||
|
|
||||||
static webgpu_encoded_op ggml_webgpu_ssm_conv(webgpu_context & ctx,
|
static webgpu_encoded_op ggml_webgpu_ssm_conv(webgpu_context & ctx,
|
||||||
ggml_tensor * src0,
|
ggml_tensor * src0,
|
||||||
ggml_tensor * src1,
|
ggml_tensor * src1,
|
||||||
|
|
@ -2753,6 +2845,8 @@ static std::optional<webgpu_encoded_op> ggml_webgpu_encode_node(webgpu_context c
|
||||||
case GGML_OP_SUM:
|
case GGML_OP_SUM:
|
||||||
case GGML_OP_SUM_ROWS:
|
case GGML_OP_SUM_ROWS:
|
||||||
return ggml_webgpu_sum_rows(ctx, src0, node);
|
return ggml_webgpu_sum_rows(ctx, src0, node);
|
||||||
|
case GGML_OP_CONV_2D:
|
||||||
|
return ggml_webgpu_conv_2d(ctx, src0, src1, node);
|
||||||
default:
|
default:
|
||||||
return std::nullopt;
|
return std::nullopt;
|
||||||
}
|
}
|
||||||
|
|
@ -3781,6 +3875,11 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
|
||||||
case GGML_OP_SOLVE_TRI:
|
case GGML_OP_SOLVE_TRI:
|
||||||
supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32;
|
supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32;
|
||||||
break;
|
break;
|
||||||
|
case GGML_OP_CONV_2D:
|
||||||
|
supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
|
||||||
|
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16) &&
|
||||||
|
(src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);
|
||||||
|
break;
|
||||||
case GGML_OP_SSM_CONV:
|
case GGML_OP_SSM_CONV:
|
||||||
supports_op = op->type == GGML_TYPE_F32;
|
supports_op = op->type == GGML_TYPE_F32;
|
||||||
break;
|
break;
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,166 @@
|
||||||
|
#include "common_decls.tmpl"
|
||||||
|
enable f16;
|
||||||
|
|
||||||
|
@group(0) @binding(0)
|
||||||
|
#if defined(WEIGHT_F32)
|
||||||
|
var<storage, read_write> weights: array<f32>;
|
||||||
|
#elif defined(WEIGHT_F16)
|
||||||
|
var<storage, read_write> weights: array<f16>;
|
||||||
|
#endif
|
||||||
|
|
||||||
|
@group(0) @binding(1)
|
||||||
|
#if defined(INPUT_F32)
|
||||||
|
var<storage, read_write> input: array<f32>;
|
||||||
|
#elif defined(INPUT_F16)
|
||||||
|
var<storage, read_write> input: array<f16>;
|
||||||
|
#endif
|
||||||
|
|
||||||
|
@group(0) @binding(2)
|
||||||
|
#if defined(OUTPUT_F32)
|
||||||
|
var<storage, read_write> output: array<f32>;
|
||||||
|
#elif defined(OUTPUT_F16)
|
||||||
|
var<storage, read_write> output: array<f16>;
|
||||||
|
#endif
|
||||||
|
|
||||||
|
struct Params {
|
||||||
|
offset_w: u32,
|
||||||
|
offset_i: u32,
|
||||||
|
offset_o: u32,
|
||||||
|
|
||||||
|
// element strides
|
||||||
|
sw0: u32, sw1: u32, sw2: u32, sw3: u32,
|
||||||
|
si0: u32, si1: u32, si2: u32, si3: u32,
|
||||||
|
so0: u32, so1: u32, so2: u32, so3: u32,
|
||||||
|
|
||||||
|
// kernel dimensions
|
||||||
|
KW: u32, KH: u32, IC: u32,
|
||||||
|
// input dimensions
|
||||||
|
IW: u32, IH: u32,
|
||||||
|
// output dimensions
|
||||||
|
OW: u32, OH: u32, OC_out: u32, N_out: u32,
|
||||||
|
|
||||||
|
// stride
|
||||||
|
s0: u32, s1: u32,
|
||||||
|
// padding
|
||||||
|
p0: u32, p1: u32,
|
||||||
|
// dilation
|
||||||
|
d0: u32, d1: u32,
|
||||||
|
};
|
||||||
|
|
||||||
|
@group(0) @binding(3)
|
||||||
|
var<uniform> params: Params;
|
||||||
|
|
||||||
|
fn load_weight(idx: u32) -> f32 {
|
||||||
|
#if defined(WEIGHT_F32)
|
||||||
|
return weights[idx];
|
||||||
|
#elif defined(WEIGHT_F16)
|
||||||
|
return f32(weights[idx]);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
fn load_input(idx: u32) -> f32 {
|
||||||
|
#if defined(INPUT_F32)
|
||||||
|
return input[idx];
|
||||||
|
#elif defined(INPUT_F16)
|
||||||
|
return f32(input[idx]);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
fn store_output(idx: u32, val: f32) {
|
||||||
|
#if defined(OUTPUT_F32)
|
||||||
|
output[idx] = val;
|
||||||
|
#elif defined(OUTPUT_F16)
|
||||||
|
output[idx] = f16(val);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
fn ceil_div_u32(x: u32, y: u32) -> u32 {
|
||||||
|
return (x + y - 1) / y;
|
||||||
|
}
|
||||||
|
|
||||||
|
// returns the first valid kernel index k such that base + k * step >= 0
|
||||||
|
fn first_valid_k(base: i32, step: u32) -> u32 {
|
||||||
|
if (base >= 0) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
return ceil_div_u32(u32(-base), step);
|
||||||
|
}
|
||||||
|
|
||||||
|
// returns the first invalid kernel index k such that base + k * step >= limit so valid k are in [0, end_valid_k)
|
||||||
|
fn end_valid_k(base: i32, step: u32, limit: u32, k_max: u32) -> u32 {
|
||||||
|
let remaining = i32(limit) - base;
|
||||||
|
if (remaining <= 0) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
return min(k_max, ceil_div_u32(u32(remaining), step));
|
||||||
|
}
|
||||||
|
|
||||||
|
@compute @workgroup_size(WG_SIZE)
|
||||||
|
fn main(
|
||||||
|
@builtin(global_invocation_id) gid: vec3<u32>,
|
||||||
|
@builtin(num_workgroups) num_wg: vec3<u32>
|
||||||
|
) {
|
||||||
|
|
||||||
|
let threads_per_group = u32(WG_SIZE);
|
||||||
|
let i_out = gid.x + (num_wg.x * threads_per_group) * gid.y;
|
||||||
|
let n_out = params.OW * params.OH * params.OC_out * params.N_out;
|
||||||
|
|
||||||
|
var sum: f32 = 0.0;
|
||||||
|
if (i_out >= n_out) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Kernel layout: [KW, KH, IC, ..]
|
||||||
|
// Input layout: [IW, IH, .., ..]
|
||||||
|
// Output layout: [OW, OH, OC, N]
|
||||||
|
|
||||||
|
var i = i_out;
|
||||||
|
let n = i / (params.OC_out * params.OH * params.OW);
|
||||||
|
i = i % (params.OC_out * params.OH * params.OW);
|
||||||
|
let oc = i / (params.OH * params.OW);
|
||||||
|
i = i % (params.OH * params.OW);
|
||||||
|
let oh = i / params.OW;
|
||||||
|
let ow = i % params.OW;
|
||||||
|
|
||||||
|
let ow_base = i32(ow * params.s0) - i32(params.p0);
|
||||||
|
let oh_base = i32(oh * params.s1) - i32(params.p1);
|
||||||
|
|
||||||
|
// clip the valid kernel window once
|
||||||
|
let kw_begin = first_valid_k(ow_base, params.d0);
|
||||||
|
let kw_end = end_valid_k(ow_base, params.d0, params.IW, params.KW);
|
||||||
|
let kh_begin = first_valid_k(oh_base, params.d1);
|
||||||
|
let kh_end = end_valid_k(oh_base, params.d1, params.IH, params.KH);
|
||||||
|
|
||||||
|
// entire receptive field is out of bounds
|
||||||
|
if (kw_begin >= kw_end || kh_begin >= kh_end) {
|
||||||
|
let out_idx = params.offset_o + ow * params.so0 + oh * params.so1 + oc * params.so2 + n * params.so3;
|
||||||
|
store_output(out_idx, 0.0);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let weight_oc_base = params.offset_w + oc * params.sw3;
|
||||||
|
let input_n_base = params.offset_i + n * params.si3;
|
||||||
|
|
||||||
|
for (var ic: u32 = 0; ic < params.IC; ic += 1) {
|
||||||
|
let w_base_ic = ic * params.sw2 + weight_oc_base;
|
||||||
|
let in_base = ic * params.si2 + input_n_base;
|
||||||
|
|
||||||
|
for (var kh: u32 = kh_begin; kh < kh_end; kh += 1) {
|
||||||
|
let ih = u32(oh_base + i32(kh * params.d1));
|
||||||
|
let w_row_base = w_base_ic + kh * params.sw1;
|
||||||
|
let in_row_base = in_base + ih * params.si1;
|
||||||
|
for (var kw: u32 = kw_begin; kw < kw_end; kw += 1) {
|
||||||
|
let iw = u32(ow_base + i32(kw * params.d0));
|
||||||
|
let w_idx = w_row_base + kw * params.sw0;
|
||||||
|
let in_idx = in_row_base + iw * params.si0;
|
||||||
|
|
||||||
|
sum += load_weight(w_idx) * load_input(in_idx);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let out_idx = params.offset_o + ow * params.so0 + oh * params.so1 + oc * params.so2 + n * params.so3;
|
||||||
|
store_output(out_idx, sum);
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue