Merge branch 'ggml-org:master' into power-law-sampler

commit ed2890e691
@ -2105,7 +2105,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
"override tensor buffer type", [](common_params & params, const std::string & value) {
|
||||
parse_tensor_buffer_overrides(value, params.tensor_buft_overrides);
|
||||
}
|
||||
));
|
||||
).set_env("LLAMA_ARG_OVERRIDE_TENSOR"));
|
||||
add_opt(common_arg(
|
||||
{"-otd", "--override-tensor-draft"}, "<tensor name pattern>=<buffer type>,...",
|
||||
"override tensor buffer type for draft model", [](common_params & params, const std::string & value) {
|
||||
|
|
@ -3536,15 +3536,15 @@ void common_params_add_preset_options(std::vector<common_arg> & args) {
|
|||
[](common_params &, const std::string &) { /* unused */ }
|
||||
).set_env(COMMON_ARG_PRESET_LOAD_ON_STARTUP).set_preset_only());
|
||||
|
||||
args.push_back(common_arg(
|
||||
{"stop-timeout"}, "SECONDS",
|
||||
"in server router mode, force-kill model instance after this many seconds of graceful shutdown",
|
||||
[](common_params &, int) { /* unused */ }
|
||||
).set_env(COMMON_ARG_PRESET_STOP_TIMEOUT).set_preset_only());
|
||||
|
||||
// args.push_back(common_arg(
|
||||
// {"pin"},
|
||||
// "in server router mode, do not unload this model if models_max is exceeded",
|
||||
// [](common_params &) { /* unused */ }
|
||||
// ).set_preset_only());
|
||||
|
||||
// args.push_back(common_arg(
|
||||
// {"unload-idle-seconds"}, "SECONDS",
|
||||
// "in server router mode, unload models idle for more than this many seconds",
|
||||
// [](common_params &, int) { /* unused */ }
|
||||
// ).set_preset_only());
|
||||
}
|
||||
@ -10,6 +10,7 @@
|
|||
|
||||
// pseudo-env variable to identify preset-only arguments
|
||||
#define COMMON_ARG_PRESET_LOAD_ON_STARTUP "__PRESET_LOAD_ON_STARTUP"
|
||||
#define COMMON_ARG_PRESET_STOP_TIMEOUT "__PRESET_STOP_TIMEOUT"
|
||||
|
||||
//
|
||||
// CLI argument parsing
|
||||
@ -7362,6 +7362,90 @@ class MiniMaxM2Model(TextModel):
|
|||
return super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
||||
@ModelBase.register("MiMoV2FlashForCausalLM")
|
||||
class MimoV2Model(TextModel):
|
||||
model_arch = gguf.MODEL_ARCH.MIMO2
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
|
||||
assert self.hparams["swa_head_dim"] == self.hparams["head_dim"]
|
||||
assert self.hparams["swa_num_attention_heads"] == self.hparams["num_attention_heads"]
|
||||
assert self.hparams["swa_v_head_dim"] == self.hparams["v_head_dim"]
|
||||
assert self.hparams["topk_method"] == "noaux_tc"
|
||||
|
||||
n_head_kv = self.hparams["num_key_value_heads"]
|
||||
n_head_kv_swa = self.hparams["swa_num_key_value_heads"]
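# hybrid_layer_pattern is assumed here to be a per-layer 0/1 list: 1 selects the SWA
# head count in the comprehension below, 0 keeps the full-attention value.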
|
||||
n_head_kv_arr = [n_head_kv_swa if use_swa == 1 else n_head_kv for use_swa in self.hparams["hybrid_layer_pattern"]]
|
||||
self.gguf_writer.add_head_count_kv(n_head_kv_arr)
|
||||
|
||||
self.gguf_writer.add_sliding_window(self.hparams["sliding_window"])
|
||||
self.gguf_writer.add_sliding_window_pattern(self.hparams["hybrid_layer_pattern"])
|
||||
self.gguf_writer.add_rope_freq_base_swa(self.hparams["swa_rope_theta"])
|
||||
self.gguf_writer.add_value_length(self.hparams["v_head_dim"])
|
||||
self.gguf_writer.add_expert_count(self.hparams["n_routed_experts"])
|
||||
self.gguf_writer.add_expert_feed_forward_length(self.hparams["moe_intermediate_size"])
|
||||
|
||||
rope_dim = int(self.hparams["head_dim"] * self.hparams["partial_rotary_factor"])
|
||||
self.gguf_writer.add_rope_dimension_count(rope_dim)
|
||||
|
||||
self.gguf_writer.add_layer_norm_rms_eps(self.hparams.get("layernorm_epsilon", 1e-5))
|
||||
|
||||
_experts: list[dict[str, Tensor]] | None = None
|
||||
|
||||
def modify_tensors(self, data_torch, name, bid):
|
||||
if name.endswith("e_score_correction_bias"):
|
||||
name = name.replace("e_score_correction_bias", "e_score_correction.bias")
|
||||
|
||||
if "attention_sink" in name and not name.endswith(".weight"):
|
||||
name += ".weight"
|
||||
|
||||
# TODO: mimo v2 does not indicate the number of next-token-prediction layers, therefore we cannot handle the MTP tensors the same way as GLM4_MOE
|
||||
if "model.mtp." in name:
|
||||
return []
|
||||
|
||||
# process the experts separately
|
||||
if name.find("mlp.experts") != -1:
|
||||
n_experts = self.hparams["n_routed_experts"]
|
||||
assert bid is not None
|
||||
|
||||
if self._experts is None:
|
||||
self._experts = [{} for _ in range(self.block_count)]
|
||||
|
||||
self._experts[bid][name] = data_torch
|
||||
|
||||
if len(self._experts[bid]) >= n_experts * 3:
|
||||
tensors: list[tuple[str, Tensor]] = []
|
||||
|
||||
# merge the experts into a single 3d tensor
|
||||
for w_name in ["gate_proj", "up_proj", "down_proj"]:
|
||||
datas: list[Tensor] = []
|
||||
|
||||
for xid in range(n_experts):
|
||||
ename_to_retrieve = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
|
||||
datas.append(self._experts[bid][ename_to_retrieve])
|
||||
del self._experts[bid][ename_to_retrieve]
|
||||
|
||||
data_torch = torch.stack(datas, dim=0)
|
||||
merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
|
||||
new_name = self.map_tensor_name(merged_name)
|
||||
tensors.append((new_name, data_torch))
|
||||
|
||||
return tensors
|
||||
else:
|
||||
return []
|
||||
return [(self.map_tensor_name(name), data_torch)]
|
||||
|
||||
def prepare_tensors(self):
|
||||
super().prepare_tensors()
|
||||
|
||||
if self._experts is not None:
|
||||
# flatten `list[dict[str, Tensor]]` into `list[str]`
|
||||
experts = [k for d in self._experts for k in d.keys()]
|
||||
if len(experts) > 0:
|
||||
raise ValueError(f"Unprocessed experts: {experts}")
|
||||
|
||||
|
||||
@ModelBase.register("PanguEmbeddedForCausalLM")
|
||||
class PanguEmbeddedModel(TextModel):
|
||||
model_arch = gguf.MODEL_ARCH.PANGU_EMBED
|
||||
|
|
@ -8695,6 +8779,11 @@ class NemotronHModel(GraniteHybridModel):
|
|||
raise ValueError(f"Unprocessed experts: {experts}")
|
||||
|
||||
|
||||
@ModelBase.register("LlamaBidirectionalModel")
|
||||
class LlamaEmbedNemotronModel(LlamaModel):
|
||||
model_arch = gguf.MODEL_ARCH.LLAMA_EMBED
|
||||
|
||||
|
||||
@ModelBase.register("BailingMoeForCausalLM")
|
||||
class BailingMoeModel(TextModel):
|
||||
model_arch = gguf.MODEL_ARCH.BAILINGMOE
|
||||
@ -829,7 +829,7 @@ use 1 SYCL GPUs: [0] with Max compute units:512
|
|||
|
||||
No. We can't support Ollama issues directly, because we aren't familiar with Ollama.
|
||||
|
||||
Sugguest reproducing on llama.cpp and report similar issue to llama.cpp. We will surpport it.
|
||||
Suggest reproducing on llama.cpp and report similar issue to llama.cpp. We will support it.
|
||||
|
||||
It's the same for other projects, including the llama.cpp SYCL backend.
|
||||
|
||||
@ -2338,19 +2338,19 @@ static void aclnn_rope_cache_init(ggml_backend_cann_context & ctx,
|
|||
// Step1.2: prepare rope_yarn_ramp, if this part updated, should update theta_scale_tensor.
|
||||
// TODO: acl_yarn_ramp_tensor use rope cache.
|
||||
bool yarn_ramp_tensor_updated = false;
|
||||
ggml_cann_pool_alloc yarn_ramp_allocator(ctx.pool());
|
||||
acl_tensor_ptr acl_yarn_ramp_tensor;
|
||||
if (ext_factor != 0 && (theta_scale_updated || ctx.rope_cache.theta_scale_length != theta_scale_length ||
|
||||
ctx.rope_cache.freq_scale != freq_scale)) {
|
||||
yarn_ramp_tensor_updated = true;
|
||||
|
||||
if (ctx.rope_cache.yarn_ramp_cache != nullptr) {
|
||||
ACL_CHECK(aclrtFree(ctx.rope_cache.yarn_ramp_cache));
|
||||
}
|
||||
ACL_CHECK(aclrtMalloc(&ctx.rope_cache.yarn_ramp_cache, theta_scale_length * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST));
|
||||
// -rope_yarn_ramp
|
||||
// const float y = (i0 / 2 - low) / MAX(0.001f, high - low);
|
||||
// return MIN(1, MAX(0, y)) - 1;
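// Worked example with hypothetical corr_dims low = 4, high = 12:
//   i0/2 <= 4  -> y <= 0   -> ramp = -1 (fully ramped)
//   i0/2 == 8  -> y == 0.5 -> ramp = -0.5
//   i0/2 >= 12 -> y >= 1   -> ramp = 0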
|
||||
yarn_ramp_allocator.alloc(theta_scale_length * sizeof(float));
|
||||
void * yarn_ramp_buffer = yarn_ramp_allocator.get();
|
||||
acl_yarn_ramp_tensor =
|
||||
ggml_cann_create_tensor(yarn_ramp_buffer, ACL_FLOAT, sizeof(float), theta_scale_ne, theta_scale_nb, 1);
|
||||
ggml_cann_create_tensor(ctx.rope_cache.yarn_ramp_cache, ACL_FLOAT, sizeof(float), theta_scale_ne, theta_scale_nb, 1);
|
||||
float zero_value = 0, one_value = 1;
|
||||
float denom_safe_value = MAX(0.001f, corr_dims[1] - corr_dims[0]);
|
||||
acl_scalar_ptr low = ggml_cann_create_scalar(&corr_dims[0], aclDataType::ACL_FLOAT);
|
||||
|
|
@ -2380,8 +2380,10 @@ static void aclnn_rope_cache_init(ggml_backend_cann_context & ctx,
|
|||
acl_scalar_ptr freq_scale_1_sc = ggml_cann_create_scalar(&freq_scale_1, aclDataType::ACL_FLOAT);
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMuls, acl_yarn_ramp_tensor.get(), freq_scale_1_sc.get());
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdds, acl_yarn_ramp_tensor.get(), freq_scale_sc.get(), one.get());
|
||||
} else {
|
||||
acl_yarn_ramp_tensor =
|
||||
ggml_cann_create_tensor(ctx.rope_cache.yarn_ramp_cache, ACL_FLOAT, sizeof(float), theta_scale_ne, theta_scale_nb, 1);
|
||||
}
|
||||
|
||||
// Step 1.3: update theta_scale_tensor according to ext_factor or freq_scale.
|
||||
if (ext_factor != 0) {
|
||||
if (theta_scale_updated || yarn_ramp_tensor_updated) {
|
||||
|
|
@ -2988,32 +2990,156 @@ void ggml_cann_argmax(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
|
|||
GGML_CANN_CALL_ACLNN_OP(ctx, ArgMax, acl_src.get(), 3, false, acl_dst.get());
|
||||
}
|
||||
|
||||
void ggml_cann_conv_transpose_1d(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
|
||||
void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* dst){
|
||||
ggml_tensor * src0 = dst->src[0];
|
||||
ggml_tensor * src1 = dst->src[1];
|
||||
|
||||
// stride
|
||||
int64_t s0 = ((const int32_t *) (dst->op_params))[0];
|
||||
int64_t s0 = ((const int32_t*)(dst->op_params))[0];
|
||||
|
||||
acl_tensor_ptr acl_input = ggml_cann_create_tensor(src1, src1->ne, src1->nb, 3, ACL_FORMAT_NCL);
|
||||
acl_tensor_ptr acl_input = ggml_cann_create_tensor(src1, src1->ne, src1->nb, 3, ACL_FORMAT_NCL);
|
||||
acl_tensor_ptr acl_weight = ggml_cann_create_tensor(src0, src0->ne, src0->nb, 3, ACL_FORMAT_NCL);
|
||||
acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst, dst->ne, dst->nb, 3, ACL_FORMAT_NCL);
|
||||
acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst, dst->ne, dst->nb, 3, ACL_FORMAT_NCL);
|
||||
|
||||
// get base information of input and kernel
|
||||
int64_t input_len = *(src1->ne);
|
||||
int64_t dst_len = *(dst->ne);
|
||||
int64_t kernel_size = *(src0->ne);
|
||||
|
||||
// set the max kernel size for each conv
|
||||
int64_t max_kernel_size = 255;
|
||||
|
||||
// compute the partition of kernel
|
||||
int64_t part_num = 1;
|
||||
part_num = (kernel_size + max_kernel_size - 1) / max_kernel_size;
|
||||
|
||||
int64_t strideVal[1];
|
||||
strideVal[0] = s0;
|
||||
acl_int_array_ptr stride = ggml_cann_create_int_array(strideVal, 1);
|
||||
int64_t paddingVal[] = { 0 };
|
||||
acl_int_array_ptr padding = ggml_cann_create_int_array(paddingVal, 1);
|
||||
int64_t dilationVal[] = { 1 };
|
||||
acl_int_array_ptr dilation = ggml_cann_create_int_array(dilationVal, 1);
|
||||
int8_t cubeMathType = 0;
|
||||
strideVal[0] = s0;
|
||||
acl_int_array_ptr stride = ggml_cann_create_int_array(strideVal, 1);
|
||||
int64_t paddingVal[] = {0};
|
||||
acl_int_array_ptr padding = ggml_cann_create_int_array(paddingVal, 1);
|
||||
int64_t dilationVal[] = {1};
|
||||
acl_int_array_ptr dilation = ggml_cann_create_int_array(dilationVal, 1);
|
||||
bool transposed = true;
|
||||
int64_t groups = 1;
|
||||
int8_t cubeMathType = 0;
|
||||
|
||||
#ifdef ASCEND_310P
|
||||
cubeMathType = 1;
|
||||
#endif
|
||||
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, Convolution, acl_input.get(), acl_weight.get(), nullptr, stride.get(), padding.get(),
|
||||
dilation.get(), true, padding.get(), 1, acl_dst.get(), cubeMathType);
|
||||
auto weight_type = ggml_cann_type_mapping(src0->type);
|
||||
auto dst_type = ggml_cann_type_mapping(dst->type);
|
||||
|
||||
// slice the kernel to make each conv available
|
||||
int64_t slice_dim = -1;
|
||||
int64_t slice_start = 0;
|
||||
int64_t slice_end = max_kernel_size;
|
||||
int64_t slice_step = 1;
|
||||
int64_t interval = max_kernel_size;
|
||||
|
||||
int64_t left_pad_len = dilationVal[0] * (max_kernel_size - 1) + 1 - 2 * paddingVal[0];
|
||||
int64_t right_pad_len = 0;
|
||||
|
||||
acl_scalar_ptr alpha = nullptr;
|
||||
float alphaValue = 1.0;
|
||||
alpha = ggml_cann_create_scalar(&alphaValue, aclDataType::ACL_FLOAT);
|
||||
|
||||
// set zero to destination
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceZero, acl_dst.get());
|
||||
|
||||
for(int k = 0; k < part_num; k++){
|
||||
|
||||
// create part kernel tensor and slice from big kernel
|
||||
slice_start = max_kernel_size * k;
|
||||
if(k == part_num - 1){
|
||||
slice_end = kernel_size;
|
||||
interval = kernel_size - max_kernel_size * k;
|
||||
}else{
|
||||
slice_end = max_kernel_size * (k+1);
|
||||
}
|
||||
|
||||
int64_t part_ne[4];
|
||||
for(int i = 0; i < 4; i++) {
|
||||
part_ne[i] = *(src0->ne + i);
|
||||
}
|
||||
part_ne[0] = interval;
|
||||
|
||||
size_t part_nb[4];
|
||||
part_nb[0] = sizeof(weight_type);
|
||||
for (int i = 1; i < 4; i++) {
|
||||
part_nb[i] = part_nb[i - 1] * part_ne[i - 1];
|
||||
}
|
||||
|
||||
ggml_cann_pool_alloc part_kernel_allocator;
|
||||
part_kernel_allocator.alloc(ctx.pool(), part_nb[3]);
|
||||
void* part_kernel_buf = part_kernel_allocator.get();
|
||||
|
||||
acl_tensor_ptr part_kernel = ggml_cann_create_tensor(part_kernel_buf, weight_type,
|
||||
ggml_element_size(src0), part_ne, part_nb, 3, ACL_FORMAT_NCL);
|
||||
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, Slice, acl_weight.get(), slice_dim, slice_start, slice_end, slice_step, part_kernel.get());
|
||||
|
||||
// create the part conv result tensor
|
||||
int64_t part_dst_ne[4];
|
||||
for(int i = 0; i < 4; i++){
|
||||
part_dst_ne[i] = *(dst->ne + i);
|
||||
}
|
||||
part_dst_ne[0] = (input_len - 1) * strideVal[0] - 2 * paddingVal[0] + dilationVal[0] * (part_ne[0] - 1) + 1;
|
||||
|
||||
size_t part_dst_nb[4];
|
||||
part_dst_nb[0] = sizeof(weight_type);
|
||||
for (int i = 1; i < 4; i++) {
|
||||
part_dst_nb[i] = part_dst_nb[i - 1] * part_dst_ne[i - 1];
|
||||
}
|
||||
ggml_cann_pool_alloc part_dst_allocator;
|
||||
part_dst_allocator.alloc(ctx.pool(), part_dst_nb[3]);
|
||||
void* part_dst_buf = part_dst_allocator.get();
|
||||
|
||||
acl_tensor_ptr acl_part_dst = ggml_cann_create_tensor(part_dst_buf, dst_type, ggml_element_size(dst),
|
||||
part_dst_ne, part_dst_nb, 3, ACL_FORMAT_NCL);
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceZero, acl_part_dst.get());
|
||||
|
||||
// compute part conv transpose 1d
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, Convolution, acl_input.get(), part_kernel.get(), nullptr, stride.get(),
|
||||
padding.get(), dilation.get(), transposed, padding.get(), groups, acl_part_dst.get(), cubeMathType);
|
||||
|
||||
// compute the position of part result in final result
|
||||
int64_t global_start = slice_start;
|
||||
int64_t global_end = std::min((input_len - 1) * strideVal[0] + slice_end, dst_len);
|
||||
|
||||
left_pad_len = global_start;
|
||||
right_pad_len = dst_len - global_end;
|
||||
|
||||
std::vector<int64_t> padDataVal = {left_pad_len,right_pad_len};
|
||||
acl_int_array_ptr padData = ggml_cann_create_int_array(padDataVal.data(), 2);
|
||||
|
||||
acl_scalar_ptr pad_value = nullptr;
|
||||
float pad_valueVal = 0.0;
|
||||
pad_value = ggml_cann_create_scalar(&pad_valueVal, aclDataType::ACL_FLOAT);
|
||||
|
||||
int64_t conv_result_ne[4];
|
||||
for(int i = 0; i < 4; i++){
|
||||
conv_result_ne[i] = *(dst->ne + i);
|
||||
}
|
||||
|
||||
size_t conv_result_nb[4];
|
||||
conv_result_nb[0] = sizeof(weight_type);
|
||||
for (int i = 1; i < 4; i++) {
|
||||
conv_result_nb[i] = conv_result_nb[i - 1] * conv_result_ne[i - 1];
|
||||
}
|
||||
|
||||
ggml_cann_pool_alloc conv_result_allocator;
|
||||
conv_result_allocator.alloc(ctx.pool(), conv_result_nb[3]);
|
||||
void* conv_result_buf = conv_result_allocator.get();
|
||||
|
||||
acl_tensor_ptr conv_result = ggml_cann_create_tensor(conv_result_buf, dst_type, ggml_element_size(dst),
|
||||
conv_result_ne, conv_result_nb, 3, ACL_FORMAT_NCL);
|
||||
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceZero, conv_result.get());
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, ConstantPadNd, acl_part_dst.get(), padData.get(), pad_value.get(), conv_result.get());
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdd, acl_dst.get(), conv_result.get(), alpha.get());
|
||||
}
|
||||
}
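For reference, a minimal host-side sketch of the partitioning arithmetic above (hypothetical sizes, no ACL calls): each at-most-255-wide kernel slice produces a partial output that is zero-padded into [global_start, global_end) of the final destination and accumulated.

#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
    const int64_t kernel_size = 600, input_len = 10, s0 = 1;     // assumed example sizes
    const int64_t max_kernel_size = 255;
    const int64_t dst_len  = (input_len - 1) * s0 + kernel_size; // dilation 1, padding 0
    const int64_t part_num = (kernel_size + max_kernel_size - 1) / max_kernel_size;
    for (int64_t k = 0; k < part_num; ++k) {
        const int64_t slice_start  = k * max_kernel_size;
        const int64_t slice_end    = std::min(slice_start + max_kernel_size, kernel_size);
        const int64_t part_len     = (input_len - 1) * s0 + (slice_end - slice_start);
        const int64_t global_start = slice_start; // left padding applied via ConstantPadNd
        const int64_t global_end   = std::min((input_len - 1) * s0 + slice_end, dst_len);
        std::printf("part %lld: kernel [%lld, %lld) -> dst [%lld, %lld), partial len %lld\n",
                    (long long) k, (long long) slice_start, (long long) slice_end,
                    (long long) global_start, (long long) global_end, (long long) part_len);
    }
    return 0;
}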
|
||||
|
||||
void ggml_cann_elu(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
|
||||
@ -47,6 +47,7 @@
|
|||
#include <aclnnop/aclnn_sign.h>
|
||||
#include <aclnnop/aclnn_silu.h>
|
||||
#include <aclnnop/aclnn_sin.h>
|
||||
#include <aclnnop/aclnn_slice.h>
|
||||
#include <aclnnop/aclnn_sqrt.h>
|
||||
#include <aclnnop/aclnn_tanh.h>
|
||||
|
||||
@ -229,6 +229,60 @@ struct ggml_graph_node_properties {
|
|||
// op
|
||||
ggml_op node_op;
|
||||
int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
|
||||
|
||||
/**
|
||||
* @brief Check if a ggml tensor node matches this property set.
|
||||
*
|
||||
* This function compares all relevant fields (address, op type, shape, source inputs, op params)
|
||||
* to determine whether the current node matches these previously recorded properties.
|
||||
*
|
||||
* @param node The current ggml tensor node.
|
||||
* @return true if all fields match (excluding GGML_OP_VIEW); false otherwise.
|
||||
*/
|
||||
bool has_matching_properties(ggml_tensor * node) {
|
||||
if (node->data != this->node_address && node->op != GGML_OP_VIEW) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (node->op != this->node_op) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (int i = 0; i < GGML_MAX_DIMS; i++) {
|
||||
if (node->ne[i] != this->ne[i]) {
|
||||
return false;
|
||||
}
|
||||
if (node->nb[i] != this->nb[i]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < GGML_MAX_SRC; i++) {
|
||||
if (node->src[i]) {
|
||||
if (node->src[i]->data != this->src_address[i] && node->op != GGML_OP_VIEW) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (int d = 0; d < GGML_MAX_DIMS; d++) {
|
||||
if (node->src[i]->ne[d] != this->src_ne[i][d]) {
|
||||
return false;
|
||||
}
|
||||
if (node->src[i]->nb[d] != this->src_nb[i][d]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (this->src_address[i] != nullptr) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (node->op == GGML_OP_SCALE || node->op == GGML_OP_UNARY || node->op == GGML_OP_GLU) {
|
||||
return memcmp(this->op_params, node->op_params, GGML_MAX_OP_PARAMS) == 0;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
struct ggml_cann_graph {
|
||||
|
|
@ -241,6 +295,79 @@ struct ggml_cann_graph {
|
|||
aclmdlRI graph = nullptr;
|
||||
|
||||
std::vector<ggml_graph_node_properties> ggml_graph_properties;
|
||||
|
||||
/**
|
||||
* @brief Create a new CANN graph from a ggml computation graph.
|
||||
*
|
||||
* This function creates a new ggml_cann_graph object and fills its node properties
|
||||
* (operation type, dimensions, strides, input sources, and operation parameters)
|
||||
* based on the current ggml computation graph.
|
||||
*
|
||||
* Each node in the ggml graph is mapped to a property entry in the new CANN graph:
|
||||
* - node address
|
||||
* - operation type
|
||||
* - shape (ne) and strides (nb)
|
||||
* - source tensor addresses
|
||||
* - operation parameters
|
||||
*
|
||||
* @param cgraph The current ggml computation graph.
|
||||
* @return Pointer to the newly created ggml_cann_graph object.
|
||||
*/
|
||||
static ggml_cann_graph * create_from_cgraph(ggml_cgraph * cgraph) {
|
||||
ggml_cann_graph * new_graph = new ggml_cann_graph();
|
||||
new_graph->ggml_graph_properties.resize(cgraph->n_nodes);
|
||||
|
||||
for (int node_idx = 0; node_idx < cgraph->n_nodes; ++node_idx) {
|
||||
ggml_tensor * node = cgraph->nodes[node_idx];
|
||||
auto & prop = new_graph->ggml_graph_properties[node_idx];
|
||||
|
||||
prop.node_address = node->data;
|
||||
prop.node_op = node->op;
|
||||
|
||||
std::copy_n(node->ne, GGML_MAX_DIMS, prop.ne);
|
||||
std::copy_n(node->nb, GGML_MAX_DIMS, prop.nb);
|
||||
|
||||
for (int src = 0; src < GGML_MAX_SRC; ++src) {
|
||||
if (node->src[src]) {
|
||||
prop.src_address[src] = node->src[src]->data;
|
||||
std::copy_n(node->src[src]->ne, GGML_MAX_DIMS, prop.src_ne[src]);
|
||||
std::copy_n(node->src[src]->nb, GGML_MAX_DIMS, prop.src_nb[src]);
|
||||
} else {
|
||||
prop.src_address[src] = nullptr;
|
||||
std::fill_n(prop.src_ne[src], GGML_MAX_DIMS, 0);
|
||||
std::fill_n(prop.src_nb[src], GGML_MAX_DIMS, 0);
|
||||
}
|
||||
}
|
||||
|
||||
memcpy(prop.op_params, node->op_params, GGML_MAX_OP_PARAMS);
|
||||
}
|
||||
|
||||
return new_graph;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Check whether this CANN graph matches the given ggml computation graph.
|
||||
*
|
||||
* This function compares the number of nodes and each node's properties
|
||||
* (operation type, dimensions, strides, inputs, and operation parameters)
|
||||
* to determine whether this CANN graph matches the given ggml graph.
|
||||
*
|
||||
* @param cgraph The current ggml computation graph.
|
||||
* @return true if this CANN graph matches the ggml graph; false otherwise.
|
||||
*/
|
||||
bool matches_cgraph(ggml_cgraph * cgraph) {
|
||||
if (this->ggml_graph_properties.size() != static_cast<size_t>(cgraph->n_nodes)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (int i = 0; i < cgraph->n_nodes; ++i) {
|
||||
if (!this->ggml_graph_properties[i].has_matching_properties(cgraph->nodes[i])) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
|
|
@ -272,15 +399,6 @@ struct ggml_cann_graph_lru_cache {
|
|||
cache_list.push_front(new_node);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Move an existing graph to the front of the cache.
|
||||
* @param node Pointer to the ggml_cann_graph to move.
|
||||
*/
|
||||
void move_to_front(ggml_cann_graph * node) {
|
||||
cache_list.remove(node);
|
||||
cache_list.push_front(node);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Clear all graphs from the cache (also frees memory).
|
||||
*/
|
||||
|
|
@ -295,6 +413,28 @@ struct ggml_cann_graph_lru_cache {
|
|||
* @brief Destructor that clears the cache and frees all cached graphs.
|
||||
*/
|
||||
~ggml_cann_graph_lru_cache() { clear(); }
|
||||
|
||||
/**
|
||||
* @brief Find a cached CANN graph that matches the given ggml graph and move it to front.
|
||||
*
|
||||
* This function iterates through the cached CANN graphs stored in the LRU cache and
|
||||
* compares them against the given ggml computation graph. If a matching graph is found,
|
||||
* it is promoted to the front of the LRU cache and true is returned. Otherwise, the function
|
||||
* returns false.
|
||||
*
|
||||
* @param cgraph The current ggml computation graph.
|
||||
* @return true if found; false otherwise.
|
||||
*/
|
||||
bool find_and_move_to_front(ggml_cgraph * cgraph) {
|
||||
for (auto & graph_ptr : this->cache_list) {
|
||||
if (graph_ptr->matches_cgraph(cgraph)) {
|
||||
cache_list.remove(graph_ptr);
|
||||
cache_list.push_front(graph_ptr);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
};
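A minimal sketch of how these two helpers are meant to be combined (hypothetical wrapper; the real call site is in ggml_backend_cann_graph_compute further down in this diff):

static bool cann_graph_lookup_or_insert(ggml_backend_cann_context * ctx, ggml_cgraph * cgraph) {
    // Reuse a cached graph if one matches, otherwise record node properties for a new capture.
    const bool capture_required = !ctx->graph_lru_cache.find_and_move_to_front(cgraph);
    if (capture_required) {
        ctx->graph_lru_cache.push(ggml_cann_graph::create_from_cgraph(cgraph));
    }
    return capture_required; // true: capture a new CANN graph, false: launch the cached one
}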
|
||||
#endif // USE_ACL_GRAPH
|
||||
|
||||
|
|
@ -318,6 +458,9 @@ struct ggml_cann_rope_cache {
|
|||
if (position_select_index_host) {
|
||||
free(position_select_index_host);
|
||||
}
|
||||
if (yarn_ramp_cache) {
|
||||
ACL_CHECK(aclrtFree(yarn_ramp_cache));
|
||||
}
|
||||
}
|
||||
|
||||
bool equal(int64_t theta_scale_length,
|
||||
|
|
@ -370,6 +513,7 @@ struct ggml_cann_rope_cache {
|
|||
float * theta_scale_exp_host = nullptr;
|
||||
int * position_select_index_host = nullptr;
|
||||
void * position_select_index = nullptr;
|
||||
void * yarn_ramp_cache = nullptr;
|
||||
// sin/cos cache, used only to accelerate first layer on each device
|
||||
void * sin_cache = nullptr;
|
||||
void * cos_cache = nullptr;
|
||||
@ -2075,162 +2075,6 @@ static void ggml_backend_cann_synchronize(ggml_backend_t backend) {
|
|||
ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream()));
|
||||
}
|
||||
|
||||
#ifdef USE_ACL_GRAPH
|
||||
/**
|
||||
* @brief Add a new CANN graph to the LRU cache by populating node properties from the ggml graph.
|
||||
*
|
||||
* This function creates a new ggml_cann_graph object and fills its node properties
|
||||
* (operation type, dimensions, strides, input sources, and operation parameters)
|
||||
* based on the current ggml computation graph.
|
||||
*
|
||||
* Each node in the ggml graph is mapped to a property entry in the new CANN graph:
|
||||
* - node address
|
||||
* - operation type
|
||||
* - shape (ne) and strides (nb)
|
||||
* - source tensor addresses
|
||||
* - operation parameters
|
||||
*
|
||||
* After initialization, the new graph is pushed into the LRU cache owned by the
|
||||
* CANN backend context. The cache takes ownership of the graph and manages its
|
||||
* lifetime (including deletion upon eviction).
|
||||
*
|
||||
* @param cann_ctx The CANN backend context containing the graph cache.
|
||||
* @param cgraph The current ggml computation graph.
|
||||
*/
|
||||
static void add_lru_matched_graph_node_properties(ggml_backend_cann_context * cann_ctx, ggml_cgraph * cgraph) {
|
||||
// Create a new ggml_cann_graph object on the heap (its lifetime is managed by the cache).
|
||||
ggml_cann_graph * new_graph = new ggml_cann_graph();
|
||||
new_graph->ggml_graph_properties.resize(cgraph->n_nodes);
|
||||
|
||||
for (int node_idx = 0; node_idx < cgraph->n_nodes; ++node_idx) {
|
||||
ggml_tensor * node = cgraph->nodes[node_idx];
|
||||
auto & prop = new_graph->ggml_graph_properties[node_idx];
|
||||
|
||||
prop.node_address = node->data;
|
||||
prop.node_op = node->op;
|
||||
|
||||
std::copy_n(node->ne, GGML_MAX_DIMS, prop.ne);
|
||||
std::copy_n(node->nb, GGML_MAX_DIMS, prop.nb);
|
||||
|
||||
for (int src = 0; src < GGML_MAX_SRC; ++src) {
|
||||
if (node->src[src]) {
|
||||
prop.src_address[src] = node->src[src]->data;
|
||||
std::copy_n(node->src[src]->ne, GGML_MAX_DIMS, prop.src_ne[src]);
|
||||
std::copy_n(node->src[src]->nb, GGML_MAX_DIMS, prop.src_nb[src]);
|
||||
} else {
|
||||
prop.src_address[src] = nullptr;
|
||||
std::fill_n(prop.src_ne[src], GGML_MAX_DIMS, 0);
|
||||
std::fill_n(prop.src_nb[src], GGML_MAX_DIMS, 0);
|
||||
}
|
||||
}
|
||||
|
||||
memcpy(prop.op_params, node->op_params, GGML_MAX_OP_PARAMS);
|
||||
}
|
||||
|
||||
// Insert into the LRU cache (cache takes ownership and will delete it when evicted).
|
||||
cann_ctx->graph_lru_cache.push(new_graph);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Check if a ggml tensor node matches a previously captured CANN graph node.
|
||||
*
|
||||
* This function compares all relevant fields (address, op type, shape, source inputs, op params)
|
||||
* to determine whether the current node matches a previously recorded version.
|
||||
*
|
||||
* @param node The current ggml tensor node.
|
||||
* @param graph_node_properties The stored properties of a CANN graph node.
|
||||
* @return true if all fields match (excluding GGML_OP_VIEW); false otherwise.
|
||||
*/
|
||||
static bool ggml_graph_node_has_matching_properties(ggml_tensor * node,
|
||||
ggml_graph_node_properties * graph_node_properties) {
|
||||
if (node->data != graph_node_properties->node_address && node->op != GGML_OP_VIEW) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (node->op != graph_node_properties->node_op) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (int i = 0; i < GGML_MAX_DIMS; i++) {
|
||||
if (node->ne[i] != graph_node_properties->ne[i]) {
|
||||
return false;
|
||||
}
|
||||
if (node->nb[i] != graph_node_properties->nb[i]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < GGML_MAX_SRC; i++) {
|
||||
if (node->src[i]) {
|
||||
if (node->src[i]->data != graph_node_properties->src_address[i] && node->op != GGML_OP_VIEW) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (int d = 0; d < GGML_MAX_DIMS; d++) {
|
||||
if (node->src[i]->ne[d] != graph_node_properties->src_ne[i][d]) {
|
||||
return false;
|
||||
}
|
||||
if (node->src[i]->nb[d] != graph_node_properties->src_nb[i][d]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (graph_node_properties->src_address[i] != nullptr) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (node->op == GGML_OP_SCALE || node->op == GGML_OP_UNARY || node->op == GGML_OP_GLU) {
|
||||
return memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) == 0;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Check whether there is a cached CANN graph that matches the current ggml graph.
|
||||
*
|
||||
* This function iterates through the cached CANN graphs stored in the LRU cache and
|
||||
* compares them against the given ggml computation graph. A match requires that the
|
||||
* number of nodes is the same and that each node’s properties (operation type,
|
||||
* dimensions, strides, inputs, and operation parameters) are identical.
|
||||
*
|
||||
* If a matching graph is found, it is promoted to the front of the LRU cache and the
|
||||
* function returns true. Otherwise, the function returns false, indicating that a new
|
||||
* CANN graph needs to be captured.
|
||||
*
|
||||
* @param cann_ctx The CANN backend context containing the graph cache.
|
||||
* @param cgraph The current ggml computation graph.
|
||||
* @return true if a matching cached graph exists; false otherwise.
|
||||
*/
|
||||
static bool is_matched_graph(ggml_backend_cann_context * cann_ctx, ggml_cgraph * cgraph) {
|
||||
ggml_cann_graph_lru_cache & lru_cache = cann_ctx->graph_lru_cache;
|
||||
for (auto & graph_ptr : lru_cache.cache_list) {
|
||||
// Skip graphs with a different number of nodes.
|
||||
if (graph_ptr->ggml_graph_properties.size() != static_cast<size_t>(cgraph->n_nodes)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Check if all nodes match.
|
||||
bool all_match = true;
|
||||
for (int i = 0; i < cgraph->n_nodes; ++i) {
|
||||
if (!ggml_graph_node_has_matching_properties(cgraph->nodes[i], &graph_ptr->ggml_graph_properties[i])) {
|
||||
all_match = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (all_match) {
|
||||
// update cache_list && renturn graph_ptr
|
||||
lru_cache.move_to_front(graph_ptr);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
#endif // USE_ACL_GRAPH
|
||||
|
||||
/**
|
||||
* @brief Evaluate the computation graph and optionally capture or execute it using CANN graph API.
|
||||
*
|
||||
|
|
@ -2239,23 +2083,23 @@ static bool is_matched_graph(ggml_backend_cann_context * cann_ctx, ggml_cgraph *
|
|||
*
|
||||
* Otherwise, it falls back to op-by-op execution using the CANN compute kernel dispatcher.
|
||||
*
|
||||
* @param cann_ctx The CANN backend context.
|
||||
* @param cgraph The ggml computation graph.
|
||||
* @param use_cann_graph Whether to use CANN graph execution.
|
||||
* @param cann_graph_update_required Whether graph capture is needed due to graph changes.
|
||||
* @param cann_ctx The CANN backend context.
|
||||
* @param cgraph The ggml computation graph.
|
||||
* @param use_cann_graph Whether to use CANN graph execution.
|
||||
* @param cann_graph_capture_required Whether graph capture is needed due to graph changes.
|
||||
*/
|
||||
static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx,
|
||||
ggml_cgraph * cgraph,
|
||||
bool & use_cann_graph,
|
||||
bool & cann_graph_update_required) {
|
||||
bool use_cann_graph,
|
||||
bool cann_graph_capture_required) {
|
||||
#ifdef USE_ACL_GRAPH
|
||||
if (use_cann_graph && cann_graph_update_required) { // Begin CANN graph capture
|
||||
if (use_cann_graph && cann_graph_capture_required) { // Begin CANN graph capture
|
||||
ACL_CHECK(aclmdlRICaptureBegin(cann_ctx->stream(), ACL_MODEL_RI_CAPTURE_MODE_GLOBAL));
|
||||
}
|
||||
#endif // USE_ACL_GRAPH
|
||||
// Only perform the graph execution if CANN graphs are not enabled, or we are capturing the graph.
|
||||
// With the use of CANN graphs, the execution will be performed by the graph launch.
|
||||
if (!use_cann_graph || cann_graph_update_required) {
|
||||
if (!use_cann_graph || cann_graph_capture_required) {
|
||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||
ggml_tensor * node = cgraph->nodes[i];
|
||||
|
||||
|
|
@ -2274,9 +2118,10 @@ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx
|
|||
|
||||
#ifdef USE_ACL_GRAPH
|
||||
if (use_cann_graph) {
|
||||
GGML_ASSERT(!cann_ctx->graph_lru_cache.cache_list.empty());
|
||||
ggml_cann_graph * matched_graph = cann_ctx->graph_lru_cache.cache_list.front();
|
||||
|
||||
if (cann_graph_update_required) { // End CANN graph capture
|
||||
if (cann_graph_capture_required) { // End CANN graph capture
|
||||
ACL_CHECK(aclmdlRICaptureEnd(cann_ctx->stream(), &matched_graph->graph));
|
||||
}
|
||||
|
||||
|
|
@ -2306,7 +2151,7 @@ static enum ggml_status ggml_backend_cann_graph_compute(ggml_backend_t backend,
|
|||
// calculate rope cache for first layer in current device.
|
||||
cann_ctx->rope_cache.cached = false;
|
||||
|
||||
bool cann_graph_update_required = false;
|
||||
bool graph_capture_required = false;
|
||||
#ifdef USE_ACL_GRAPH
|
||||
bool use_cann_graph = true;
|
||||
|
||||
|
|
@ -2331,16 +2176,17 @@ static enum ggml_status ggml_backend_cann_graph_compute(ggml_backend_t backend,
|
|||
|
||||
if (use_cann_graph) {
|
||||
// If no matching graph is found, the graph needs to be recaptured.
|
||||
cann_graph_update_required = !is_matched_graph(cann_ctx, cgraph);
|
||||
if (cann_graph_update_required) {
|
||||
graph_capture_required = !cann_ctx->graph_lru_cache.find_and_move_to_front(cgraph);
|
||||
if (graph_capture_required) {
|
||||
// If no matching graph is found, add a new ACL graph.
|
||||
add_lru_matched_graph_node_properties(cann_ctx, cgraph);
|
||||
ggml_cann_graph * new_graph = ggml_cann_graph::create_from_cgraph(cgraph);
|
||||
cann_ctx->graph_lru_cache.push(new_graph);
|
||||
}
|
||||
}
|
||||
#else
|
||||
bool use_cann_graph = false;
|
||||
#endif // USE_ACL_GRAPH
|
||||
evaluate_and_capture_cann_graph(cann_ctx, cgraph, use_cann_graph, cann_graph_update_required);
|
||||
evaluate_and_capture_cann_graph(cann_ctx, cgraph, use_cann_graph, graph_capture_required);
|
||||
|
||||
return GGML_STATUS_SUCCESS;
|
||||
}
|
||||
|
|
@ -2578,8 +2424,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
|
|||
}
|
||||
}
|
||||
case GGML_OP_CONV_TRANSPOSE_1D:
|
||||
// TODO: ((weightL - 1) * dilationW - padLeft)=1336 should not be larger than 255.
|
||||
return (op->src[0]->ne[0] - 1) <= 255;
|
||||
return true;
|
||||
case GGML_OP_SCALE:
|
||||
float bias;
|
||||
memcpy(&bias, (const float *) (op->op_params) + 1, sizeof(float));
|
||||
@ -15,6 +15,7 @@ if (CUDAToolkit_FOUND)
|
|||
# 80 == Ampere, asynchronous data loading, faster tensor core instructions
|
||||
# 86 == RTX 3000, needs CUDA v11.1
|
||||
# 89 == RTX 4000, needs CUDA v11.8
|
||||
# 120 == Blackwell, needs CUDA v12.8, FP4 tensor cores
|
||||
#
|
||||
# XX-virtual == compile CUDA code as PTX, do JIT compilation to binary code on first run
|
||||
# XX-real == compile CUDA code as device code for this specific architecture
|
||||
|
|
@ -40,6 +41,32 @@ if (CUDAToolkit_FOUND)
|
|||
|
||||
enable_language(CUDA)
|
||||
|
||||
# Replace any 12x-real architectures with 12x{a}-real. FP4 ptx instructions are not available in just 12x
|
||||
if (GGML_NATIVE)
|
||||
set(PROCESSED_ARCHITECTURES "")
|
||||
if (CMAKE_CUDA_ARCHITECTURES_NATIVE)
|
||||
set(ARCH_LIST ${CMAKE_CUDA_ARCHITECTURES_NATIVE})
|
||||
else()
|
||||
set(ARCH_LIST ${CMAKE_CUDA_ARCHITECTURES})
|
||||
endif()
|
||||
foreach(ARCH ${ARCH_LIST})
|
||||
if (ARCH MATCHES "^12[0-9](-real|-virtual)?$")
|
||||
string(REGEX REPLACE "^(12[0-9]).*$" "\\1" BASE_ARCH ${ARCH})
|
||||
message(STATUS "Replacing ${ARCH} with ${BASE_ARCH}a-real")
|
||||
list(APPEND PROCESSED_ARCHITECTURES "${BASE_ARCH}a-real")
|
||||
else()
|
||||
list(APPEND PROCESSED_ARCHITECTURES ${ARCH})
|
||||
endif()
|
||||
endforeach()
|
||||
set(CMAKE_CUDA_ARCHITECTURES ${PROCESSED_ARCHITECTURES})
|
||||
else()
|
||||
foreach(ARCH ${CMAKE_CUDA_ARCHITECTURES})
|
||||
if(ARCH MATCHES "^12[0-9]$")
|
||||
message(FATAL_ERROR "Compute capability ${ARCH} used, use ${ARCH}a or ${ARCH}f for Blackwell specific optimizations")
|
||||
endif()
|
||||
endforeach()
|
||||
endif()
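# Example (assumed input): CMAKE_CUDA_ARCHITECTURES_NATIVE = "120-real;89-real"
#   -> PROCESSED_ARCHITECTURES = "120a-real;89-real" ("89-real" does not match 12[0-9] and is kept as-is)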
|
||||
|
||||
file(GLOB GGML_HEADERS_CUDA "*.cuh")
|
||||
list(APPEND GGML_HEADERS_CUDA "../../include/ggml-cuda.h")
|
||||
|
||||
@ -50,6 +50,10 @@
|
|||
#define GGML_CUDA_CC_TURING 750
|
||||
#define GGML_CUDA_CC_AMPERE 800
|
||||
#define GGML_CUDA_CC_ADA_LOVELACE 890
|
||||
// While BW spans CC 1000, 1100 & 1200, we are integrating Tensor Core instructions available to 1200 family, see
|
||||
// https://docs.nvidia.com/cutlass/media/docs/cpp/blackwell_functionality.html#blackwell-sm120-gemms
|
||||
#define GGML_CUDA_CC_BLACKWELL 1200
|
||||
#define GGML_CUDA_CC_RUBIN 1300
|
||||
#define GGML_CUDA_CC_OFFSET_AMD 0x1000000
|
||||
#define GGML_CUDA_CC_OFFSET_MTHREADS 0x0100000
|
||||
#define GGML_CUDA_CC_IS_NVIDIA(cc) (cc < GGML_CUDA_CC_OFFSET_MTHREADS)
|
||||
|
|
@ -246,6 +250,10 @@ static const char * cu_get_error_str(CUresult err) {
|
|||
#define AMPERE_MMA_AVAILABLE
|
||||
#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
|
||||
|
||||
#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_BLACKWELL && __CUDA_ARCH__ < GGML_CUDA_CC_RUBIN
|
||||
# define BLACKWELL_MMA_AVAILABLE
|
||||
#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_BLACKWELL
|
||||
|
||||
#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
|
||||
#define CP_ASYNC_AVAILABLE
|
||||
#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
|
||||
|
|
@ -316,6 +324,11 @@ static bool cp_async_available(const int cc) {
|
|||
return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_AMPERE;
|
||||
}
|
||||
|
||||
static bool blackwell_mma_available(const int cc) {
|
||||
return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_BLACKWELL &&
|
||||
ggml_cuda_highest_compiled_arch(cc) < GGML_CUDA_CC_RUBIN;
|
||||
}
|
||||
|
||||
static constexpr __device__ int ggml_cuda_get_physical_warp_size() {
|
||||
#if defined(GGML_USE_HIP) && (defined(__GFX9__) || defined(__GFX8__))
|
||||
return 64;
|
||||
|
|
@ -701,6 +714,28 @@ static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) {
|
|||
#endif // CUDART_VERSION >= 12050
|
||||
}
|
||||
|
||||
__device__ __forceinline__ uint8_t ggml_cuda_float_to_fp4_e2m1(float x, float e) {
|
||||
const uint8_t sign_bit = (x < 0.0f) << 3;
|
||||
float ax = fabsf(x) * e;
|
||||
|
||||
// Positive LUT
|
||||
static constexpr float pos_lut[8] = { 0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f };
|
||||
|
||||
int best_i = 0;
|
||||
float best_err = fabsf(ax - pos_lut[0]);
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 1; i < 8; ++i) {
|
||||
const float err = fabsf(ax - pos_lut[i]);
|
||||
if (err < best_err) {
|
||||
best_err = err;
|
||||
best_i = i;
|
||||
}
|
||||
}
|
||||
|
||||
return static_cast<uint8_t>(best_i | sign_bit);
|
||||
}
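For intuition, a host-side replica of the mapping above (same LUT, plain C++); here `e` is assumed to be the reciprocal of the block scale:

#include <cmath>
#include <cstdint>
#include <cstdio>

static uint8_t float_to_fp4_e2m1_host(float x, float e) {
    const uint8_t sign_bit = (x < 0.0f) ? 0x8 : 0x0;
    const float   ax       = std::fabs(x) * e;
    const float   lut[8]   = { 0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f };
    int   best_i   = 0;
    float best_err = std::fabs(ax - lut[0]);
    for (int i = 1; i < 8; ++i) {
        const float err = std::fabs(ax - lut[i]);
        if (err < best_err) { best_err = err; best_i = i; }
    }
    return (uint8_t) (best_i | sign_bit);
}

int main() {
    // with e = 1.0f the mapping is direct: 0.4 -> 1 (0.5), -2.6 -> 13 (sign | 3.0), 10.0 -> 7 (clamps to 6.0)
    std::printf("%d %d %d\n", (int) float_to_fp4_e2m1_host(0.4f, 1.0f),
                (int) float_to_fp4_e2m1_host(-2.6f, 1.0f), (int) float_to_fp4_e2m1_host(10.0f, 1.0f));
    return 0;
}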
|
||||
|
||||
// See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1.
|
||||
// Precompute mp (m' in the paper) and L such that division
|
||||
// can be computed using a multiply (high 32b of 64b result)
|
||||
@ -5,7 +5,7 @@
|
|||
#include "ggml.h"
|
||||
|
||||
#ifdef GGML_CUDA_USE_CUB
|
||||
# include <cub/device/device_scan.cuh>
|
||||
# include <cub/block/block_scan.cuh>
|
||||
#endif // GGML_CUDA_USE_CUB
|
||||
|
||||
template<typename T, int BLOCK_SIZE>
|
||||
|
|
@ -16,12 +16,14 @@ static __global__ void cumsum_cub_kernel(
|
|||
const int64_t s01, const int64_t s02, const int64_t s03,
|
||||
const int64_t s1, const int64_t s2, const int64_t s3) {
|
||||
#ifdef GGML_CUDA_USE_CUB
|
||||
using BlockScan = cub::BlockScan<T, BLOCK_SIZE>;
|
||||
using BlockScanT = cub::BlockScan<T, BLOCK_SIZE>;
|
||||
|
||||
__shared__ typename BlockScan::TempStorage temp_storage;
|
||||
__shared__ T block_carry; // carry from previous tile
|
||||
__shared__ typename BlockScanT::TempStorage temp_storage;
|
||||
__shared__ T block_carry;
|
||||
|
||||
const int tid = threadIdx.x;
|
||||
constexpr int UNROLL_FACTOR = 4;
|
||||
constexpr int TILE_SIZE = BLOCK_SIZE * UNROLL_FACTOR;
|
||||
|
||||
const int64_t i1 = blockIdx.x;
|
||||
const int64_t i2 = blockIdx.y;
|
||||
|
|
@ -39,29 +41,38 @@ static __global__ void cumsum_cub_kernel(
|
|||
}
|
||||
__syncthreads();
|
||||
|
||||
for (int64_t start = 0; start < ne00; start += BLOCK_SIZE) {
|
||||
int64_t idx = start + tid;
|
||||
T x = (idx < ne00) ? src_row[idx] : T(0);
|
||||
for (int64_t start = 0; start < ne00; start += TILE_SIZE) {
|
||||
T items[UNROLL_FACTOR];
|
||||
T thread_sum = T(0);
|
||||
|
||||
T inclusive;
|
||||
T block_total;
|
||||
BlockScan(temp_storage).InclusiveSum(x, inclusive, block_total);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
T final_val = inclusive + block_carry;
|
||||
|
||||
// store result
|
||||
if (idx < ne00) {
|
||||
dst_row[idx] = final_val;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < UNROLL_FACTOR; i++) {
|
||||
int64_t idx = start + tid * UNROLL_FACTOR + i;
|
||||
T val = (idx < ne00) ? src_row[idx] : T(0);
|
||||
thread_sum += val;
|
||||
items[i] = thread_sum;
|
||||
}
|
||||
|
||||
// Block-wide scan on thread sums
|
||||
T thread_prefix;
|
||||
T block_total;
|
||||
BlockScanT(temp_storage).InclusiveSum(thread_sum, thread_prefix, block_total);
|
||||
__syncthreads();
|
||||
|
||||
// Add offset to each item and store
|
||||
T thread_offset = thread_prefix - thread_sum + block_carry;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < UNROLL_FACTOR; i++) {
|
||||
int64_t idx = start + tid * UNROLL_FACTOR + i;
|
||||
if (idx < ne00) {
|
||||
dst_row[idx] = items[i] + thread_offset;
|
||||
}
|
||||
}
|
||||
|
||||
// Update carry for next tile
|
||||
if (tid == 0) {
|
||||
block_carry += block_total;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
#else
|
||||
|
|
@ -69,7 +80,7 @@ static __global__ void cumsum_cub_kernel(
|
|||
#endif // GGML_CUDA_USE_CUB
|
||||
}
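A plain-C++ model of the tiling above (assumed tiny sizes) showing how the per-thread sequential scan, the scan over thread totals, and the running carry combine:

#include <cstdio>
#include <vector>

int main() {
    const int BLOCK = 4, UNROLL = 4;                      // stand-ins for BLOCK_SIZE / UNROLL_FACTOR
    std::vector<float> x(BLOCK * UNROLL, 1.0f), y(x.size());
    std::vector<float> thread_sum(BLOCK), thread_prefix(BLOCK);
    float carry = 0.0f;                                   // carry from the previous tile
    for (int t = 0; t < BLOCK; ++t) {                     // per-"thread" sequential inclusive scan
        float s = 0.0f;
        for (int i = 0; i < UNROLL; ++i) { s += x[t * UNROLL + i]; y[t * UNROLL + i] = s; }
        thread_sum[t] = s;
    }
    float run = 0.0f;                                     // inclusive scan over per-thread totals
    for (int t = 0; t < BLOCK; ++t) { run += thread_sum[t]; thread_prefix[t] = run; }
    for (int t = 0; t < BLOCK; ++t) {                     // add the exclusive offset back to each item
        const float offset = thread_prefix[t] - thread_sum[t] + carry;
        for (int i = 0; i < UNROLL; ++i) { y[t * UNROLL + i] += offset; }
    }
    for (float v : y) { std::printf("%.0f ", v); }        // prints 1 2 3 ... 16
    std::printf("\n");
    return 0;
}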
|
||||
|
||||
// Fallback kernel implementation (original)
|
||||
// Fallback kernel implementation
|
||||
template<typename T>
|
||||
static __global__ void cumsum_kernel(
|
||||
const T * src, T * dst,
|
||||
|
|
@ -86,10 +97,10 @@ static __global__ void cumsum_kernel(
|
|||
const int warps_per_block = blockDim.x / warp_size;
|
||||
|
||||
extern __shared__ float smem[];
|
||||
float * s_vals = smem;
|
||||
float * s_warp_sums = smem + blockDim.x;
|
||||
float * s_carry = smem + blockDim.x + warps_per_block;
|
||||
float * s_chunk_total = s_carry + 1;
|
||||
float * s_vals = smem;
|
||||
float * s_warp_sums = smem + blockDim.x;
|
||||
float * s_carry = smem + blockDim.x + warps_per_block;
|
||||
float * s_chunk_total = s_carry + 1;
|
||||
|
||||
// Initialize carry
|
||||
if (tid == 0) {
|
||||
|
|
@ -107,21 +118,39 @@ static __global__ void cumsum_kernel(
|
|||
const T * src_row = src + i1 * s01 + i2 * s02 + i3 * s03;
|
||||
T * dst_row = dst + i1 * s1 + i2 * s2 + i3 * s3;
|
||||
|
||||
for (int64_t start = 0; start < ne00; start += blockDim.x) {
|
||||
int64_t idx = start + tid;
|
||||
float val = (idx < ne00) ? ggml_cuda_cast<float, T>(src_row[idx]) : 0.0f;
|
||||
// register blocking: process 4 elements per thread to hide latency
|
||||
// and reduce synchronization overhead
|
||||
constexpr int num_unroll = 4;
|
||||
T temp[num_unroll];
|
||||
|
||||
// 1. Warp inclusive scan
|
||||
for (int64_t i = 0; i < ne00; i += num_unroll * blockDim.x) {
|
||||
int64_t idx = i + tid * num_unroll;
|
||||
|
||||
// thread local sequential scan
|
||||
temp[0] = (idx < ne00 ? src_row[idx] : T(0));
|
||||
#pragma unroll
|
||||
for (int64_t j = 1; j < num_unroll; j++) {
|
||||
temp[j] = temp[j - 1];
|
||||
if (idx + j < ne00) {
|
||||
temp[j] += src_row[idx + j];
|
||||
} else {
|
||||
temp[j] += 0;
|
||||
}
|
||||
}
|
||||
|
||||
// last element is the sum of all values assigned to this thread
|
||||
float val = (idx < ne00) ? ggml_cuda_cast<float, T>(temp[num_unroll - 1]) : 0.0f;
|
||||
|
||||
// Warp inclusive scan
|
||||
val = warp_prefix_inclusive_sum<T, warp_size>(val);
|
||||
s_vals[tid] = val;
|
||||
|
||||
// Store warp total
|
||||
if (lane == warp_size - 1) {
|
||||
s_warp_sums[warp] = val;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// 2. Exclusive scan of warp sums (warp 0 only)
|
||||
// Exclusive scan of warp sums (warp 0 only)
|
||||
if (warp == 0) {
|
||||
float w = (tid < warps_per_block) ? s_warp_sums[tid] : 0.0f;
|
||||
float inc = warp_prefix_inclusive_sum<T, warp_size>(w);
|
||||
|
|
@ -134,12 +163,17 @@ static __global__ void cumsum_kernel(
|
|||
}
|
||||
__syncthreads();
|
||||
|
||||
// write back results
|
||||
float carry = *s_carry;
|
||||
float final_val = s_vals[tid] + s_warp_sums[warp] + carry;
|
||||
if (idx < ne00) {
|
||||
dst_row[idx] = ggml_cuda_cast<T, float>(final_val);
|
||||
// calculate sum offset for this thread
|
||||
float final_val_offset = s_vals[tid] + s_warp_sums[warp] + carry - temp[num_unroll - 1];
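// (s_vals[tid] is this thread's inclusive warp-scan value, so subtracting the thread's own
// total temp[num_unroll - 1] turns it into the exclusive offset added back to each temp[j] below)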
|
||||
|
||||
#pragma unroll
|
||||
for (int32_t j = 0; j < num_unroll; j++) {
|
||||
if (idx + j < ne00) {
|
||||
dst_row[idx + j] = temp[j] + ggml_cuda_cast<T, float>(final_val_offset);
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Update carry for next chunk
|
||||
if (tid == 0) {
|
||||
|
|
@ -177,7 +211,7 @@ static void cumsum_cuda(
|
|||
const int warps_per_block = block_size / warp_size;
|
||||
const size_t shmem_size = (block_size + warps_per_block + 2) * sizeof(float);
|
||||
|
||||
if (use_cub) {
|
||||
if (use_cub && ne00 >= 1024) {
|
||||
cumsum_cub_kernel<T, CUDA_CUMSUM_BLOCK_SIZE><<<grid_dims, CUDA_CUMSUM_BLOCK_SIZE, 0, stream>>>(
|
||||
src, dst,
|
||||
ne00, ne01, ne02, ne03,
|
||||
@ -900,6 +900,27 @@ namespace ggml_cuda_mma {
|
|||
#endif // AMPERE_MMA_AVAILABLE
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ void mma_block_scaled(tile<16, 8, float> & D,
|
||||
const tile<16, 8, int> & A,
|
||||
const tile<8, 8, int> & B,
|
||||
uint32_t a_scale,
|
||||
uint32_t b_scale) {
|
||||
#ifdef BLACKWELL_MMA_AVAILABLE
|
||||
const int * Axi = (const int *) A.x;
|
||||
const int * Bxi = (const int *) B.x;
|
||||
float * Dxi = (float *) D.x;
|
||||
|
||||
asm volatile(
|
||||
"mma.sync.aligned.kind::mxf4.block_scale.scale_vec::2X.m16n8k64.row.col.f32.e2m1.e2m1.f32.ue8m0 "
|
||||
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3}, "
|
||||
"%10, {0, 0}, %11, {0, 0};"
|
||||
: "+f"(Dxi[0]), "+f"(Dxi[1]), "+f"(Dxi[2]), "+f"(Dxi[3])
|
||||
: "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]), "r"(a_scale), "r"(b_scale));
|
||||
#else
|
||||
GGML_UNUSED_VARS(D, A, B, a_scale, b_scale);
|
||||
#endif // BLACKWELL_MMA_AVAILABLE
|
||||
}
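// Semantics sketch (see the PTX ISA notes on warp-level block scaling, linked below in mmq.cuh):
// D += dequant(A) * dequant(B), where A and B hold e2m1 (FP4) values and each group of 32
// values along K is scaled by a ue8m0 power-of-two exponent taken from a_scale / b_scale
// (two scales per K = 64, hence scale_vec::2X); the {0, 0} operands select byte/thread 0 of
// the packed scale registers.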
|
||||
|
||||
static __device__ __forceinline__ void mma(
|
||||
tile<16, 8, float> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) {
|
||||
#ifdef TURING_MMA_AVAILABLE
|
||||
@ -1,3 +1,4 @@
|
|||
#include "common.cuh"
|
||||
#include "mmq.cuh"
|
||||
#include "quantize.cuh"
|
||||
#include "mmid.cuh"
|
||||
|
|
@ -114,6 +115,9 @@ void ggml_cuda_mul_mat_q(
|
|||
const bool use_stream_k = (GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA)
|
||||
|| GGML_CUDA_CC_IS_CDNA(cc);
|
||||
|
||||
// TODO: tighter pool buffer size vs q8 path
|
||||
const bool use_native_mxfp4 = blackwell_mma_available(cc) && src0->type == GGML_TYPE_MXFP4;
|
||||
|
||||
if (!ids) {
|
||||
const size_t nbytes_src1_q8_1 = ne13*ne12 * ne11*ne10_padded * sizeof(block_q8_1)/QK8_1 +
|
||||
get_mmq_x_max_host(cc)*sizeof(block_q8_1_mmq);
|
||||
|
|
@ -123,12 +127,24 @@ void ggml_cuda_mul_mat_q(
|
|||
const int64_t s11 = src1->nb[1] / ts_src1;
|
||||
const int64_t s12 = src1->nb[2] / ts_src1;
|
||||
const int64_t s13 = src1->nb[3] / ts_src1;
|
||||
quantize_mmq_q8_1_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type,
|
||||
ne10, s11, s12, s13, ne10_padded, ne11, ne12, ne13, stream);
|
||||
if (use_native_mxfp4) {
|
||||
static_assert(sizeof(block_fp4_mmq) == 4 * sizeof(block_q8_1));
|
||||
quantize_mmq_mxfp4_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded,
|
||||
ne11, ne12, ne13, stream);
|
||||
|
||||
} else {
|
||||
quantize_mmq_q8_1_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded,
|
||||
ne11, ne12, ne13, stream);
|
||||
}
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
}
|
||||
|
||||
const int64_t s12 = ne11*ne10_padded * sizeof(block_q8_1)/(QK8_1*sizeof(int));
|
||||
// Stride depends on quantization format
|
||||
const int64_t s12 = use_native_mxfp4 ?
|
||||
ne11 * ne10_padded * sizeof(block_fp4_mmq) /
|
||||
(8 * QK_MXFP4 * sizeof(int)) // block_fp4_mmq holds 256 values (8 blocks of 32)
|
||||
:
|
||||
ne11 * ne10_padded * sizeof(block_q8_1) / (QK8_1 * sizeof(int));
|
||||
const int64_t s13 = ne12*s12;
|
||||
|
||||
const mmq_args args = {
|
||||
|
|
@ -175,12 +191,19 @@ void ggml_cuda_mul_mat_q(
|
|||
const int64_t s11 = src1->nb[1] / ts_src1;
|
||||
const int64_t s12 = src1->nb[2] / ts_src1;
|
||||
const int64_t s13 = src1->nb[2] / ts_src1;
|
||||
quantize_mmq_q8_1_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type,
|
||||
ne10, s11, s12, s13, ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream);
|
||||
|
||||
if (use_native_mxfp4) {
|
||||
quantize_mmq_mxfp4_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type, ne10, s11, s12, s13,
|
||||
ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream);
|
||||
} else {
|
||||
quantize_mmq_q8_1_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type, ne10, s11, s12, s13,
|
||||
ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream);
|
||||
}
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
}
|
||||
|
||||
const int64_t s12 = ne11*ne10_padded * sizeof(block_q8_1)/(QK8_1*sizeof(int));
|
||||
const int64_t s12 = use_native_mxfp4 ? ne11 * ne10_padded * sizeof(block_fp4_mmq) / (8 * QK_MXFP4 * sizeof(int)) :
|
||||
ne11 * ne10_padded * sizeof(block_q8_1) / (QK8_1 * sizeof(int));
|
||||
const int64_t s13 = ne12*s12;
|
||||
|
||||
// Note that ne02 is used instead of ne12 because the number of y channels determines the z dimension of the CUDA grid.
|
||||
@ -11,6 +11,7 @@ using namespace ggml_cuda_mma;
|
|||
|
||||
#define MMQ_DP4A_MAX_BATCH_SIZE 64 // Max. batch size to use for dp4a MMQ kernels when FP16 tensor cores are available.
|
||||
#define MMQ_ITER_K 256
|
||||
#define MMQ_ITER_K_MXFP4_FP4 512
|
||||
#define MMQ_NWARPS 8
|
||||
|
||||
typedef void (*load_tiles_mmq_t)(const char * __restrict__ x, int * x_tile, const int kbx0, const int i_max, const int stride);
|
||||
|
|
@ -44,8 +45,15 @@ struct block_q8_1_mmq {
|
|||
};
|
||||
int8_t qs[4*QK8_1]; // 128 values quantized to 8 bit each
|
||||
};
|
||||
|
||||
struct block_fp4_mmq {
|
||||
uint32_t d4[4]; // 8 E8M0 scales (1 per 32 values), 2 packed per uint32: d4[0]={s0,s1}, d4[1]={s2,s3}, etc.
|
||||
int8_t qs[4 * 32]; // 256 FP4 values packed as 4-bit pairs (2 per byte), 8 blocks of 32 values
|
||||
};
|
||||
|
||||
static_assert(sizeof(block_q8_1_mmq) == 4*QK8_1 + 4*sizeof(half2), "Unexpected block_q8_1_mmq size");
|
||||
static_assert(sizeof(block_q8_1_mmq) == 4*sizeof(block_q8_1), "Unexpected block_q8_1_mmq size");
|
||||
static_assert(sizeof(block_fp4_mmq) == sizeof(block_q8_1_mmq), "Unexpected block_fp4_mmq size");
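// In numbers (assuming QK8_1 == 32 and QK_MXFP4 == 32): both layouts are 16 bytes of scales
// plus 128 bytes of quantized data, i.e. 144 bytes; block_fp4_mmq packs 256 FP4 values as
// 128 nibble pairs.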
|
||||
|
||||
static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
|
||||
switch (type_x) {
|
||||
|
|
@ -129,6 +137,14 @@ static int get_mmq_y_host(const int cc) {
|
|||
((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) ? 128 : 64);
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_iter_k([[maybe_unused]] const ggml_type type) {
|
||||
#if defined(BLACKWELL_MMA_AVAILABLE)
|
||||
return type == GGML_TYPE_MXFP4 ? MMQ_ITER_K_MXFP4_FP4 : MMQ_ITER_K;
|
||||
#else
|
||||
return MMQ_ITER_K;
|
||||
#endif // defined(BLACKWELL_MMA_AVAILABLE)
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_mmq_y_device() {
|
||||
#if defined(GGML_USE_HIP)
|
||||
#if defined(RDNA1)
|
||||
|
|
@ -191,6 +207,7 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
|
|||
}
|
||||
|
||||
#define MMQ_MMA_TILE_X_K_Q8_0 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
|
||||
#define MMQ_MMA_TILE_X_K_FP4 (2*MMQ_TILE_NE_K + 8 + 4)
|
||||
#define MMQ_MMA_TILE_X_K_Q8_1 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
|
||||
#define MMQ_MMA_TILE_X_K_Q2_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K + 4)
|
||||
#define MMQ_MMA_TILE_X_K_Q3_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4)
|
||||
|
|
@ -201,6 +218,8 @@ static_assert(MMQ_MMA_TILE_X_K_Q8_1 % 8 == 4, "Wrong padding.");
|
|||
static_assert(MMQ_MMA_TILE_X_K_Q2_K % 8 == 4, "Wrong padding.");
|
||||
static_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, "Wrong padding.");
|
||||
static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding.");
|
||||
static_assert(MMQ_MMA_TILE_X_K_FP4 % 8 == 4, "Wrong padding.");
|
||||
static_assert(MMQ_MMA_TILE_X_K_FP4 == MMQ_MMA_TILE_X_K_Q8_1, "Wrong tile size for MXFP4");
|
||||
|
||||
static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
|
||||
switch (type) {
|
||||
|
|
@ -209,6 +228,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
|
|||
case GGML_TYPE_Q5_0: return MMQ_MMA_TILE_X_K_Q8_0;
|
||||
case GGML_TYPE_Q5_1: return MMQ_MMA_TILE_X_K_Q8_1;
|
||||
case GGML_TYPE_Q8_0: return MMQ_MMA_TILE_X_K_Q8_0;
|
||||
// tile sizes are the same for Q8_1 and FP4 for blackwell
|
||||
case GGML_TYPE_MXFP4: return MMQ_MMA_TILE_X_K_Q8_1;
|
||||
case GGML_TYPE_Q2_K: return MMQ_MMA_TILE_X_K_Q2_K;
|
||||
case GGML_TYPE_Q3_K: return MMQ_MMA_TILE_X_K_Q3_K;
|
||||
|
|
@ -228,7 +248,8 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
|
|||
}
|
||||
|
||||
// block_q8_1_mmq has (128 8-bit ints == 32 32-bit ints + 4 32-bit scales)
|
||||
#define MMQ_TILE_Y_K (MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI8_1)
|
||||
#define MMQ_TILE_Y_K (MMQ_TILE_NE_K + MMQ_TILE_NE_K / QI8_1)
|
||||
#define MMQ_TILE_Y_FP4_K MMQ_TILE_Y_K
|
||||
|
||||
static int mmq_get_granularity_host(const int mmq_x, const int cc) {
|
||||
if (amd_mfma_available(cc) || amd_wmma_available(cc)) {
|
||||
|
|
@ -761,6 +782,50 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|||
}
|
||||
}
|
||||
|
||||
template <int mmq_y, bool need_check>
static __device__ __forceinline__ void load_tiles_mxfp4_fp4(const char * __restrict__ x,
int * __restrict__ x_tile,
const int kbx0,
const int i_max,
const int stride) {
constexpr int nwarps = mmq_get_nwarps_device();
constexpr int warp_size = ggml_cuda_get_physical_warp_size();

int * x_qs = (int *) x_tile;
uint32_t * x_sc = (uint32_t *) (x_qs + 2 * MMQ_TILE_NE_K);

const int txi = threadIdx.x;

constexpr int iter_k = get_iter_k(GGML_TYPE_MXFP4);

constexpr int threads_per_row = iter_k / QK_MXFP4; // each thread processes 1 block
constexpr int rows_per_warp = warp_size / threads_per_row;
const int kbx = txi % threads_per_row;
const int row_in_warp = txi / threads_per_row;

#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += rows_per_warp * nwarps) {
int i = i0 + threadIdx.y * rows_per_warp + row_in_warp;

if constexpr (need_check) {
i = min(i, i_max);
}

const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i * stride + kbx;

// quantize_mxfp4_mmq permutes nibbles to match the quantized format
const int k0 = kbx * 4;
memcpy(x_qs + i * MMQ_MMA_TILE_X_K_FP4 + k0, bxi->qs, 16);

// Load E8M0 scales: pack 2 consecutive scales into one uint32
if (kbx % 2 == 0) {
uint32_t e = bxi->e;
e |= ((bxi + 1)->e << 8);
x_sc[i * MMQ_MMA_TILE_X_K_FP4 + kbx / 2] = e;
}
}
}

template <int mmq_x, int mmq_y>
static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {

@@ -931,6 +996,78 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
#endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
}

template <int mmq_x, int mmq_y>
static __device__ __forceinline__ void vec_dot_mxfp4_mxfp4_mma(const int * __restrict__ x,
const int * __restrict__ y,
float * __restrict__ sum,
const int k00) {
typedef tile<16, 8, int> tile_A;
typedef tile<8, 8, int> tile_B;
typedef tile<16, 8, float> tile_C; // Output is float for native scaled MMA

constexpr int granularity = mmq_get_granularity_device(mmq_x);
constexpr int rows_per_warp = 2 * granularity;
constexpr int ntx = rows_per_warp / tile_C::I; // Number of x minitiles per warp.

y += (threadIdx.y % ntx) * (tile_C::J * MMQ_TILE_Y_FP4_K);

// Match layout from load_tiles_mxfp4_fp4
const int * x_qs = (const int *) x;
const uint32_t * x_sc = (const uint32_t *) (x_qs + 2 * MMQ_TILE_NE_K);
const int * y_qs = (const int *) y + 4;
const uint32_t * y_sc = (const uint32_t *) y;

// tile_A has a length of 64 logical values vs. 32 values in block_mxfp4
tile_A A[ntx][MMQ_TILE_NE_K / (2 * QI_MXFP4)];
uint32_t scaleA[ntx][MMQ_TILE_NE_K / (2 * QI_MXFP4)];

// Block scale
// Each thread has to point to a 4 byte scale value
// https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-block-scaling

const int i0 = (threadIdx.y / ntx) * rows_per_warp;

#pragma unroll
for (int n = 0; n < ntx; ++n) {
#pragma unroll
for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 2 * QI_MXFP4) {
const int k0 = k00 + k01;

load_ldmatrix(A[n][k01 / (2 * QI_MXFP4)], x_qs + (i0 + n * tile_A::I) * MMQ_MMA_TILE_X_K_FP4 + k0,
MMQ_MMA_TILE_X_K_FP4);

// based on the block-scaling document, 2 threads in each quad need to supply the scale value
const int tidx = threadIdx.x / 4 + (threadIdx.x % 2) * 8;
scaleA[n][k01 / (2 * QI_MXFP4)] =
*(x_sc + (i0 + n * tile_A::I + tidx) * MMQ_MMA_TILE_X_K_FP4 + k0 / (2 * QI_MXFP4));
}
}

#pragma unroll
for (int j0 = 0; j0 < mmq_x; j0 += ntx * tile_C::J) {
#pragma unroll
for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 2 * QI_MXFP4) {
tile_B B;
uint32_t scaleB; // 2xN scales

load_generic(B, y_qs + j0 * MMQ_TILE_Y_FP4_K + k01, MMQ_TILE_Y_FP4_K);

scaleB = y_sc[(j0 + threadIdx.x / 4) * MMQ_TILE_Y_FP4_K + k01 / (2 * QI_MXFP4)];

#pragma unroll
for (int n = 0; n < ntx; ++n) {
tile_C C;

mma_block_scaled(C, A[n][k01 / (2 * QI_MXFP4)], B, scaleA[n][k01 / (2 * QI_MXFP4)], scaleB);
#pragma unroll
for (int l = 0; l < tile_C::ne; ++l) {
sum[(j0 / tile_C::J + n) * tile_C::ne + l] += C.x[l];
}
}
}
}
}
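// Host-side sketch of the scale-lane mapping used above, for illustration only. It mirrors the
// expression tidx = threadIdx.x / 4 + (threadIdx.x % 2) * 8 and assumes a 32-lane warp and a
// 16-row A tile; it is not part of the kernel.
#include <cstdio>

int main() {
    for (int lane = 0; lane < 32; ++lane) {
        const int quad = lane / 4;
        const int tidx = lane / 4 + (lane % 2) * 8; // row of x_sc this lane supplies
        std::printf("lane %2d (quad %d) -> scale row %2d\n", lane, quad, tidx);
    }
    // Each quad covers rows q and q + 8, so the 8 quads together cover all 16 rows of the
    // A tile, with two lanes per quad providing a (duplicated) scale value each.
    return 0;
}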
|
||||
|
||||
template <int mmq_x, int mmq_y>
|
||||
static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a(
|
||||
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
|
||||
|
|
@ -3109,8 +3246,13 @@ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q8_0> {
|
|||
template <int mmq_x, int mmq_y, bool need_check>
|
||||
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_MXFP4> {
|
||||
static constexpr int vdr = VDR_MXFP4_Q8_1_MMQ;
|
||||
#ifdef BLACKWELL_MMA_AVAILABLE
|
||||
static constexpr load_tiles_mmq_t load_tiles = load_tiles_mxfp4_fp4<mmq_y, need_check>;
|
||||
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_mxfp4_mxfp4_mma<mmq_x, mmq_y>;
|
||||
#else
|
||||
static constexpr load_tiles_mmq_t load_tiles = load_tiles_mxfp4<mmq_y, need_check>;
|
||||
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
|
||||
#endif // BLACKWELL_MMA_AVAILABLE
|
||||
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
|
||||
};
|
||||
|
||||
|
|
@ -3243,17 +3385,26 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
|
|||
constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, need_check>;
|
||||
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
||||
|
||||
constexpr int blocks_per_iter = MMQ_ITER_K / qk;
|
||||
#if defined(BLACKWELL_MMA_AVAILABLE)
|
||||
// FP4 tile stores 8 blocks
|
||||
constexpr int ne_block = (type == GGML_TYPE_MXFP4) ? 8 * QK_MXFP4 : 4 * QK8_1;
|
||||
#else
|
||||
constexpr int ne_block = 4 * QK8_1;
|
||||
#endif // defined(BLACKWELL_MMA_AVAILABLE)
|
||||
|
||||
constexpr int ITER_K = get_iter_k(type);
|
||||
constexpr int blocks_per_iter = ITER_K / qk;
|
||||
|
||||
float sum[mmq_x*mmq_y / (nwarps*warp_size)] = {0.0f};
|
||||
|
||||
constexpr int sz = sizeof(block_q8_1_mmq) / sizeof(int);
|
||||
|
||||
for (int kb0 = kb0_start; kb0 < kb0_stop; kb0 += blocks_per_iter) {
|
||||
load_tiles(x, tile_x, offset_x + kb0, tile_x_max_i, stride_row_x);
|
||||
|
||||
{
|
||||
const int * by0 = y + ncols_y*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 0*sizeof(block_q8_1_mmq)/sizeof(int));
|
||||
const int * by0 = y + ncols_y * (kb0 * qk / ne_block) * sz;
|
||||
#pragma unroll
|
||||
for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*warp_size) {
|
||||
for (int l0 = 0; l0 < mmq_x * MMQ_TILE_Y_K; l0 += nwarps * warp_size) {
|
||||
int l = l0 + threadIdx.y*warp_size + threadIdx.x;
|
||||
|
||||
tile_y[l] = by0[l];
|
||||
|
|
@ -3267,9 +3418,9 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
|
|||
__syncthreads();
|
||||
|
||||
{
|
||||
const int * by0 = y + ncols_y*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 1*sizeof(block_q8_1_mmq)/sizeof(int));
|
||||
const int * by0 = y + ncols_y * ((kb0 * qk / ne_block) * sz + sz);
|
||||
#pragma unroll
|
||||
for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*warp_size) {
|
||||
for (int l0 = 0; l0 < mmq_x * MMQ_TILE_Y_K; l0 += nwarps * warp_size) {
|
||||
int l = l0 + threadIdx.y*warp_size + threadIdx.x;
|
||||
|
||||
tile_y[l] = by0[l];
|
||||
|
|
@ -3401,8 +3552,10 @@ static __global__ void mul_mat_q(
|
|||
}
|
||||
#endif // (defined(GGML_USE_HIP) && !defined(CDNA3)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
|
||||
|
||||
constexpr int ITER_K = get_iter_k(type);
|
||||
|
||||
const int64_t blocks_per_ne00 = ncols_x / qk;
|
||||
constexpr int blocks_per_iter = MMQ_ITER_K / qk;
|
||||
constexpr int blocks_per_iter = ITER_K / qk;
|
||||
|
||||
// kbc == k block continuous, current index in continuous ijk space.
|
||||
int64_t kbc = (int64_t) blockIdx.x *nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;
|
||||
|
|
@ -3463,7 +3616,7 @@ static __global__ void mul_mat_q(
|
|||
__syncthreads();
|
||||
}
|
||||
|
||||
offset_y += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int));
|
||||
offset_y += (col_low + jt * mmq_x) * (sizeof(block_q8_1_mmq) / sizeof(int));
|
||||
offset_dst += it*mmq_y;
|
||||
|
||||
const int tile_x_max_i = nrows_x - it*mmq_y - 1;
|
||||
|
|
@ -3530,7 +3683,7 @@ static __global__ void mul_mat_q(
|
|||
__syncthreads();
|
||||
}
|
||||
|
||||
offset_y += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int));
|
||||
offset_y += (col_low + jt * mmq_x) * (sizeof(block_q8_1_mmq) / sizeof(int));
|
||||
offset_dst += it*mmq_y;
|
||||
|
||||
const int tile_x_max_i = nrows_x - it*mmq_y - 1;
|
||||
|
|
@ -3553,7 +3706,9 @@ static __global__ void mul_mat_q_stream_k_fixup(
|
|||
const int ncols_max) {
|
||||
constexpr int mmq_y = get_mmq_y_device();
|
||||
constexpr int qk = ggml_cuda_type_traits<type>::qk;
|
||||
constexpr int blocks_per_iter = MMQ_ITER_K / qk;
|
||||
constexpr int ITER_K = get_iter_k(type);
|
||||
|
||||
constexpr int blocks_per_iter = ITER_K / qk;
|
||||
const int64_t blocks_per_ne00 = ncols_x / qk;
|
||||
|
||||
constexpr int nwarps = mmq_get_nwarps_device();
|
||||
|
|
@ -3711,7 +3866,7 @@ static size_t mmq_get_nbytes_shared(const int mmq_x, const int mmq_y, const int
|
|||
const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type);
|
||||
const size_t nbs_ids = mmq_x*sizeof(int);
|
||||
const size_t nbs_x = (turing_mma_available(cc) || amd_mfma_available(cc) || amd_wmma_available(cc)) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
|
||||
const size_t nbs_y = mmq_x*sizeof(block_q8_1_mmq);
|
||||
const size_t nbs_y = mmq_x * (sizeof(block_q8_1_mmq));
|
||||
return nbs_ids + nbs_x + GGML_PAD(nbs_y, nwarps*warp_size*sizeof(int));
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@@ -47,6 +47,131 @@ static __global__ void quantize_q8_1(
y[ib].ds = make_half2(d, sum);
}

__device__ __forceinline__ uint8_t compute_e8m0_scale(float amax) {
if (!(amax > 0.0f)) {
return 0;
}

// FP4 E2M1: max exponent (unbiased) is 2.
constexpr int FP4_E2M1_EMAX = 2;

const float e = log2f(amax);

// "even" -> round-to-nearest integer, ties-to-even
const int e_int = __float2int_rn(e);

const int shared_exp = e_int - FP4_E2M1_EMAX;

int biased = shared_exp + 127;

biased = max(biased, 0);
biased = min(biased, 254);

return static_cast<uint8_t>(biased);
}
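// Host-side reference of the E8M0 scale selection above; a sketch for illustration, assuming
// std::nearbyint (ties-to-even under the default rounding mode) behaves like __float2int_rn,
// and that decoding an E8M0 byte is 2^(e - 127).
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

static uint8_t e8m0_scale_ref(float amax) {
    if (!(amax > 0.0f)) {
        return 0;
    }
    constexpr int FP4_E2M1_EMAX = 2;                          // largest unbiased exponent of FP4 E2M1
    const int e_int  = (int) std::nearbyint(std::log2(amax)); // round-to-nearest, ties-to-even
    const int biased = std::clamp(e_int - FP4_E2M1_EMAX + 127, 0, 254);
    return (uint8_t) biased;
}

static float e8m0_to_fp32_ref(uint8_t e) {
    return std::ldexp(1.0f, (int) e - 127); // 2^(e - 127)
}

int main() {
    for (float amax : {0.0f, 0.75f, 1.0f, 6.0f, 448.0f}) {
        const uint8_t e = e8m0_scale_ref(amax);
        std::printf("amax=%8.3f -> e8m0=%3u -> scale=%g\n", amax, (unsigned) e, e8m0_to_fp32_ref(e));
    }
    return 0;
}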

// quantize values into the interleaved-nibble layout that mxfp4 is stored in
// i.e. a block a0..a31 is represented as a0a16, a1a17, ..., a15a31
static __global__ void quantize_mmq_mxfp4(const float * __restrict__ x,
const int32_t * __restrict__ ids,
void * __restrict__ vy,
const int64_t ne00,
const int64_t s01,
const int64_t s02,
const int64_t s03,
const int64_t ne0,
const int ne1,
const int ne2) {
constexpr int vals_per_scale = 32;
constexpr int vals_per_warp = 2 * vals_per_scale; // Each warp processes 2 blocks of 32 = 64 values

const int warp_id = threadIdx.y;
const int lane_id_32 = threadIdx.x;

const int nwarps = blockDim.y;

const int64_t warp_start_offset = (blockIdx.y * nwarps + warp_id) * vals_per_warp;

if (warp_start_offset >= ne0) {
return;
}

const int64_t i1 = blockIdx.x;
const int64_t i2 = blockIdx.z % ne2;
const int64_t i3 = blockIdx.z / ne2;

const int64_t i01 = ids ? ids[i1] : i1;
const int64_t i02 = i2;
const int64_t i03 = i3;

block_fp4_mmq * y = (block_fp4_mmq *) vy;

const int64_t block_fp4_mmq_size = 8 * QK_MXFP4; // 256 values
const int64_t ib0 = blockIdx.z * ((int64_t) ne1 * (ne0 / block_fp4_mmq_size));
const int64_t ib = ib0 + (warp_start_offset / block_fp4_mmq_size) * ne1 + blockIdx.x;
const int64_t quad_idx_in_block = (warp_start_offset % block_fp4_mmq_size) / vals_per_warp;

const int group_id = lane_id_32 / 4;
const int lane_in_group = lane_id_32 % 4;
const int base = group_id * 2;
char2 * yqs2 = (char2 *) y[ib].qs;

const int64_t base_pos = i03 * s03 + i02 * s02 + i01 * s01;

uint8_t scales[2];

#pragma unroll
for (int b = 0; b < 2; ++b) {
const int64_t i0 = warp_start_offset + b * vals_per_scale + lane_id_32;
const float xi = (i0 < ne00) ? x[base_pos + i0] : 0.0f;

float amax = fabsf(xi);
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask, WARP_SIZE));
}

const uint8_t e = compute_e8m0_scale(amax);
scales[b] = e;
const float inv_s = (amax == 0.0f) ? 0.0f : __frcp_rn(ggml_cuda_e8m0_to_fp32(e));

#if CUDART_VERSION >= 12080
const float scaled_val = xi * inv_s;

const float val0 = __shfl_sync(0xFFFFFFFF, scaled_val, base, WARP_SIZE);
const float val1 = __shfl_sync(0xFFFFFFFF, scaled_val, base + 16, WARP_SIZE);
const float val2 = __shfl_sync(0xFFFFFFFF, scaled_val, base + 1, WARP_SIZE);
const float val3 = __shfl_sync(0xFFFFFFFF, scaled_val, base + 17, WARP_SIZE);

if (lane_in_group == 0) {
__nv_fp4x4_e2m1 fp4_packed(make_float4(val0, val1, val2, val3));

yqs2[quad_idx_in_block * 16 + b * 8 + group_id] = *(char2 *) &fp4_packed;
}
#else
// Fallback: manual FP4 conversion using LUT
const uint8_t q_val = ggml_cuda_float_to_fp4_e2m1(xi, inv_s);

const uint8_t q_lo_0 = __shfl_sync(0xFFFFFFFF, q_val, base, WARP_SIZE);
const uint8_t q_lo_1 = __shfl_sync(0xFFFFFFFF, q_val, base + 1, WARP_SIZE);
const uint8_t q_hi_0 = __shfl_sync(0xFFFFFFFF, q_val, base + 16, WARP_SIZE);
const uint8_t q_hi_1 = __shfl_sync(0xFFFFFFFF, q_val, base + 17, WARP_SIZE);

if (lane_in_group == 0) {
char2 q;
q.x = (q_hi_0 << 4) | q_lo_0;
q.y = (q_hi_1 << 4) | q_lo_1;
yqs2[quad_idx_in_block * 16 + b * 8 + group_id] = q;
}
#endif // CUDART_VERSION >= 12080
}

if (lane_id_32 == 0) {
// Store 2 scales packed into 1 uint32
y[ib].d4[quad_idx_in_block] = (scales[1] << 8) | scales[0];
}
}
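// Minimal host-side sketch of the interleaved-nibble layout described above: byte i of a
// 32-value block holds value i in its low nibble and value i + 16 in its high nibble
// (mirroring the (q_hi << 4) | q_lo packing of the fallback path). Illustration only.
#include <array>
#include <cstdint>
#include <cstdio>

static std::array<uint8_t, 16> pack_fp4_block(const std::array<uint8_t, 32> & q) {
    std::array<uint8_t, 16> out{};
    for (int i = 0; i < 16; ++i) {
        out[i] = (uint8_t) ((q[i + 16] << 4) | (q[i] & 0x0F)); // a_i low nibble, a_{i+16} high nibble
    }
    return out;
}

int main() {
    std::array<uint8_t, 32> q{};
    for (int i = 0; i < 32; ++i) {
        q[i] = (uint8_t) (i & 0x0F); // dummy 4-bit codes
    }
    const auto packed = pack_fp4_block(q);
    std::printf("byte 0: lo=%d hi=%d\n", packed[0] & 0x0F, packed[0] >> 4); // a0 and a16
    return 0;
}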

template <mmq_q8_1_ds_layout ds_layout>
static __global__ void quantize_mmq_q8_1(
const float * __restrict__ x, const int32_t * __restrict__ ids, void * __restrict__ vy,

@@ -190,3 +315,29 @@ void quantize_mmq_q8_1_cuda(
break;
}
}

void quantize_mmq_mxfp4_cuda(const float * x,
const int32_t * ids,
void * vy,
[[maybe_unused]] const ggml_type type_src0,
const int64_t ne00,
const int64_t s01,
const int64_t s02,
const int64_t s03,
const int64_t ne0,
const int64_t ne1,
const int64_t ne2,
const int64_t ne3,
cudaStream_t stream) {
GGML_ASSERT(ne0 % (2 * QK_MXFP4) == 0);

constexpr int nwarps = 8;
constexpr int vals_per_warp = 2 * QK_MXFP4;
constexpr int vals_per_block = nwarps * vals_per_warp;

const int64_t block_num_y = (ne0 + vals_per_block - 1) / vals_per_block;
const dim3 num_blocks(ne1, block_num_y, ne2 * ne3);
const dim3 block_size(WARP_SIZE, nwarps, 1);

quantize_mmq_mxfp4<<<num_blocks, block_size, 0, stream>>>(x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2);
}
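// Worked example of the launch geometry above (host-only sketch; QK_MXFP4 == 32 and
// WARP_SIZE == 32 are assumed here for illustration, and the tensor shape is hypothetical).
#include <cstdio>

int main() {
    constexpr int QK_MXFP4       = 32;                         // assumed
    constexpr int nwarps         = 8;
    constexpr int vals_per_warp  = 2 * QK_MXFP4;               // 64 values per warp
    constexpr int vals_per_block = nwarps * vals_per_warp;     // 512 values per thread block

    const long long ne0 = 4096, ne1 = 8, ne2 = 1, ne3 = 1;     // hypothetical shape
    const long long block_num_y = (ne0 + vals_per_block - 1) / vals_per_block;

    std::printf("grid = (%lld, %lld, %lld), block = (32, %d, 1)\n",
                ne1, block_num_y, ne2 * ne3, nwarps);          // grid = (8, 8, 1)
    return 0;
}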
|
||||
|
|
|
|||
|
|
@ -25,3 +25,17 @@ void quantize_mmq_q8_1_cuda(
|
|||
const float * x, const int32_t * ids, void * vy,
|
||||
ggml_type type_src0, int64_t ne00, int64_t s01, int64_t s02, int64_t s03,
|
||||
int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3, cudaStream_t stream);
|
||||
|
||||
void quantize_mmq_mxfp4_cuda(const float * x,
|
||||
const int32_t * ids,
|
||||
void * vy,
|
||||
ggml_type type_src0,
|
||||
int64_t ne00,
|
||||
int64_t s01,
|
||||
int64_t s02,
|
||||
int64_t s03,
|
||||
int64_t ne0,
|
||||
int64_t ne1,
|
||||
int64_t ne2,
|
||||
int64_t ne3,
|
||||
cudaStream_t stream);
|
||||
|
|
|
|||
|
|
@ -10,6 +10,10 @@
|
|||
#include <cuda_fp8.h>
|
||||
#endif // CUDART_VERSION >= 12050
|
||||
|
||||
#if CUDART_VERSION >= 12080
|
||||
#include <cuda_fp4.h>
|
||||
#endif // CUDART_VERSION >= 12080
|
||||
|
||||
#if CUDART_VERSION < 11020
|
||||
#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
|
||||
#define CUBLAS_TF32_TENSOR_OP_MATH CUBLAS_TENSOR_OP_MATH
|
||||
|
|
|
|||
|
|
@@ -379,18 +379,18 @@ enum FaCodePath {
};

struct vk_fa_pipeline_state {
vk_fa_pipeline_state(uint32_t HSK, uint32_t HSV, bool small_rows, FaCodePath path, bool aligned, bool f32acc)
: HSK(HSK), HSV(HSV), small_rows(small_rows), path(path), aligned(aligned), f32acc(f32acc) {}
vk_fa_pipeline_state(uint32_t HSK, uint32_t HSV, bool small_rows, bool small_cache, FaCodePath path, bool aligned, bool f32acc)
: HSK(HSK), HSV(HSV), small_rows(small_rows), small_cache(small_cache), path(path), aligned(aligned), f32acc(f32acc) {}

uint32_t HSK, HSV;
bool small_rows;
bool small_rows, small_cache;
FaCodePath path;
bool aligned;
bool f32acc;

bool operator<(const vk_fa_pipeline_state &b) const {
return std::tie(HSK, HSV, small_rows, path, aligned, f32acc) <
std::tie(b.HSK, b.HSV, b.small_rows, b.path, b.aligned, b.f32acc);
return std::tie(HSK, HSV, small_rows, small_cache, path, aligned, f32acc) <
std::tie(b.HSK, b.HSV, b.small_rows, b.small_cache, b.path, b.aligned, b.f32acc);
}
};
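// Self-contained sketch of the std::tie ordering pattern used by vk_fa_pipeline_state: every
// field that distinguishes a pipeline (including the new small_cache flag) must appear in both
// tie() calls, otherwise two distinct keys compare equal and collide in the pipeline map.
// The struct below is illustrative, not the real one.
#include <cstdint>
#include <map>
#include <tuple>

struct fa_key {
    uint32_t hsk, hsv;
    bool small_rows, small_cache, aligned, f32acc;

    bool operator<(const fa_key & b) const {
        return std::tie(hsk, hsv, small_rows, small_cache, aligned, f32acc) <
               std::tie(b.hsk, b.hsv, b.small_rows, b.small_cache, b.aligned, b.f32acc);
    }
};

int main() {
    std::map<fa_key, int> pipelines;
    pipelines[{128, 128, false, false, true, false}] = 1;
    pipelines[{128, 128, false, true,  true, false}] = 2; // distinct entry thanks to small_cache
    return pipelines.size() == 2 ? 0 : 1;
}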
|
||||
|
||||
|
|
@ -2582,10 +2582,10 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector<vk::Event>&& events
|
|||
static constexpr uint32_t flash_attention_num_small_rows = 32;
|
||||
static constexpr uint32_t scalar_flash_attention_num_small_rows = 1;
|
||||
|
||||
static uint32_t get_fa_scalar_num_large_rows(uint32_t hsk, uint32_t hsv) {
|
||||
static uint32_t get_fa_scalar_num_large_rows(uint32_t hsk, uint32_t hsv, bool small_cache) {
|
||||
if (hsv >= 192) {
|
||||
return 2;
|
||||
} else if ((hsv | hsk) & 8) {
|
||||
} else if ((hsv | hsk) & 8 || small_cache) {
|
||||
return 4;
|
||||
} else {
|
||||
return 8;
|
||||
|
|
@ -2607,9 +2607,8 @@ static uint32_t get_fa_num_small_rows(FaCodePath path) {
|
|||
}
|
||||
}
|
||||
|
||||
static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) {
|
||||
static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache) {
|
||||
GGML_UNUSED(clamp);
|
||||
GGML_UNUSED(hsv);
|
||||
|
||||
if (path == FA_SCALAR) {
|
||||
if (small_rows) {
|
||||
|
|
@ -2618,9 +2617,9 @@ static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint3
|
|||
if ((hsv | hsk) & 8) {
|
||||
// HSV/HSK not being a multiple of 16 makes D_split smaller, which makes cols_per_iter
|
||||
// larger, and Bc needs to be >= cols_per_thread. 64 is large enough, 32 is not.
|
||||
return {get_fa_scalar_num_large_rows(hsk, hsv), 64};
|
||||
return {get_fa_scalar_num_large_rows(hsk, hsv, small_cache), 64};
|
||||
} else {
|
||||
return {get_fa_scalar_num_large_rows(hsk, hsv), 32};
|
||||
return {get_fa_scalar_num_large_rows(hsk, hsv, small_cache), 32};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -2649,8 +2648,8 @@ static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint3
|
|||
return {64, 64};
|
||||
}
|
||||
|
||||
static uint32_t fa_align(FaCodePath path, uint32_t hsk, uint32_t hsv, ggml_type type, bool small_rows) {
|
||||
return fa_rows_cols(path, hsk, hsv, 0, type, small_rows)[1];
|
||||
static uint32_t fa_align(FaCodePath path, uint32_t hsk, uint32_t hsv, ggml_type type, bool small_rows, bool small_cache) {
|
||||
return fa_rows_cols(path, hsk, hsv, 0, type, small_rows, small_cache)[1];
|
||||
}
|
||||
|
||||
static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector<uint32_t>& warptile, bool mul_mat_id, ggml_type src0_type) {
|
||||
|
|
@ -2992,11 +2991,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||
align, disable_robustness, require_full_subgroups, required_subgroup_size);
|
||||
};
|
||||
|
||||
auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) -> std::array<uint32_t, 3> {
|
||||
return {fa_rows_cols(path, hsk, hsv, clamp, type, small_rows)[0], 1, 1};
|
||||
auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache) -> std::array<uint32_t, 3> {
|
||||
return {fa_rows_cols(path, hsk, hsv, clamp, type, small_rows, small_cache)[0], 1, 1};
|
||||
};
|
||||
|
||||
auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector<uint32_t> {
|
||||
auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache) -> std::vector<uint32_t> {
|
||||
// For large number of rows, 128 invocations seems to work best.
|
||||
// For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we
|
||||
// can't use 256 for D==80.
|
||||
|
|
@ -3006,7 +3005,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||
uint32_t wg_size = (path == FA_SCALAR || path == FA_COOPMAT1)
|
||||
? scalar_flash_attention_workgroup_size
|
||||
: ((small_rows && (D % 32) == 0) ? 256 : 128);
|
||||
auto rows_cols = fa_rows_cols(path, hsk, hsv, clamp, type, small_rows);
|
||||
auto rows_cols = fa_rows_cols(path, hsk, hsv, clamp, type, small_rows, small_cache);
|
||||
|
||||
// D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it.
|
||||
// D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader.
|
||||
|
|
@ -3021,21 +3020,22 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||
uint32_t HSK = fa.first.HSK; \
|
||||
uint32_t HSV = fa.first.HSV; \
|
||||
bool small_rows = fa.first.small_rows; \
|
||||
bool small_cache = fa.first.small_cache; \
|
||||
FaCodePath path = fa.first.path; \
|
||||
bool aligned = fa.first.aligned; \
|
||||
bool f32acc = fa.first.f32acc; \
|
||||
if (path == FAPATH) { \
|
||||
if (aligned) { \
|
||||
if (f32acc) { \
|
||||
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_align(FAPATH,HSK,HSV,TYPE,small_rows), true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
||||
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
||||
} else { \
|
||||
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_align(FAPATH,HSK,HSV,TYPE,small_rows), true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
||||
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
||||
} \
|
||||
} else { \
|
||||
if (f32acc) { \
|
||||
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows), 1, true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
||||
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1, true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
||||
} else { \
|
||||
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows), 1, true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
||||
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1, true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
||||
} \
|
||||
} \
|
||||
} \
|
||||
|
|
@ -8008,11 +8008,11 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx
|
|||
}
|
||||
}
|
||||
|
||||
static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv) {
|
||||
static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, bool small_cache) {
|
||||
// Needs to be kept up to date on shader changes
|
||||
GGML_UNUSED(hsv);
|
||||
const uint32_t wg_size = scalar_flash_attention_workgroup_size;
|
||||
const uint32_t Br = get_fa_scalar_num_large_rows(hsk, hsv);
|
||||
const uint32_t Br = get_fa_scalar_num_large_rows(hsk, hsv, small_cache);
|
||||
const uint32_t Bc = scalar_flash_attention_Bc;
|
||||
|
||||
const uint32_t tmpsh = wg_size * sizeof(float);
|
||||
|
|
@ -8136,6 +8136,8 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|||
uint32_t workgroups_y = (uint32_t)neq2;
|
||||
uint32_t workgroups_z = (uint32_t)neq3;
|
||||
|
||||
const bool small_cache = nek1 < 1024;
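// Host-side sketch of how the scalar FA row count reacts to small_cache, re-expressing the
// branch structure of get_fa_scalar_num_large_rows above. The function name and the printed
// head sizes are illustrative only.
#include <cstdint>
#include <cstdio>

static uint32_t fa_scalar_num_large_rows(uint32_t hsk, uint32_t hsv, bool small_cache) {
    if (hsv >= 192) {
        return 2;
    } else if (((hsv | hsk) & 8) || small_cache) {
        return 4;
    } else {
        return 8;
    }
}

int main() {
    for (bool small_cache : {false, true}) {
        std::printf("hsk=hsv=128, small_cache=%d -> Br=%u\n",
                    (int) small_cache, fa_scalar_num_large_rows(128, 128, small_cache));
    }
    return 0; // Br=8 without small_cache, Br=4 with it
}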
|
||||
|
||||
// For scalar/coopmat1 FA, we can use the "large" size to accommodate gqa.
|
||||
// For coopmat2 FA, we always use the small size (which is still pretty large for gqa).
|
||||
uint32_t max_gqa;
|
||||
|
|
@ -8143,7 +8145,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|||
case FA_SCALAR:
|
||||
case FA_COOPMAT1:
|
||||
// We may switch from coopmat1 to scalar, so use the scalar limit for both
|
||||
max_gqa = get_fa_scalar_num_large_rows(HSK, HSV);
|
||||
max_gqa = get_fa_scalar_num_large_rows(HSK, HSV, small_cache);
|
||||
break;
|
||||
case FA_COOPMAT2:
|
||||
max_gqa = get_fa_num_small_rows(FA_COOPMAT2);
|
||||
|
|
@ -8177,7 +8179,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|||
|
||||
// with large hsk/hsv, scalar path may need to use small_rows to fit in shared memory
|
||||
if (path == FA_SCALAR &&
|
||||
!ggml_vk_flash_attn_scalar_shmem_support(ctx->device, HSK, HSV)) {
|
||||
!ggml_vk_flash_attn_scalar_shmem_support(ctx->device, HSK, HSV, small_cache)) {
|
||||
small_rows = true;
|
||||
}
|
||||
|
||||
|
|
@ -8193,7 +8195,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|||
v_stride /= 4;
|
||||
}
|
||||
|
||||
uint32_t alignment = fa_align(path, HSK, HSV, k->type, small_rows);
|
||||
uint32_t alignment = fa_align(path, HSK, HSV, k->type, small_rows, small_cache);
|
||||
bool aligned = (KV % alignment) == 0 &&
|
||||
// the "aligned" shader variant will forcibly align strides, for performance
|
||||
(q_stride & 7) == 0 && (k_stride & 7) == 0 && (v_stride & 7) == 0;
|
||||
|
|
@ -8205,7 +8207,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|||
|
||||
bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32;
|
||||
|
||||
vk_fa_pipeline_state fa_pipeline_state(HSK, HSV, small_rows, path, aligned, f32acc);
|
||||
vk_fa_pipeline_state fa_pipeline_state(HSK, HSV, small_rows, small_cache, path, aligned, f32acc);
|
||||
|
||||
vk_pipeline pipeline = nullptr;
|
||||
|
||||
|
|
@ -13716,6 +13718,7 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
|
|||
}
|
||||
|
||||
static void ggml_backend_vk_event_record(ggml_backend_t backend, ggml_backend_event_t event) {
|
||||
VK_LOG_DEBUG("ggml_backend_vk_event_record(backend=" << backend << ", event=" << event << ")");
|
||||
ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
|
||||
vk_event *vkev = (vk_event *)event->context;
|
||||
|
||||
|
|
@ -13745,6 +13748,7 @@ static void ggml_backend_vk_event_record(ggml_backend_t backend, ggml_backend_ev
|
|||
}
|
||||
|
||||
static void ggml_backend_vk_event_wait(ggml_backend_t backend, ggml_backend_event_t event) {
|
||||
VK_LOG_DEBUG("ggml_backend_vk_event_wait(backend=" << backend << ", event=" << event << ")");
|
||||
ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
|
||||
vk_event *vkev = (vk_event *)event->context;
|
||||
|
||||
|
|
@ -13760,6 +13764,8 @@ static void ggml_backend_vk_event_wait(ggml_backend_t backend, ggml_backend_even
|
|||
}
|
||||
|
||||
ggml_vk_wait_events(transfer_ctx, {vkev->event});
|
||||
ggml_vk_ctx_end(transfer_ctx);
|
||||
ctx->transfer_ctx.reset();
|
||||
}
|
||||
|
||||
// TODO: enable async and synchronize
|
||||
|
|
@ -14519,6 +14525,7 @@ static void ggml_backend_vk_device_event_free(ggml_backend_dev_t dev, ggml_backe
|
|||
}
|
||||
|
||||
static void ggml_backend_vk_device_event_synchronize(ggml_backend_dev_t dev, ggml_backend_event_t event) {
|
||||
VK_LOG_DEBUG("ggml_backend_vk_device_event_synchronize(backend=" << dev << ", event=" << event << ")");
|
||||
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
|
||||
auto device = ggml_vk_get_device(ctx->device);
|
||||
vk_event *vkev = (vk_event *)event->context;
|
||||
|
|
|
|||
|
|
@ -449,6 +449,8 @@ class MODEL_ARCH(IntEnum):
|
|||
RND1 = auto()
|
||||
PANGU_EMBED = auto()
|
||||
MISTRAL3 = auto()
|
||||
MIMO2 = auto()
|
||||
LLAMA_EMBED = auto()
|
||||
|
||||
|
||||
class VISION_PROJECTOR_TYPE(IntEnum):
|
||||
|
|
@ -844,6 +846,8 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
|||
MODEL_ARCH.RND1: "rnd1",
|
||||
MODEL_ARCH.PANGU_EMBED: "pangu-embedded",
|
||||
MODEL_ARCH.MISTRAL3: "mistral3",
|
||||
MODEL_ARCH.MIMO2: "mimo2",
|
||||
MODEL_ARCH.LLAMA_EMBED: "llama-embed",
|
||||
}
|
||||
|
||||
VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {
|
||||
|
|
@ -3196,6 +3200,46 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
|||
MODEL_TENSOR.FFN_DOWN_EXP,
|
||||
MODEL_TENSOR.FFN_UP_EXP,
|
||||
],
|
||||
MODEL_ARCH.MIMO2: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_Q,
|
||||
MODEL_TENSOR.ATTN_K,
|
||||
MODEL_TENSOR.ATTN_V,
|
||||
MODEL_TENSOR.ATTN_SINKS,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
MODEL_TENSOR.FFN_NORM,
|
||||
MODEL_TENSOR.FFN_GATE,
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
MODEL_TENSOR.FFN_GATE_INP,
|
||||
MODEL_TENSOR.FFN_GATE_EXP,
|
||||
MODEL_TENSOR.FFN_DOWN_EXP,
|
||||
MODEL_TENSOR.FFN_UP_EXP,
|
||||
MODEL_TENSOR.FFN_EXP_PROBS_B,
|
||||
],
|
||||
MODEL_ARCH.LLAMA_EMBED: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
MODEL_TENSOR.ROPE_FREQS,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_Q,
|
||||
MODEL_TENSOR.ATTN_K,
|
||||
MODEL_TENSOR.ATTN_V,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
MODEL_TENSOR.ATTN_ROT_EMBD,
|
||||
MODEL_TENSOR.FFN_GATE_INP,
|
||||
MODEL_TENSOR.FFN_NORM,
|
||||
MODEL_TENSOR.FFN_GATE,
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
MODEL_TENSOR.FFN_GATE_EXP,
|
||||
MODEL_TENSOR.FFN_DOWN_EXP,
|
||||
MODEL_TENSOR.FFN_UP_EXP,
|
||||
],
|
||||
# TODO
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -320,6 +320,7 @@ class TensorNameMap:
|
|||
|
||||
MODEL_TENSOR.ATTN_SINKS: (
|
||||
"model.layers.{bid}.self_attn.sinks", # openai-moe
|
||||
"model.layers.{bid}.self_attn.attention_sink_bias", # mimov2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.ATTN_GATE: (
|
||||
|
|
|
|||
|
|
@ -88,6 +88,7 @@ add_library(llama
|
|||
models/llama-iswa.cpp
|
||||
models/llama.cpp
|
||||
models/mamba.cpp
|
||||
models/mimo2-iswa.cpp
|
||||
models/minicpm3.cpp
|
||||
models/minimax-m2.cpp
|
||||
models/modern-bert.cpp
|
||||
|
|
|
|||
|
|
@ -115,6 +115,8 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
|||
{ LLM_ARCH_RND1, "rnd1" },
|
||||
{ LLM_ARCH_PANGU_EMBED, "pangu-embedded" },
|
||||
{ LLM_ARCH_MISTRAL3, "mistral3" },
|
||||
{ LLM_ARCH_MIMO2, "mimo2" },
|
||||
{ LLM_ARCH_LLAMA_EMBED, "llama-embed" },
|
||||
{ LLM_ARCH_UNKNOWN, "(unknown)" },
|
||||
};
|
||||
|
||||
|
|
@ -500,6 +502,7 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
|
|||
case LLM_ARCH_LLAMA:
|
||||
case LLM_ARCH_DECI:
|
||||
case LLM_ARCH_MISTRAL3:
|
||||
case LLM_ARCH_LLAMA_EMBED:
|
||||
return {
|
||||
LLM_TENSOR_TOKEN_EMBD,
|
||||
LLM_TENSOR_OUTPUT_NORM,
|
||||
|
|
@ -2188,6 +2191,27 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
|
|||
LLM_TENSOR_VISEXP_FFN_DOWN,
|
||||
LLM_TENSOR_VISEXP_FFN_UP,
|
||||
};
|
||||
case LLM_ARCH_MIMO2:
|
||||
return {
|
||||
LLM_TENSOR_TOKEN_EMBD,
|
||||
LLM_TENSOR_OUTPUT_NORM,
|
||||
LLM_TENSOR_OUTPUT,
|
||||
LLM_TENSOR_ATTN_NORM,
|
||||
LLM_TENSOR_ATTN_Q,
|
||||
LLM_TENSOR_ATTN_K,
|
||||
LLM_TENSOR_ATTN_V,
|
||||
LLM_TENSOR_ATTN_SINKS,
|
||||
LLM_TENSOR_ATTN_OUT,
|
||||
LLM_TENSOR_FFN_NORM,
|
||||
LLM_TENSOR_FFN_GATE,
|
||||
LLM_TENSOR_FFN_DOWN,
|
||||
LLM_TENSOR_FFN_UP,
|
||||
LLM_TENSOR_FFN_GATE_INP,
|
||||
LLM_TENSOR_FFN_GATE_EXPS,
|
||||
LLM_TENSOR_FFN_DOWN_EXPS,
|
||||
LLM_TENSOR_FFN_UP_EXPS,
|
||||
LLM_TENSOR_FFN_EXP_PROBS_B,
|
||||
};
|
||||
case LLM_ARCH_GPTJ:
|
||||
case LLM_ARCH_UNKNOWN:
|
||||
return {
|
||||
|
|
|
|||
|
|
@ -119,6 +119,8 @@ enum llm_arch {
|
|||
LLM_ARCH_RND1,
|
||||
LLM_ARCH_PANGU_EMBED,
|
||||
LLM_ARCH_MISTRAL3,
|
||||
LLM_ARCH_MIMO2,
|
||||
LLM_ARCH_LLAMA_EMBED,
|
||||
LLM_ARCH_UNKNOWN,
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -123,10 +123,11 @@ struct llama_hparams {
|
|||
llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
|
||||
// the size of the sliding window (0 - no SWA)
|
||||
uint32_t n_swa = 0;
|
||||
// if swa_layers[il] == true, then layer il is SWA
|
||||
// if swa_layers[il] == false, then layer il is dense (i.e. non-SWA)
|
||||
// if swa_layers[il] == 1, then layer il is SWA
|
||||
// if swa_layers[il] == 0, then layer il is dense (i.e. non-SWA)
|
||||
// by default, all layers are dense
|
||||
std::array<bool, LLAMA_MAX_LAYERS> swa_layers;
|
||||
// note: using uint32_t type for compatibility reason
|
||||
std::array<uint32_t, LLAMA_MAX_LAYERS> swa_layers;
|
||||
|
||||
// for State Space Models
|
||||
uint32_t ssm_d_conv = 0;
|
||||
|
|
|
|||
|
|
@ -130,6 +130,7 @@ const char * llm_type_name(llm_type type) {
|
|||
case LLM_TYPE_230B_A10B: return "230B.A10B";
|
||||
case LLM_TYPE_235B_A22B: return "235B.A22B";
|
||||
case LLM_TYPE_300B_A47B: return "300B.A47B";
|
||||
case LLM_TYPE_310B_A15B: return "310B.A15B";
|
||||
case LLM_TYPE_355B_A32B: return "355B.A32B";
|
||||
case LLM_TYPE_E2B: return "E2B";
|
||||
case LLM_TYPE_E4B: return "E4B";
|
||||
|
|
@ -606,7 +607,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
|||
|
||||
ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false);
|
||||
|
||||
if (arch == LLM_ARCH_LLAMA || arch == LLM_ARCH_DECI || arch == LLM_ARCH_FALCON) {
|
||||
if (arch == LLM_ARCH_LLAMA || arch == LLM_ARCH_DECI || arch == LLM_ARCH_FALCON || arch == LLM_ARCH_LLAMA_EMBED) {
|
||||
if (hparams.n_rot != hparams.n_embd_head_k) {
|
||||
throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd_head_k));
|
||||
}
|
||||
|
|
@ -630,6 +631,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
|||
// arch-specific KVs
|
||||
switch (arch) {
|
||||
case LLM_ARCH_LLAMA:
|
||||
case LLM_ARCH_LLAMA_EMBED:
|
||||
{
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
|
||||
|
|
@ -2338,6 +2340,22 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
|||
default: type = LLM_TYPE_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_MIMO2:
|
||||
{
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
|
||||
hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
|
||||
|
||||
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
|
||||
ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
|
||||
ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa);
|
||||
ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.swa_layers, hparams.n_layer);
|
||||
|
||||
switch (hparams.n_layer) {
|
||||
case 48: type = LLM_TYPE_310B_A15B; break;
|
||||
default: type = LLM_TYPE_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
default: throw std::runtime_error("unsupported model architecture");
|
||||
}
|
||||
|
||||
|
|
@ -2652,6 +2670,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
|||
case LLM_ARCH_GRANITE:
|
||||
case LLM_ARCH_GRANITE_MOE:
|
||||
case LLM_ARCH_MISTRAL3:
|
||||
case LLM_ARCH_LLAMA_EMBED:
|
||||
{
|
||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
||||
|
||||
|
|
@ -6646,6 +6665,44 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
|||
layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { hparams.n_ff_shexp, n_embd }, 0);
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_MIMO2:
|
||||
{
|
||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
||||
|
||||
// output
|
||||
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
||||
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
|
||||
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
auto & layer = layers[i];
|
||||
uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i);
|
||||
uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i);
|
||||
uint32_t n_head = hparams.n_head(i);
|
||||
|
||||
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head }, 0);
|
||||
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, 0);
|
||||
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, 0);
|
||||
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_v * n_head, n_embd }, 0);
|
||||
|
||||
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
|
||||
layer.attn_sinks = create_tensor(tn(LLM_TENSOR_ATTN_SINKS, "weight", i), {n_head}, TENSOR_NOT_REQUIRED);
|
||||
|
||||
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
|
||||
|
||||
// non-MoE branch
|
||||
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED);
|
||||
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, TENSOR_NOT_REQUIRED);
|
||||
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED);
|
||||
|
||||
// MoE branch
|
||||
int64_t n_ff_exp = hparams.n_ff_exp;
|
||||
layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, TENSOR_NOT_REQUIRED);
|
||||
layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED);
|
||||
layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, TENSOR_NOT_REQUIRED);
|
||||
layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED);
|
||||
layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED);
|
||||
}
|
||||
} break;
|
||||
default:
|
||||
throw std::runtime_error("unknown architecture");
|
||||
}
|
||||
|
|
@ -7269,16 +7326,20 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
|
|||
switch (arch) {
|
||||
case LLM_ARCH_LLAMA:
|
||||
{
|
||||
llm = std::make_unique<llm_build_llama>(*this, params);
|
||||
llm = std::make_unique<llm_build_llama<false>>(*this, params);
|
||||
} break;
|
||||
case LLM_ARCH_LLAMA4:
|
||||
{
|
||||
if (hparams.swa_type == LLAMA_SWA_TYPE_NONE) {
|
||||
llm = std::make_unique<llm_build_llama>(*this, params);
|
||||
llm = std::make_unique<llm_build_llama<false>>(*this, params);
|
||||
} else {
|
||||
llm = std::make_unique<llm_build_llama_iswa>(*this, params);
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_LLAMA_EMBED:
|
||||
{
|
||||
llm = std::make_unique<llm_build_llama<true>>(*this, params);
|
||||
} break;
|
||||
case LLM_ARCH_DECI:
|
||||
{
|
||||
llm = std::make_unique<llm_build_deci>(*this, params);
|
||||
|
|
@ -7704,6 +7765,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
|
|||
{
|
||||
llm = std::make_unique<llm_build_mistral3>(*this, params);
|
||||
} break;
|
||||
case LLM_ARCH_MIMO2:
|
||||
{
|
||||
llm = std::make_unique<llm_build_mimo2_iswa>(*this, params);
|
||||
} break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
|
|
@ -7874,6 +7939,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
|
|||
case LLM_ARCH_ERNIE4_5:
|
||||
case LLM_ARCH_ERNIE4_5_MOE:
|
||||
case LLM_ARCH_MISTRAL3:
|
||||
case LLM_ARCH_LLAMA_EMBED:
|
||||
return LLAMA_ROPE_TYPE_NORM;
|
||||
|
||||
// the pairs of head values are offset by n_rot/2
|
||||
|
|
@ -7933,6 +7999,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
|
|||
case LLM_ARCH_PANGU_EMBED:
|
||||
case LLM_ARCH_AFMOE:
|
||||
case LLM_ARCH_QWEN3NEXT:
|
||||
case LLM_ARCH_MIMO2:
|
||||
return LLAMA_ROPE_TYPE_NEOX;
|
||||
|
||||
case LLM_ARCH_QWEN2VL:
|
||||
|
|
|
|||
|
|
@ -123,6 +123,7 @@ enum llm_type {
|
|||
LLM_TYPE_230B_A10B, // Minimax M2
|
||||
LLM_TYPE_235B_A22B,
|
||||
LLM_TYPE_300B_A47B, // Ernie MoE big
|
||||
LLM_TYPE_310B_A15B, // MiMo-V2-Flash
|
||||
LLM_TYPE_355B_A32B, // GLM-4.5
|
||||
LLM_TYPE_E2B,
|
||||
LLM_TYPE_E4B,
|
||||
|
|
|
|||
|
|
@@ -1,6 +1,7 @@
#include "models.h"

llm_build_llama::llm_build_llama(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
template <bool embed>
llm_build_llama<embed>::llm_build_llama(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;

GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);

@@ -14,7 +15,14 @@ llm_build_llama::llm_build_llama(const llama_model & model, const llm_graph_para
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();

auto * inp_attn = build_attn_inp_kv();
using inp_attn_type = std::conditional_t<embed, llm_graph_input_attn_no_cache, llm_graph_input_attn_kv>;

inp_attn_type * inp_attn = nullptr;
if constexpr (embed) {
inp_attn = build_attn_inp_no_cache();
} else {
inp_attn = build_attn_inp_kv();
}

const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;

@@ -145,11 +153,16 @@ llm_build_llama::llm_build_llama(const llama_model & model, const llm_graph_para
cb(cur, "result_norm", -1);
res->t_embd = cur;

// lm_head
cur = build_lora_mm(model.output, cur);
if constexpr (!embed) {
// lm_head
cur = build_lora_mm(model.output, cur);

cb(cur, "result_output", -1);
res->t_logits = cur;
cb(cur, "result_output", -1);
res->t_logits = cur;
}

ggml_build_forward_expand(gf, cur);
}

template struct llm_build_llama<false>;
template struct llm_build_llama<true>;
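// Minimal standalone sketch of the compile-time selection pattern used above: a class template
// picks its attention-input type with std::conditional_t and branches with if constexpr, so the
// embedding build skips the lm_head entirely. The types and counters here are stand-ins, not the
// real llama graph classes.
#include <type_traits>

struct attn_no_cache { };
struct attn_kv       { };

template <bool embed>
struct graph_builder {
    using inp_attn_type = std::conditional_t<embed, attn_no_cache, attn_kv>;

    int build() {
        inp_attn_type inp_attn{};   // selected at compile time
        (void) inp_attn;
        int result = 0;             // stands in for the final norm output
        if constexpr (!embed) {
            result += 1;            // stands in for the lm_head matmul
        }
        return result;
    }
};

int main() {
    graph_builder<false> causal;    // full model: produces logits
    graph_builder<true>  embedder;  // embedding model: stops at the hidden state
    return causal.build() - embedder.build() - 1; // returns 0
}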
|
||||
|
|
|
|||
|
|
@ -0,0 +1,123 @@
|
|||
|
||||
#include "models.h"
|
||||
|
||||
llm_build_mimo2_iswa::llm_build_mimo2_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
inpL = build_inp_embd(model.tok_embd);
|
||||
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
auto * inp_attn = build_attn_inp_kv_iswa();
|
||||
ggml_tensor * inp_out_ids = build_inp_out_ids();
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
ggml_tensor * inpSA = inpL;
|
||||
|
||||
uint32_t n_head_l = hparams.n_head(il);
|
||||
uint32_t n_head_kv_l = hparams.n_head_kv(il);
|
||||
const float freq_base_l = model.get_rope_freq_base(cparams, il);
|
||||
const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
|
||||
|
||||
cur = inpL;
|
||||
|
||||
// self_attention
|
||||
{
|
||||
cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
|
||||
cb(cur, "attn_norm", il);
|
||||
|
||||
// compute Q and K and RoPE them
|
||||
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
|
||||
cb(Qcur, "Qcur", il);
|
||||
|
||||
ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
|
||||
cb(Kcur, "Kcur", il);
|
||||
|
||||
ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
|
||||
cb(Vcur, "Vcur", il);
|
||||
|
||||
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head_l, n_tokens);
|
||||
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv_l, n_tokens);
|
||||
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head_v, n_head_kv_l, n_tokens);
|
||||
|
||||
Qcur = ggml_rope_ext(
|
||||
ctx0, Qcur, inp_pos, nullptr,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
|
||||
Kcur = ggml_rope_ext(
|
||||
ctx0, Kcur, inp_pos, nullptr,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
|
||||
cb(Qcur, "Qcur", il);
|
||||
cb(Kcur, "Kcur", il);
|
||||
cb(Vcur, "Vcur", il);
|
||||
|
||||
ggml_tensor * sinks = model.layers[il].attn_sinks;
|
||||
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, NULL,
|
||||
Qcur, Kcur, Vcur, nullptr, sinks, nullptr, 1.0f/sqrtf(float(n_embd_head_k)), il);
|
||||
}
|
||||
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
||||
}
|
||||
|
||||
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
|
||||
cb(ffn_inp, "ffn_inp", il);
|
||||
|
||||
cur = build_norm(ffn_inp,
|
||||
model.layers[il].ffn_norm, NULL,
|
||||
LLM_NORM_RMS, il);
|
||||
cb(cur, "ffn_norm", il);
|
||||
|
||||
// feed-forward network
|
||||
if (model.layers[il].ffn_gate_inp == nullptr) {
|
||||
// dense branch
|
||||
cur = build_ffn(cur,
|
||||
model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL,
|
||||
model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
|
||||
model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
|
||||
NULL,
|
||||
LLM_FFN_SILU, LLM_FFN_PAR, il);
|
||||
cb(cur, "ffn_out", il);
|
||||
} else {
|
||||
// MoE branch
|
||||
cur = build_moe_ffn(cur, model.layers[il].ffn_gate_inp, model.layers[il].ffn_up_exps,
|
||||
model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps,
|
||||
model.layers[il].ffn_exp_probs_b, n_expert, n_expert_used, LLM_FFN_SILU, true, false,
|
||||
0.0, LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID, il);
|
||||
cb(cur, "ffn_moe_out", il);
|
||||
}
|
||||
|
||||
cur = ggml_add(ctx0, cur, ffn_inp);
|
||||
|
||||
cur = build_cvec(cur, il);
|
||||
cb(cur, "l_out", il);
|
||||
|
||||
// input for next layer
|
||||
inpL = cur;
|
||||
}
|
||||
|
||||
cur = inpL;
|
||||
|
||||
cur = build_norm(cur,
|
||||
model.output_norm, NULL,
|
||||
LLM_NORM_RMS, -1);
|
||||
|
||||
cb(cur, "result_norm", -1);
|
||||
res->t_embd = cur;
|
||||
|
||||
// lm_head
|
||||
cur = build_lora_mm(model.output, cur);
|
||||
|
||||
cb(cur, "result_output", -1);
|
||||
res->t_logits = cur;
|
||||
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
}
|
||||
|
|
@ -303,6 +303,7 @@ struct llm_build_llada_moe : public llm_graph_context {
|
|||
llm_build_llada_moe(const llama_model & model, const llm_graph_params & params);
|
||||
};
|
||||
|
||||
template <bool embed>
|
||||
struct llm_build_llama : public llm_graph_context {
|
||||
llm_build_llama(const llama_model & model, const llm_graph_params & params);
|
||||
};
|
||||
|
|
@ -315,6 +316,10 @@ struct llm_build_mamba : public llm_graph_context_mamba {
|
|||
llm_build_mamba(const llama_model & model, const llm_graph_params & params);
|
||||
};
|
||||
|
||||
struct llm_build_mimo2_iswa : public llm_graph_context {
|
||||
llm_build_mimo2_iswa(const llama_model & model, const llm_graph_params & params);
|
||||
};
|
||||
|
||||
struct llm_build_minicpm3 : public llm_graph_context {
|
||||
llm_build_minicpm3(const llama_model & model, const llm_graph_params & params);
|
||||
};
|
||||
|
|
|
|||
|
|
@ -35,7 +35,7 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
|
||||
LOG_INF("%s: printing fitted CLI arguments to stdout...\n", __func__);
|
||||
std::this_thread::sleep_for(10ms); // to avoid a race between stderr and stdout
|
||||
common_log_flush(common_log_main());
|
||||
printf("-c %" PRIu32 " -ngl %" PRIu32, cparams.n_ctx, mparams.n_gpu_layers);
|
||||
|
||||
size_t nd = llama_max_devices();
|
||||
|
|
|
|||
|
|
@ -1486,6 +1486,7 @@ The precedence rule for preset options is as follows:
|
|||
|
||||
We also offer additional options that are exclusive to presets (these aren't treated as command-line arguments):
|
||||
- `load-on-startup` (boolean): Controls whether the model loads automatically when the server starts
|
||||
- `stop-timeout` (int, seconds): After an unload is requested, wait this many seconds for a graceful shutdown before force-killing the instance (default: 10)
|
||||
|
||||
### Routing requests
|
||||
|
||||
|
|
@ -1574,8 +1575,7 @@ Payload:
|
|||
|
||||
```json
|
||||
{
|
||||
"model": "ggml-org/gemma-3-4b-it-GGUF:Q4_K_M",
|
||||
"extra_args": ["-n", "128", "--top-k", "4"]
|
||||
"model": "ggml-org/gemma-3-4b-it-GGUF:Q4_K_M"
|
||||
}
|
||||
```
|
||||
|
||||
|
|
|
|||
|
|
@ -34,6 +34,8 @@
|
|||
#include <limits.h>
|
||||
#endif
|
||||
|
||||
#define DEFAULT_STOP_TIMEOUT 10 // seconds
|
||||
|
||||
#define CMD_ROUTER_TO_CHILD_EXIT "cmd_router_to_child:exit"
|
||||
#define CMD_CHILD_TO_ROUTER_READY "cmd_child_to_router:ready"
|
||||
|
||||
|
|
@ -203,13 +205,14 @@ void server_models::load_models() {
|
|||
// convert presets to server_model_meta and add to mapping
|
||||
for (const auto & preset : final_presets) {
|
||||
server_model_meta meta{
|
||||
/* preset */ preset.second,
|
||||
/* name */ preset.first,
|
||||
/* port */ 0,
|
||||
/* status */ SERVER_MODEL_STATUS_UNLOADED,
|
||||
/* last_used */ 0,
|
||||
/* args */ std::vector<std::string>(),
|
||||
/* exit_code */ 0
|
||||
/* preset */ preset.second,
|
||||
/* name */ preset.first,
|
||||
/* port */ 0,
|
||||
/* status */ SERVER_MODEL_STATUS_UNLOADED,
|
||||
/* last_used */ 0,
|
||||
/* args */ std::vector<std::string>(),
|
||||
/* exit_code */ 0,
|
||||
/* stop_timeout */ DEFAULT_STOP_TIMEOUT,
|
||||
};
|
||||
add_model(std::move(meta));
|
||||
}
|
||||
|
|
@ -227,6 +230,20 @@ void server_models::load_models() {
|
|||
}
|
||||
}
|
||||
|
||||
// handle custom stop-timeout option
for (auto & [name, inst] : mapping) {
std::string val;
if (inst.meta.preset.get_option(COMMON_ARG_PRESET_STOP_TIMEOUT, val)) {
try {
inst.meta.stop_timeout = std::stoi(val);
} catch (...) {
SRV_WRN("invalid stop-timeout value '%s' for model '%s', using default %d seconds\n",
val.c_str(), name.c_str(), DEFAULT_STOP_TIMEOUT);
inst.meta.stop_timeout = DEFAULT_STOP_TIMEOUT;
}
}
}
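// Small standalone sketch of the same "integer option with fallback" behaviour: std::stoi throws
// std::invalid_argument or std::out_of_range for bad input, so a catch-all restores the default.
// The helper name and the default value are illustrative only.
#include <cstdio>
#include <string>

static int parse_int_or(const std::string & val, int fallback) {
    try {
        return std::stoi(val);
    } catch (...) {
        return fallback;
    }
}

int main() {
    std::printf("%d %d %d\n",
                parse_int_or("30", 10),            // 30
                parse_int_or("abc", 10),           // 10 (invalid_argument)
                parse_int_or("99999999999", 10));  // 10 (out_of_range)
    return 0;
}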
|
||||
|
||||
// load any autoload models
|
||||
std::vector<std::string> models_to_load;
|
||||
for (const auto & [name, inst] : mapping) {
|
||||
|
|
@ -362,7 +379,7 @@ void server_models::unload_lru() {
|
|||
int64_t lru_last_used = ggml_time_ms();
|
||||
size_t count_active = 0;
|
||||
{
|
||||
std::lock_guard<std::mutex> lk(mutex);
|
||||
std::unique_lock<std::mutex> lk(mutex);
|
||||
for (const auto & m : mapping) {
|
||||
if (m.second.meta.is_active()) {
|
||||
count_active++;
|
||||
|
|
@ -376,6 +393,13 @@ void server_models::unload_lru() {
|
|||
if (!lru_model_name.empty() && count_active >= (size_t)base_params.models_max) {
|
||||
SRV_INF("models_max limit reached, removing LRU name=%s\n", lru_model_name.c_str());
|
||||
unload(lru_model_name);
|
||||
// wait for unload to complete
|
||||
{
|
||||
std::unique_lock<std::mutex> lk(mutex);
|
||||
cv.wait(lk, [this, &lru_model_name]() {
|
||||
return mapping[lru_model_name].meta.status == SERVER_MODEL_STATUS_UNLOADED;
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -436,38 +460,83 @@ void server_models::load(const std::string & name) {
|
|||
|
||||
// start a thread to manage the child process
|
||||
// captured variables are guaranteed to be destroyed only after the thread is joined
|
||||
inst.th = std::thread([this, name, child_proc = inst.subproc, port = inst.meta.port]() {
|
||||
// read stdout/stderr and forward to main server log
|
||||
bool state_received = false; // true if child state received
|
||||
FILE * p_stdout_stderr = subprocess_stdout(child_proc.get());
|
||||
if (p_stdout_stderr) {
|
||||
char buffer[4096];
|
||||
while (fgets(buffer, sizeof(buffer), p_stdout_stderr) != nullptr) {
|
||||
LOG("[%5d] %s", port, buffer);
|
||||
if (!state_received && std::strstr(buffer, CMD_CHILD_TO_ROUTER_READY) != nullptr) {
|
||||
// child process is ready
|
||||
this->update_status(name, SERVER_MODEL_STATUS_LOADED);
|
||||
state_received = true;
|
||||
inst.th = std::thread([this, name, child_proc = inst.subproc, port = inst.meta.port, stop_timeout = inst.meta.stop_timeout]() {
|
||||
FILE * stdin_file = subprocess_stdin(child_proc.get());
|
||||
FILE * stdout_file = subprocess_stdout(child_proc.get()); // combined stdout/stderr
|
||||
|
||||
std::thread log_thread([&]() {
|
||||
// read stdout/stderr and forward to main server log
|
||||
// also handle status report from child process
|
||||
bool state_received = false; // true if child state received
|
||||
if (stdout_file) {
|
||||
char buffer[4096];
|
||||
while (fgets(buffer, sizeof(buffer), stdout_file) != nullptr) {
|
||||
LOG("[%5d] %s", port, buffer);
|
||||
if (!state_received && std::strstr(buffer, CMD_CHILD_TO_ROUTER_READY) != nullptr) {
|
||||
// child process is ready
|
||||
this->update_status(name, SERVER_MODEL_STATUS_LOADED, 0);
|
||||
state_received = true;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
SRV_ERR("failed to get stdout/stderr of child process for name=%s\n", name.c_str());
|
||||
}
|
||||
} else {
|
||||
SRV_ERR("failed to get stdout/stderr of child process for name=%s\n", name.c_str());
|
||||
}
|
||||
});
|
||||
|
||||
std::thread stopping_thread([&]() {
// thread to monitor stopping signal
auto is_stopping = [this, &name]() {
return this->stopping_models.find(name) != this->stopping_models.end();
};
{
std::unique_lock<std::mutex> lk(this->mutex);
this->cv_stop.wait(lk, is_stopping);
}
SRV_INF("stopping model instance name=%s\n", name.c_str());
// send interrupt to child process
fprintf(stdin_file, "%s\n", CMD_ROUTER_TO_CHILD_EXIT);
fflush(stdin_file);
// wait to stop gracefully or timeout
int64_t start_time = ggml_time_ms();
while (true) {
std::unique_lock<std::mutex> lk(this->mutex);
if (!is_stopping()) {
return; // already stopped
}
int64_t elapsed = ggml_time_ms() - start_time;
if (elapsed >= stop_timeout * 1000) {
// timeout, force kill
SRV_WRN("force-killing model instance name=%s after %d seconds timeout\n", name.c_str(), stop_timeout);
subprocess_terminate(child_proc.get());
return;
}
this->cv_stop.wait_for(lk, std::chrono::seconds(1));
}
});
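// Standalone sketch of the graceful-stop-with-timeout pattern used by the stopping thread above:
// wait until a stop is requested, ask the worker to exit, then wait on the condition variable with
// a deadline and escalate to a forced kill once the timeout elapses. Names and the 3 s timeout are
// illustrative; the real code signals a child process over stdin and calls subprocess_terminate().
#include <chrono>
#include <condition_variable>
#include <cstdio>
#include <mutex>
#include <thread>

int main() {
    std::mutex mtx;
    std::condition_variable cv;
    bool stop_requested = false;
    bool stopped        = false;
    const int stop_timeout_s = 3;

    std::thread watcher([&] {
        std::unique_lock<std::mutex> lk(mtx);
        cv.wait(lk, [&] { return stop_requested; });          // block until a stop is requested
        std::puts("stop requested, waiting for graceful exit");
        const auto deadline = std::chrono::steady_clock::now() + std::chrono::seconds(stop_timeout_s);
        while (!stopped) {
            if (cv.wait_until(lk, deadline) == std::cv_status::timeout) {
                std::puts("timeout, force-killing instance");  // subprocess_terminate() in the real code
                return;
            }
        }
        std::puts("instance exited gracefully");
    });

    {
        std::lock_guard<std::mutex> lk(mtx);
        stop_requested = true;              // unload() sets the flag and notifies
    }
    cv.notify_all();

    std::this_thread::sleep_for(std::chrono::seconds(1));
    {
        std::lock_guard<std::mutex> lk(mtx);
        stopped = true;                     // simulate the child exiting before the timeout
    }
    cv.notify_all();

    watcher.join();
    return 0;
}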
|
||||
|
||||
// we reach here when the child process exits
|
||||
// note: we cannot join() prior to this point because it will close stdin_file
|
||||
if (log_thread.joinable()) {
|
||||
log_thread.join();
|
||||
}
|
||||
|
||||
// stop the timeout monitoring thread
|
||||
{
|
||||
std::lock_guard<std::mutex> lk(this->mutex);
|
||||
stopping_models.erase(name);
|
||||
cv_stop.notify_all();
|
||||
}
|
||||
if (stopping_thread.joinable()) {
|
||||
stopping_thread.join();
|
||||
}
|
||||
|
||||
// get the exit code
|
||||
int exit_code = 0;
|
||||
subprocess_join(child_proc.get(), &exit_code);
|
||||
subprocess_destroy(child_proc.get());
|
||||
// update PID and status
|
||||
{
|
||||
std::lock_guard<std::mutex> lk(mutex);
|
||||
auto it = mapping.find(name);
|
||||
if (it != mapping.end()) {
|
||||
auto & meta = it->second.meta;
|
||||
meta.exit_code = exit_code;
|
||||
meta.status = SERVER_MODEL_STATUS_UNLOADED;
|
||||
}
|
||||
cv.notify_all();
|
||||
}
|
||||
|
||||
// update status and exit code
|
||||
this->update_status(name, SERVER_MODEL_STATUS_UNLOADED, exit_code);
|
||||
SRV_INF("instance name=%s exited with status %d\n", name.c_str(), exit_code);
|
||||
});
|
||||
|
||||
|
|
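The stopping_thread in the hunk above follows a "graceful stop, then force kill" pattern: block until the model is flagged for stopping, ask the child process to exit, then keep re-checking under the same condition variable until the flag is cleared or the deadline passes. The following is a minimal, self-contained sketch of that deadline loop, not the actual server code; StopGate, wait_for_stop, force_kill and the 5-second timeout are illustrative names and values, and the real implementation keys the flag by model name in a std::set rather than a bool.

    // Sketch only: generic "graceful stop with timeout" loop, mirroring the
    // stopping_thread above. StopGate / wait_for_stop / force_kill are
    // illustrative names, not identifiers from the patch.
    #include <chrono>
    #include <condition_variable>
    #include <cstdio>
    #include <functional>
    #include <mutex>
    #include <thread>

    struct StopGate {
        std::mutex mtx;
        std::condition_variable cv;
        bool stopped = false; // set once the child has actually exited
    };

    // Returns true if the child exited within the timeout, false if it was force-killed.
    bool wait_for_stop(StopGate & gate, int timeout_sec, const std::function<void()> & force_kill) {
        const auto deadline = std::chrono::steady_clock::now() + std::chrono::seconds(timeout_sec);
        std::unique_lock<std::mutex> lk(gate.mtx);
        while (!gate.stopped) {
            if (std::chrono::steady_clock::now() >= deadline) {
                force_kill(); // the real code calls subprocess_terminate() here
                return false;
            }
            // re-check roughly once per second, like cv_stop.wait_for(lk, 1s) above
            gate.cv.wait_for(lk, std::chrono::seconds(1));
        }
        return true;
    }

    int main() {
        StopGate gate;
        std::thread child([&] {
            std::this_thread::sleep_for(std::chrono::seconds(2)); // pretend graceful shutdown work
            std::lock_guard<std::mutex> lk(gate.mtx);
            gate.stopped = true;
            gate.cv.notify_all();
        });
        bool graceful = wait_for_stop(gate, /*timeout_sec=*/5, [] { std::puts("force kill"); });
        std::puts(graceful ? "child exited gracefully" : "child was force-killed");
        child.join();
    }

Waking once per second instead of waiting for the full timeout keeps the thread responsive to an early, clean exit while still bounding the total shutdown time.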
@@ -488,22 +557,14 @@ void server_models::load(const std::string & name) {
    cv.notify_all();
}

static void interrupt_subprocess(FILE * stdin_file) {
    // because subprocess.h does not provide a way to send SIGINT,
    // we will send a command to the child process to exit gracefully
    if (stdin_file) {
        fprintf(stdin_file, "%s\n", CMD_ROUTER_TO_CHILD_EXIT);
        fflush(stdin_file);
    }
}

void server_models::unload(const std::string & name) {
    std::lock_guard<std::mutex> lk(mutex);
    auto it = mapping.find(name);
    if (it != mapping.end()) {
        if (it->second.meta.is_active()) {
            SRV_INF("unloading model instance name=%s\n", name.c_str());
            interrupt_subprocess(it->second.stdin_file);
            stopping_models.insert(name);
            cv_stop.notify_all();
            // status change will be handled by the managing thread
        } else {
            SRV_WRN("model instance name=%s is not loaded\n", name.c_str());

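Because subprocess.h has no portable way to deliver SIGINT, the router tells a child to exit by writing a command line to the child's stdin (this is why the managing thread must keep stdin_file open and cannot join the log thread early). The child-side handler is not part of this hunk; as a rough illustration only, a listener for such a command could look like the sketch below, where the "exit" string and request_shutdown() are placeholders rather than the real CMD_ROUTER_TO_CHILD_EXIT handling.

    // Sketch only: a child-side stdin listener for a router->child "exit" command.
    // The command string and request_shutdown() are placeholders.
    #include <atomic>
    #include <chrono>
    #include <cstdio>
    #include <cstring>
    #include <string>
    #include <thread>

    static std::atomic<bool> g_shutdown{false};

    static void request_shutdown() {
        g_shutdown.store(true); // the real server would trigger its normal clean-exit path here
    }

    static void stdin_listener(const std::string & exit_cmd) {
        char line[256];
        // fgets() returns nullptr on EOF, i.e. when the router closes our stdin
        while (fgets(line, sizeof(line), stdin) != nullptr) {
            line[strcspn(line, "\r\n")] = '\0'; // strip the trailing newline before comparing
            if (exit_cmd == line) {
                request_shutdown();
                break;
            }
        }
    }

    int main() {
        std::thread listener(stdin_listener, "exit"); // placeholder for the real exit command
        while (!g_shutdown.load()) {
            std::this_thread::sleep_for(std::chrono::milliseconds(100)); // pretend to serve requests
        }
        listener.join();
        return 0;
    }

Running the sketch and typing "exit" on stdin ends the loop, which is essentially what interrupt_subprocess() triggers from the router side.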
@@ -518,7 +579,8 @@ void server_models::unload_all() {
    for (auto & [name, inst] : mapping) {
        if (inst.meta.is_active()) {
            SRV_INF("unloading model instance name=%s\n", name.c_str());
            interrupt_subprocess(inst.stdin_file);
            stopping_models.insert(name);
            cv_stop.notify_all();
            // status change will be handled by the managing thread
        }
        // moving the thread to join list to avoid deadlock

@@ -532,16 +594,15 @@ void server_models::unload_all() {
    }
}

void server_models::update_status(const std::string & name, server_model_status status) {
    // for now, we only allow updating to LOADED status
    if (status != SERVER_MODEL_STATUS_LOADED) {
        throw std::runtime_error("invalid status value");
    }
    auto meta = get_meta(name);
    if (meta.has_value()) {
        meta->status = status;
        update_meta(name, meta.value());
void server_models::update_status(const std::string & name, server_model_status status, int exit_code) {
    std::unique_lock<std::mutex> lk(mutex);
    auto it = mapping.find(name);
    if (it != mapping.end()) {
        auto & meta = it->second.meta;
        meta.status = status;
        meta.exit_code = exit_code;
    }
    cv.notify_all();
}

void server_models::wait_until_loaded(const std::string & name) {

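The reworked update_status() takes the lock itself, mutates the entry in mapping, and notifies cv, so any waiter such as wait_until_loaded() can block on the same condition variable with a predicate over the status. A minimal sketch of that producer/consumer pairing under a simplified stand-in registry (Registry, Status, and wait_until_settled are illustrative names, not the server's types):

    // Sketch only: one mutex + condition_variable guarding a status map,
    // in the spirit of update_status() / wait_until_loaded().
    #include <chrono>
    #include <condition_variable>
    #include <cstdio>
    #include <map>
    #include <mutex>
    #include <string>
    #include <thread>

    enum class Status { LOADING, LOADED, UNLOADED };

    struct Registry {
        std::mutex mtx;
        std::condition_variable cv;
        std::map<std::string, Status> status;

        void update(const std::string & name, Status s) {
            std::lock_guard<std::mutex> lk(mtx);
            status[name] = s;
            cv.notify_all(); // wake every waiter; each re-checks its own predicate
        }

        // blocks until the model leaves the LOADING state (either loaded or unloaded/failed)
        Status wait_until_settled(const std::string & name) {
            std::unique_lock<std::mutex> lk(mtx);
            cv.wait(lk, [&] { return status[name] != Status::LOADING; });
            return status[name];
        }
    };

    int main() {
        Registry reg;
        reg.update("my-model", Status::LOADING);
        std::thread loader([&] {
            std::this_thread::sleep_for(std::chrono::milliseconds(200)); // pretend to load
            reg.update("my-model", Status::LOADED);
        });
        Status s = reg.wait_until_settled("my-model");
        std::printf("settled: %s\n", s == Status::LOADED ? "LOADED" : "UNLOADED");
        loader.join();
    }

Notifying under the lock and waiting with a predicate avoids lost wakeups when the loader finishes before the waiter starts waiting.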
@@ -568,6 +629,7 @@ bool server_models::ensure_model_loaded(const std::string & name) {
        load(name);
    }

    // for loading state
    SRV_INF("waiting until model name=%s is fully loaded...\n", name.c_str());
    wait_until_loaded(name);

@@ -795,7 +857,7 @@ void server_models_routes::init_routes() {
            res_err(res, format_error_response("model is not found", ERROR_TYPE_INVALID_REQUEST));
            return res;
        }
        if (model->status != SERVER_MODEL_STATUS_LOADED) {
        if (!model->is_active()) {
            res_err(res, format_error_response("model is not loaded", ERROR_TYPE_INVALID_REQUEST));
            return res;
        }

@@ -9,6 +9,7 @@
#include <condition_variable>
#include <functional>
#include <memory>
#include <set>

/**
 * state diagram:

@@ -56,6 +57,7 @@ struct server_model_meta {
    int64_t last_used = 0; // for LRU unloading
    std::vector<std::string> args; // args passed to the model instance, will be populated by render_args()
    int exit_code = 0; // exit code of the model instance process (only valid if status == FAILED)
    int stop_timeout = 0; // seconds to wait before force-killing the model instance during shutdown

    bool is_active() const {
        return status == SERVER_MODEL_STATUS_LOADED || status == SERVER_MODEL_STATUS_LOADING;

@@ -83,6 +85,10 @@ private:
    std::condition_variable cv;
    std::map<std::string, instance_t> mapping;

    // for stopping models
    std::condition_variable cv_stop;
    std::set<std::string> stopping_models;

    common_preset_context ctx_preset;

    common_params base_params;

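A single cv_stop plus a std::set<std::string> of names gives every per-model managing thread one shared signal channel: unload() publishes a stop request by inserting the model name and notifying, and each managing thread waits with a predicate that only inspects its own name. As a compact illustration of that design choice (StopRegistry and its methods are invented names for this sketch, not part of the patch):

    // Sketch only: one condition_variable + a set of names as a shared
    // "stop requested" channel, mirroring cv_stop / stopping_models.
    #include <condition_variable>
    #include <cstdio>
    #include <mutex>
    #include <set>
    #include <string>
    #include <thread>
    #include <vector>

    class StopRegistry {
        std::mutex mtx;
        std::condition_variable cv;
        std::set<std::string> pending;

    public:
        void request_stop(const std::string & name) {
            std::lock_guard<std::mutex> lk(mtx);
            pending.insert(name);
            cv.notify_all(); // every manager thread wakes up and re-checks its own name
        }

        void wait_for_stop(const std::string & name) {
            std::unique_lock<std::mutex> lk(mtx);
            cv.wait(lk, [&] { return pending.count(name) > 0; });
        }

        void clear(const std::string & name) {
            std::lock_guard<std::mutex> lk(mtx);
            pending.erase(name);
            cv.notify_all(); // lets a timeout watcher see that the stop has completed
        }
    };

    int main() {
        StopRegistry reg;
        std::vector<std::thread> managers;
        for (const std::string name : {"model-a", "model-b"}) {
            managers.emplace_back([&reg, name] {
                reg.wait_for_stop(name);                    // blocks until this model is asked to stop
                std::printf("stopping %s\n", name.c_str()); // the real code writes the exit command here
                reg.clear(name);                            // mark the stop as handled
            });
        }
        reg.request_stop("model-a");
        reg.request_stop("model-b");
        for (auto & t : managers) t.join();
    }

Broadcasting with notify_all is cheap here because a spuriously woken thread simply re-evaluates its predicate and goes back to sleep.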
@@ -119,7 +125,7 @@ public:
    void unload_all();

    // update the status of a model instance (thread-safe)
    void update_status(const std::string & name, server_model_status status);
    void update_status(const std::string & name, server_model_status status, int exit_code);

    // wait until the model instance is fully loaded (thread-safe)
    // return when the model is loaded or failed to load