This commit is contained in:
oliver 2026-03-16 09:37:11 +08:00 committed by GitHub
commit efa26be955
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 16 additions and 2 deletions

View File

@ -804,6 +804,12 @@ bool common_speculative_is_compat(llama_context * ctx_tgt) {
return false; return false;
} }
// Skip speculative decoding for embedding models, i.e. any context whose pooling type is not NONE.
// Embedding models don't produce the output logits that speculative decoding needs.
if (llama_pooling_type(ctx_tgt) != LLAMA_POOLING_TYPE_NONE) {
return false;
}
bool res = true; bool res = true;
llama_memory_clear(mem, true); llama_memory_clear(mem, true);

View File

@ -1767,12 +1767,20 @@ int llama_context::decode(const llama_batch & batch_inp) {
// extract sequence embeddings (cleared before processing each batch) // extract sequence embeddings (cleared before processing each batch)
auto & embd_seq_out = embd_seq; auto & embd_seq_out = embd_seq;
// For V-L models, the embedding dimension of the output tensor may differ from the model's n_embd
// The embedding dimension is determined by the tensor shape (ne[0]), not by model hparams
const uint32_t n_embd_tensor = t_embd->ne[0];
// Use the tensor's embedding dimension if valid, otherwise fall back to model dimension
const uint32_t n_embd_to_use = n_embd_tensor > 0 ? n_embd_tensor : n_embd;
for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) { for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
const llama_seq_id seq_id = ubatch.seq_id_unq[s]; const llama_seq_id seq_id = ubatch.seq_id_unq[s];
const int32_t seq_idx = ubatch.seq_idx[seq_id]; const int32_t seq_idx = ubatch.seq_idx[seq_id];
embd_seq_out[seq_id].resize(n_embd); embd_seq_out[seq_id].resize(n_embd_to_use);
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_idx)*sizeof(float), n_embd*sizeof(float)); ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(),
(n_embd_to_use*seq_idx)*sizeof(float), n_embd_to_use*sizeof(float));
} }
} break; } break;
case LLAMA_POOLING_TYPE_RANK: case LLAMA_POOLING_TYPE_RANK: