kq_mask naming fix

This commit is contained in:
Mustafa Cavus 2026-01-15 14:38:53 -08:00
parent d3649c11cb
commit d7dccf887b
3 changed files with 7 additions and 7 deletions

View File

@@ -324,7 +324,7 @@ std::pair<ModelParams, ComputeParams> GgmlOvDecoder::compute_llm_params(ggml_cgr
int layer = extract_layer_from_name(cache_k->name);
auto * mask = node->src[3];
std::string mask_name(mask->name);
assert(mask_name.find("KQ_mask") == 0);
assert(mask_name.find("self_kq_mask") == 0);
if (std::string(node->src[3]->name).find("swa") != std::string::npos) {
model_params.swa_layers.push_back(layer);
@@ -392,7 +392,7 @@ ov::PartialShape GgmlOvDecoder::get_graph_input_shape(const ggml_tensor * op, co
} else if (name == "inp_out_ids") {
input_shape = ov::PartialShape{1, 1, 1, m_is_static ? m_compute_params.output_len : -1};
} else if (name.find("KQ_mask") == 0) {
} else if (name.find("self_kq_mask") == 0) {
if (m_is_static) {
input_shape = ov::PartialShape{1, 1, m_is_prefill ? m_prefill_chunk_size : 1, m_model_params.ctx};
} else if (m_is_stateful) {

View File

@@ -109,8 +109,8 @@ void add_sliced_mask(TensorMap & tensor_map, GgmlDecoder & ggml_model_decoder) {
}
};
create_sliced_mask("KQ_mask", "KQ_mask_sliced", ggml_model_decoder.is_static());
create_sliced_mask("KQ_mask_swa", "KQ_mask_swa_sliced", ggml_model_decoder.is_static());
create_sliced_mask("self_kq_mask", "KQ_mask_sliced", ggml_model_decoder.is_static());
create_sliced_mask("self_kq_mask_swa", "KQ_mask_swa_sliced", ggml_model_decoder.is_static());
}
void add_rope_sin_cos(TensorMap & tensor_map, GgmlDecoder & ggml_model_decoder) {

View File

@@ -525,7 +525,7 @@ ov::Tensor get_ov_input_tensor_static_decode(std::shared_ptr<GgmlOvDecoder> ggml
return input_tensor;
}
if (param_name.find("KQ_mask") == 0) {
if (param_name.find("self_kq_mask") == 0) {
size_t context_size = ggml_decoder->get_ctx_size();
std::vector<float> padded_data = pad_input<float>(ggml_tensor, 1, context_size, -INFINITY);
ov::Tensor input_tensor(ov::element::f32, ov::Shape{1, 1, 1, context_size});
@@ -591,7 +591,7 @@ ov::Tensor get_ov_input_tensor_static_prefill(std::shared_ptr<GgmlOvDecoder> ggm
return input_tensor;
}
if (param_name.find("KQ_mask") == 0) {
if (param_name.find("self_kq_mask") == 0) {
size_t cols = ggml_tensor->ne[0];
size_t rows = ggml_tensor->ne[1];
float * ggml_data = (float *) ggml_tensor->data + chunk_index * chunk_size * cols;
@@ -645,7 +645,7 @@ void print_input_tensor_info(const std::string & name, const ov::Tensor & tensor
<< std::endl;
switch (tensor.get_element_type()) {
case ov::element::f32: {
if (name.find("KQ_mask") == std::string::npos) {
if (name.find("self_kq_mask") == std::string::npos) {
std::cout << *(tensor.data<float>()) << std::endl;
} else {
size_t rows = tensor.get_shape()[2];