vulkan: fuse rms_norm + mul + rope (+ view + set_rows) (#16977)
This change combines the rms_norm+mul and rope+view+set_rows fusions to allow fusing the whole sequence together. This comes up in Qwen3, Bailing, and some other models.
This commit is contained in:
parent
d6fe40fa00
commit
b4e335d8dc
|
|
@ -466,6 +466,14 @@ static constexpr std::initializer_list<std::array<int, 3>> rope_view_set_rows_ed
|
||||||
{ 2, 0, 1 }, // set_rows->src[0] == view
|
{ 2, 0, 1 }, // set_rows->src[0] == view
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Required edges of the fused rms_norm + mul + rope + view + set_rows sequence.
// Node numbers are positions within the sequence (0 = rms_norm ... 4 = set_rows);
// each row reads { consumer node, src slot on the consumer, producer node }.
static constexpr std::initializer_list<std::array<int, 3>> rms_norm_mul_rope_view_set_rows_edges {
    { 1, 0, 0 }, // node 1 (mul):      src[0] is node 0 (rms_norm)
    { 2, 0, 1 }, // node 2 (rope):     src[0] is node 1 (mul)
    { 3, 0, 2 }, // node 3 (view):     src[0] is node 2 (rope)
    { 4, 0, 3 }, // node 4 (set_rows): src[0] is node 3 (view)
};
|
||||||
|
|
||||||
|
|
||||||
struct vk_device_struct {
|
struct vk_device_struct {
|
||||||
std::recursive_mutex mutex;
|
std::recursive_mutex mutex;
|
||||||
|
|
||||||
|
|
@ -617,6 +625,8 @@ struct vk_device_struct {
|
||||||
vk_pipeline pipeline_rms_norm_mul_f32;
|
vk_pipeline pipeline_rms_norm_mul_f32;
|
||||||
vk_pipeline pipeline_rms_norm_partials_f32;
|
vk_pipeline pipeline_rms_norm_partials_f32;
|
||||||
vk_pipeline pipeline_rms_norm_mul_partials_f32;
|
vk_pipeline pipeline_rms_norm_mul_partials_f32;
|
||||||
|
vk_pipeline pipeline_rms_norm_mul_rope_f32_f32;
|
||||||
|
vk_pipeline pipeline_rms_norm_mul_rope_f32_f16;
|
||||||
vk_pipeline pipeline_rms_norm_back_f32;
|
vk_pipeline pipeline_rms_norm_back_f32;
|
||||||
vk_pipeline pipeline_l2_norm_f32;
|
vk_pipeline pipeline_l2_norm_f32;
|
||||||
|
|
||||||
|
|
@ -1060,6 +1070,7 @@ struct vk_op_diag_mask_push_constants {
|
||||||
};
|
};
|
||||||
|
|
||||||
struct vk_op_rope_push_constants {
|
struct vk_op_rope_push_constants {
|
||||||
|
uint32_t rope_mode;
|
||||||
uint32_t ncols;
|
uint32_t ncols;
|
||||||
uint32_t n_dims;
|
uint32_t n_dims;
|
||||||
float freq_scale;
|
float freq_scale;
|
||||||
|
|
@ -1079,6 +1090,12 @@ struct vk_op_rope_push_constants {
|
||||||
uint32_t set_rows_stride;
|
uint32_t set_rows_stride;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// For fused rms_norm+mul+rope(+view+set_rows)
|
||||||
|
struct vk_op_rms_norm_mul_rope_push_constants {
|
||||||
|
vk_op_binary_push_constants bin;
|
||||||
|
vk_op_rope_push_constants rope;
|
||||||
|
};
|
||||||
|
|
||||||
struct vk_op_soft_max_push_constants {
|
struct vk_op_soft_max_push_constants {
|
||||||
uint32_t KX;
|
uint32_t KX;
|
||||||
uint32_t KY;
|
uint32_t KY;
|
||||||
|
|
@ -3557,6 +3574,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_partials_f32, "rms_norm_partials_f32", rms_norm_partials_f32_len, rms_norm_partials_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1, true);
|
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_partials_f32, "rms_norm_partials_f32", rms_norm_partials_f32_len, rms_norm_partials_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1, true);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_partials_f32, "rms_norm_mul_partials_f32", rms_norm_partials_f32_len, rms_norm_partials_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1, true);
|
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_partials_f32, "rms_norm_mul_partials_f32", rms_norm_partials_f32_len, rms_norm_partials_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1, true);
|
||||||
|
|
||||||
|
if (device->float_controls_rte_fp16 &&
|
||||||
|
sizeof(vk_op_rms_norm_mul_rope_push_constants) <= device->properties.limits.maxPushConstantsSize) {
|
||||||
|
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_rope_f32_f32, "rms_norm_mul_rope_f32_f32", rms_norm_mul_rope_f32_f32_len, rms_norm_mul_rope_f32_f32_data, "main", 7, sizeof(vk_op_rms_norm_mul_rope_push_constants), {1, 1, 1}, {0, 1}, 1, true);
|
||||||
|
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_rope_f32_f16, "rms_norm_mul_rope_f32_f16", rms_norm_mul_rope_f32_f16_rte_len, rms_norm_mul_rope_f32_f16_rte_data, "main", 7, sizeof(vk_op_rms_norm_mul_rope_push_constants), {1, 1, 1}, {0, 1}, 1, true);
|
||||||
|
}
|
||||||
|
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
|
||||||
|
|
||||||
|
|
@ -9590,21 +9613,149 @@ static uint32_t ggml_vk_rms_partials_size(ggml_backend_vk_context * ctx, const g
|
||||||
return num_bytes;
|
return num_bytes;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, float * op_params) {
|
// Build the rope push constants from a ROPE node's op_params and its input tensor.
//   dst             - tensor whose op_params carry the rope settings
//   src0            - tensor supplying ne[0..2] and the nb[1]/nb[2] strides
//   has_ff          - callers pass true when the rope node has a src[2] ("ff") tensor
//   backprop        - forwarded unchanged into the constants
//   set_rows_stride - forwarded unchanged; callers pass 0 when set_rows is not fused
static vk_op_rope_push_constants ggml_vk_make_rope_constants(const ggml_tensor *dst, const ggml_tensor *src0, const bool has_ff, bool backprop, const uint32_t set_rows_stride) {
    // op_params mixes int32 and float slots, so read it through both views.
    const int32_t * iparams = (const int32_t *) dst->op_params;
    const float   * fparams = (const float *)   dst->op_params;

    const int n_dims     = iparams[1];
    const int mode       = iparams[2];
    // slot 3 (n_ctx) is not used here
    const int n_ctx_orig = iparams[4];

    const float freq_base   = fparams[5];
    const float freq_scale  = fparams[6];
    const float ext_factor  = fparams[7];
    const float attn_factor = fparams[8];
    const float beta_fast   = fparams[9];
    const float beta_slow   = fparams[10];

    // Sections stay zeroed unless one of the MROPE mode bits is set.
    int sections[4] {};
    if (mode & GGML_ROPE_TYPE_MROPE) {
        memcpy(sections, iparams + 11, sizeof(sections));
    }

    const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;

    float corr_dims[2];
    ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);

    const float theta_scale = powf(freq_base, -2.0f/n_dims);

    // Strides are passed in elements rather than bytes.
    const auto src0_elem_size = ggml_type_size(src0->type);
    const uint32_t nb01 = src0->nb[1] / src0_elem_size;
    const uint32_t nb02 = src0->nb[2] / src0_elem_size;

    // Positional aggregate init: the order must match the struct's field order.
    return vk_op_rope_push_constants {
        (uint32_t)mode, (uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1],
        freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale,
        has_ff, (uint32_t)src0->ne[2], nb01, nb02,
        { sections[0], sections[1], sections[2], sections[3] }, is_imrope, backprop, set_rows_stride,
    };
}
|
||||||
|
|
||||||
|
static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx, float * op_params) {
|
||||||
|
ggml_tensor * dst;
|
||||||
|
const ggml_tensor * src0;
|
||||||
|
const ggml_tensor * src1;
|
||||||
|
|
||||||
|
if (ctx->num_additional_fused_ops > 0) {
|
||||||
|
// fused rms_norm + mul
|
||||||
|
ggml_tensor *mul = cgraph->nodes[node_idx + 1];
|
||||||
|
ggml_tensor *other_src = mul->src[0] == cgraph->nodes[node_idx + 0] ? mul->src[1] : mul->src[0];
|
||||||
|
dst = mul;
|
||||||
|
src0 = cgraph->nodes[node_idx]->src[0];
|
||||||
|
src1 = other_src;
|
||||||
|
} else {
|
||||||
|
dst = cgraph->nodes[node_idx];
|
||||||
|
src0 = src1 = dst->src[0];
|
||||||
|
}
|
||||||
|
|
||||||
const uint32_t src0_type_size = ggml_type_size(src0->type);
|
const uint32_t src0_type_size = ggml_type_size(src0->type);
|
||||||
const uint32_t src1_type_size = ggml_type_size(src1->type);
|
const uint32_t src1_type_size = ggml_type_size(src1->type);
|
||||||
const uint32_t dst_type_size = ggml_type_size(dst->type);
|
const uint32_t dst_type_size = ggml_type_size(dst->type);
|
||||||
|
|
||||||
uint32_t param3 = ctx->do_add_rms_partials ? ggml_vk_rms_num_partials(ctx, dst) : 0;
|
uint32_t param3 = ctx->do_add_rms_partials ? ggml_vk_rms_num_partials(ctx, dst) : 0;
|
||||||
|
|
||||||
ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_RMS_NORM, {
|
vk_op_binary_push_constants bin {
|
||||||
(uint32_t)ggml_nelements(src0),
|
(uint32_t)ggml_nelements(src0),
|
||||||
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
|
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
|
||||||
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
|
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
|
||||||
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
|
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
|
||||||
0,
|
0,
|
||||||
op_params[0], 0.0f, (int32_t)param3,
|
op_params[0], 0.0f, (int32_t)param3,
|
||||||
});
|
};
|
||||||
|
|
||||||
|
// more than one fused op means rms_norm+mul+rope
|
||||||
|
if (ctx->num_additional_fused_ops > 1) {
|
||||||
|
static constexpr uint32_t max_tensors = 7;
|
||||||
|
const ggml_tensor *tensors[max_tensors] {};
|
||||||
|
|
||||||
|
ggml_tensor *rms = cgraph->nodes[node_idx + 0];
|
||||||
|
ggml_tensor *mul = cgraph->nodes[node_idx + 1];
|
||||||
|
ggml_tensor *rope = cgraph->nodes[node_idx + 2];
|
||||||
|
|
||||||
|
ggml_tensor *other_src = mul->src[0] == rms ? mul->src[1] : mul->src[0];
|
||||||
|
|
||||||
|
bool do_set_rows = ctx->num_additional_fused_ops == 4;
|
||||||
|
|
||||||
|
tensors[0] = rms->src[0];
|
||||||
|
tensors[1] = other_src;
|
||||||
|
tensors[2] = mul;
|
||||||
|
tensors[3] = rope->src[1]; // pos
|
||||||
|
tensors[4] = rope->src[2]; // ff
|
||||||
|
tensors[5] = cgraph->nodes[node_idx + ctx->num_additional_fused_ops]; // dst
|
||||||
|
tensors[6] = do_set_rows ? tensors[5]->src[1] : nullptr;
|
||||||
|
const uint32_t set_rows_stride = do_set_rows ? tensors[5]->nb[1] / ggml_type_size(tensors[5]->type) : 0;
|
||||||
|
|
||||||
|
vk_op_rms_norm_mul_rope_push_constants pc;
|
||||||
|
pc.bin = bin;
|
||||||
|
pc.rope = ggml_vk_make_rope_constants(rope, rope->src[0], tensors[4] != nullptr, false, set_rows_stride);
|
||||||
|
|
||||||
|
vk_pipeline pipeline = tensors[5]->type == GGML_TYPE_F16 ? ctx->device->pipeline_rms_norm_mul_rope_f32_f16 : ctx->device->pipeline_rms_norm_mul_rope_f32_f32;
|
||||||
|
|
||||||
|
ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
|
||||||
|
|
||||||
|
ggml_backend_vk_buffer_context * buf_ctx[max_tensors];
|
||||||
|
vk_buffer buf[max_tensors];
|
||||||
|
size_t offset[max_tensors];
|
||||||
|
bool uma[max_tensors];
|
||||||
|
|
||||||
|
for (uint32_t i = 0; i < max_tensors; ++i) {
|
||||||
|
if (!tensors[i]) {
|
||||||
|
// If any remaining descriptors are unused, just point them at src[0]
|
||||||
|
buf[i] = buf[0];
|
||||||
|
offset[i] = 0;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
buf_ctx[i] = (ggml_backend_vk_buffer_context *)tensors[i]->buffer->context;
|
||||||
|
buf[i] = nullptr;
|
||||||
|
offset[i] = 0;
|
||||||
|
uma[i] = false;
|
||||||
|
|
||||||
|
if (ctx->device->uma) {
|
||||||
|
ggml_vk_host_get(ctx->device, tensors[i]->data, buf[i], offset[i]);
|
||||||
|
uma[i] = buf[i] != nullptr;
|
||||||
|
}
|
||||||
|
if (!uma[i]) {
|
||||||
|
buf[i] = buf_ctx[i]->dev_buffer;
|
||||||
|
offset[i] = vk_tensor_offset(tensors[i]) + tensors[i]->view_offs;
|
||||||
|
}
|
||||||
|
GGML_ASSERT(buf[i] != nullptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::array<uint32_t, 3> elements;
|
||||||
|
elements = { (uint32_t)rms->src[0]->ne[1], (uint32_t)rms->src[0]->ne[2], (uint32_t)rms->src[0]->ne[3] };
|
||||||
|
|
||||||
|
static_assert(max_tensors == 7);
|
||||||
|
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
|
||||||
|
{
|
||||||
|
ggml_vk_subbuffer(ctx, buf[0], offset[0]),
|
||||||
|
ggml_vk_subbuffer(ctx, buf[1], offset[1]),
|
||||||
|
ggml_vk_subbuffer(ctx, buf[2], offset[2]),
|
||||||
|
ggml_vk_subbuffer(ctx, buf[3], offset[3]),
|
||||||
|
ggml_vk_subbuffer(ctx, buf[4], offset[4]),
|
||||||
|
ggml_vk_subbuffer(ctx, buf[5], offset[5]),
|
||||||
|
ggml_vk_subbuffer(ctx, buf[6], offset[6]),
|
||||||
|
}, pc, elements);
|
||||||
|
} else {
|
||||||
|
ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_RMS_NORM, std::move(bin));
|
||||||
|
}
|
||||||
|
|
||||||
if (ctx->do_add_rms_partials_offset_calculation) {
|
if (ctx->do_add_rms_partials_offset_calculation) {
|
||||||
ctx->prealloc_size_add_rms_partials_offset += ggml_vk_rms_partials_size(ctx, src0);
|
ctx->prealloc_size_add_rms_partials_offset += ggml_vk_rms_partials_size(ctx, src0);
|
||||||
|
|
@ -9758,9 +9909,6 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, cons
|
||||||
// const int n_ctx = ((int32_t *) dst->op_params)[3];
|
// const int n_ctx = ((int32_t *) dst->op_params)[3];
|
||||||
const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
|
const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
|
||||||
const float freq_base = ((float *) dst->op_params)[5];
|
const float freq_base = ((float *) dst->op_params)[5];
|
||||||
const float freq_scale = ((float *) dst->op_params)[6];
|
|
||||||
const float ext_factor = ((float *) dst->op_params)[7];
|
|
||||||
const float attn_factor = ((float *) dst->op_params)[8];
|
|
||||||
const float beta_fast = ((float *) dst->op_params)[9];
|
const float beta_fast = ((float *) dst->op_params)[9];
|
||||||
const float beta_slow = ((float *) dst->op_params)[10];
|
const float beta_slow = ((float *) dst->op_params)[10];
|
||||||
int sections[4] {};
|
int sections[4] {};
|
||||||
|
|
@ -9768,16 +9916,9 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, cons
|
||||||
memcpy(sections, (int32_t *) dst->op_params + 11, sizeof(int)*4);
|
memcpy(sections, (int32_t *) dst->op_params + 11, sizeof(int)*4);
|
||||||
}
|
}
|
||||||
|
|
||||||
const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
|
|
||||||
|
|
||||||
float corr_dims[2];
|
float corr_dims[2];
|
||||||
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
|
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
|
||||||
|
|
||||||
const float theta_scale = powf(freq_base, -2.0f/n_dims);
|
|
||||||
|
|
||||||
uint32_t s1 = src0->nb[1] / ggml_type_size(src0->type);
|
|
||||||
uint32_t s2 = src0->nb[2] / ggml_type_size(src0->type);
|
|
||||||
|
|
||||||
uint32_t set_rows_stride = 0;
|
uint32_t set_rows_stride = 0;
|
||||||
// Fused rope + view + set_rows passes the set_rows destination stride in set_rows_stride
|
// Fused rope + view + set_rows passes the set_rows destination stride in set_rows_stride
|
||||||
// and overrides the dst and sets src3=row_indices
|
// and overrides the dst and sets src3=row_indices
|
||||||
|
|
@ -9787,12 +9928,8 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, cons
|
||||||
dst = cgraph->nodes[node_idx + 2];
|
dst = cgraph->nodes[node_idx + 2];
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_vk_op_f32<vk_op_rope_push_constants>(ctx, subctx, src0, src1, src2, src3, dst, GGML_OP_ROPE, {
|
ggml_vk_op_f32<vk_op_rope_push_constants>(ctx, subctx, src0, src1, src2, src3, dst, GGML_OP_ROPE,
|
||||||
(uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1],
|
ggml_vk_make_rope_constants(cgraph->nodes[node_idx], src0, src2 != nullptr, backprop, set_rows_stride));
|
||||||
freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale,
|
|
||||||
src2 != nullptr, (uint32_t)src0->ne[2], s1, s2,
|
|
||||||
{ sections[0], sections[1], sections[2], sections[3] }, is_imrope, backprop, set_rows_stride,
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
|
static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
|
||||||
|
|
@ -11307,6 +11444,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
||||||
if (n->op == GGML_OP_GLU) {
|
if (n->op == GGML_OP_GLU) {
|
||||||
std::cerr << " " << ggml_glu_op_name(ggml_get_glu_op(n)) << " " << (n->src[1] ? "split" : "single") << " ";
|
std::cerr << " " << ggml_glu_op_name(ggml_get_glu_op(n)) << " " << (n->src[1] ? "split" : "single") << " ";
|
||||||
}
|
}
|
||||||
|
if (n->op == GGML_OP_ROPE) {
|
||||||
|
const int mode = ((const int32_t *) n->op_params)[2];
|
||||||
|
std::cerr << " rope mode: " << mode;
|
||||||
|
}
|
||||||
std::cerr << std::endl;
|
std::cerr << std::endl;
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
@ -11414,14 +11555,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
||||||
|
|
||||||
break;
|
break;
|
||||||
case GGML_OP_RMS_NORM:
|
case GGML_OP_RMS_NORM:
|
||||||
if (ctx->num_additional_fused_ops > 0) {
|
ggml_vk_rms_norm(ctx, compute_ctx, cgraph, node_idx, (float *)node->op_params);
|
||||||
// fused rms_norm + mul
|
|
||||||
ggml_tensor *mul = cgraph->nodes[node_idx + 1];
|
|
||||||
ggml_tensor *other_src = mul->src[0] == node ? mul->src[1] : mul->src[0];
|
|
||||||
ggml_vk_rms_norm(ctx, compute_ctx, src0, other_src, mul, (float *)node->op_params);
|
|
||||||
} else {
|
|
||||||
ggml_vk_rms_norm(ctx, compute_ctx, src0, src0, node, (float *)node->op_params);
|
|
||||||
}
|
|
||||||
break;
|
break;
|
||||||
case GGML_OP_RMS_NORM_BACK:
|
case GGML_OP_RMS_NORM_BACK:
|
||||||
ggml_vk_rms_norm_back(ctx, compute_ctx, src0, src1, node);
|
ggml_vk_rms_norm_back(ctx, compute_ctx, src0, src1, node);
|
||||||
|
|
@ -12407,6 +12541,70 @@ static bool ggml_vk_can_fuse_rope_set_rows(ggml_backend_vk_context * ctx, const
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check whether the tensors overlap in memory but are not equal.
|
||||||
|
// Fusions can potentially overwrite src tensors in ways that are not prevented
|
||||||
|
// by ggml-alloc. If the fusion is entirely elementwise, then it's OK for them
|
||||||
|
// to overlap if they are exactly equal.
|
||||||
|
// XXX TODO this check is probably missing from several fusion optimizations.
|
||||||
|
static bool ggml_vk_tensors_overlap_but_not_equal(const ggml_tensor * a, const ggml_tensor * b) {
|
||||||
|
ggml_backend_vk_buffer_context * a_buf_ctx = (ggml_backend_vk_buffer_context *)a->buffer->context;
|
||||||
|
vk_buffer a_buf = a_buf_ctx->dev_buffer;
|
||||||
|
ggml_backend_vk_buffer_context * b_buf_ctx = (ggml_backend_vk_buffer_context *)b->buffer->context;
|
||||||
|
vk_buffer b_buf = b_buf_ctx->dev_buffer;
|
||||||
|
if (a_buf == b_buf) {
|
||||||
|
auto a_base = vk_tensor_offset(a) + a->view_offs;
|
||||||
|
auto a_size = ggml_nbytes(a);
|
||||||
|
auto b_base = vk_tensor_offset(b) + b->view_offs;
|
||||||
|
auto b_size = ggml_nbytes(b);
|
||||||
|
|
||||||
|
if (a_base == b_base && a_size == b_size) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if ((b_base <= a_base && a_base < b_base + b_size) ||
|
||||||
|
(a_base <= b_base && b_base < a_base + a_size)) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decide whether the RMS_NORM, MUL, ROPE nodes starting at node_idx may be executed by the
// fused rms_norm+mul+rope pipeline. Returns false when the pipeline is unavailable on this
// device or when the tensors fall outside what the fused shader handles.
static bool ggml_vk_can_fuse_rms_norm_mul_rope(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph,
                                               int node_idx) {
    // The fused pipelines are only created when the device has RTE fp16 float controls and
    // the combined push-constant block fits within the device limit. Mirror that condition
    // here so fusion is never selected for a pipeline that was not created.
    if (!ctx->device->float_controls_rte_fp16 ||
        sizeof(vk_op_rms_norm_mul_rope_push_constants) > ctx->device->properties.limits.maxPushConstantsSize) {
        return false;
    }

    const ggml_tensor *rms = cgraph->nodes[node_idx + 0];
    const ggml_tensor *mul = cgraph->nodes[node_idx + 1];
    const ggml_tensor *rope = cgraph->nodes[node_idx + 2];

    const int mode = ((const int32_t *) rope->op_params)[2];

    // noncontig tensors aren't tested, and don't seem common in practice
    if (!ggml_is_contiguous(rms) ||
        !ggml_is_contiguous(mul) ||
        !ggml_is_contiguous(rope)) {
        return false;
    }

    // only norm/neox are handled in the shader
    if (mode != GGML_ROPE_TYPE_NEOX && mode != GGML_ROPE_TYPE_NORMAL) {
        return false;
    }

    // shared memory size for passing data from mul->rope
    if (mul->ne[0] > 1024) {
        return false;
    }

    // must not overwrite srcs in a way that's not elementwise
    ggml_tensor *other_src = mul->src[0] == rms ? mul->src[1] : mul->src[0];
    if (ggml_vk_tensors_overlap_but_not_equal(rms->src[0], rope) ||
        ggml_vk_tensors_overlap_but_not_equal(other_src, rope)) {
        return false;
    }

    return true;
}
|
||||||
|
|
||||||
static uint32_t ggml_vk_fuse_multi_add(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph, int node_idx) {
|
static uint32_t ggml_vk_fuse_multi_add(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph, int node_idx) {
|
||||||
|
|
||||||
const ggml_tensor *first_node = cgraph->nodes[node_idx];
|
const ggml_tensor *first_node = cgraph->nodes[node_idx];
|
||||||
|
|
@ -12552,12 +12750,20 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
||||||
uint32_t num_adds = ggml_vk_fuse_multi_add(ctx, cgraph, i);
|
uint32_t num_adds = ggml_vk_fuse_multi_add(ctx, cgraph, i);
|
||||||
if (num_adds) {
|
if (num_adds) {
|
||||||
ctx->num_additional_fused_ops = num_adds - 1;
|
ctx->num_additional_fused_ops = num_adds - 1;
|
||||||
} else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
|
|
||||||
ctx->num_additional_fused_ops = 1;
|
|
||||||
} else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT, GGML_OP_ADD })) {
|
} else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT, GGML_OP_ADD })) {
|
||||||
ctx->num_additional_fused_ops = 1;
|
ctx->num_additional_fused_ops = 1;
|
||||||
} else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID })) {
|
} else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID })) {
|
||||||
ctx->num_additional_fused_ops = 1;
|
ctx->num_additional_fused_ops = 1;
|
||||||
|
} else if (ggml_can_fuse_subgraph(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, { i + 4 }) &&
|
||||||
|
ggml_check_edges(cgraph, i, rms_norm_mul_rope_view_set_rows_edges) &&
|
||||||
|
ggml_vk_can_fuse_rms_norm_mul_rope(ctx, cgraph, i) &&
|
||||||
|
ggml_vk_can_fuse_rope_set_rows(ctx, cgraph, i + 2)) {
|
||||||
|
ctx->num_additional_fused_ops = 4;
|
||||||
|
} else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ROPE })&&
|
||||||
|
ggml_vk_can_fuse_rms_norm_mul_rope(ctx, cgraph, i)) {
|
||||||
|
ctx->num_additional_fused_ops = 2;
|
||||||
|
} else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
|
||||||
|
ctx->num_additional_fused_ops = 1;
|
||||||
} else if (ggml_can_fuse_subgraph(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, { i + 2 }) &&
|
} else if (ggml_can_fuse_subgraph(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, { i + 2 }) &&
|
||||||
ggml_check_edges(cgraph, i, rope_view_set_rows_edges) &&
|
ggml_check_edges(cgraph, i, rope_view_set_rows_edges) &&
|
||||||
ggml_vk_can_fuse_rope_set_rows(ctx, cgraph, i)) {
|
ggml_vk_can_fuse_rope_set_rows(ctx, cgraph, i)) {
|
||||||
|
|
@ -12790,14 +12996,34 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
|
||||||
}
|
}
|
||||||
if (ok) {
|
if (ok) {
|
||||||
current_set.push_back(j);
|
current_set.push_back(j);
|
||||||
|
|
||||||
|
int rope_idx = j;
|
||||||
|
|
||||||
|
// When we've found RMS_NORM + MUL, try to find a ROPE that uses it
|
||||||
|
if (j > 0 &&
|
||||||
|
graph->nodes[j]->op == GGML_OP_MUL &&
|
||||||
|
graph->nodes[j-1]->op == GGML_OP_RMS_NORM) {
|
||||||
|
for (int k = j + 1; k < std::min(j + 15, graph->n_nodes); ++k) {
|
||||||
|
if (graph->nodes[k]->op == GGML_OP_ROPE &&
|
||||||
|
graph->nodes[k]->src[0] == graph->nodes[j] &&
|
||||||
|
// Check that other srcs are already valid
|
||||||
|
graph->nodes[k]->src[1]->op == GGML_OP_NONE &&
|
||||||
|
(graph->nodes[k]->src[2] == nullptr || graph->nodes[k]->src[2]->op == GGML_OP_NONE)) {
|
||||||
|
rope_idx = k;
|
||||||
|
current_set.push_back(rope_idx);
|
||||||
|
used[rope_idx] = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
// Look for ROPE + VIEW + SET_ROWS and make them consecutive
|
// Look for ROPE + VIEW + SET_ROWS and make them consecutive
|
||||||
if (graph->nodes[j]->op == GGML_OP_ROPE) {
|
if (graph->nodes[rope_idx]->op == GGML_OP_ROPE) {
|
||||||
int view_idx = -1;
|
int view_idx = -1;
|
||||||
int set_rows_idx = -1;
|
int set_rows_idx = -1;
|
||||||
for (int k = j+1; k < std::min(j + 10, graph->n_nodes); ++k) {
|
for (int k = rope_idx+1; k < std::min(rope_idx + 10, graph->n_nodes); ++k) {
|
||||||
if (view_idx == -1 &&
|
if (view_idx == -1 &&
|
||||||
graph->nodes[k]->op == GGML_OP_VIEW &&
|
graph->nodes[k]->op == GGML_OP_VIEW &&
|
||||||
graph->nodes[k]->src[0] == graph->nodes[j]) {
|
graph->nodes[k]->src[0] == graph->nodes[rope_idx]) {
|
||||||
view_idx = k;
|
view_idx = k;
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,9 @@
|
||||||
|
|
||||||
#include "rte.glsl"
|
#include "rte.glsl"
|
||||||
#include "utils.glsl"
|
#include "utils.glsl"
|
||||||
|
#if RMS_NORM_ROPE_FUSION
|
||||||
|
#include "rope_params.glsl"
|
||||||
|
#endif
|
||||||
|
|
||||||
layout (push_constant) uniform parameter
|
layout (push_constant) uniform parameter
|
||||||
{
|
{
|
||||||
|
|
@ -12,11 +15,16 @@ layout (push_constant) uniform parameter
|
||||||
uint ne20; uint ne21; uint ne22; uint ne23; uint nb20; uint nb21; uint nb22; uint nb23;
|
uint ne20; uint ne21; uint ne22; uint ne23; uint nb20; uint nb21; uint nb22; uint nb23;
|
||||||
uint misalign_offsets;
|
uint misalign_offsets;
|
||||||
float param1; float param2; int param3;
|
float param1; float param2; int param3;
|
||||||
|
#if RMS_NORM_ROPE_FUSION
|
||||||
|
rope_params rope;
|
||||||
|
#endif
|
||||||
} p;
|
} p;
|
||||||
|
|
||||||
|
#if !RMS_NORM_ROPE_FUSION
|
||||||
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
|
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
|
||||||
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
|
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
|
||||||
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
|
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
|
||||||
|
#endif
|
||||||
|
|
||||||
// true if src0/src1 are the same shape and the indices can be reused without additional modulus
|
// true if src0/src1 are the same shape and the indices can be reused without additional modulus
|
||||||
layout(constant_id = 0) const bool norepeat = false;
|
layout(constant_id = 0) const bool norepeat = false;
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,32 @@
|
||||||
#include "generic_binary_head.glsl"
|
#include "generic_binary_head.glsl"
|
||||||
#include "types.glsl"
|
#include "types.glsl"
|
||||||
|
|
||||||
|
#if RMS_NORM_ROPE_FUSION
|
||||||
|
|
||||||
|
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
|
||||||
|
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
|
||||||
|
|
||||||
|
// data is passed from rms_norm -> rope through shared memory.
|
||||||
|
// rms_norm calls this data_d, rope calls this rope_data_a.
|
||||||
|
// Binding 2 is not used
|
||||||
|
shared FLOAT_TYPE rope_data_a[1024];
|
||||||
|
#define data_d rope_data_a
|
||||||
|
|
||||||
|
layout (binding = 3) readonly buffer R_Y {int rope_data_pos[];};
|
||||||
|
layout (binding = 4) readonly buffer R_Z {float rope_data_ff[];};
|
||||||
|
layout (binding = 5) writeonly buffer R_D {ROPE_D_TYPE rope_data_d[];};
|
||||||
|
layout (binding = 6) readonly buffer R_I {uvec2 rope_data_i[];}; // indices for set_rows
|
||||||
|
|
||||||
|
#include "rope_params.glsl"
|
||||||
|
#include "rope_funcs.glsl"
|
||||||
|
|
||||||
|
#define GGML_ROPE_TYPE_NORMAL 0
|
||||||
|
#define GGML_ROPE_TYPE_NEOX 2
|
||||||
|
#define GGML_ROPE_TYPE_MROPE 8
|
||||||
|
#define GGML_ROPE_TYPE_VISION 24
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
#extension GL_EXT_control_flow_attributes : enable
|
#extension GL_EXT_control_flow_attributes : enable
|
||||||
#define BLOCK_SIZE 512
|
#define BLOCK_SIZE 512
|
||||||
|
|
||||||
|
|
@ -28,8 +54,12 @@ void rms_norm(uint num_iters) {
|
||||||
|
|
||||||
uint32_t a_offset = samp*stride_sample + channel*stride_channel + row*stride_row + get_aoffset();
|
uint32_t a_offset = samp*stride_sample + channel*stride_channel + row*stride_row + get_aoffset();
|
||||||
uint32_t b_offset = src1_idx(0, row, channel, samp) + get_boffset();
|
uint32_t b_offset = src1_idx(0, row, channel, samp) + get_boffset();
|
||||||
|
#if RMS_NORM_ROPE_FUSION
|
||||||
|
// Per-row offset in shared memory
|
||||||
|
uint32_t d_offset = 0;
|
||||||
|
#else
|
||||||
uint32_t d_offset = ((samp*nchannels + channel)*nrows + row)*ncols + get_doffset();
|
uint32_t d_offset = ((samp*nchannels + channel)*nrows + row)*ncols + get_doffset();
|
||||||
|
#endif
|
||||||
FLOAT_TYPE sum = FLOAT_TYPE(0.0f); // partial sum for thread in warp
|
FLOAT_TYPE sum = FLOAT_TYPE(0.0f); // partial sum for thread in warp
|
||||||
|
|
||||||
[[unroll]] for (uint col = tid, idx = 0; idx < num_iters; col += BLOCK_SIZE, ++idx) {
|
[[unroll]] for (uint col = tid, idx = 0; idx < num_iters; col += BLOCK_SIZE, ++idx) {
|
||||||
|
|
@ -79,6 +109,18 @@ void rms_norm(uint num_iters) {
|
||||||
data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]));
|
data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
#if RMS_NORM_ROPE_FUSION
|
||||||
|
barrier();
|
||||||
|
rope_params rp = p.rope;
|
||||||
|
uint rope_row = (samp*nchannels + channel)*nrows + row;
|
||||||
|
for (uint t = 2*tid; t < ncols; t += 2*BLOCK_SIZE) {
|
||||||
|
if (rp.rope_mode == GGML_ROPE_TYPE_NEOX) {
|
||||||
|
rope_neox(t, rope_row, rp);
|
||||||
|
} else if (rp.rope_mode == GGML_ROPE_TYPE_NORMAL) {
|
||||||
|
rope_norm(t, rope_row, rp);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
void main() {
|
void main() {
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,227 @@
|
||||||
|
|
||||||
|
float rope_yarn_ramp(const float low, const float high, const uint i0) {
|
||||||
|
const float y = (i0 / 2 - low) / max(0.001f, high - low);
|
||||||
|
return 1.0f - min(1.0f, max(0.0f, y));
|
||||||
|
}
|
||||||
|
|
||||||
|
uint rope_a_coord(const uint i0, const uint i01, const uint i02, rope_params p) {
|
||||||
|
#if RMS_NORM_ROPE_FUSION
|
||||||
|
// Per-row offset in shared memory
|
||||||
|
const uint ix = i0;
|
||||||
|
#else
|
||||||
|
const uint ix = i02*p.nb02 + i01*p.nb01 + i0;
|
||||||
|
#endif
|
||||||
|
return ix;
|
||||||
|
}
|
||||||
|
|
||||||
|
void rope_yarn(const float theta_extrap, const uint i0, out float cos_theta, out float sin_theta, rope_params p) {
|
||||||
|
float mscale = p.attn_factor;
|
||||||
|
// Get n-d rotational scaling corrected for extrapolation
|
||||||
|
float theta_interp = p.freq_scale * theta_extrap;
|
||||||
|
float theta = theta_interp;
|
||||||
|
if (p.ext_factor != 0.0f) {
|
||||||
|
float ramp_mix = rope_yarn_ramp(p.corr_dims[0], p.corr_dims[1], i0) * p.ext_factor;
|
||||||
|
theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
|
||||||
|
|
||||||
|
// Get n-d magnitude scaling corrected for interpolation
|
||||||
|
mscale *= 1.0f + 0.1f * log(1.0f / p.freq_scale);
|
||||||
|
}
|
||||||
|
// Backpropagation uses inverted rotation
|
||||||
|
if (p.is_back != 0) {
|
||||||
|
theta = -theta;
|
||||||
|
}
|
||||||
|
cos_theta = cos(theta) * mscale;
|
||||||
|
sin_theta = sin(theta) * mscale;
|
||||||
|
}
|
||||||
|
|
||||||
|
void rope_norm(const uint i0, const uint i1, rope_params p) {
|
||||||
|
uint ne0 = p.ncols;
|
||||||
|
uint ne1 = p.p_delta_rows;
|
||||||
|
|
||||||
|
if (i0 >= ne0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// i1 is actually i2*nb2+i1, but the rows are contiguous
|
||||||
|
const uint i01 = i1 % ne1;
|
||||||
|
const uint i02 = i1 / ne1;
|
||||||
|
|
||||||
|
uint idst = i1*ne0 + i0;
|
||||||
|
const uint ix = rope_a_coord(i0, i01, i02, p);
|
||||||
|
|
||||||
|
// Fusion optimization: ROPE + VIEW + SET_ROWS.
|
||||||
|
// The rope output is viewed as a 1D tensor and offset based on a row index in rope_data_i.
|
||||||
|
if (p.set_rows_stride != 0) {
|
||||||
|
idst = i01*ne0 + i0;
|
||||||
|
idst += rope_data_i[i02].x * p.set_rows_stride;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (i0 >= p.n_dims) {
|
||||||
|
rope_data_d[idst + 0] = ROPE_D_TYPE(rope_data_a[ix + 0]);
|
||||||
|
rope_data_d[idst + 1] = ROPE_D_TYPE(rope_data_a[ix + 1]);
|
||||||
|
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const float theta_base = rope_data_pos[i02] * pow(p.theta_scale, i0/2.0f);
|
||||||
|
|
||||||
|
const float freq_factor = p.has_ff != 0 ? rope_data_ff[i0/2] : 1.0f;
|
||||||
|
|
||||||
|
float cos_theta, sin_theta;
|
||||||
|
rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta, p);
|
||||||
|
|
||||||
|
const float x0 = float(rope_data_a[ix + 0]);
|
||||||
|
const float x1 = float(rope_data_a[ix + 1]);
|
||||||
|
|
||||||
|
rope_data_d[idst + 0] = ROPE_D_TYPE(x0*cos_theta - x1*sin_theta);
|
||||||
|
rope_data_d[idst + 1] = ROPE_D_TYPE(x0*sin_theta + x1*cos_theta);
|
||||||
|
}
|
||||||
|
|
||||||
|
void rope_neox(const uint i0, const uint i1, rope_params p) {
|
||||||
|
uint ne0 = p.ncols;
|
||||||
|
uint ne1 = p.p_delta_rows;
|
||||||
|
|
||||||
|
if (i0 >= ne0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const uint i01 = i1 % ne1;
|
||||||
|
const uint i02 = i1 / ne1;
|
||||||
|
|
||||||
|
uint idst = i1*ne0 + i0/2;
|
||||||
|
const uint ix = rope_a_coord(i0/2, i01, i02, p);
|
||||||
|
|
||||||
|
// Fusion optimization: ROPE + VIEW + SET_ROWS.
|
||||||
|
// The rope output is viewed as a 1D tensor and offset based on a row index in rope_data_i.
|
||||||
|
if (p.set_rows_stride != 0) {
|
||||||
|
idst = i01*ne0 + i0/2;
|
||||||
|
idst += rope_data_i[i02].x * p.set_rows_stride;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (i0 >= p.n_dims) {
|
||||||
|
rope_data_d[idst + i0/2 + 0] = ROPE_D_TYPE(rope_data_a[ix + i0/2 + 0]);
|
||||||
|
rope_data_d[idst + i0/2 + 1] = ROPE_D_TYPE(rope_data_a[ix + i0/2 + 1]);
|
||||||
|
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const float theta_base = rope_data_pos[i02] * pow(p.theta_scale, i0/2.0f);
|
||||||
|
|
||||||
|
const float freq_factor = p.has_ff != 0 ? rope_data_ff[i0/2] : 1.0f;
|
||||||
|
|
||||||
|
float cos_theta, sin_theta;
|
||||||
|
rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta, p);
|
||||||
|
|
||||||
|
const float x0 = float(rope_data_a[ix + 0]);
|
||||||
|
const float x1 = float(rope_data_a[ix + p.n_dims/2]);
|
||||||
|
|
||||||
|
rope_data_d[idst + 0] = ROPE_D_TYPE(x0*cos_theta - x1*sin_theta);
|
||||||
|
rope_data_d[idst + p.n_dims/2] = ROPE_D_TYPE(x0*sin_theta + x1*cos_theta);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
void rope_multi(const uint i0, const uint i1, rope_params p) {
|
||||||
|
uint ne0 = p.ncols;
|
||||||
|
uint ne1 = p.p_delta_rows;
|
||||||
|
uint ne2 = p.ne02;
|
||||||
|
|
||||||
|
if (i0 >= ne0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const uint i01 = i1 % ne1;
|
||||||
|
const uint i02 = i1 / ne1;
|
||||||
|
|
||||||
|
const uint idst = i1*ne0 + i0/2;
|
||||||
|
const uint ix = rope_a_coord(i0/2, i01, i02, p);
|
||||||
|
|
||||||
|
if (i0 >= p.n_dims) {
|
||||||
|
rope_data_d[idst + i0/2 + 0] = ROPE_D_TYPE(rope_data_a[ix + i0/2 + 0]);
|
||||||
|
rope_data_d[idst + i0/2 + 1] = ROPE_D_TYPE(rope_data_a[ix + i0/2 + 1]);
|
||||||
|
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const int sect_dims = p.sections[0] + p.sections[1] + p.sections[2] + p.sections[3];
|
||||||
|
const int sec_w = p.sections[1] + p.sections[0];
|
||||||
|
const uint sector = (i0 / 2) % sect_dims;
|
||||||
|
|
||||||
|
float theta_base = 0.0;
|
||||||
|
if (p.is_imrope != 0) {
|
||||||
|
if (sector % 3 == 1 && sector < 3 * p.sections[1]) {
|
||||||
|
theta_base = rope_data_pos[i02 + ne2 * 1]*pow(p.theta_scale, i0/2.0f);
|
||||||
|
} else if (sector % 3 == 2 && sector < 3 * p.sections[2]) {
|
||||||
|
theta_base = rope_data_pos[i02 + ne2 * 2]*pow(p.theta_scale, i0/2.0f);
|
||||||
|
} else if (sector % 3 == 0 && sector < 3 * p.sections[0]) {
|
||||||
|
theta_base = rope_data_pos[i02]*pow(p.theta_scale, i0/2.0f);
|
||||||
|
} else {
|
||||||
|
theta_base = rope_data_pos[i02 + ne2 * 3]*pow(p.theta_scale, i0/2.0f);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if (sector < p.sections[0]) {
|
||||||
|
theta_base = rope_data_pos[i02]*pow(p.theta_scale, i0/2.0f);
|
||||||
|
}
|
||||||
|
else if (sector >= p.sections[0] && sector < sec_w) {
|
||||||
|
theta_base = rope_data_pos[i02 + ne2 * 1]*pow(p.theta_scale, i0/2.0f);
|
||||||
|
}
|
||||||
|
else if (sector >= sec_w && sector < sec_w + p.sections[2]) {
|
||||||
|
theta_base = rope_data_pos[i02 + ne2 * 2]*pow(p.theta_scale, i0/2.0f);
|
||||||
|
}
|
||||||
|
else if (sector >= sec_w + p.sections[2]) {
|
||||||
|
theta_base = rope_data_pos[i02 + ne2 * 3]*pow(p.theta_scale, i0/2.0f);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const float freq_factor = p.has_ff != 0 ? rope_data_ff[i0/2] : 1.0f;
|
||||||
|
|
||||||
|
float cos_theta, sin_theta;
|
||||||
|
rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta, p);
|
||||||
|
|
||||||
|
const float x0 = float(rope_data_a[ix + 0]);
|
||||||
|
const float x1 = float(rope_data_a[ix + p.n_dims/2]);
|
||||||
|
|
||||||
|
rope_data_d[idst + 0] = ROPE_D_TYPE(x0*cos_theta - x1*sin_theta);
|
||||||
|
rope_data_d[idst + p.n_dims/2] = ROPE_D_TYPE(x0*sin_theta + x1*cos_theta);
|
||||||
|
}
|
||||||
|
|
||||||
|
void rope_vision(const uint i0, const uint i1, rope_params p) {
|
||||||
|
uint ne0 = p.ncols;
|
||||||
|
uint ne1 = p.p_delta_rows;
|
||||||
|
uint ne2 = p.ne02;
|
||||||
|
|
||||||
|
if (i0 >= ne0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const uint i01 = i1 % ne1;
|
||||||
|
const uint i02 = i1 / ne1;
|
||||||
|
|
||||||
|
const uint idst = i1*ne0 + i0/2;
|
||||||
|
const uint ix = rope_a_coord(i0/2, i01, i02, p);
|
||||||
|
|
||||||
|
const int sect_dims = p.sections[0] + p.sections[1];
|
||||||
|
const int sec_w = p.sections[1] + p.sections[0];
|
||||||
|
const uint sector = (i0 / 2) % sect_dims;
|
||||||
|
|
||||||
|
float theta_base = 0.0;
|
||||||
|
if (sector < p.sections[0]) {
|
||||||
|
const uint p0 = sector;
|
||||||
|
theta_base = rope_data_pos[i02]*pow(p.theta_scale, p0);
|
||||||
|
}
|
||||||
|
else if (sector >= p.sections[0] && sector < sec_w) {
|
||||||
|
const uint p0 = sector - p.sections[0];
|
||||||
|
theta_base = rope_data_pos[i02 + ne2]*pow(p.theta_scale, p0);
|
||||||
|
}
|
||||||
|
|
||||||
|
const float freq_factor = p.has_ff != 0 ? rope_data_ff[i0/2] : 1.0f;
|
||||||
|
|
||||||
|
float cos_theta, sin_theta;
|
||||||
|
rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta, p);
|
||||||
|
|
||||||
|
const float x0 = float(rope_data_a[ix + 0]);
|
||||||
|
const float x1 = float(rope_data_a[ix + p.n_dims]);
|
||||||
|
|
||||||
|
rope_data_d[idst + 0] = ROPE_D_TYPE(x0*cos_theta - x1*sin_theta);
|
||||||
|
rope_data_d[idst + p.n_dims] = ROPE_D_TYPE(x0*sin_theta + x1*cos_theta);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
@ -3,56 +3,18 @@
|
||||||
#extension GL_EXT_shader_16bit_storage : require
|
#extension GL_EXT_shader_16bit_storage : require
|
||||||
|
|
||||||
#include "rte.glsl"
|
#include "rte.glsl"
|
||||||
|
#include "rope_params.glsl"
|
||||||
|
|
||||||
layout(local_size_x = 1, local_size_y = 256, local_size_z = 1) in;
|
layout(local_size_x = 1, local_size_y = 256, local_size_z = 1) in;
|
||||||
|
|
||||||
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
|
layout (binding = 0) readonly buffer X {A_TYPE rope_data_a[];};
|
||||||
layout (binding = 1) readonly buffer Y {int data_pos[];};
|
layout (binding = 1) readonly buffer Y {int rope_data_pos[];};
|
||||||
layout (binding = 2) readonly buffer Z {float data_ff[];};
|
layout (binding = 2) readonly buffer Z {float rope_data_ff[];};
|
||||||
layout (binding = 3) writeonly buffer D {D_TYPE data_d[];};
|
layout (binding = 3) writeonly buffer D {ROPE_D_TYPE rope_data_d[];};
|
||||||
layout (binding = 4) readonly buffer I {uvec2 data_i[];}; // indices for set_rows
|
layout (binding = 4) readonly buffer I {uvec2 rope_data_i[];}; // indices for set_rows
|
||||||
|
|
||||||
|
|
||||||
layout (push_constant) uniform parameter {
|
layout (push_constant) uniform parameter {
|
||||||
uint ncols;
|
rope_params pc;
|
||||||
uint n_dims;
|
};
|
||||||
float freq_scale;
|
|
||||||
uint p_delta_rows;
|
|
||||||
float freq_base;
|
|
||||||
float ext_factor;
|
|
||||||
float attn_factor;
|
|
||||||
float corr_dims[2];
|
|
||||||
float theta_scale;
|
|
||||||
uint has_ff;
|
|
||||||
uint ne02;
|
|
||||||
uint s1;
|
|
||||||
uint s2;
|
|
||||||
int sections[4];
|
|
||||||
uint is_imrope;
|
|
||||||
uint is_back;
|
|
||||||
uint set_rows_stride;
|
|
||||||
} p;
|
|
||||||
|
|
||||||
float rope_yarn_ramp(const float low, const float high, const uint i0) {
|
|
||||||
const float y = (i0 / 2 - low) / max(0.001f, high - low);
|
|
||||||
return 1.0f - min(1.0f, max(0.0f, y));
|
|
||||||
}
|
|
||||||
|
|
||||||
void rope_yarn(const float theta_extrap, const uint i0, out float cos_theta, out float sin_theta) {
|
|
||||||
float mscale = p.attn_factor;
|
|
||||||
// Get n-d rotational scaling corrected for extrapolation
|
|
||||||
float theta_interp = p.freq_scale * theta_extrap;
|
|
||||||
float theta = theta_interp;
|
|
||||||
if (p.ext_factor != 0.0f) {
|
|
||||||
float ramp_mix = rope_yarn_ramp(p.corr_dims[0], p.corr_dims[1], i0) * p.ext_factor;
|
|
||||||
theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
|
|
||||||
|
|
||||||
// Get n-d magnitude scaling corrected for interpolation
|
|
||||||
mscale *= 1.0f + 0.1f * log(1.0f / p.freq_scale);
|
|
||||||
}
|
|
||||||
// Backprogagation uses inverted rotation
|
|
||||||
if (p.is_back != 0) {
|
|
||||||
theta = -theta;
|
|
||||||
}
|
|
||||||
cos_theta = cos(theta) * mscale;
|
|
||||||
sin_theta = sin(theta) * mscale;
|
|
||||||
}
|
|
||||||
|
|
|
||||||
|
|
@ -1,70 +1,11 @@
|
||||||
#version 450
|
#version 450
|
||||||
|
|
||||||
#include "rope_head.glsl"
|
#include "rope_head.glsl"
|
||||||
|
#include "rope_funcs.glsl"
|
||||||
|
|
||||||
void main() {
|
void main() {
|
||||||
const uint i0 = 2*gl_GlobalInvocationID.y;
|
const uint i0 = 2*gl_GlobalInvocationID.y;
|
||||||
uint ne0 = p.ncols;
|
// i1 is actually i2*nb2+i1, but the rows are contiguous
|
||||||
uint ne1 = p.p_delta_rows;
|
const uint i1 = gl_GlobalInvocationID.x;
|
||||||
uint ne2 = p.ne02;
|
rope_multi(i0, i1, pc);
|
||||||
|
|
||||||
if (i0 >= ne0) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const uint row_dst = gl_GlobalInvocationID.x;
|
|
||||||
|
|
||||||
const uint row_x = row_dst % ne1;
|
|
||||||
const uint channel_x = row_dst / ne1;
|
|
||||||
|
|
||||||
const uint idst = row_dst*ne0 + i0/2;
|
|
||||||
const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2;
|
|
||||||
|
|
||||||
if (i0 >= p.n_dims) {
|
|
||||||
data_d[idst + i0/2 + 0] = data_a[ix + i0/2 + 0];
|
|
||||||
data_d[idst + i0/2 + 1] = data_a[ix + i0/2 + 1];
|
|
||||||
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const int sect_dims = p.sections[0] + p.sections[1] + p.sections[2] + p.sections[3];
|
|
||||||
const int sec_w = p.sections[1] + p.sections[0];
|
|
||||||
const uint sector = (i0 / 2) % sect_dims;
|
|
||||||
|
|
||||||
float theta_base = 0.0;
|
|
||||||
if (p.is_imrope != 0) {
|
|
||||||
if (sector % 3 == 1 && sector < 3 * p.sections[1]) {
|
|
||||||
theta_base = data_pos[channel_x + ne2 * 1]*pow(p.theta_scale, i0/2.0f);
|
|
||||||
} else if (sector % 3 == 2 && sector < 3 * p.sections[2]) {
|
|
||||||
theta_base = data_pos[channel_x + ne2 * 2]*pow(p.theta_scale, i0/2.0f);
|
|
||||||
} else if (sector % 3 == 0 && sector < 3 * p.sections[0]) {
|
|
||||||
theta_base = data_pos[channel_x]*pow(p.theta_scale, i0/2.0f);
|
|
||||||
} else {
|
|
||||||
theta_base = data_pos[channel_x + ne2 * 3]*pow(p.theta_scale, i0/2.0f);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if (sector < p.sections[0]) {
|
|
||||||
theta_base = data_pos[channel_x]*pow(p.theta_scale, i0/2.0f);
|
|
||||||
}
|
|
||||||
else if (sector >= p.sections[0] && sector < sec_w) {
|
|
||||||
theta_base = data_pos[channel_x + ne2 * 1]*pow(p.theta_scale, i0/2.0f);
|
|
||||||
}
|
|
||||||
else if (sector >= sec_w && sector < sec_w + p.sections[2]) {
|
|
||||||
theta_base = data_pos[channel_x + ne2 * 2]*pow(p.theta_scale, i0/2.0f);
|
|
||||||
}
|
|
||||||
else if (sector >= sec_w + p.sections[2]) {
|
|
||||||
theta_base = data_pos[channel_x + ne2 * 3]*pow(p.theta_scale, i0/2.0f);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f;
|
|
||||||
|
|
||||||
float cos_theta, sin_theta;
|
|
||||||
rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta);
|
|
||||||
|
|
||||||
const float x0 = float(data_a[ix + 0]);
|
|
||||||
const float x1 = float(data_a[ix + p.n_dims/2]);
|
|
||||||
|
|
||||||
data_d[idst + 0] = D_TYPE(x0*cos_theta - x1*sin_theta);
|
|
||||||
data_d[idst + p.n_dims/2] = D_TYPE(x0*sin_theta + x1*cos_theta);
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,48 +1,11 @@
|
||||||
#version 450
|
#version 450
|
||||||
|
|
||||||
#include "rope_head.glsl"
|
#include "rope_head.glsl"
|
||||||
|
#include "rope_funcs.glsl"
|
||||||
|
|
||||||
void main() {
|
void main() {
|
||||||
const uint i0 = 2*gl_GlobalInvocationID.y;
|
const uint i0 = 2*gl_GlobalInvocationID.y;
|
||||||
uint ne0 = p.ncols;
|
// i1 is actually i2*nb2+i1, but the rows are contiguous
|
||||||
uint ne1 = p.p_delta_rows;
|
const uint i1 = gl_GlobalInvocationID.x;
|
||||||
|
rope_neox(i0, i1, pc);
|
||||||
if (i0 >= ne0) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const uint row_dst = gl_GlobalInvocationID.x;
|
|
||||||
|
|
||||||
const uint row_x = row_dst % ne1;
|
|
||||||
const uint channel_x = row_dst / ne1;
|
|
||||||
|
|
||||||
uint idst = row_dst*ne0 + i0/2;
|
|
||||||
const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2;
|
|
||||||
|
|
||||||
// Fusion optimization: ROPE + VIEW + SET_ROWS..
|
|
||||||
// The rope output is viewed as a 1D tensor and offset based on a row index in data_i.
|
|
||||||
if (p.set_rows_stride != 0) {
|
|
||||||
idst = row_x*ne0 + i0/2;
|
|
||||||
idst += data_i[channel_x].x * p.set_rows_stride;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (i0 >= p.n_dims) {
|
|
||||||
data_d[idst + i0/2 + 0] = D_TYPE(data_a[ix + i0/2 + 0]);
|
|
||||||
data_d[idst + i0/2 + 1] = D_TYPE(data_a[ix + i0/2 + 1]);
|
|
||||||
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const float theta_base = data_pos[channel_x] * pow(p.theta_scale, i0/2.0f);
|
|
||||||
|
|
||||||
const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f;
|
|
||||||
|
|
||||||
float cos_theta, sin_theta;
|
|
||||||
rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta);
|
|
||||||
|
|
||||||
const float x0 = float(data_a[ix + 0]);
|
|
||||||
const float x1 = float(data_a[ix + p.n_dims/2]);
|
|
||||||
|
|
||||||
data_d[idst + 0] = D_TYPE(x0*cos_theta - x1*sin_theta);
|
|
||||||
data_d[idst + p.n_dims/2] = D_TYPE(x0*sin_theta + x1*cos_theta);
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,48 +1,11 @@
|
||||||
#version 450
|
#version 450
|
||||||
|
|
||||||
#include "rope_head.glsl"
|
#include "rope_head.glsl"
|
||||||
|
#include "rope_funcs.glsl"
|
||||||
|
|
||||||
void main() {
|
void main() {
|
||||||
const uint i0 = 2*gl_GlobalInvocationID.y;
|
const uint i0 = 2*gl_GlobalInvocationID.y;
|
||||||
uint ne0 = p.ncols;
|
// i1 is actually i2*nb2+i1, but the rows are contiguous
|
||||||
uint ne1 = p.p_delta_rows;
|
const uint i1 = gl_GlobalInvocationID.x;
|
||||||
|
rope_norm(i0, i1, pc);
|
||||||
if (i0 >= ne0) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const uint row_dst = gl_GlobalInvocationID.x;
|
|
||||||
|
|
||||||
const uint row_x = row_dst % ne1;
|
|
||||||
const uint channel_x = row_dst / ne1;
|
|
||||||
|
|
||||||
uint idst = row_dst*ne0 + i0;
|
|
||||||
const uint ix = channel_x*p.s2 + row_x*p.s1 + i0;
|
|
||||||
|
|
||||||
// Fusion optimization: ROPE + VIEW + SET_ROWS..
|
|
||||||
// The rope output is viewed as a 1D tensor and offset based on a row index in data_i.
|
|
||||||
if (p.set_rows_stride != 0) {
|
|
||||||
idst = row_x*ne0 + i0;
|
|
||||||
idst += data_i[channel_x].x * p.set_rows_stride;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (i0 >= p.n_dims) {
|
|
||||||
data_d[idst + 0] = D_TYPE(data_a[ix + 0]);
|
|
||||||
data_d[idst + 1] = D_TYPE(data_a[ix + 1]);
|
|
||||||
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const float theta_base = data_pos[channel_x] * pow(p.theta_scale, i0/2.0f);
|
|
||||||
|
|
||||||
const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f;
|
|
||||||
|
|
||||||
float cos_theta, sin_theta;
|
|
||||||
rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta);
|
|
||||||
|
|
||||||
const float x0 = float(data_a[ix + 0]);
|
|
||||||
const float x1 = float(data_a[ix + 1]);
|
|
||||||
|
|
||||||
data_d[idst + 0] = D_TYPE(x0*cos_theta - x1*sin_theta);
|
|
||||||
data_d[idst + 1] = D_TYPE(x0*sin_theta + x1*cos_theta);
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,27 @@
|
||||||
|
#if !defined(GGML_ROPE_PARAMS)
|
||||||
|
#define GGML_ROPE_PARAMS
|
||||||
|
|
||||||
|
#include "rte.glsl"
|
||||||
|
|
||||||
|
struct rope_params {
|
||||||
|
uint rope_mode;
|
||||||
|
uint ncols;
|
||||||
|
uint n_dims;
|
||||||
|
float freq_scale;
|
||||||
|
uint p_delta_rows;
|
||||||
|
float freq_base;
|
||||||
|
float ext_factor;
|
||||||
|
float attn_factor;
|
||||||
|
float corr_dims[2];
|
||||||
|
float theta_scale;
|
||||||
|
uint has_ff;
|
||||||
|
uint ne02;
|
||||||
|
uint nb01;
|
||||||
|
uint nb02;
|
||||||
|
int sections[4];
|
||||||
|
uint is_imrope;
|
||||||
|
uint is_back;
|
||||||
|
uint set_rows_stride;
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif // !defined(GGML_ROPE_PARAMS)
|
||||||
|
|
@ -1,47 +1,11 @@
|
||||||
#version 450
|
#version 450
|
||||||
|
|
||||||
#include "rope_head.glsl"
|
#include "rope_head.glsl"
|
||||||
|
#include "rope_funcs.glsl"
|
||||||
|
|
||||||
void main() {
|
void main() {
|
||||||
const uint i0 = 2*gl_GlobalInvocationID.y;
|
const uint i0 = 2*gl_GlobalInvocationID.y;
|
||||||
uint ne0 = p.ncols;
|
// i1 is actually i2*nb2+i1, but the rows are contiguous
|
||||||
uint ne1 = p.p_delta_rows;
|
const uint i1 = gl_GlobalInvocationID.x;
|
||||||
uint ne2 = p.ne02;
|
rope_vision(i0, i1, pc);
|
||||||
|
|
||||||
if (i0 >= ne0) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const uint row_dst = gl_GlobalInvocationID.x;
|
|
||||||
|
|
||||||
const uint row_x = row_dst % ne1;
|
|
||||||
const uint channel_x = row_dst / ne1;
|
|
||||||
|
|
||||||
const uint idst = row_dst*ne0 + i0/2;
|
|
||||||
const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2;
|
|
||||||
|
|
||||||
const int sect_dims = p.sections[0] + p.sections[1];
|
|
||||||
const int sec_w = p.sections[1] + p.sections[0];
|
|
||||||
const uint sector = (i0 / 2) % sect_dims;
|
|
||||||
|
|
||||||
float theta_base = 0.0;
|
|
||||||
if (sector < p.sections[0]) {
|
|
||||||
const uint p0 = sector;
|
|
||||||
theta_base = data_pos[channel_x]*pow(p.theta_scale, p0);
|
|
||||||
}
|
|
||||||
else if (sector >= p.sections[0] && sector < sec_w) {
|
|
||||||
const uint p0 = sector - p.sections[0];
|
|
||||||
theta_base = data_pos[channel_x + ne2]*pow(p.theta_scale, p0);
|
|
||||||
}
|
|
||||||
|
|
||||||
const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f;
|
|
||||||
|
|
||||||
float cos_theta, sin_theta;
|
|
||||||
rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta);
|
|
||||||
|
|
||||||
const float x0 = float(data_a[ix + 0]);
|
|
||||||
const float x1 = float(data_a[ix + p.n_dims]);
|
|
||||||
|
|
||||||
data_d[idst + 0] = D_TYPE(x0*cos_theta - x1*sin_theta);
|
|
||||||
data_d[idst + p.n_dims] = D_TYPE(x0*sin_theta + x1*cos_theta);
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -695,6 +695,8 @@ void process_shaders() {
|
||||||
string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||||
string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
|
string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||||
string_to_spv("rms_norm_partials_f32", "rms_norm_partials.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
|
string_to_spv("rms_norm_partials_f32", "rms_norm_partials.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||||
|
string_to_spv("rms_norm_mul_rope_f32_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"ROPE_D_TYPE", "float"}, {"RMS_NORM_ROPE_FUSION", "1"}}));
|
||||||
|
string_to_spv("rms_norm_mul_rope_f32_f16_rte", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}, {"RMS_NORM_ROPE_FUSION", "1"}, {"RTE16", "1"}}));
|
||||||
string_to_spv("rms_norm_back_f32", "rms_norm_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
|
string_to_spv("rms_norm_back_f32", "rms_norm_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||||
string_to_spv("l2_norm_f32", "l2_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
string_to_spv("l2_norm_f32", "l2_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||||
|
|
||||||
|
|
@ -840,25 +842,25 @@ void process_shaders() {
|
||||||
string_to_spv("soft_max_f32_f16", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}));
|
string_to_spv("soft_max_f32_f16", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}));
|
||||||
string_to_spv("soft_max_back_f32", "soft_max_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
|
string_to_spv("soft_max_back_f32", "soft_max_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||||
|
|
||||||
string_to_spv("rope_norm_f32", "rope_norm.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
string_to_spv("rope_norm_f32", "rope_norm.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float"}});
|
||||||
string_to_spv("rope_norm_f16", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
string_to_spv("rope_norm_f16", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}});
|
||||||
string_to_spv("rope_norm_f16_rte", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});
|
string_to_spv("rope_norm_f16_rte", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}});
|
||||||
string_to_spv("rope_norm_f32_f16", "rope_norm.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
|
string_to_spv("rope_norm_f32_f16", "rope_norm.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}});
|
||||||
string_to_spv("rope_norm_f32_f16_rte", "rope_norm.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});
|
string_to_spv("rope_norm_f32_f16_rte", "rope_norm.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}});
|
||||||
|
|
||||||
string_to_spv("rope_neox_f32", "rope_neox.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
string_to_spv("rope_neox_f32", "rope_neox.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float"}});
|
||||||
string_to_spv("rope_neox_f16", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
string_to_spv("rope_neox_f16", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}});
|
||||||
string_to_spv("rope_neox_f16_rte", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});
|
string_to_spv("rope_neox_f16_rte", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}});
|
||||||
string_to_spv("rope_neox_f32_f16", "rope_neox.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
|
string_to_spv("rope_neox_f32_f16", "rope_neox.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}});
|
||||||
string_to_spv("rope_neox_f32_f16_rte", "rope_neox.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});
|
string_to_spv("rope_neox_f32_f16_rte", "rope_neox.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}});
|
||||||
|
|
||||||
string_to_spv("rope_multi_f32", "rope_multi.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
string_to_spv("rope_multi_f32", "rope_multi.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float"}});
|
||||||
string_to_spv("rope_multi_f16", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
string_to_spv("rope_multi_f16", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}});
|
||||||
string_to_spv("rope_multi_f16_rte", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});
|
string_to_spv("rope_multi_f16_rte", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}});
|
||||||
|
|
||||||
string_to_spv("rope_vision_f32", "rope_vision.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
string_to_spv("rope_vision_f32", "rope_vision.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float"}});
|
||||||
string_to_spv("rope_vision_f16", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
string_to_spv("rope_vision_f16", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}});
|
||||||
string_to_spv("rope_vision_f16_rte", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});
|
string_to_spv("rope_vision_f16_rte", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}});
|
||||||
|
|
||||||
string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}});
|
string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}});
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2294,6 +2294,79 @@ struct test_rope_set_rows : public test_case {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// GGML_OP_RMS_NORM + GGML_OP_MUL + GGML_OP_ROPE (+ GGML_OP_VIEW + GGML_OP_SET_ROWS)
|
||||||
|
struct test_rms_norm_mul_rope : public test_case {
|
||||||
|
const std::array<int64_t, 4> ne;
|
||||||
|
const float eps;
|
||||||
|
const bool multi_add; // test a sequence of adds feeding into rms_norm
|
||||||
|
const bool set_rows;
|
||||||
|
int mode;
|
||||||
|
|
||||||
|
std::string op_desc(ggml_tensor * t) override {
|
||||||
|
GGML_UNUSED(t);
|
||||||
|
return "RMS_NORM_MUL_ROPE";
|
||||||
|
}
|
||||||
|
|
||||||
|
bool run_whole_graph() override { return true; }
|
||||||
|
|
||||||
|
std::string vars() override {
|
||||||
|
return VARS_TO_STR5(ne, eps, multi_add, set_rows, mode);
|
||||||
|
}
|
||||||
|
|
||||||
|
test_rms_norm_mul_rope(std::array<int64_t, 4> ne, float eps = 1e-6f, bool multi_add = false,
|
||||||
|
bool set_rows = false, int mode = GGML_ROPE_TYPE_NORMAL)
|
||||||
|
: ne(ne), eps(eps), multi_add(multi_add), set_rows(set_rows), mode(mode) {}
|
||||||
|
|
||||||
|
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||||
|
ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, ne[0], ne[1], ne[2], 1);
|
||||||
|
ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, ne[0], ne[1], ne[2], 1);
|
||||||
|
ggml_tensor * c = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, ne[0], ne[1], ne[2], 1);
|
||||||
|
|
||||||
|
if (multi_add) {
|
||||||
|
a = ggml_add(ctx, ggml_add(ctx, a, b), c);
|
||||||
|
}
|
||||||
|
|
||||||
|
a = ggml_mul(ctx, ggml_rms_norm(ctx, a, eps), b);
|
||||||
|
|
||||||
|
ggml_tensor * pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ne[2]);
|
||||||
|
|
||||||
|
ggml_tensor * rope = ggml_rope(ctx, a, pos, ne[0], mode);
|
||||||
|
|
||||||
|
ggml_tensor * out;
|
||||||
|
|
||||||
|
if (set_rows) {
|
||||||
|
ggml_tensor * view = ggml_view_2d(ctx, rope, ne[0] * ne[1], ne[2], rope->nb[2], 0);
|
||||||
|
|
||||||
|
ggml_tensor * dst = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, ne[0] * ne[1], ne[2] * ne[3], 1, 1);
|
||||||
|
ggml_set_name(dst, "dst");
|
||||||
|
|
||||||
|
ggml_tensor * row_idxs = ggml_new_tensor_3d(ctx, GGML_TYPE_I64, ne[2], 1, 1);
|
||||||
|
ggml_set_name(row_idxs, "row_idxs");
|
||||||
|
|
||||||
|
out = ggml_set_rows(ctx, dst, view, row_idxs);
|
||||||
|
ggml_set_name(out, "out");
|
||||||
|
} else {
|
||||||
|
out = rope;
|
||||||
|
}
|
||||||
|
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
|
void initialize_tensors(ggml_context * ctx) override {
|
||||||
|
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
|
||||||
|
if (t->type == GGML_TYPE_I64 || t->type == GGML_TYPE_I32) {
|
||||||
|
if (ggml_is_view_op(t->op)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
init_set_rows_row_ids(t, ne[2]);
|
||||||
|
} else {
|
||||||
|
init_tensor_uniform(t);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
// GGML_OP_ARGMAX
|
// GGML_OP_ARGMAX
|
||||||
struct test_argmax : public test_case {
|
struct test_argmax : public test_case {
|
||||||
const ggml_type type;
|
const ggml_type type;
|
||||||
|
|
@ -6751,6 +6824,22 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for (auto multi_add : {false, true}) {
|
||||||
|
for (auto set_rows : {false, true}) {
|
||||||
|
for (auto rope : {GGML_ROPE_TYPE_NORMAL, GGML_ROPE_TYPE_NEOX}) {
|
||||||
|
test_cases.emplace_back(new test_rms_norm_mul_rope({768, 1, 1, 1}, 1e-6f, multi_add, set_rows, rope));
|
||||||
|
test_cases.emplace_back(new test_rms_norm_mul_rope({768, 3, 1, 1}, 1e-6f, multi_add, set_rows, rope));
|
||||||
|
test_cases.emplace_back(new test_rms_norm_mul_rope({768, 3, 5, 1}, 1e-6f, multi_add, set_rows, rope));
|
||||||
|
test_cases.emplace_back(new test_rms_norm_mul_rope({128, 32, 2, 1}, 1e-6f, multi_add, set_rows, rope));
|
||||||
|
test_cases.emplace_back(new test_rms_norm_mul_rope({128, 4, 2, 1}, 1e-6f, multi_add, set_rows, rope));
|
||||||
|
test_cases.emplace_back(new test_rms_norm_mul_rope({128, 32, 50, 1}, 1e-6f, multi_add, set_rows, rope));
|
||||||
|
test_cases.emplace_back(new test_rms_norm_mul_rope({128, 4, 50, 1}, 1e-6f, multi_add, set_rows, rope));
|
||||||
|
test_cases.emplace_back(new test_rms_norm_mul_rope({8192, 2, 2, 1}, 1e-6f, multi_add, set_rows, rope));
|
||||||
|
test_cases.emplace_back(new test_rms_norm_mul_rope({8192, 2, 2, 1}, 1e-6f, multi_add, set_rows, rope));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, {64, 5, 4, 3}, 1e-12f));
|
test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, {64, 5, 4, 3}, 1e-12f));
|
||||||
|
|
||||||
for (int64_t d_conv : {3, 4}) {
|
for (int64_t d_conv : {3, 4}) {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue