ggml-webgpu: address quantization precision and backend lifecycle managment (#21521)
* ggml(webgpu): fix the busy-polls in Emscripten in the waitAny after #20618, and remove the busy webgpu log * Merge with upstream * Fix GET_ROWS packed integer NaN when using f16 as memory buffer in shader quants * Update Unary wgsl EXP and EXPM1 for f16 stability * Fix GET_ROWS IQ4_XS strcut for NaN f16 canonicalization * Fix numerical percision for unary sqrt when working with f16 * Fix NaN canonicalization for packed integers using f16 * Update err threshold for binary div ops when using f16 * backend: Keep one Dawn/WebGPU instance alive for the lifetime of the static backend * clean: uncomment existing code logs * clean: clean the unncessary debug info * Refactor and generalize dequant helpers * Remove deprecated quant structs * Refactor shader defines to reduce repetition * Remove error override for F16 type * fix: fix the accidential removal of the proper initialization of ctx * clean: clean legacy and format code * fix: did not modify tests ops --------- Co-authored-by: Jeremy J. Hartmann <jeremy@mtion.tv>
This commit is contained in:
parent
5dd102539b
commit
e4fed9d08d
|
|
@ -1115,6 +1115,32 @@ class ggml_webgpu_shader_lib {
|
|||
std::string type_upper = type_str;
|
||||
std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
|
||||
|
||||
switch (key.src_type)
|
||||
{
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q5_0:
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_Q3_K:
|
||||
case GGML_TYPE_Q6_K:
|
||||
case GGML_TYPE_IQ2_XXS:
|
||||
case GGML_TYPE_IQ2_XS:
|
||||
case GGML_TYPE_IQ2_S:
|
||||
case GGML_TYPE_IQ3_XXS:
|
||||
case GGML_TYPE_IQ3_S:
|
||||
case GGML_TYPE_IQ1_S:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
{
|
||||
// Quantized types using u32 buffers for portability.
|
||||
defines.push_back("SRC_TYPE=u32");
|
||||
defines.push_back("U32_DEQUANT_HELPERS");
|
||||
break;
|
||||
}
|
||||
default:
|
||||
{
|
||||
defines.push_back(std::string("SRC_TYPE=") + type_str);
|
||||
}
|
||||
}
|
||||
|
||||
defines.push_back("BYTE_HELPERS");
|
||||
defines.push_back(type_upper + "_T");
|
||||
defines.push_back(type_upper);
|
||||
|
|
@ -1125,7 +1151,6 @@ class ggml_webgpu_shader_lib {
|
|||
variant += "_";
|
||||
variant += type_str;
|
||||
|
||||
defines.push_back(std::string("SRC_TYPE=") + type_str);
|
||||
defines.push_back("DST_TYPE=f32");
|
||||
|
||||
if ((key.src_type >= GGML_TYPE_Q4_0 && key.src_type <= GGML_TYPE_Q8_1) ||
|
||||
|
|
@ -1593,11 +1618,35 @@ class ggml_webgpu_shader_lib {
|
|||
break;
|
||||
default:
|
||||
{
|
||||
// quantized types
|
||||
std::string type_upper = src0_name;
|
||||
std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
|
||||
|
||||
defines.push_back(std::string("SRC0_TYPE=") + src0_name);
|
||||
switch (context.src0->type)
|
||||
{
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q5_0:
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_Q3_K:
|
||||
case GGML_TYPE_Q6_K:
|
||||
case GGML_TYPE_IQ2_XXS:
|
||||
case GGML_TYPE_IQ2_XS:
|
||||
case GGML_TYPE_IQ2_S:
|
||||
case GGML_TYPE_IQ3_XXS:
|
||||
case GGML_TYPE_IQ3_S:
|
||||
case GGML_TYPE_IQ1_S:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
{
|
||||
// Quantized types using u32 buffers for portability.
|
||||
defines.push_back("SRC0_TYPE=u32");
|
||||
defines.push_back("U32_DEQUANT_HELPERS");
|
||||
break;
|
||||
}
|
||||
default:
|
||||
{
|
||||
defines.push_back(std::string("SRC0_TYPE=") + src0_name);
|
||||
}
|
||||
}
|
||||
|
||||
defines.push_back("BYTE_HELPERS");
|
||||
defines.push_back(type_upper + "_T");
|
||||
defines.push_back(type_upper);
|
||||
|
|
|
|||
|
|
@ -97,6 +97,14 @@ static inline void compute_2d_workgroups(uint32_t total_wg, uint32_t max_per_dim
|
|||
|
||||
/* End Constants */
|
||||
|
||||
static inline wgpu::CallbackMode ggml_webgpu_callback_mode() {
|
||||
#ifdef __EMSCRIPTEN__
|
||||
return wgpu::CallbackMode::AllowProcessEvents;
|
||||
#else
|
||||
return wgpu::CallbackMode::AllowSpontaneous;
|
||||
#endif
|
||||
}
|
||||
|
||||
// This is a "fake" base pointer, since WebGPU buffers do not have pointers to
|
||||
// their locations.
|
||||
static void * const webgpu_ptr_base = (void *) (uintptr_t) 0x1000; // NOLINT
|
||||
|
|
@ -474,7 +482,7 @@ static void ggml_backend_webgpu_wait_queue(webgpu_global_context & ctx) {
|
|||
|
||||
const wgpu::WaitStatus wait_status = ctx->instance.WaitAny(
|
||||
ctx->queue.OnSubmittedWorkDone(
|
||||
wgpu::CallbackMode::AllowSpontaneous,
|
||||
ggml_webgpu_callback_mode(),
|
||||
[&callback_status, &callback_message](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
|
||||
callback_status = status;
|
||||
callback_message = std::string(message);
|
||||
|
|
@ -494,7 +502,7 @@ static void ggml_backend_webgpu_map_buffer(webgpu_global_context & ctx,
|
|||
std::string callback_message;
|
||||
|
||||
const wgpu::WaitStatus wait_status = ctx->instance.WaitAny(
|
||||
buffer.MapAsync(mode, offset, size, wgpu::CallbackMode::AllowSpontaneous,
|
||||
buffer.MapAsync(mode, offset, size, ggml_webgpu_callback_mode(),
|
||||
[&callback_status, &callback_message](wgpu::MapAsyncStatus status, wgpu::StringView message) {
|
||||
callback_status = status;
|
||||
callback_message = std::string(message);
|
||||
|
|
@ -526,7 +534,11 @@ static void ggml_backend_webgpu_debug(webgpu_global_context & ctx) {
|
|||
encoder.CopyBufferToBuffer(ctx->debug_dev_buf, 0, ctx->debug_host_buf, 0, ctx->debug_host_buf.GetSize());
|
||||
wgpu::CommandBuffer commands = encoder.Finish();
|
||||
ctx->queue.Submit(1, &commands);
|
||||
ggml_backend_webgpu_map_buffer(ctx, ctx->debug_host_buf, wgpu::MapMode::Read, 0, ctx->debug_host_buf.GetSize());
|
||||
if (!ggml_backend_webgpu_map_buffer(ctx, ctx->debug_host_buf, wgpu::MapMode::Read, 0,
|
||||
ctx->debug_host_buf.GetSize())) {
|
||||
GGML_LOG_ERROR("ggml_webgpu: Debug buffer map failed\n");
|
||||
return;
|
||||
}
|
||||
const float * debug_data = (const float *) ctx->debug_host_buf.GetConstMappedRange();
|
||||
std::cout << "debug[0]: " << debug_data[0] << "\n";
|
||||
ctx->debug_host_buf.Unmap();
|
||||
|
|
@ -542,7 +554,7 @@ static void ggml_backend_webgpu_collect_profile_futures(webgpu_global_context &
|
|||
auto ts_bufs = command.timestamp_query_bufs;
|
||||
|
||||
wgpu::Future f = ts_bufs.host_buf.MapAsync(
|
||||
wgpu::MapMode::Read, 0, ts_bufs.host_buf.GetSize(), wgpu::CallbackMode::AllowSpontaneous,
|
||||
wgpu::MapMode::Read, 0, ts_bufs.host_buf.GetSize(), ggml_webgpu_callback_mode(),
|
||||
[ctx, ts_bufs, label](wgpu::MapAsyncStatus status, wgpu::StringView message) {
|
||||
if (status != wgpu::MapAsyncStatus::Success) {
|
||||
GGML_LOG_ERROR("ggml_webgpu: Failed to map timestamp buffer: %s\n", std::string(message).c_str());
|
||||
|
|
@ -3420,7 +3432,7 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
|
|||
|
||||
ctx->webgpu_global_ctx->instance.WaitAny(
|
||||
ctx->webgpu_global_ctx->instance.RequestAdapter(
|
||||
&options, wgpu::CallbackMode::AllowSpontaneous,
|
||||
&options, ggml_webgpu_callback_mode(),
|
||||
[&ctx](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message) {
|
||||
if (status != wgpu::RequestAdapterStatus::Success) {
|
||||
GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message);
|
||||
|
|
@ -3491,8 +3503,8 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
|
|||
dev_desc.requiredFeatures = required_features.data();
|
||||
dev_desc.requiredFeatureCount = required_features.size();
|
||||
dev_desc.SetDeviceLostCallback(
|
||||
wgpu::CallbackMode::AllowSpontaneous,
|
||||
[](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) {
|
||||
ggml_webgpu_callback_mode(),
|
||||
[ctx](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) {
|
||||
if (reason == wgpu::DeviceLostReason::Destroyed) {
|
||||
return;
|
||||
}
|
||||
|
|
@ -3525,7 +3537,7 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
|
|||
|
||||
ctx->webgpu_global_ctx->instance.WaitAny(
|
||||
ctx->webgpu_global_ctx->adapter.RequestDevice(
|
||||
&dev_desc, wgpu::CallbackMode::AllowSpontaneous,
|
||||
&dev_desc, ggml_webgpu_callback_mode(),
|
||||
[ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) {
|
||||
if (status != wgpu::RequestDeviceStatus::Success) {
|
||||
GGML_LOG_ERROR("ggml_webgpu: Failed to get a device: %s\n", std::string(message).c_str());
|
||||
|
|
@ -4046,6 +4058,13 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() {
|
|||
ctx.name = GGML_WEBGPU_NAME;
|
||||
ctx.device_count = 0;
|
||||
|
||||
// Keep one Dawn/WebGPU instance alive for the lifetime of the static backend
|
||||
// registry. Recreating it on repeated registry lookups can invalidate
|
||||
// adapter/device references that are still held by the backend/device layer.
|
||||
if (ctx.webgpu_global_ctx != nullptr && ctx.webgpu_global_ctx->instance != nullptr) {
|
||||
return ®
|
||||
}
|
||||
|
||||
wgpu::InstanceDescriptor instance_descriptor{};
|
||||
std::vector<wgpu::InstanceFeatureName> instance_features = { wgpu::InstanceFeatureName::TimedWaitAny };
|
||||
instance_descriptor.requiredFeatures = instance_features.data();
|
||||
|
|
@ -4063,11 +4082,11 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() {
|
|||
ctx.webgpu_global_ctx = webgpu_global_context(new webgpu_global_context_struct());
|
||||
ctx.webgpu_global_ctx->instance = std::move(inst);
|
||||
|
||||
// Probe for adapter support
|
||||
wgpu::Adapter adapter;
|
||||
if (ctx.webgpu_global_ctx->instance != nullptr) {
|
||||
wgpu::RequestAdapterOptions options = {};
|
||||
|
||||
// probe for adapter support
|
||||
ctx.webgpu_global_ctx->instance.WaitAny(
|
||||
ctx.webgpu_global_ctx->instance.RequestAdapter(
|
||||
&options, wgpu::CallbackMode::AllowSpontaneous,
|
||||
|
|
|
|||
|
|
@ -9,35 +9,43 @@ fn get_byte_i32(value: u32, index: u32) -> i32 {
|
|||
#endif
|
||||
|
||||
#ifdef U32_DEQUANT_HELPERS
|
||||
fn load_src0_u16_at(byte_offset: u32) -> u32 {
|
||||
let word = src0[byte_offset / 4u];
|
||||
let shift = (byte_offset & 2u) * 8u;
|
||||
return (word >> shift) & 0xFFFFu;
|
||||
fn load_u16_at(
|
||||
buf: ptr<storage, array<u32>, read_write>,
|
||||
byte_offset: u32) -> u32 {
|
||||
let word = buf[byte_offset / 4];
|
||||
let shift = (byte_offset & 0x2) * 8;
|
||||
return (word >> shift) & 0xFFFF;
|
||||
}
|
||||
|
||||
fn load_src0_u32_at(byte_offset: u32) -> u32 {
|
||||
let word_idx = byte_offset / 4u;
|
||||
let shift = (byte_offset & 3u) * 8u;
|
||||
let lo = src0[word_idx];
|
||||
if (shift == 0u) {
|
||||
return lo;
|
||||
}
|
||||
let hi = src0[word_idx + 1u];
|
||||
return (lo >> shift) | (hi << (32u - shift));
|
||||
fn load_u32_at(
|
||||
buf: ptr<storage, array<u32>, read_write>,
|
||||
byte_offset: u32) -> u32 {
|
||||
let word_idx = byte_offset / 4;
|
||||
let shift = (byte_offset & 0x3) * 8;
|
||||
let lo = buf[word_idx];
|
||||
let hi = buf[word_idx + 1];
|
||||
let shifted = (lo >> shift) | (hi << (32 - shift));
|
||||
return select(shifted, lo, shift == 0);
|
||||
}
|
||||
|
||||
fn load_src0_f16_at(byte_offset: u32) -> f16 {
|
||||
let packed = unpack2x16float(load_src0_u16_at(byte_offset));
|
||||
fn load_f16_at(
|
||||
buf: ptr<storage, array<u32>, read_write>,
|
||||
byte_offset: u32) -> f16 {
|
||||
let packed = unpack2x16float(load_u16_at(buf, byte_offset));
|
||||
return f16(packed[0]);
|
||||
}
|
||||
|
||||
fn load_f16_as_f32_at(
|
||||
buf: ptr<storage, array<u32>, read_write>,
|
||||
byte_offset: u32) -> f32 {
|
||||
let word = buf[byte_offset / 4];
|
||||
let shift = (byte_offset & 0x2) * 8;
|
||||
let d_bits = (word >> shift) & 0xFFFF;
|
||||
return unpack2x16float(d_bits)[0];
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef Q4_0_T
|
||||
struct q4_0 {
|
||||
d: f16,
|
||||
qs: array<f16, 8>
|
||||
};
|
||||
#endif
|
||||
|
||||
|
||||
#ifdef Q4_1_T
|
||||
struct q4_1 {
|
||||
|
|
@ -47,13 +55,6 @@ struct q4_1 {
|
|||
};
|
||||
#endif
|
||||
|
||||
#ifdef Q5_0_T
|
||||
struct q5_0 {
|
||||
d: f16,
|
||||
qh: array<f16, 2>,
|
||||
qs: array<f16, 8>
|
||||
};
|
||||
#endif
|
||||
|
||||
#ifdef Q5_1_T
|
||||
struct q5_1 {
|
||||
|
|
@ -64,12 +65,6 @@ struct q5_1 {
|
|||
};
|
||||
#endif
|
||||
|
||||
#ifdef Q8_0_T
|
||||
struct q8_0 {
|
||||
d: f16,
|
||||
qs: array<f16, 16>
|
||||
};
|
||||
#endif
|
||||
|
||||
#ifdef Q8_1_T
|
||||
struct q8_1 {
|
||||
|
|
@ -88,14 +83,6 @@ struct q2_K {
|
|||
};
|
||||
#endif
|
||||
|
||||
#ifdef Q3_K_T
|
||||
struct q3_K {
|
||||
hmask: array<f16, 16>,
|
||||
qs: array<f16, 32>,
|
||||
scales: array<f16, 6>,
|
||||
d: f16
|
||||
};
|
||||
#endif
|
||||
|
||||
#if defined(Q4_K_SCALE_MIN) || defined(Q5_K_SCALE_MIN)
|
||||
fn get_scale_min(is: u32, scales: array<u32, 3>) -> vec2<f32> {
|
||||
|
|
@ -132,64 +119,6 @@ struct q5_K {
|
|||
};
|
||||
#endif
|
||||
|
||||
#ifdef Q6_K_T
|
||||
struct q6_K {
|
||||
ql: array<f16, 64>,
|
||||
qh: array<f16, 32>,
|
||||
scales: array<f16, 8>,
|
||||
d: f16
|
||||
};
|
||||
#endif
|
||||
|
||||
#ifdef IQ2_XXS_T
|
||||
struct iq2_xxs {
|
||||
d: f16,
|
||||
qs: array<f16, 32>
|
||||
};
|
||||
#endif
|
||||
|
||||
#ifdef IQ2_XS_T
|
||||
struct iq2_xs {
|
||||
d: f16,
|
||||
qs: array<f16, 32>,
|
||||
scales: array<f16, 4>
|
||||
};
|
||||
#endif
|
||||
|
||||
#ifdef IQ2_S_T
|
||||
struct iq2_s {
|
||||
d: f16,
|
||||
qs: array<f16, 32>,
|
||||
qh: array<f16, 4>,
|
||||
scales: array<f16, 4>
|
||||
};
|
||||
#endif
|
||||
|
||||
#ifdef IQ3_XXS_T
|
||||
struct iq3_xxs {
|
||||
d: f16,
|
||||
qs: array<f16, 48>
|
||||
};
|
||||
#endif
|
||||
|
||||
#ifdef IQ3_S_T
|
||||
struct iq3_s {
|
||||
d: f16,
|
||||
qs: array<f16, 32>,
|
||||
qh: array<f16, 4>,
|
||||
signs: array<f16, 16>,
|
||||
scales: array<f16, 2>
|
||||
};
|
||||
#endif
|
||||
|
||||
#ifdef IQ1_S_T
|
||||
struct iq1_s {
|
||||
d: f16,
|
||||
qs: array<f16, 16>,
|
||||
qh: array<f16, 8>
|
||||
};
|
||||
#endif
|
||||
|
||||
#ifdef IQ1_M_T
|
||||
struct iq1_m {
|
||||
qs: array<u32, 8>,
|
||||
|
|
@ -198,17 +127,9 @@ struct iq1_m {
|
|||
};
|
||||
#endif
|
||||
|
||||
#ifdef IQ4_NL_T
|
||||
struct iq4_nl {
|
||||
d: f16,
|
||||
qs: array<f16, 8>,
|
||||
};
|
||||
#endif
|
||||
|
||||
#ifdef IQ4_XS_T
|
||||
struct iq4_xs {
|
||||
d: f16,
|
||||
scales_h: f16,
|
||||
d_scales_h: u32,
|
||||
scales_l: u32,
|
||||
qs: array<u32, 32>
|
||||
};
|
||||
|
|
|
|||
|
|
@ -27,17 +27,18 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
|||
|
||||
#ifdef Q4_0
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
let block_q4_0 = src[src_base + offset];
|
||||
let d = f32(block_q4_0.d);
|
||||
for (var j: u32 = 0; j < 4; j++) {
|
||||
let q_packed = bitcast<u32>(vec2(block_q4_0.qs[2 * j], block_q4_0.qs[2 * j + 1]));
|
||||
let block_byte_base = (src_base + offset) * 18; // Block stride: 18 bytes
|
||||
let d = load_f16_as_f32_at(&src, block_byte_base);
|
||||
for (var j: u32 = 0u; j < 4; j++) {
|
||||
let q_byte_offset = block_byte_base + 2 + j * 4;
|
||||
let q_packed = load_u32_at(&src, q_byte_offset);
|
||||
for (var k: u32 = 0; k < 4; k++) {
|
||||
let q_byte = get_byte(q_packed, k);
|
||||
let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0f) * d;
|
||||
let q_lo = (f32(q_byte & 0xF) - 8.0f) * d;
|
||||
let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0) * d;
|
||||
let q_lo = (f32(q_byte & 0xFu) - 8.0) * d;
|
||||
let dst_offset = dst_base + offset * 32 + j * 4 + k;
|
||||
dst[dst_offset] = q_lo;
|
||||
dst[dst_offset + 16] = q_hi;
|
||||
dst[dst_offset + 16u] = q_hi;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -64,17 +65,22 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
|||
|
||||
#ifdef Q5_0
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
let block_q5_0 = src[src_base + offset];
|
||||
let d = f32(block_q5_0.d);
|
||||
let qh_packed = bitcast<u32>(vec2(block_q5_0.qh[0], block_q5_0.qh[1]));
|
||||
let block_byte_base = (src_base + offset) * 22; // Block stride: 22 bytes
|
||||
let d = load_f16_as_f32_at(&src, block_byte_base);
|
||||
let qh_packed = load_u32_at(&src, block_byte_base + 2);
|
||||
for (var j: u32 = 0; j < 4; j++) {
|
||||
let q_packed = bitcast<u32>(vec2(block_q5_0.qs[2 * j], block_q5_0.qs[2 * j + 1]));
|
||||
let q_byte_offset = block_byte_base + 6 + j * 4;
|
||||
let q_packed = load_u32_at(&src, q_byte_offset);
|
||||
|
||||
for (var k: u32 = 0; k < 4; k++) {
|
||||
let q_byte = get_byte(q_packed, k);
|
||||
|
||||
let qh_hi = (qh_packed >> (j * 4 + k + 12)) & 0x10;
|
||||
let q_hi = (f32(((q_byte >> 4) & 0xF) | qh_hi) - 16.0) * d;
|
||||
|
||||
let qh_lo = ((qh_packed >> (j * 4 + k)) << 4) & 0x10;
|
||||
let q_lo = (f32((q_byte & 0xF) | qh_lo) - 16.0) * d;
|
||||
|
||||
let dst_offset = dst_base + offset * 32 + j * 4 + k;
|
||||
dst[dst_offset] = q_lo;
|
||||
dst[dst_offset + 16] = q_hi;
|
||||
|
|
@ -106,14 +112,15 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
|||
|
||||
#ifdef Q8_0
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
let block_q8_0 = src[src_base + offset];
|
||||
let d = f32(block_q8_0.d);
|
||||
for (var j: u32 = 0; j < 8; j++) {
|
||||
let q_packed = bitcast<u32>(vec2(block_q8_0.qs[2 * j], block_q8_0.qs[2 * j + 1]));
|
||||
for (var k: u32 = 0; k < 4; k++) {
|
||||
let block_byte_base = (src_base + offset) * 34; // Block stride: 34 bytes
|
||||
let d = load_f16_as_f32_at(&src, block_byte_base);
|
||||
for (var j: u32 = 0u; j < 8u; j++) {
|
||||
let q_byte_offset = block_byte_base + 2u + j * 4u;
|
||||
let q_packed = load_u32_at(&src, q_byte_offset);
|
||||
for (var k: u32 = 0u; k < 4u; k++) {
|
||||
let q_byte = get_byte_i32(q_packed, k);
|
||||
let q_val = f32(q_byte) * d;
|
||||
let dst_offset = dst_base + offset * 32 + j * 4 + k;
|
||||
let dst_offset = dst_base + offset * 32u + j * 4u + k;
|
||||
dst[dst_offset] = q_val;
|
||||
}
|
||||
}
|
||||
|
|
@ -152,36 +159,42 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
|||
|
||||
#ifdef Q3_K
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
let block = src[src_base + offset];
|
||||
let d = f32(block.d);
|
||||
let block_byte_base = (src_base + offset) * 110; // Block stride: 110 bytes
|
||||
|
||||
// extract 6-bit scales, which consist of 4-bits from first 8 bytes of scale,
|
||||
// and 2-bits from the last 4 bytes
|
||||
// Bytes 108-109: f16 scale 'd'
|
||||
let d = load_f16_as_f32_at(&src, block_byte_base + 108);
|
||||
|
||||
// Bytes 96-107: 12 bytes of scales (3 u32s)
|
||||
let kmask1: u32 = 0x03030303;
|
||||
let kmask2: u32 = 0x0f0f0f0f;
|
||||
|
||||
var scale_vals: array<u32, 4>;
|
||||
for (var i: u32 = 0; i < 4; i++) {
|
||||
scale_vals[i] = bitcast<u32>(vec2(block.scales[2 * i], block.scales[2 * i + 1]));
|
||||
}
|
||||
scale_vals[0] = load_u32_at(&src, block_byte_base + 96);
|
||||
scale_vals[1] = load_u32_at(&src, block_byte_base + 100);
|
||||
scale_vals[2] = load_u32_at(&src, block_byte_base + 104);
|
||||
|
||||
var tmp: u32 = scale_vals[2];
|
||||
scale_vals[2] = ((scale_vals[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
|
||||
scale_vals[3] = ((scale_vals[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
|
||||
scale_vals[0] = (scale_vals[0] & kmask2) | ((tmp & kmask1) << 4);
|
||||
scale_vals[1] = (scale_vals[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
|
||||
|
||||
// convert arrays of f16 -> u32
|
||||
// Bytes 0-31: 32 bytes of hmask (8 u32s)
|
||||
var hmask_vals: array<u32, 8>;
|
||||
for (var i: u32 = 0; i < 8; i++) {
|
||||
hmask_vals[i] = bitcast<u32>(vec2(block.hmask[2 * i], block.hmask[2 * i + 1]));
|
||||
hmask_vals[i] = load_u32_at(&src, block_byte_base + i * 4);
|
||||
}
|
||||
|
||||
// Bytes 32-95: 64 bytes of qs (16 u32s)
|
||||
var qs_vals: array<u32, 16>;
|
||||
for (var i: u32 = 0; i < 16; i++) {
|
||||
qs_vals[i] = bitcast<u32>(vec2(block.qs[2 * i], block.qs[2 * i + 1]));
|
||||
for (var i: u32 = 0u; i < 16; i++) {
|
||||
qs_vals[i] = load_u32_at(&src, block_byte_base + 32 + i * 4);
|
||||
}
|
||||
|
||||
var dst_i = dst_base + offset * 256;
|
||||
var is: u32 = 0;
|
||||
var m: u32 = 1;
|
||||
|
||||
// 2 halves of the block (128 elements each)
|
||||
for (var q_b_idx: u32 = 0; q_b_idx < 64; q_b_idx += 32) {
|
||||
// 4 groups (each group has 2 blocks of 16 elements)
|
||||
|
|
@ -191,11 +204,13 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
|||
let sc = get_byte(scale_vals[is / 4], is % 4);
|
||||
is++;
|
||||
let dl = d * (f32(sc) - 32.0);
|
||||
for (var l: u32 = 0u; l < 16u; l++) {
|
||||
|
||||
for (var l: u32 = 0; l < 16; l++) {
|
||||
let q_idx = q_b_idx + k + l;
|
||||
let hm_idx = k + l;
|
||||
let q_byte = get_byte(qs_vals[q_idx / 4], q_idx % 4);
|
||||
let hmask_byte = get_byte(hmask_vals[hm_idx / 4], hm_idx % 4);
|
||||
|
||||
let hm = select(4.0, 0.0, (hmask_byte & m) != 0);
|
||||
let qs_val = (q_byte >> shift) & 3;
|
||||
dst[dst_i] = (f32(qs_val) - hm) * dl;
|
||||
|
|
@ -268,21 +283,27 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
|||
#ifdef Q6_K
|
||||
// 16 blocks of 16 elements each
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
let block = src[src_base + offset];
|
||||
let d = f32(block.d);
|
||||
let block_byte_base = (src_base + offset) * 210; // Block stride: 210 bytes
|
||||
|
||||
// convert arrays of f16 -> u32
|
||||
// Bytes 208-209: f16 scale 'd'
|
||||
let d = load_f16_as_f32_at(&src, block_byte_base + 208);
|
||||
|
||||
// Bytes 0-127: 128 bytes of ql (32 u32s)
|
||||
var ql_vals: array<u32, 32>;
|
||||
for (var i: u32 = 0; i < 32; i++) {
|
||||
ql_vals[i] = bitcast<u32>(vec2(block.ql[2 * i], block.ql[2 * i + 1]));
|
||||
ql_vals[i] = load_u32_at(&src, block_byte_base + i * 4);
|
||||
}
|
||||
|
||||
// Bytes 128-191: 64 bytes of qh (16 u32s)
|
||||
var qh_vals: array<u32, 16>;
|
||||
for (var i: u32 = 0; i < 16; i++) {
|
||||
qh_vals[i] = bitcast<u32>(vec2(block.qh[2 * i], block.qh[2 * i + 1]));
|
||||
for (var i: u32 = 0; i < 16u; i++) {
|
||||
qh_vals[i] = load_u32_at(&src, block_byte_base + 128 + i * 4u);
|
||||
}
|
||||
|
||||
// Bytes 192-207: 16 bytes of scales (4 u32s)
|
||||
var scale_vals: array<u32, 4>;
|
||||
for (var i: u32 = 0; i < 4; i++) {
|
||||
scale_vals[i] = bitcast<u32>(vec2(block.scales[2 * i], block.scales[2 * i + 1]));
|
||||
scale_vals[i] = load_u32_at(&src, block_byte_base + 192 + i * 4);
|
||||
}
|
||||
|
||||
var dst_i = dst_base + offset * 256;
|
||||
|
|
@ -323,12 +344,14 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
|||
|
||||
#ifdef IQ2_XXS
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
let block = src[src_base + offset];
|
||||
let d = f32(block.d);
|
||||
let block_byte_base = (src_base + offset) * 66; // Block stride: 66 bytes
|
||||
let d = load_f16_as_f32_at(&src, block_byte_base);
|
||||
var dst_i = dst_base + offset * 256;
|
||||
for (var ib: u32 = 0; ib < 32; ib += 4) {
|
||||
let aux0 = bitcast<u32>(vec2(block.qs[ib], block.qs[ib + 1]));
|
||||
let aux1 = bitcast<u32>(vec2(block.qs[ib + 2], block.qs[ib + 3]));
|
||||
let aux0_offset = block_byte_base + 2 + ib * 2;
|
||||
let aux1_offset = block_byte_base + 2 + (ib + 2) * 2;
|
||||
let aux0 = load_u32_at(&src, aux0_offset);
|
||||
let aux1 = load_u32_at(&src, aux1_offset);
|
||||
let db = d * (0.5 + f32(aux1 >> 28)) * 0.25;
|
||||
for (var l: u32 = 0; l < 4; l++) {
|
||||
let ig = get_byte(aux0, l) * 8;
|
||||
|
|
@ -345,15 +368,19 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
|||
}
|
||||
#endif
|
||||
|
||||
|
||||
|
||||
#ifdef IQ2_XS
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
let block = src[src_base + offset];
|
||||
let d = f32(block.d);
|
||||
let block_byte_base = (src_base + offset) * 74; // Block stride: 74 bytes
|
||||
let d = load_f16_as_f32_at(&src, block_byte_base);
|
||||
var dst_i = dst_base + offset * 256;
|
||||
|
||||
var scale_vals = array<u32, 2>(
|
||||
bitcast<u32>(vec2(block.scales[0], block.scales[1])),
|
||||
bitcast<u32>(vec2(block.scales[2], block.scales[3]))
|
||||
load_u32_at(&src, block_byte_base + 66),
|
||||
load_u32_at(&src, block_byte_base + 70)
|
||||
);
|
||||
|
||||
for (var ib: u32 = 0; ib < 32; ib += 4) {
|
||||
let s = get_byte(scale_vals[ib / 16], (ib % 16) / 4);
|
||||
let db = array<f32, 2>(
|
||||
|
|
@ -361,7 +388,8 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
|||
d * (0.5 + f32(s >> 4)) * 0.25
|
||||
);
|
||||
for (var l: u32 = 0; l < 4; l++) {
|
||||
let qs_val = bitcast<u32>(vec2(block.qs[ib + l], 0.0));
|
||||
let qs_offset = block_byte_base + 2 + (ib + l) * 2;
|
||||
let qs_val = load_u32_at(&src, qs_offset) & 0xFFFF;
|
||||
let ig = (qs_val & 511) * 8;
|
||||
let is = qs_val >> 9;
|
||||
let signs = get_byte(ksigns_iq2xs[is / 4], is % 4);
|
||||
|
|
@ -379,21 +407,23 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
|||
|
||||
#ifdef IQ2_S
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
let block = src[src_base + offset];
|
||||
let d = f32(block.d);
|
||||
let block_byte_base = (src_base + offset) * 82; // Block stride: 82 bytes
|
||||
let d = load_f16_as_f32_at(&src, block_byte_base);
|
||||
var dst_i = dst_base + offset * 256;
|
||||
|
||||
var qs_vals : array<u32, 16>;
|
||||
for (var i: u32 = 0; i < 16; i++) {
|
||||
qs_vals[i] = bitcast<u32>(vec2(block.qs[i * 2], block.qs[i * 2 + 1]));
|
||||
qs_vals[i] = load_u32_at(&src, block_byte_base + 2 + i * 4);
|
||||
}
|
||||
var qh_vals = array<u32, 2>(
|
||||
bitcast<u32>(vec2(block.qh[0], block.qh[1])),
|
||||
bitcast<u32>(vec2(block.qh[2], block.qh[3]))
|
||||
);
|
||||
var scale_vals = array<u32, 2>(
|
||||
bitcast<u32>(vec2(block.scales[0], block.scales[1])),
|
||||
bitcast<u32>(vec2(block.scales[2], block.scales[3]))
|
||||
);
|
||||
|
||||
var qh_vals: array<u32, 2>;
|
||||
qh_vals[0] = load_u32_at(&src, block_byte_base + 66);
|
||||
qh_vals[1] = load_u32_at(&src, block_byte_base + 70);
|
||||
|
||||
var scale_vals: array<u32, 2>;
|
||||
scale_vals[0] = load_u32_at(&src, block_byte_base + 74);
|
||||
scale_vals[1] = load_u32_at(&src, block_byte_base + 78);
|
||||
|
||||
for (var ib: u32 = 0; ib < 8; ib ++) {
|
||||
let s = get_byte(scale_vals[ib / 4], ib % 4);
|
||||
let db = array<f32, 2>(
|
||||
|
|
@ -419,16 +449,17 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
|||
|
||||
#ifdef IQ3_XXS
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
let block = src[src_base + offset];
|
||||
let d = f32(block.d);
|
||||
let block_byte_base = (src_base + offset) * 98; // Block stride: 98 bytes
|
||||
let d = load_f16_as_f32_at(&src, block_byte_base);
|
||||
var dst_i = dst_base + offset * 256;
|
||||
for (var ib: u32 = 0; ib < 16; ib += 2) {
|
||||
let sc_sign = bitcast<u32>(vec2(block.qs[ib + 32], block.qs[ib + 33]));
|
||||
let sc_sign_offset = block_byte_base + 2 + (ib + 32) * 2;
|
||||
let sc_sign = load_u32_at(&src, sc_sign_offset);
|
||||
let db = d * (0.5 + f32(sc_sign >> 28)) * 0.5;
|
||||
for (var l: u32 = 0; l < 4; l++) {
|
||||
let is = (sc_sign >> (7 * l)) & 127;
|
||||
let signs = get_byte(ksigns_iq2xs[is / 4], is % 4);
|
||||
let ig_val = bitcast<u32>(vec2(block.qs[ib * 2 + l], 0.0));
|
||||
let ig_val = load_u32_at(&src, block_byte_base + 2 + (ib * 2 + l) * 2) & 0xFFFF;
|
||||
let ig1 = get_byte(ig_val, 0);
|
||||
let ig2 = get_byte(ig_val, 1);
|
||||
for (var j: u32 = 0; j < 4; j++) {
|
||||
|
|
@ -448,18 +479,22 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
|||
|
||||
#ifdef IQ3_S
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
let block = src[src_base + offset];
|
||||
let d = f32(block.d);
|
||||
let block_byte_base = (src_base + offset) * 110; // Block stride: 110 bytes
|
||||
let d = load_f16_as_f32_at(&src, block_byte_base);
|
||||
var dst_i = dst_base + offset * 256;
|
||||
|
||||
var qh_vals = array<u32, 2>(
|
||||
bitcast<u32>(vec2(block.qh[0], block.qh[1])),
|
||||
bitcast<u32>(vec2(block.qh[2], block.qh[3]))
|
||||
load_u32_at(&src, block_byte_base + 66),
|
||||
load_u32_at(&src, block_byte_base + 70)
|
||||
);
|
||||
|
||||
var sign_vals: array<u32, 8>;
|
||||
for (var i: u32 = 0; i < 8; i++) {
|
||||
sign_vals[i] = bitcast<u32>(vec2(block.signs[i * 2], block.signs[i * 2 + 1]));
|
||||
sign_vals[i] = load_u32_at(&src, block_byte_base + 74 + i * 4);
|
||||
}
|
||||
var scale_vals = bitcast<u32>(vec2(block.scales[0], block.scales[1]));
|
||||
|
||||
var scale_vals = load_u32_at(&src, block_byte_base + 106);
|
||||
|
||||
for (var ib: u32 = 0; ib < 4; ib++) {
|
||||
let s = get_byte(scale_vals, ib);
|
||||
let db = array<f32, 2>(
|
||||
|
|
@ -472,7 +507,7 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
|||
let sign_w = sign_vals[ib * 2 + k];
|
||||
for (var l: u32 = 0; l < 4; l++) {
|
||||
let signs = get_byte(sign_w, l);
|
||||
let ig_val = bitcast<u32>(vec2(block.qs[ib * 8 + k * 4 + l], 0.0));
|
||||
let ig_val = load_u32_at(&src, block_byte_base + 2 + (ib * 8 + k * 4 + l) * 2) & 0xFFFF;
|
||||
let ig1 = get_byte(ig_val, 0) | ((qh_byte << ((8 - (2 * l)))) & 256);
|
||||
let ig2 = get_byte(ig_val, 1) | ((qh_byte << ((7 - (2 * l)))) & 256);
|
||||
for (var j: u32 = 0; j < 4; j++) {
|
||||
|
|
@ -493,14 +528,14 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
|||
|
||||
#ifdef IQ1_S
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
let block = src[src_base + offset];
|
||||
let d = f32(block.d);
|
||||
let block_byte_base = (src_base + offset) * 50; // Block stride: 50 bytes
|
||||
let d = load_f16_as_f32_at(&src, block_byte_base);
|
||||
var dst_i = dst_base + offset * 256;
|
||||
for (var ib: u32 = 0; ib < 8; ib++) {
|
||||
let qh = bitcast<u32>(vec2(block.qh[ib], 0.0));
|
||||
let dl = d * (2 * f32((qh >> 12) & 7) + 1);
|
||||
let qh = load_u32_at(&src, block_byte_base + 34 + ib * 2) & 0xFFFF;
|
||||
let dl = d * (2.0 * f32((qh >> 12) & 7) + 1.0);
|
||||
let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000) != 0);
|
||||
let qs_w = bitcast<u32>(vec2(block.qs[ib * 2], block.qs[ib * 2 + 1]));
|
||||
let qs_w = load_u32_at(&src, block_byte_base + 2 + ib * 4);
|
||||
for (var l: u32 = 0; l < 4; l++) {
|
||||
let ig = (get_byte(qs_w, l) | (((qh >> (3 * l)) & 7) << 8)) * 8;
|
||||
for (var j: u32 = 0; j < 8; j++) {
|
||||
|
|
@ -560,12 +595,12 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
|||
|
||||
#ifdef IQ4_NL
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
let block = src[src_base + offset];
|
||||
let d = f32(block.d);
|
||||
let block_byte_base = (src_base + offset) * 18; // Block stride: 18 bytes
|
||||
let d = load_f16_as_f32_at(&src, block_byte_base);
|
||||
var dst_i = dst_base + offset * 32;
|
||||
var qs: array<u32, 4>;
|
||||
for (var i: u32 = 0; i < 4; i++) {
|
||||
qs[i] = bitcast<u32>(vec2(block.qs[i * 2], block.qs[i * 2 + 1]));
|
||||
qs[i] = load_u32_at(&src, block_byte_base + 2 + i * 4);
|
||||
}
|
||||
for (var j: u32 = 0; j < 16; j++) {
|
||||
let qsb = get_byte(qs[j / 4], j % 4);
|
||||
|
|
@ -579,8 +614,8 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
|||
#ifdef IQ4_XS
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
let block = src[src_base + offset];
|
||||
let d = f32(block.d);
|
||||
let scales_h = bitcast<u32>(vec2(block.scales_h, 0.0));
|
||||
let d = unpack2x16float(block.d_scales_h)[0];
|
||||
let scales_h = block.d_scales_h >> 16;
|
||||
var dst_i = dst_base + offset * 256;
|
||||
for (var ib: u32 = 0; ib < 8; ib++) {
|
||||
let ls = ((get_byte(block.scales_l, ib / 2) >> (4 * (ib % 2))) & 0xF) | (((scales_h >> (2 * ib)) & 3) << 4);
|
||||
|
|
|
|||
|
|
@ -20,11 +20,12 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
|||
|
||||
#ifdef Q4_0
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block_q4_0 = src0[src0_idx_base + offset];
|
||||
let d = f32(block_q4_0.d);
|
||||
let block_byte_base = (src0_idx_base + offset) * 18; // Block stride: 18 bytes
|
||||
let d = load_f16_as_f32_at(&src0, block_byte_base);
|
||||
var sum: f32 = 0.0;
|
||||
for (var j: u32 = 0; j < 4; j++) {
|
||||
let q_packed = bitcast<u32>(vec2(block_q4_0.qs[2 * j], block_q4_0.qs[2 * j + 1]));
|
||||
let q_byte_offset = block_byte_base + 2 + j * 4;
|
||||
let q_packed = load_u32_at(&src0, q_byte_offset);
|
||||
for (var k: u32 = 0; k < 4; k++) {
|
||||
let q_byte = get_byte(q_packed, k);
|
||||
let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0f) * d;
|
||||
|
|
@ -61,12 +62,13 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
|||
|
||||
#ifdef Q5_0
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block_q5_0 = src0[src0_idx_base + offset];
|
||||
let d = f32(block_q5_0.d);
|
||||
let block_byte_base = (src0_idx_base + offset) * 22; // Block stride: 22 bytes
|
||||
let d = load_f16_as_f32_at(&src0, block_byte_base);
|
||||
var sum: f32 = 0.0;
|
||||
let qh_packed = bitcast<u32>(vec2(block_q5_0.qh[0], block_q5_0.qh[1]));
|
||||
let qh_packed = load_u32_at(&src0, block_byte_base + 2);
|
||||
for (var j: u32 = 0; j < 4; j++) {
|
||||
let q_packed = bitcast<u32>(vec2(block_q5_0.qs[2 * j], block_q5_0.qs[2 * j + 1]));
|
||||
let q_byte_offset = block_byte_base + 6 + j * 4;
|
||||
let q_packed = load_u32_at(&src0, q_byte_offset);
|
||||
for (var k: u32 = 0; k < 4; k++) {
|
||||
let q_byte = get_byte(q_packed, k);
|
||||
let qh_hi = (qh_packed >> (j * 4 + k + 12)) & 0x10;
|
||||
|
|
@ -107,12 +109,13 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
|||
|
||||
#ifdef Q8_0
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block_q8_0 = src0[src0_idx_base + offset];
|
||||
let d = f32(block_q8_0.d);
|
||||
let block_byte_base = (src0_idx_base + offset) * 34; // Block stride: 34 bytes
|
||||
let d = load_f16_as_f32_at(&src0, block_byte_base);
|
||||
var sum: f32 = 0.0;
|
||||
for (var j: u32 = 0; j < 8; j++) {
|
||||
let q_packed = bitcast<u32>(vec2(block_q8_0.qs[2 * j], block_q8_0.qs[2 * j + 1]));
|
||||
for (var k: u32 = 0; k < 4; k++) {
|
||||
let q_byte_offset = block_byte_base + 2 + j * 4;
|
||||
let q_packed = load_u32_at(&src0, q_byte_offset);
|
||||
for (var k: u32 = 0u; k < 4u; k++) {
|
||||
let q_byte = get_byte_i32(q_packed, k);
|
||||
let q_val = f32(q_byte) * d;
|
||||
let src1_offset = src1_idx_base + offset * 32 + j * 4 + k;
|
||||
|
|
@ -178,31 +181,37 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
|||
#ifdef Q3_K
|
||||
// 16 blocks of 16 elements each
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block = src0[src0_idx_base + offset];
|
||||
let d = f32(block.d);
|
||||
let block_byte_base = (src0_idx_base + offset) * 110; // Block stride: 110 bytes
|
||||
|
||||
// Bytes 108-109: f16 scale 'd'
|
||||
let d = load_f16_as_f32_at(&src0, block_byte_base + 108);
|
||||
|
||||
// extract 6-bit scales, which consist of 4-bits from first 8 bytes of scale,
|
||||
// and 2-bits from the last 4 bytes
|
||||
// Bytes 96-107: 12 bytes of scales (3 u32s)
|
||||
let kmask1: u32 = 0x03030303;
|
||||
let kmask2: u32 = 0x0f0f0f0f;
|
||||
var scale_vals: array<u32, 4>;
|
||||
for (var i: u32 = 0; i < 4; i++) {
|
||||
scale_vals[i] = bitcast<u32>(vec2(block.scales[2 * i], block.scales[2 * i + 1]));
|
||||
}
|
||||
scale_vals[0] = load_u32_at(&src0, block_byte_base + 96);
|
||||
scale_vals[1] = load_u32_at(&src0, block_byte_base + 100);
|
||||
scale_vals[2] = load_u32_at(&src0, block_byte_base + 104);
|
||||
|
||||
var tmp: u32 = scale_vals[2];
|
||||
scale_vals[2] = ((scale_vals[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
|
||||
scale_vals[3] = ((scale_vals[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
|
||||
scale_vals[0] = (scale_vals[0] & kmask2) | ((tmp & kmask1) << 4);
|
||||
scale_vals[1] = (scale_vals[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
|
||||
|
||||
// convert arrays of f16 -> u32
|
||||
// Bytes 0-31: 32 bytes of hmask (8 u32s)
|
||||
var hmask_vals: array<u32, 8>;
|
||||
for (var i: u32 = 0; i < 8; i++) {
|
||||
hmask_vals[i] = bitcast<u32>(vec2(block.hmask[2 * i], block.hmask[2 * i + 1]));
|
||||
hmask_vals[i] = load_u32_at(&src0, block_byte_base + i * 4);
|
||||
}
|
||||
|
||||
// Bytes 32-95: 64 bytes of qs (16 u32s)
|
||||
var qs_vals: array<u32, 16>;
|
||||
for (var i: u32 = 0; i < 16; i++) {
|
||||
qs_vals[i] = bitcast<u32>(vec2(block.qs[2 * i], block.qs[2 * i + 1]));
|
||||
for (var i: u32 = 0u; i < 16; i++) {
|
||||
qs_vals[i] = load_u32_at(&src0, block_byte_base + 32 + i * 4);
|
||||
}
|
||||
|
||||
var sum = 0.0;
|
||||
|
|
@ -301,21 +310,27 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
|||
#ifdef Q6_K
|
||||
// 16 blocks of 16 elements each
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block = src0[src0_idx_base + offset];
|
||||
let d = f32(block.d);
|
||||
let block_byte_base = (src0_idx_base + offset) * 210; // Block stride: 210 bytes
|
||||
|
||||
// convert arrays of f16 -> u32
|
||||
// Bytes 208-209: f16 scale 'd'
|
||||
let d = load_f16_as_f32_at(&src0, block_byte_base + 208);
|
||||
|
||||
// Bytes 0-127: 128 bytes of ql (32 u32s)
|
||||
var ql_vals: array<u32, 32>;
|
||||
for (var i: u32 = 0; i < 32; i++) {
|
||||
ql_vals[i] = bitcast<u32>(vec2(block.ql[2 * i], block.ql[2 * i + 1]));
|
||||
ql_vals[i] = load_u32_at(&src0, block_byte_base + i * 4);
|
||||
}
|
||||
|
||||
// Bytes 128-191: 64 bytes of qh (16 u32s)
|
||||
var qh_vals: array<u32, 16>;
|
||||
for (var i: u32 = 0; i < 16; i++) {
|
||||
qh_vals[i] = bitcast<u32>(vec2(block.qh[2 * i], block.qh[2 * i + 1]));
|
||||
qh_vals[i] = load_u32_at(&src0, block_byte_base + 128 + i * 4);
|
||||
}
|
||||
|
||||
// Bytes 192-207: 16 bytes of scales (4 u32s)
|
||||
var scale_vals: array<u32, 4>;
|
||||
for (var i: u32 = 0; i < 4; i++) {
|
||||
scale_vals[i] = bitcast<u32>(vec2(block.scales[2 * i], block.scales[2 * i + 1]));
|
||||
scale_vals[i] = load_u32_at(&src0, block_byte_base + 192 + i * 4);
|
||||
}
|
||||
|
||||
var sum = 0.0;
|
||||
|
|
@ -358,13 +373,15 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
|||
|
||||
#ifdef IQ2_XXS
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block = src0[src0_idx_base + offset];
|
||||
let d = f32(block.d);
|
||||
let block_byte_base = (src0_idx_base + offset) * 66; // Block stride: 66 bytes
|
||||
let d = load_f16_as_f32_at(&src0, block_byte_base);
|
||||
var src1_i = src1_idx_base + offset * 256;
|
||||
var sum = 0.0;
|
||||
for (var ib: u32 = 0; ib < 32; ib += 4) {
|
||||
let aux0 = bitcast<u32>(vec2(block.qs[ib], block.qs[ib + 1]));
|
||||
let aux1 = bitcast<u32>(vec2(block.qs[ib + 2], block.qs[ib + 3]));
|
||||
let aux0_offset = block_byte_base + 2 + ib * 2;
|
||||
let aux1_offset = block_byte_base + 2 + (ib + 2) * 2;
|
||||
let aux0 = load_u32_at(&src0, aux0_offset);
|
||||
let aux1 = load_u32_at(&src0, aux1_offset);
|
||||
let db = d * (0.5 + f32(aux1 >> 28)) * 0.25;
|
||||
for (var l: u32 = 0; l < 4; l++) {
|
||||
let ig = get_byte(aux0, l) * 8;
|
||||
|
|
@ -384,13 +401,15 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
|||
|
||||
#ifdef IQ2_XS
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block = src0[src0_idx_base + offset];
|
||||
let d = f32(block.d);
|
||||
let block_byte_base = (src0_idx_base + offset) * 74; // Block stride: 74 bytes
|
||||
let d = load_f16_as_f32_at(&src0, block_byte_base);
|
||||
var src1_i = src1_idx_base + offset * 256;
|
||||
|
||||
var scale_vals = array<u32, 2>(
|
||||
bitcast<u32>(vec2(block.scales[0], block.scales[1])),
|
||||
bitcast<u32>(vec2(block.scales[2], block.scales[3]))
|
||||
load_u32_at(&src0, block_byte_base + 66),
|
||||
load_u32_at(&src0, block_byte_base + 70)
|
||||
);
|
||||
|
||||
var sum = 0.0;
|
||||
for (var ib: u32 = 0; ib < 32; ib += 4) {
|
||||
let s = get_byte(scale_vals[ib / 16], (ib % 16) / 4);
|
||||
|
|
@ -399,7 +418,8 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
|||
d * (0.5 + f32(s >> 4)) * 0.25
|
||||
);
|
||||
for (var l: u32 = 0; l < 4; l++) {
|
||||
let qs_val = bitcast<u32>(vec2(block.qs[ib + l], 0.0));
|
||||
let qs_offset = block_byte_base + 2 + (ib + l) * 2;
|
||||
let qs_val = load_u32_at(&src0, qs_offset) & 0xFFFF;
|
||||
let ig = (qs_val & 511) * 8;
|
||||
let is = qs_val >> 9;
|
||||
let signs = get_byte(ksigns_iq2xs[is / 4], is % 4);
|
||||
|
|
@ -418,21 +438,23 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
|||
|
||||
#ifdef IQ2_S
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block = src0[src0_idx_base + offset];
|
||||
let d = f32(block.d);
|
||||
let block_byte_base = (src0_idx_base + offset) * 82; // Block stride: 82 bytes
|
||||
let d = load_f16_as_f32_at(&src0, block_byte_base);
|
||||
var src1_i = src1_idx_base + offset * 256;
|
||||
|
||||
var qs_vals : array<u32, 16>;
|
||||
for (var i: u32 = 0; i < 16; i++) {
|
||||
qs_vals[i] = bitcast<u32>(vec2(block.qs[i * 2], block.qs[i * 2 + 1]));
|
||||
qs_vals[i] = load_u32_at(&src0, block_byte_base + 2 + i * 4);
|
||||
}
|
||||
var qh_vals = array<u32, 2>(
|
||||
bitcast<u32>(vec2(block.qh[0], block.qh[1])),
|
||||
bitcast<u32>(vec2(block.qh[2], block.qh[3]))
|
||||
);
|
||||
var scale_vals = array<u32, 2>(
|
||||
bitcast<u32>(vec2(block.scales[0], block.scales[1])),
|
||||
bitcast<u32>(vec2(block.scales[2], block.scales[3]))
|
||||
);
|
||||
|
||||
var qh_vals: array<u32, 2>;
|
||||
qh_vals[0] = load_u32_at(&src0, block_byte_base + 66);
|
||||
qh_vals[1] = load_u32_at(&src0, block_byte_base + 70);
|
||||
|
||||
var scale_vals: array<u32, 2>;
|
||||
scale_vals[0] = load_u32_at(&src0, block_byte_base + 74);
|
||||
scale_vals[1] = load_u32_at(&src0, block_byte_base + 78);
|
||||
|
||||
var sum = 0.0;
|
||||
for (var ib: u32 = 0; ib < 8; ib ++) {
|
||||
let s = get_byte(scale_vals[ib / 4], ib % 4);
|
||||
|
|
@ -460,17 +482,18 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
|||
|
||||
#ifdef IQ3_XXS
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block = src0[src0_idx_base + offset];
|
||||
let d = f32(block.d);
|
||||
let block_byte_base = (src0_idx_base + offset) * 98; // Block stride: 98 bytes
|
||||
let d = load_f16_as_f32_at(&src0, block_byte_base);
|
||||
var src1_i = src1_idx_base + offset * 256;
|
||||
var sum = 0.0;
|
||||
for (var ib: u32 = 0; ib < 16; ib += 2) {
|
||||
let sc_sign = bitcast<u32>(vec2(block.qs[ib + 32], block.qs[ib + 33]));
|
||||
let sc_sign_offset = block_byte_base + 2 + (ib + 32) * 2;
|
||||
let sc_sign = load_u32_at(&src0, sc_sign_offset);
|
||||
let db = d * (0.5 + f32(sc_sign >> 28)) * 0.5;
|
||||
for (var l: u32 = 0; l < 4; l++) {
|
||||
let is = (sc_sign >> (7 * l)) & 127;
|
||||
let signs = get_byte(ksigns_iq2xs[is / 4], is % 4);
|
||||
let ig_val = bitcast<u32>(vec2(block.qs[ib * 2 + l], 0.0));
|
||||
let ig_val = load_u32_at(&src0, block_byte_base + 2 + (ib * 2 + l) * 2) & 0xFFFF;
|
||||
let ig1 = get_byte(ig_val, 0);
|
||||
let ig2 = get_byte(ig_val, 1);
|
||||
for (var j: u32 = 0; j < 4; j++) {
|
||||
|
|
@ -491,18 +514,22 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
|||
|
||||
#ifdef IQ3_S
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block = src0[src0_idx_base + offset];
|
||||
let d = f32(block.d);
|
||||
let block_byte_base = (src0_idx_base + offset) * 110; // Block stride: 110 bytes
|
||||
let d = load_f16_as_f32_at(&src0, block_byte_base);
|
||||
var src1_i = src1_idx_base + offset * 256;
|
||||
|
||||
var qh_vals = array<u32, 2>(
|
||||
bitcast<u32>(vec2(block.qh[0], block.qh[1])),
|
||||
bitcast<u32>(vec2(block.qh[2], block.qh[3]))
|
||||
load_u32_at(&src0, block_byte_base + 66),
|
||||
load_u32_at(&src0, block_byte_base + 70)
|
||||
);
|
||||
|
||||
var sign_vals: array<u32, 8>;
|
||||
for (var i: u32 = 0; i < 8; i++) {
|
||||
sign_vals[i] = bitcast<u32>(vec2(block.signs[i * 2], block.signs[i * 2 + 1]));
|
||||
sign_vals[i] = load_u32_at(&src0, block_byte_base + 74 + i * 4);
|
||||
}
|
||||
var scale_vals = bitcast<u32>(vec2(block.scales[0], block.scales[1]));
|
||||
|
||||
var scale_vals = load_u32_at(&src0, block_byte_base + 106);
|
||||
|
||||
var sum = 0.0;
|
||||
for (var ib: u32 = 0; ib < 4; ib++) {
|
||||
let s = get_byte(scale_vals, ib);
|
||||
|
|
@ -516,7 +543,7 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
|||
let sign_w = sign_vals[ib * 2 + k];
|
||||
for (var l: u32 = 0; l < 4; l++) {
|
||||
let signs = get_byte(sign_w, l);
|
||||
let ig_val = bitcast<u32>(vec2(block.qs[ib * 8 + k * 4 + l], 0.0));
|
||||
let ig_val = load_u32_at(&src0, block_byte_base + 2 + (ib * 8 + k * 4 + l) * 2) & 0xFFFF;
|
||||
let ig1 = get_byte(ig_val, 0) | ((qh_byte << ((8 - (2 * l)))) & 256);
|
||||
let ig2 = get_byte(ig_val, 1) | ((qh_byte << ((7 - (2 * l)))) & 256);
|
||||
for (var j: u32 = 0; j < 4; j++) {
|
||||
|
|
@ -538,15 +565,15 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
|||
|
||||
#ifdef IQ1_S
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block = src0[src0_idx_base + offset];
|
||||
let d = f32(block.d);
|
||||
let block_byte_base = (src0_idx_base + offset) * 50; // Block stride: 50 bytes
|
||||
let d = load_f16_as_f32_at(&src0, block_byte_base);
|
||||
var src1_i = src1_idx_base + offset * 256;
|
||||
var sum = 0.0;
|
||||
for (var ib: u32 = 0; ib < 8; ib++) {
|
||||
let qh = bitcast<u32>(vec2(block.qh[ib], 0.0));
|
||||
let dl = d * (2 * f32((qh >> 12) & 7) + 1);
|
||||
let qh = load_u32_at(&src0, block_byte_base + 34 + ib * 2) & 0xFFFF;
|
||||
let dl = d * (2.0 * f32((qh >> 12) & 7) + 1.0);
|
||||
let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000) != 0);
|
||||
let qs_w = bitcast<u32>(vec2(block.qs[ib * 2], block.qs[ib * 2 + 1]));
|
||||
let qs_w = load_u32_at(&src0, block_byte_base + 2 + ib * 4);
|
||||
for (var l: u32 = 0; l < 4; l++) {
|
||||
let ig = (get_byte(qs_w, l) | (((qh >> (3 * l)) & 7) << 8)) * 8;
|
||||
for (var j: u32 = 0; j < 8; j++) {
|
||||
|
|
@ -610,13 +637,13 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
|||
|
||||
#ifdef IQ4_NL
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block = src0[src0_idx_base + offset];
|
||||
let d = f32(block.d);
|
||||
let block_byte_base = (src0_idx_base + offset) * 18; // Block stride: 18 bytes
|
||||
let d = load_f16_as_f32_at(&src0, block_byte_base);
|
||||
var src1_i = src1_idx_base + offset * 32;
|
||||
var sum = 0.0;
|
||||
var qs: array<u32, 4>;
|
||||
for (var i: u32 = 0; i < 4; i++) {
|
||||
qs[i] = bitcast<u32>(vec2(block.qs[i * 2], block.qs[i * 2 + 1]));
|
||||
qs[i] = load_u32_at(&src0, block_byte_base + 2 + i * 4);
|
||||
}
|
||||
for (var j: u32 = 0; j < 16; j++) {
|
||||
let qsb = get_byte(qs[j / 4], j % 4);
|
||||
|
|
@ -631,8 +658,8 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
|||
#ifdef IQ4_XS
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block = src0[src0_idx_base + offset];
|
||||
let d = f32(block.d);
|
||||
let scales_h = bitcast<u32>(vec2(block.scales_h, 0.0));
|
||||
let d = unpack2x16float(block.d_scales_h)[0];
|
||||
let scales_h = block.d_scales_h >> 16;
|
||||
var src1_i = src1_idx_base + offset * 256;
|
||||
var sum = 0.0;
|
||||
for (var ib: u32 = 0; ib < 8; ib++) {
|
||||
|
|
|
|||
|
|
@ -84,11 +84,11 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
|||
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
|
||||
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
|
||||
let d = load_src0_f16_at(block_byte_base);
|
||||
let d = load_f16_at(&src0, block_byte_base);
|
||||
|
||||
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
|
||||
let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
|
||||
let q_packed = load_src0_u32_at(q_byte_offset);
|
||||
let q_packed = load_u32_at(&src0, q_byte_offset);
|
||||
for (var k = 0u; k < 4u; k++) {
|
||||
let q_byte = get_byte(q_packed, k);
|
||||
let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d;
|
||||
|
|
@ -125,12 +125,12 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
|||
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
|
||||
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
|
||||
let d = load_src0_f16_at(block_byte_base);
|
||||
let m = load_src0_f16_at(block_byte_base + 2u);
|
||||
let d = load_f16_at(&src0, block_byte_base);
|
||||
let m = load_f16_at(&src0, block_byte_base + 2u);
|
||||
|
||||
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
|
||||
let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j);
|
||||
let q_packed = load_src0_u32_at(q_byte_offset);
|
||||
let q_packed = load_u32_at(&src0, q_byte_offset);
|
||||
for (var k = 0u; k < 4u; k++) {
|
||||
let q_byte = get_byte(q_packed, k);
|
||||
let q_lo = f16(q_byte & 0xF) * d + m;
|
||||
|
|
@ -171,12 +171,12 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
|||
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
|
||||
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
|
||||
|
||||
let d = load_src0_f16_at(block_byte_base);
|
||||
let qh_packed = load_src0_u32_at(block_byte_base + 2u);
|
||||
let d = load_f16_at(&src0, block_byte_base);
|
||||
let qh_packed = load_u32_at(&src0, block_byte_base + 2u);
|
||||
|
||||
for (var j = 0u; j < 2; j++) {
|
||||
let q_byte_offset = block_byte_base + 6u + 2u * (block_offset + j * 2u);
|
||||
let q_packed = load_src0_u32_at(q_byte_offset);
|
||||
let q_packed = load_u32_at(&src0, q_byte_offset);
|
||||
|
||||
let j_adjusted = j + (block_offset / 2u);
|
||||
|
||||
|
|
@ -225,14 +225,14 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
|||
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
|
||||
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
|
||||
|
||||
let d = load_src0_f16_at(block_byte_base);
|
||||
let m = load_src0_f16_at(block_byte_base + 2u);
|
||||
let qh_packed = load_src0_u32_at(block_byte_base + 4u);
|
||||
let d = load_f16_at(&src0, block_byte_base);
|
||||
let m = load_f16_at(&src0, block_byte_base + 2u);
|
||||
let qh_packed = load_u32_at(&src0, block_byte_base + 4u);
|
||||
|
||||
for (var j = 0u; j < 2; j++) {
|
||||
|
||||
let q_byte_offset = block_byte_base + 8u + 2u * (block_offset + j * 2u);
|
||||
let q_packed = load_src0_u32_at(q_byte_offset);
|
||||
let q_packed = load_u32_at(&src0, q_byte_offset);
|
||||
|
||||
let j_adjusted = j + (block_offset / 2u);
|
||||
|
||||
|
|
@ -277,11 +277,11 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
|||
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
|
||||
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
|
||||
let d = load_src0_f16_at(block_byte_base);
|
||||
let d = load_f16_at(&src0, block_byte_base);
|
||||
|
||||
for (var j = 0u; j < F16_PER_THREAD; j+=2) {
|
||||
let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
|
||||
let q_packed = load_src0_u32_at(q_byte_offset);
|
||||
let q_packed = load_u32_at(&src0, q_byte_offset);
|
||||
for (var k = 0u; k < 4u; k++) {
|
||||
let q_byte = get_byte_i32(q_packed, k);
|
||||
|
||||
|
|
@ -317,12 +317,12 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
|||
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
|
||||
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
|
||||
let d = load_src0_f16_at(block_byte_base);
|
||||
let m = load_src0_f16_at(block_byte_base + 2u);
|
||||
let d = load_f16_at(&src0, block_byte_base);
|
||||
let m = load_f16_at(&src0, block_byte_base + 2u);
|
||||
|
||||
for (var j = 0u; j < F16_PER_THREAD; j+=2) {
|
||||
let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j);
|
||||
let q_packed = load_src0_u32_at(q_byte_offset);
|
||||
let q_packed = load_u32_at(&src0, q_byte_offset);
|
||||
for (var k = 0u; k < 4u; k++) {
|
||||
let q_byte = get_byte_i32(q_packed, k);
|
||||
|
||||
|
|
@ -359,8 +359,8 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
|||
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
|
||||
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
|
||||
|
||||
let d = load_src0_f16_at(block_byte_base + 80u);
|
||||
let dmin = load_src0_f16_at(block_byte_base + 82u);
|
||||
let d = load_f16_at(&src0, block_byte_base + 80u);
|
||||
let dmin = load_f16_at(&src0, block_byte_base + 82u);
|
||||
|
||||
// Decode the element at position k_in_block
|
||||
let block_of_32 = k_in_block / 32u;
|
||||
|
|
@ -373,14 +373,14 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
|||
|
||||
let is = k_in_block / 16u;
|
||||
|
||||
let sc_packed = load_src0_u32_at(block_byte_base + 4u * (is / 4u));
|
||||
let sc_packed = load_u32_at(&src0, block_byte_base + 4u * (is / 4u));
|
||||
let sc = get_byte(sc_packed, is % 4u);
|
||||
|
||||
let dl = d * f16(sc & 0xFu);
|
||||
let ml = dmin * f16(sc >> 4u);
|
||||
|
||||
let q_idx = q_b_idx + k + l;
|
||||
let q_packed = load_src0_u32_at(block_byte_base + 16u + 4u * (q_idx / 4u));
|
||||
let q_packed = load_u32_at(&src0, block_byte_base + 16u + 4u * (q_idx / 4u));
|
||||
let q_byte = get_byte(q_packed, q_idx % 4u);
|
||||
let qs_val = (q_byte >> shift) & 3u;
|
||||
|
||||
|
|
@ -413,7 +413,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
|||
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
|
||||
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
|
||||
|
||||
let d = load_src0_f16_at(block_byte_base + 108u);
|
||||
let d = load_f16_at(&src0, block_byte_base + 108u);
|
||||
|
||||
// Load and unpack scales
|
||||
let kmask1: u32 = 0x03030303u;
|
||||
|
|
@ -421,7 +421,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
|||
|
||||
var scale_vals: array<u32, 4>;
|
||||
for (var i: u32 = 0u; i < 4u; i++) {
|
||||
scale_vals[i] = load_src0_u32_at(block_byte_base + 96u + 4u * i);
|
||||
scale_vals[i] = load_u32_at(&src0, block_byte_base + 96u + 4u * i);
|
||||
}
|
||||
|
||||
var tmp: u32 = scale_vals[2];
|
||||
|
|
@ -433,12 +433,12 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
|||
// Load hmask and qs arrays
|
||||
var hmask_vals: array<u32, 8>;
|
||||
for (var i: u32 = 0u; i < 8u; i++) {
|
||||
hmask_vals[i] = load_src0_u32_at(block_byte_base + 4u * i);
|
||||
hmask_vals[i] = load_u32_at(&src0, block_byte_base + 4u * i);
|
||||
}
|
||||
|
||||
var qs_vals: array<u32, 16>;
|
||||
for (var i: u32 = 0u; i < 16u; i++) {
|
||||
qs_vals[i] = load_src0_u32_at(block_byte_base + 32u + 4u * i);
|
||||
qs_vals[i] = load_u32_at(&src0, block_byte_base + 32u + 4u * i);
|
||||
}
|
||||
|
||||
let half = k_in_block / 128u; // 0 or 1
|
||||
|
|
@ -499,13 +499,13 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
|||
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
|
||||
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
|
||||
|
||||
let d = load_src0_f16_at(block_byte_base);
|
||||
let dmin = load_src0_f16_at(block_byte_base + 2u);
|
||||
let d = load_f16_at(&src0, block_byte_base);
|
||||
let dmin = load_f16_at(&src0, block_byte_base + 2u);
|
||||
|
||||
// Load packed scales
|
||||
var scale_vals: array<u32, 3>;
|
||||
for (var i: u32 = 0u; i < 3u; i++) {
|
||||
scale_vals[i] = load_src0_u32_at(block_byte_base + 4u + 4u * i);
|
||||
scale_vals[i] = load_u32_at(&src0, block_byte_base + 4u + 4u * i);
|
||||
}
|
||||
|
||||
// Map k_in_block to loop structure:
|
||||
|
|
@ -541,7 +541,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
|||
let ml = dmin * f16(mn);
|
||||
|
||||
let q_idx = q_b_idx + l;
|
||||
let q_packed = load_src0_u32_at(block_byte_base + 16u + 4u * (q_idx / 4u));
|
||||
let q_packed = load_u32_at(&src0, block_byte_base + 16u + 4u * (q_idx / 4u));
|
||||
|
||||
let q_byte = get_byte(q_packed, q_idx % 4u);
|
||||
let qs_val = (q_byte >> shift) & 0xFu;
|
||||
|
|
@ -575,13 +575,13 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
|||
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
|
||||
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
|
||||
|
||||
let d = load_src0_f16_at(block_byte_base);
|
||||
let dmin = load_src0_f16_at(block_byte_base + 2u);
|
||||
let d = load_f16_at(&src0, block_byte_base);
|
||||
let dmin = load_f16_at(&src0, block_byte_base + 2u);
|
||||
|
||||
// Load packed scales
|
||||
var scale_vals: array<u32, 3>;
|
||||
for (var i: u32 = 0u; i < 3u; i++) {
|
||||
scale_vals[i] = load_src0_u32_at(block_byte_base + 4u + 4u * i);
|
||||
scale_vals[i] = load_u32_at(&src0, block_byte_base + 4u + 4u * i);
|
||||
}
|
||||
|
||||
// The original loop processes elements in groups of 64
|
||||
|
|
@ -621,11 +621,11 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
|||
let ml = dmin * f16(mn);
|
||||
|
||||
let q_idx = q_b_idx + l;
|
||||
let q_packed = load_src0_u32_at(block_byte_base + 48u + 4u * (q_idx / 4u));
|
||||
let q_packed = load_u32_at(&src0, block_byte_base + 48u + 4u * (q_idx / 4u));
|
||||
|
||||
let q_byte = get_byte(q_packed, q_idx % 4u);
|
||||
|
||||
let qh_packed = load_src0_u32_at(block_byte_base + 16u + 4u * (l / 4u));
|
||||
let qh_packed = load_u32_at(&src0, block_byte_base + 16u + 4u * (l / 4u));
|
||||
|
||||
let qh_byte = get_byte(qh_packed, l % 4u);
|
||||
|
||||
|
|
@ -673,17 +673,17 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
|||
|
||||
// Load only ql13 word needed
|
||||
let ql13_flat = ql_b_idx + l;
|
||||
let ql13 = load_src0_u32_at(block_byte_base + ql13_flat);
|
||||
let ql13 = load_u32_at(&src0, block_byte_base + ql13_flat);
|
||||
let ql13_b = get_byte(ql13, 0u);
|
||||
|
||||
// Load only ql24 word needed
|
||||
let ql24_flat = ql_b_idx + l + 32u;
|
||||
let ql24 = load_src0_u32_at(block_byte_base + ql24_flat);
|
||||
let ql24 = load_u32_at(&src0, block_byte_base + ql24_flat);
|
||||
let ql24_b = get_byte(ql24, 0u);
|
||||
|
||||
// Load only qh word needed
|
||||
let qh_flat = qh_b_idx + l;
|
||||
let qh = load_src0_u32_at(block_byte_base + 128u + qh_flat);
|
||||
let qh = load_u32_at(&src0, block_byte_base + 128u + qh_flat);
|
||||
let qh_b = get_byte(qh, 0u);
|
||||
|
||||
let q1 = f16((ql13_b & 0xFu) | ((qh_b & 3u) << 4u)) - f16(32.0);
|
||||
|
|
@ -694,10 +694,10 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
|||
// Load only the scale word needed
|
||||
let is = l / 16u;
|
||||
let sc_idx = sc_b_idx + is + quarter * 2u;
|
||||
let sc = load_src0_u32_at(block_byte_base + 192u + sc_idx);
|
||||
let sc = load_u32_at(&src0, block_byte_base + 192u + sc_idx);
|
||||
let sc_val = get_byte_i32(sc, 0u);
|
||||
|
||||
let d = load_src0_f16_at(block_byte_base + 208u);
|
||||
let d = load_f16_at(&src0, block_byte_base + 208u);
|
||||
|
||||
var q_val: f16;
|
||||
if (quarter == 0u) {
|
||||
|
|
|
|||
|
|
@ -65,10 +65,10 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
|
|||
let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
|
||||
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
|
||||
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
|
||||
let d = f32(load_src0_f16_at(block_byte_base));
|
||||
let d = f32(load_f16_at(&src0, block_byte_base));
|
||||
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
|
||||
let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
|
||||
let q_packed = load_src0_u32_at(q_byte_offset);
|
||||
let q_packed = load_u32_at(&src0, q_byte_offset);
|
||||
for (var k: u32 = 0; k < 4; k++) {
|
||||
let q_byte = get_byte(q_packed, k);
|
||||
let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0) * d;
|
||||
|
|
@ -98,11 +98,11 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
|
|||
let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
|
||||
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
|
||||
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
|
||||
let d = f32(load_src0_f16_at(block_byte_base));
|
||||
let m = f32(load_src0_f16_at(block_byte_base + 2u));
|
||||
let d = f32(load_f16_at(&src0, block_byte_base));
|
||||
let m = f32(load_f16_at(&src0, block_byte_base + 2u));
|
||||
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
|
||||
let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j);
|
||||
let q_packed = load_src0_u32_at(q_byte_offset);
|
||||
let q_packed = load_u32_at(&src0, q_byte_offset);
|
||||
for (var k: u32 = 0; k < 4; k++) {
|
||||
let q_byte = get_byte(q_packed, k);
|
||||
let q_hi = f32((q_byte >> 4) & 0xF) * d + m;
|
||||
|
|
@ -132,12 +132,12 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
|
|||
let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
|
||||
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
|
||||
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
|
||||
let d = f32(load_src0_f16_at(block_byte_base));
|
||||
let qh_packed = load_src0_u32_at(block_byte_base + 2u);
|
||||
let d = f32(load_f16_at(&src0, block_byte_base));
|
||||
let qh_packed = load_u32_at(&src0, block_byte_base + 2u);
|
||||
|
||||
for (var j = 0u; j < 2; j++) {
|
||||
let q_byte_offset = block_byte_base + 6u + 2u * (block_offset + j * 2u);
|
||||
let q_packed = load_src0_u32_at(q_byte_offset);
|
||||
let q_packed = load_u32_at(&src0, q_byte_offset);
|
||||
|
||||
let j_adjusted = j + (block_offset / 2u);
|
||||
|
||||
|
|
@ -176,13 +176,13 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
|
|||
let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
|
||||
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
|
||||
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
|
||||
let d = f32(load_src0_f16_at(block_byte_base));
|
||||
let m = load_src0_f16_at(block_byte_base + 2u);
|
||||
let qh_packed = load_src0_u32_at(block_byte_base + 4u);
|
||||
let d = f32(load_f16_at(&src0, block_byte_base));
|
||||
let m = load_f16_at(&src0, block_byte_base + 2u);
|
||||
let qh_packed = load_u32_at(&src0, block_byte_base + 4u);
|
||||
|
||||
for (var j = 0u; j < 2; j++) {
|
||||
let q_byte_offset = block_byte_base + 8u + 2u * (block_offset + j * 2u);
|
||||
let q_packed = load_src0_u32_at(q_byte_offset);
|
||||
let q_packed = load_u32_at(&src0, q_byte_offset);
|
||||
|
||||
let j_adjusted = j + (block_offset / 2u);
|
||||
|
||||
|
|
@ -221,11 +221,11 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
|
|||
let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
|
||||
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
|
||||
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
|
||||
let d = f32(load_src0_f16_at(block_byte_base));
|
||||
let d = f32(load_f16_at(&src0, block_byte_base));
|
||||
|
||||
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
|
||||
let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
|
||||
let q_packed = load_src0_u32_at(q_byte_offset);
|
||||
let q_packed = load_u32_at(&src0, q_byte_offset);
|
||||
for (var k: u32 = 0; k < 4; k++) {
|
||||
let q_byte = get_byte_i32(q_packed, k);
|
||||
let q_val = f32(q_byte) * d;
|
||||
|
|
@ -254,12 +254,12 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
|
|||
let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
|
||||
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
|
||||
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
|
||||
let d = f32(load_src0_f16_at(block_byte_base));
|
||||
let m = load_src0_f16_at(block_byte_base + 2u);
|
||||
let d = f32(load_f16_at(&src0, block_byte_base));
|
||||
let m = load_f16_at(&src0, block_byte_base + 2u);
|
||||
|
||||
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
|
||||
let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j);
|
||||
let q_packed = load_src0_u32_at(q_byte_offset);
|
||||
let q_packed = load_u32_at(&src0, q_byte_offset);
|
||||
for (var k: u32 = 0; k < 4; k++) {
|
||||
let q_byte = get_byte_i32(q_packed, k);
|
||||
let q_val = f32(q_byte) * d + f32(m);
|
||||
|
|
@ -309,13 +309,13 @@ fn mul_acc(tig: u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
|
|||
for (var i = ix; i < nb; i += 2u) {
|
||||
let bbase = (idx_base + k_block_start + i) * BLOCK_SIZE_BYTES;
|
||||
|
||||
let d = f32(load_src0_f16_at(bbase + 208u));
|
||||
let d = f32(load_f16_at(&src0, bbase + 208u));
|
||||
|
||||
let ql1_u32 = load_src0_u32_at(bbase + q_offset_l);
|
||||
let ql2_u32 = load_src0_u32_at(bbase + q_offset_l + 32u);
|
||||
let qh_u32 = load_src0_u32_at(bbase + 128u + q_offset_h);
|
||||
let sc_u32_0 = load_src0_u32_at(bbase + sc_base_byte);
|
||||
let sc_u32_1 = load_src0_u32_at(bbase + sc_base_byte + 4u);
|
||||
let ql1_u32 = load_u32_at(&src0, bbase + q_offset_l);
|
||||
let ql2_u32 = load_u32_at(&src0, bbase + q_offset_l + 32u);
|
||||
let qh_u32 = load_u32_at(&src0, bbase + 128u + q_offset_h);
|
||||
let sc_u32_0 = load_u32_at(&src0, bbase + sc_base_byte);
|
||||
let sc_u32_1 = load_u32_at(&src0, bbase + sc_base_byte + 4u);
|
||||
|
||||
let sc0 = sbyte_of(sc_u32_0, sc_byte_pos);
|
||||
let sc2 = sbyte_of(sc_u32_0, sc_byte_pos + 2u);
|
||||
|
|
|
|||
|
|
@ -107,7 +107,8 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
|||
let res = src[params.offset_src + src_idx] / (1.0 + exp(-src[params.offset_src + src_idx]));
|
||||
#endif
|
||||
#ifdef EXP
|
||||
let res = exp(src[params.offset_src + src_idx]);
|
||||
let src_f32 = f32(src[params.offset_src + src_idx]);
|
||||
let res = TYPE(exp(src_f32));
|
||||
#endif
|
||||
#ifdef LOG
|
||||
let res = TYPE(log(f32(src[params.offset_src + src_idx])));
|
||||
|
|
@ -161,7 +162,8 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
|||
let res = TYPE(select(log(1.0 + exp(src_f32)), src_f32, src_f32 > 20.0));
|
||||
#endif
|
||||
#ifdef EXPM1
|
||||
let res = exp(src[params.offset_src + src_idx]) - 1.0;
|
||||
let src_f32 = f32(src[params.offset_src + src_idx]);
|
||||
let res = TYPE(exp(src_f32) - 1.0);
|
||||
#endif
|
||||
#ifdef FLOOR
|
||||
let res = floor(src[params.offset_src + src_idx]);
|
||||
|
|
@ -181,7 +183,7 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
|||
let res = src[params.offset_src + src_idx] * src[params.offset_src + src_idx];
|
||||
#endif
|
||||
#ifdef SQRT
|
||||
let res = sqrt(src[params.offset_src + src_idx]);
|
||||
let res = TYPE(sqrt(f32(src[params.offset_src + src_idx])));
|
||||
#endif
|
||||
#ifdef SIN
|
||||
let res_f32 = sin(f32(src[params.offset_src + src_idx]));
|
||||
|
|
|
|||
Loading…
Reference in New Issue