feat: add EAGLE3 speculative decoding support

EAGLE3 is an encoder-decoder based speculative decoding method:
- Extracts features from target model at specific layers
- Uses feature fusion layer to compress target features
- Generates draft tokens with single-layer decoder
- Maps draft vocabulary to target vocabulary via d2t tensor

Key changes:
- Add LLM_ARCH_EAGLE3 architecture
- Add EAGLE3 encoder/decoder graph (src/models/eagle3.cpp)
- Add feature extraction from target model layers
- Add g_embeddings handling for decoder input
- Add GGML_TENSOR_FLAG_SYNC for GPU synchronization
- Add --eagle3 flag for speculative-simple example
- Add EAGLE3 model conversion in convert_hf_to_gguf.py
This commit is contained in:
ruixiangw 2025-12-14 18:12:33 +00:00
parent 5c8a717128
commit 8fac4b1cc8
25 changed files with 1119 additions and 31 deletions

View File

@ -3007,6 +3007,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.speculative.p_min = std::stof(value);
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_DRAFT_P_MIN"));
add_opt(common_arg(
{"--eagle3"},
"use EAGLE3 speculative decoding with the draft model",
[](common_params & params) {
params.speculative.eagle3 = true;
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
add_opt(common_arg(
{"-cd", "--ctx-size-draft"}, "N",
string_format("size of the prompt context for the draft model (default: %d, 0 = loaded from model)", params.speculative.n_ctx),

View File

@ -241,6 +241,8 @@ struct common_params_speculative {
int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
float p_split = 0.1f; // speculative decoding split probability
float p_min = 0.75f; // minimum speculative decoding probability (greedy)
bool eagle3 = false; // use EAGLE3 speculative decoding
std::vector<std::pair<std::string, std::string>> replacements; // main to speculative model replacements
std::vector<llama_model_tensor_buft_override> tensor_buft_overrides;

View File

@ -22,6 +22,11 @@ struct common_speculative {
llama_tokens prompt_dft;
bool vocab_dft_compatible = true; // whether retokenization is needed
std::map<std::string, std::string> tgt_dft_replacements = {};
// EAGLE3 specific
struct llama_context * eagle3_encoder = nullptr;
struct llama_context * eagle3_decoder = nullptr;
int32_t eagle3_n_past = 0; // number of verified positions in decoder KV cache
};
struct common_speculative * common_speculative_init(
@ -74,6 +79,35 @@ struct common_speculative * common_speculative_init(
return result;
}
// Initialize a common_speculative state for EAGLE3 speculative decoding.
// Unlike the standard path there is no separate draft context (ctx_dft):
// drafting uses a dedicated encoder context (feature fusion of target-model
// features) and a decoder context (single-layer draft model). The encoder and
// decoder contexts are owned by the returned object and are released in
// common_speculative_free().
struct common_speculative * common_speculative_init_eagle3(
        struct llama_context * ctx_tgt,
        struct llama_context * ctx_encoder,
        struct llama_context * ctx_decoder) {
    auto * result = new common_speculative {
        /* .ctx_tgt = */ ctx_tgt,
        /* .ctx_dft = */ nullptr, // Not used for EAGLE3
        /* .smpl = */ nullptr, // initialized below
        /* .batch = */ llama_batch_init(llama_n_batch(ctx_decoder), 0, 1),
        /* .prompt_dft = */ {},
        /* .vocab_dft_compatible = */ true, // EAGLE3 uses same vocab
        /* .tgt_dft_replacements = */ {},
        /* .eagle3_encoder = */ ctx_encoder,
        /* .eagle3_decoder = */ ctx_decoder,
    };

    // Initialize sampler for EAGLE3 decoder
    {
        common_params_sampling params;
        params.no_perf = false;
        // set 1 for greedy sampling (argmax) to match vLLM's default behavior,
        // but >1 always gets higher acceptance rate for eagle3
        params.top_k = 10;
        params.samplers = { COMMON_SAMPLER_TYPE_TOP_K };

        result->smpl = common_sampler_init(llama_get_model(ctx_decoder), params);
    }

    return result;
}
void common_speculative_free(struct common_speculative * spec) {
if (spec == nullptr) {
return;
@ -81,6 +115,14 @@ void common_speculative_free(struct common_speculative * spec) {
common_sampler_free(spec->smpl);
// EAGLE3 cleanup
if (spec->eagle3_encoder) {
llama_free(spec->eagle3_encoder);
}
if (spec->eagle3_decoder) {
llama_free(spec->eagle3_decoder);
}
llama_batch_free(spec->batch);
delete spec;
@ -181,12 +223,169 @@ static std::string replace_to_tgt(
return result;
}
// EAGLE3 Draft Generation with KV Cache Reuse
//
// ============================================================================
// EXAMPLE: Two rounds of speculative decoding
// ============================================================================
//
// ROUND 1 (Initial):
// Prompt: [t0, t1, t2, t3, t4], target generates t5
// prompt_tgt = [t0, t1, t2, t3, t4], id_last = t5 (GENERATED)
// n = 5, n_past = 0, n_new = 5
//
// Step 1: Encoder
// features: [f0, f1, f2, f3, f4] → g_embeddings: [g0, g1, g2, g3, g4]
//
// Step 2: Decoder batch (positions 0-4)
// tokens: [t1, t2, t3, t4, t5] ← prompt[1:] + id_last
// g_embd: [g0, g1, g2, g3, g4]
// positions: [0, 1, 2, 3, 4 ]
// → KV cache: [0, 1, 2, 3, 4]
// → sample d1 from logits[4]
//
// Step 3: Autoregressive (positions 5, 6, ...)
// pos 5: token=d1, g_embd=prenorm[4] → KV cache: [0,1,2,3,4,5] → d2
// pos 6: token=d2, g_embd=prenorm → KV cache: [0,1,2,3,4,5,6] → d3
//
// Output: [d1, d2, d3]
// Update: n_past = 5 (verified positions from batch decode)
//
// ROUND 2 (assuming d1 accepted, d2/d3 rejected):
// prompt_tgt = [t0, t1, t2, t3, t4, t5, d1], id_last = t6 (new target output)
// n = 7, n_past = 5, n_new = 2
//
// Step 1: Clear KV cache [5, inf) - remove draft positions
// KV cache: [0, 1, 2, 3, 4] (reuse from round 1!)
//
// Step 2: Encoder (only new tokens)
// features: [f5, f6] → g_embeddings: [g5, g6]
//
// Step 3: Decoder batch (only new positions 5-6)
// tokens: [d1, t6] (prompt_tgt[6], id_last)
// g_embd: [g5, g6]
// positions: [5, 6 ]
// → KV cache: [0,1,2,3,4] + [5,6] = [0,1,2,3,4,5,6]
// → sample d1' from logits[1] (last position in batch)
//
// Step 4: Autoregressive...
//
// ============================================================================
//
// Key insight: Decoder KV cache stores K/V computed from (tok_embd + g_embd).
// For verified positions, both tok_embd and g_embd are fixed (encoder output),
// so KV cache can be reused. Draft positions use prenorm as g_embd, which
// differs from encoder output, so they must be cleared and recomputed.
//
// Generate a draft token sequence with the EAGLE3 encoder/decoder pair.
//
// prompt_tgt holds all tokens currently in the target context and id_last is
// the last target-sampled token (kept separate, same convention as the
// standard path). Verified positions [0, eagle3_n_past) are kept in the
// decoder KV cache across calls; draft positions are cleared and recomputed
// every round (see the worked example above).
//
// Returns the drafted tokens; drafting stops early when a sampled token's
// probability drops below params.p_min.
static llama_tokens gen_eagle3_draft(
        struct common_speculative * spec,
        struct common_speculative_params params,
        const llama_tokens & prompt_tgt,
        llama_token id_last) {
    auto * ctx_tgt     = spec->ctx_tgt;
    auto * ctx_encoder = spec->eagle3_encoder;
    auto * ctx_decoder = spec->eagle3_decoder;
    auto * smpl        = spec->smpl;

    auto & batch = spec->batch;

    const int n_embd = llama_model_n_embd(llama_get_model(ctx_encoder));
    const int n      = (int) prompt_tgt.size();
    const int n_new  = n - spec->eagle3_n_past;

    GGML_ASSERT(n >= 1 && "prompt_tgt is empty");
    GGML_ASSERT(n_new >= 1 && "must have at least 1 new token");

    // Clear draft positions from decoder KV cache [n_past, inf)
    llama_memory_seq_rm(llama_get_memory(ctx_decoder), 0, spec->eagle3_n_past, -1);

    // Encoder: fuse target-model features into g_embeddings
    const float * features = llama_get_eagle3_target_features(ctx_tgt);
    GGML_ASSERT(features && "no target features");

    llama_batch enc_batch = {
        /*.n_tokens =*/ n_new,
        /*.token    =*/ nullptr,
        /*.embd     =*/ const_cast<float *>(features),
        /*.pos      =*/ nullptr,
        /*.n_seq_id =*/ nullptr,
        /*.seq_id   =*/ nullptr,
        /*.logits   =*/ nullptr,
    };

    // note: keep side-effecting calls out of GGML_ASSERT - assert on the
    //       captured return code instead
    const int ret_enc = llama_encode(ctx_encoder, enc_batch);
    GGML_ASSERT(ret_enc == 0 && "llama_encode failed");

    const float * g_embd = llama_get_embeddings(ctx_encoder);
    GGML_ASSERT(g_embd && "encoder output failed");

    // Decoder batch: process only the new tokens, reusing the KV cache for
    // already-verified positions
    llama_set_eagle3_g_embeddings(ctx_decoder, g_embd, n_embd, n_new);

    common_batch_clear(batch);
    for (int i = 0; i < n_new; i++) {
        const int pos = spec->eagle3_n_past + i;
        // decoder input is shifted by one: position pos takes prompt[pos + 1],
        // and the final slot carries id_last
        const llama_token tok = (pos < n - 1) ? prompt_tgt[pos + 1] : id_last;
        common_batch_add(batch, tok, pos, {0}, true);
    }

    const int ret_dec = llama_decode(ctx_decoder, batch);
    GGML_ASSERT(ret_dec == 0 && "llama_decode failed");

    spec->eagle3_n_past = n; // update verified positions

    // Sample draft tokens
    llama_tokens result;

    common_sampler_reset(smpl);

    // Sample one token and check its probability against p_min
    // (consistent with the standard speculative decoding path)
    auto sample_and_check = [&](int idx) -> bool {
        common_sampler_sample(smpl, ctx_decoder, idx);
        const auto * cur_p = common_sampler_get_candidates(smpl, true);
        const llama_token id = cur_p->data[0].id;
        common_sampler_accept(smpl, id, true);
        result.push_back(id);
        return cur_p->data[0].p >= params.p_min;
    };

    // First draft token from batch decode
    if (!sample_and_check(n_new - 1)) {
        return result;
    }

    // Autoregressive: use prenorm as g_embd (-1 = last output)
    const float * prenorm = llama_get_embeddings_ith(ctx_decoder, -1);
    for (int i = 1; i < params.n_draft; i++) {
        GGML_ASSERT(prenorm && "prenorm failed");
        llama_set_eagle3_g_embeddings(ctx_decoder, prenorm, n_embd, 1);

        common_batch_clear(batch);
        common_batch_add(batch, result.back(), n - 1 + i, {0}, true);

        const int ret = llama_decode(ctx_decoder, batch);
        GGML_ASSERT(ret == 0 && "llama_decode failed");

        prenorm = llama_get_embeddings_ith(ctx_decoder, -1);

        if (!sample_and_check(0)) {
            break;
        }
    }

    return result;
}
llama_tokens common_speculative_gen_draft(
struct common_speculative * spec,
struct common_speculative_params params,
const llama_tokens & prompt_tgt_main_model, // specified in target model vocab
llama_token id_last) {
// EAGLE3 path
if (spec->eagle3_encoder && spec->eagle3_decoder) {
return gen_eagle3_draft(spec, params, prompt_tgt_main_model, id_last);
}
// Standard draft model path
auto & batch = spec->batch;
auto & ctx_tgt = spec->ctx_tgt;
auto & ctx_dft = spec->ctx_dft;

View File

@ -17,6 +17,13 @@ struct common_speculative * common_speculative_init(
struct llama_context * ctx_dft
);
// EAGLE3: Initialize speculative decoding with EAGLE3 encoder and decoder contexts
struct common_speculative * common_speculative_init_eagle3(
struct llama_context * ctx_tgt,
struct llama_context * ctx_encoder,
struct llama_context * ctx_decoder
);
void common_speculative_free(struct common_speculative * spec);
bool common_speculative_are_compatible(

View File

@ -97,6 +97,7 @@ class ModelBase:
metadata_override: Path | None
dir_model_card: Path
remote_hf_model_id: str | None
target_model_dir: Path | None
# subclasses should define this!
model_arch: gguf.MODEL_ARCH
@ -116,7 +117,7 @@ class ModelBase:
split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False,
small_first_shard: bool = False, hparams: dict[str, Any] | None = None, remote_hf_model_id: str | None = None,
disable_mistral_community_chat_template: bool = False,
sentence_transformers_dense_modules: bool = False):
sentence_transformers_dense_modules: bool = False, target_model_dir: Path | None = None):
if type(self) is ModelBase or \
type(self) is TextModel or \
type(self) is MmprojModel:
@ -135,6 +136,7 @@ class ModelBase:
self.dry_run = dry_run
self.remote_hf_model_id = remote_hf_model_id
self.sentence_transformers_dense_modules = sentence_transformers_dense_modules
self.target_model_dir = target_model_dir
self.hparams = ModelBase.load_hparams(self.dir_model, self.is_mistral_format) if hparams is None else hparams
self.rope_parameters = self.hparams.get("rope_parameters", self.hparams.get("rope_scaling")) or {}
self.model_tensors = self.index_tensors(remote_hf_model_id=remote_hf_model_id)
@ -2373,7 +2375,55 @@ class LlamaModel(TextModel):
if self.hf_arch == "VLlama3ForCausalLM":
self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 32)
# detect EAGLE-3 llama checkpoint
if "draft_vocab_size" in self.hparams and self.hparams["num_hidden_layers"] == 1:
self.is_eagle3 = True
self.model_arch = gguf.MODEL_ARCH.EAGLE3
logger.info("Detected EAGLE-3 draft model, switching to EAGLE3 architecture")
# Re-initialize tensor_map with EAGLE3 architecture
self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
# Update gguf_writer architecture
self.gguf_writer.arch = gguf.MODEL_ARCH_NAMES[self.model_arch]
self.gguf_writer.add_architecture()
if not hasattr(self, 'target_model_dir') or not self.target_model_dir:
raise ValueError(
"EAGLE3 model requires --target-model-dir to be specified. "
"Please provide the path to the target model directory to read config.json"
)
# Read both EAGLE3 raw config and target model config
with open(self.dir_model / "config.json", 'r', encoding='utf-8') as f:
eagle3_raw_config = json.load(f)
with open(self.target_model_dir / "config.json", 'r', encoding='utf-8') as f:
target_config = json.load(f)
# EAGLE3 extract_layers
target_num_layers = target_config["num_hidden_layers"]
extract_layers = [2, target_num_layers // 2, target_num_layers - 3]
logger.info(f"EAGLE3: extract_layers = {extract_layers} (target model has {target_num_layers} layers)")
self.gguf_writer.add_array(f"{self.gguf_writer.arch}.extract_layers", extract_layers)
# EAGLE3 target_hidden_size: prefer EAGLE3 config, fallback to target config
if "target_hidden_size" in eagle3_raw_config and eagle3_raw_config["target_hidden_size"] is not None:
target_hidden_size = eagle3_raw_config["target_hidden_size"]
logger.info(f"EAGLE3: target_hidden_size = {target_hidden_size} (from EAGLE3 config)")
else:
target_hidden_size = target_config["hidden_size"]
logger.info(f"EAGLE3: target_hidden_size = {target_hidden_size} (from target model config)")
self.gguf_writer.add_uint32(f"{self.gguf_writer.arch}.target_hidden_size", target_hidden_size)
def set_vocab(self):
# For EAGLE-3 models, use tokenizer from target model if provided
if hasattr(self, 'is_eagle3') and self.is_eagle3:
if self.target_model_dir is None:
raise ValueError(
"EAGLE-3 draft model requires --target-model-dir to be specified. "
"Please provide the path to the target model directory containing the tokenizer."
)
logger.info(f"EAGLE-3: Using tokenizer from target model: {self.target_model_dir}")
# Temporarily swap dir_model to load tokenizer from target model
original_dir_model = self.dir_model
self.dir_model = self.target_model_dir
if self.is_mistral_format:
return self._set_vocab_mistral()
@ -2391,6 +2441,10 @@ class LlamaModel(TextModel):
# Llama 3
self._set_vocab_gpt2()
# Restore original dir_model for EAGLE-3
if hasattr(self, 'is_eagle3') and self.is_eagle3:
self.dir_model = original_dir_model
# Apply to CodeLlama only (and ignore for Llama 3 with a vocab size of 128256)
if self.hparams.get("vocab_size", 32000) == 32016:
special_vocab = gguf.SpecialVocab(
@ -2435,7 +2489,45 @@ class LlamaModel(TextModel):
_experts: list[dict[str, Tensor]] | None = None
def index_tensors(self, remote_hf_model_id: str | None = None) -> dict[str, Callable[[], Tensor]]:
    """Index model tensors, renaming EAGLE-3 ``midlayer.*`` entries.

    EAGLE-3 draft checkpoints store their single decoder layer under
    ``midlayer.*``; those keys are renamed to ``model.layers.0.*`` so the
    standard llama tensor mapping applies. Non-EAGLE-3 checkpoints are
    returned unchanged.
    """
    tensors = super().index_tensors(remote_hf_model_id)

    # EAGLE-3 detection: check hparams directly (this runs before self.is_eagle3 is set)
    is_eagle3_ckpt = "draft_vocab_size" in self.hparams and self.hparams["num_hidden_layers"] == 1
    if not is_eagle3_ckpt:
        return tensors

    logger.info("EAGLE-3: Renaming midlayer.* to model.layers.0.*")
    prefix = "midlayer."
    # dict comprehension preserves the original insertion order of the index
    return {
        ("model.layers.0." + name[len(prefix):] if name.startswith(prefix) else name): gen
        for name, gen in tensors.items()
    }
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# Eagle-3 llama checkpoint special handling
if hasattr(self, 'is_eagle3') and self.is_eagle3:
# Eagle-3 llama checkpoint special weights handling
# fc.weight: feature fusion layer
if name == "fc.weight":
return [(name, data_torch)]
# d2t: draft to target vocabulary mapping
elif name == "d2t":
# Skip parent class processing (store for manual handling in prepare_tensors)
if not hasattr(self, '_eagle3_int_tensors'):
self._eagle3_int_tensors = {}
self._eagle3_int_tensors[name] = data_torch
return []
# t2d: target to draft vocabulary mapping (not used, skip completely)
elif name == "t2d":
return []
# hidden_norm: EAGLE-3 specific layer normalization
elif name == "model.layers.0.hidden_norm.weight":
return [("blk.0.hidden_norm.weight", data_torch)]
n_head = self.find_hparam(["n_heads", "num_attention_heads"])
n_kv_head = self.find_hparam(["n_kv_heads", "num_key_value_heads"])
@ -2538,8 +2630,26 @@ class LlamaModel(TextModel):
yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), torch.tensor(rope_factors, dtype=torch.float32))
def prepare_tensors(self):
# EAGLE-3: collect original dtypes BEFORE parent class converts them to F32
eagle3_original_dtypes = {}
if hasattr(self, 'is_eagle3') and self.is_eagle3:
for name, data_torch in self.get_tensors():
if name == "d2t":
eagle3_original_dtypes[name] = data_torch.dtype
super().prepare_tensors()
if hasattr(self, 'is_eagle3') and self.is_eagle3 and hasattr(self, '_eagle3_int_tensors'):
for name, data_torch in self._eagle3_int_tensors.items():
old_dtype = eagle3_original_dtypes.get(name, data_torch.dtype)
# Keep as int64 to match original torch tensor dtype
data = data_torch.to(torch.int64).numpy()
data_qtype = gguf.GGMLQuantizationType.I64
shape_str = f"{{{', '.join(str(n) for n in reversed(data.shape))}}}"
logger.info(f"{name + ',':<30} {old_dtype} --> {data_qtype.name}, shape = {shape_str}")
self.gguf_writer.add_tensor(name, data, raw_dtype=data_qtype)
if self._experts is not None:
# flatten `list[dict[str, Tensor]]` into `list[str]`
experts = [k for d in self._experts for k in d.keys()]
@ -10125,6 +10235,7 @@ class LazyTorchTensor(gguf.LazyBase):
torch.float16: np.float16,
torch.float32: np.float32,
torch.uint8: np.uint8,
torch.int64: np.int64,
}
# only used when byteswapping data. Only correct size is needed
@ -10285,6 +10396,10 @@ def parse_args() -> argparse.Namespace:
"--no-tensor-first-split", action="store_true",
help="do not add tensors to the first split (disabled by default)"
)
parser.add_argument(
"--target-model-dir", type=str, default=None,
help="directory containing target model tokenizer (for EAGLE-3 draft models that don't have their own tokenizer)",
)
parser.add_argument(
"--metadata", type=Path,
help="Specify the path for an authorship metadata override file"
@ -10457,7 +10572,8 @@ def main() -> None:
split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run,
small_first_shard=args.no_tensor_first_split,
remote_hf_model_id=hf_repo_id, disable_mistral_community_chat_template=disable_mistral_community_chat_template,
sentence_transformers_dense_modules=args.sentence_transformers_dense_modules
sentence_transformers_dense_modules=args.sentence_transformers_dense_modules,
target_model_dir=Path(args.target_model_dir) if args.target_model_dir else None
)
if args.vocab_only:

View File

@ -4,6 +4,7 @@
#include "speculative.h"
#include "log.h"
#include "llama.h"
#include "chat.h"
#include <cstdio>
#include <cstring>
@ -34,16 +35,42 @@ int main(int argc, char ** argv) {
llama_numa_init(params.numa);
llama_model * model_tgt = NULL;
//llama_model * model_dft = NULL;
llama_model * model_dft = NULL;
llama_context * ctx_tgt = NULL;
llama_context * ctx_dft = NULL;
// load the target model
auto llama_init_tgt = common_init_from_params(params);
// EAGLE3 specific contexts
llama_context * ctx_encoder = NULL;
llama_context * ctx_decoder = NULL;
model_tgt = llama_init_tgt->model();
ctx_tgt = llama_init_tgt->context();
// For EAGLE3: load both draft model and target model
if (params.speculative.eagle3) {
llama_model_params dft_mp = llama_model_default_params();
dft_mp.n_gpu_layers = params.speculative.n_gpu_layers;
model_dft = llama_model_load_from_file(params.speculative.model.path.c_str(), dft_mp);
if (!model_dft) {
LOG_ERR("failed to load EAGLE3 draft model\n");
return 1;
}
llama_model_params tgt_mp = llama_model_default_params();
tgt_mp.n_gpu_layers = params.n_gpu_layers;
model_tgt = llama_model_load_from_file(params.model.path.c_str(), tgt_mp);
if (!model_tgt) {
LOG_ERR("failed to load target model\n");
return 1;
}
llama_context_params tcp = common_context_params_to_llama(params);
tcp.eagle3_model = model_dft; // Enable feature extraction
ctx_tgt = llama_init_from_model(model_tgt, tcp);
} else {
// Standard load the target model
auto llama_init_tgt = common_init_from_params(params);
model_tgt = llama_init_tgt->model();
ctx_tgt = llama_init_tgt->context();
}
const llama_vocab * vocab = llama_model_get_vocab(model_tgt);
@ -61,18 +88,57 @@ int main(int argc, char ** argv) {
params.cpuparams_batch.n_threads = params.speculative.cpuparams_batch.n_threads;
params.tensor_buft_overrides = params.speculative.tensor_buft_overrides;
auto llama_init_dft = common_init_from_params(params);
if (params.speculative.eagle3) {
// EAGLE3: create encoder and decoder contexts
llama_context_params enc_params = common_context_params_to_llama(params);
enc_params.embeddings = true;
ctx_encoder = llama_init_from_model(model_dft, enc_params);
if (!ctx_encoder) {
LOG_ERR("failed to create EAGLE3 encoder context\n");
return 1;
}
//model_dft = llama_init_dft->model();
ctx_dft = llama_init_dft->context();
llama_context_params dec_params = common_context_params_to_llama(params);
dec_params.target_model = model_tgt;
dec_params.embeddings = true;
ctx_decoder = llama_init_from_model(model_dft, dec_params);
if (!ctx_decoder) {
LOG_ERR("failed to create EAGLE3 decoder context\n");
return 1;
}
} else {
// Standard: load draft model context
auto llama_init_dft = common_init_from_params(params);
model_dft = llama_init_dft->model();
ctx_dft = llama_init_dft->context();
if (!common_speculative_are_compatible(ctx_tgt, ctx_dft)) {
LOG_INF("the draft model '%s' is not compatible with the target model '%s'. tokens will be translated between the draft and target models.\n", params.speculative.model.path.c_str(), params.model.path.c_str());
if (!common_speculative_are_compatible(ctx_tgt, ctx_dft)) {
LOG_INF("the draft model '%s' is not compatible with the target model '%s'. tokens will be translated between the draft and target models.\n", params.speculative.model.path.c_str(), params.model.path.c_str());
}
}
// Apply chat template for EAGLE3 if available which can increase the acceptance rate
std::string prompt = params.prompt;
if (params.speculative.eagle3) {
auto chat_templates = common_chat_templates_init(model_tgt, params.chat_template);
if (common_chat_templates_was_explicit(chat_templates.get())) {
std::vector<common_chat_msg> chat_msgs;
common_chat_msg user_msg;
user_msg.role = "user";
user_msg.content = params.prompt;
chat_msgs.push_back(user_msg);
common_chat_templates_inputs inputs;
inputs.messages = chat_msgs;
inputs.add_generation_prompt = true;
prompt = common_chat_templates_apply(chat_templates.get(), inputs).prompt;
LOG_INF("%s: EAGLE3 chat template applied\n", __func__);
}
}
// Tokenize the prompt
std::vector<llama_token> inp;
inp = common_tokenize(ctx_tgt, params.prompt, true, true);
inp = common_tokenize(ctx_tgt, prompt, true, true);
if (llama_n_ctx(ctx_tgt) < (uint32_t) inp.size()) {
LOG_ERR("%s: the prompt exceeds the context size (%d tokens, ctx %d)\n", __func__, (int) inp.size(), llama_n_ctx(ctx_tgt));
@ -115,26 +181,52 @@ int main(int argc, char ** argv) {
struct common_sampler * smpl = common_sampler_init(model_tgt, params.sampling);
// eval the prompt
llama_decode(ctx_tgt, llama_batch_get_one(inp.data(), inp.size() - 1));
llama_token id_last;
llama_tokens prompt_tgt;
int n_past;
// note: keep the last token separate!
llama_token id_last = inp.back();
if (params.speculative.eagle3) {
// Target model decodes full prompt and sample first token and intermediate features are extracted
llama_decode(ctx_tgt, llama_batch_get_one(inp.data(), inp.size()));
// all tokens currently in the target context
llama_tokens prompt_tgt(inp.begin(), inp.end() - 1);
prompt_tgt.reserve(llama_n_ctx(ctx_tgt));
id_last = common_sampler_sample(smpl, ctx_tgt, -1);
common_sampler_accept(smpl, id_last, true);
LOG("%s", common_token_to_piece(ctx_tgt, id_last).c_str());
n_predict++;
int n_past = inp.size() - 1;
// all tokens currently in the target context
prompt_tgt.assign(inp.begin(), inp.end());
prompt_tgt.reserve(llama_n_ctx(ctx_tgt));
n_past = inp.size();
} else {
llama_decode(ctx_tgt, llama_batch_get_one(inp.data(), inp.size() - 1));
// note: keep the last token separate!
id_last = inp.back();
// all tokens currently in the target context
prompt_tgt.assign(inp.begin(), inp.end() - 1);
prompt_tgt.reserve(llama_n_ctx(ctx_tgt));
n_past = inp.size() - 1;
}
// init the speculator
struct common_speculative_params params_spec;
params_spec.n_draft = n_draft;
params_spec.n_reuse = llama_n_ctx(ctx_dft) - n_draft;
params_spec.p_min = p_min;
struct common_speculative * spec = common_speculative_init(ctx_tgt, ctx_dft);
for (auto &pair : params.speculative.replacements) {
common_speculative_add_replacement_tgt_dft(spec, pair.first.c_str(), pair.second.c_str());
struct common_speculative * spec = NULL;
if (params.speculative.eagle3) {
spec = common_speculative_init_eagle3(ctx_tgt, ctx_encoder, ctx_decoder);
} else {
params_spec.n_reuse = llama_n_ctx(ctx_dft) - n_draft;
spec = common_speculative_init(ctx_tgt, ctx_dft);
for (auto &pair : params.speculative.replacements) {
common_speculative_add_replacement_tgt_dft(spec, pair.first.c_str(), pair.second.c_str());
}
}
llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, 1);
@ -249,7 +341,14 @@ int main(int argc, char ** argv) {
LOG_INF("\n");
LOG_INF("draft:\n\n");
llama_perf_context_print(ctx_dft);
if (ctx_dft) {
llama_perf_context_print(ctx_dft);
} else if (ctx_encoder && ctx_decoder) {
LOG_INF(" Eagle3 Draft encoder:\n");
llama_perf_context_print(ctx_encoder);
LOG_INF("\nEagle3 Draft decoder:\n");
llama_perf_context_print(ctx_decoder);
}
LOG_INF("\n");
LOG_INF("target:\n\n");

View File

@ -629,6 +629,7 @@ extern "C" {
GGML_TENSOR_FLAG_OUTPUT = 2, // ...is an output for the GGML compute graph
GGML_TENSOR_FLAG_PARAM = 4, // ...contains trainable parameters
GGML_TENSOR_FLAG_LOSS = 8, // ...defines loss for numerical optimization (multiple loss tensors add up)
GGML_TENSOR_FLAG_SYNC = 16, // ...forces a new split/sync point in the scheduler (e.g. for EAGLE3 decoder)
};
enum ggml_tri_type {
@ -853,6 +854,7 @@ extern "C" {
GGML_API void ggml_set_output(struct ggml_tensor * tensor);
GGML_API void ggml_set_param(struct ggml_tensor * tensor);
GGML_API void ggml_set_loss(struct ggml_tensor * tensor);
GGML_API void ggml_set_sync(struct ggml_tensor * tensor); // force sync point in scheduler
//
// operations on tensors with backpropagation

View File

@ -1202,6 +1202,11 @@ void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgra
}
}
// check if this node requires a sync point (e.g. for EAGLE3 parallel path fix)
if (node->flags & GGML_TENSOR_FLAG_SYNC) {
need_new_split = true;
}
if (node_backend_id != cur_backend_id || need_new_split) {
split->i_end = i;
i_split++;
@ -1576,6 +1581,15 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
if (ec != GGML_STATUS_SUCCESS) {
return ec;
}
// If any node in this split has SYNC flag, synchronize after compute
// This ensures the sync node is complete before next split (e.g. for EAGLE3 parallel path sync fix)
for (int j = 0; j < split->graph.n_nodes; j++) {
if (split->graph.nodes[j]->flags & GGML_TENSOR_FLAG_SYNC) {
ggml_backend_synchronize(split_backend);
break;
}
}
} else {
// similar to ggml_backend_compare_graph_backend
for (int j0 = 0; j0 < split->graph.n_nodes; j0++) {

View File

@ -7451,6 +7451,10 @@ void ggml_set_loss(struct ggml_tensor * tensor) {
tensor->flags |= GGML_TENSOR_FLAG_LOSS;
}
// Mark a tensor with GGML_TENSOR_FLAG_SYNC: the backend scheduler starts a
// new graph split at this node and synchronizes the backend after computing
// the split that contains it (see ggml_backend_sched_split_graph and
// ggml_backend_sched_compute_splits).
void ggml_set_sync(struct ggml_tensor * tensor) {
    tensor->flags |= GGML_TENSOR_FLAG_SYNC;
}
////////////////////////////////////////////////////////////////////////////////
void ggml_quantize_init(enum ggml_type type) {

View File

@ -147,6 +147,8 @@ class Keys:
EMBD_LENGTH_PER_LAYER_INP = "{arch}.embedding_length_per_layer_input"
DENSE_FEAT_IN_SIZE = "{arch}.{dense}_feat_in"
DENSE_FEAT_OUT_SIZE = "{arch}.{dense}_feat_out"
EAGLE3_EXTRACT_LAYERS = "{arch}.extract_layers"
EAGLE3_TARGET_HIDDEN_SIZE = "{arch}.target_hidden_size"
class Attention:
HEAD_COUNT = "{arch}.attention.head_count"
@ -446,6 +448,7 @@ class MODEL_ARCH(IntEnum):
RND1 = auto()
PANGU_EMBED = auto()
MISTRAL3 = auto()
EAGLE3 = auto()
class VISION_PROJECTOR_TYPE(IntEnum):
@ -710,6 +713,10 @@ class MODEL_TENSOR(IntEnum):
NEXTN_HNORM = auto()
NEXTN_SHARED_HEAD_HEAD = auto()
NEXTN_SHARED_HEAD_NORM = auto()
# EAGLE3 specific tensors
EAGLE3_FC = auto() # feature fusion layer
EAGLE3_HIDDEN_NORM = auto() # hidden normalization
EAGLE3_D2T = auto() # draft to target vocabulary mapping
MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
@ -820,6 +827,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.RND1: "rnd1",
MODEL_ARCH.PANGU_EMBED: "pangu-embedded",
MODEL_ARCH.MISTRAL3: "mistral3",
MODEL_ARCH.EAGLE3: "eagle3",
}
VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {
@ -1082,6 +1090,9 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.NEXTN_HNORM: "blk.{bid}.nextn.hnorm",
MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD: "blk.{bid}.nextn.shared_head_head",
MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM: "blk.{bid}.nextn.shared_head_norm",
MODEL_TENSOR.EAGLE3_FC: "fc",
MODEL_TENSOR.EAGLE3_HIDDEN_NORM: "blk.{bid}.hidden_norm",
MODEL_TENSOR.EAGLE3_D2T: "d2t",
}
MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
@ -3094,6 +3105,24 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
],
MODEL_ARCH.EAGLE3: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.EAGLE3_FC,
MODEL_TENSOR.EAGLE3_HIDDEN_NORM,
MODEL_TENSOR.EAGLE3_D2T,
],
# TODO
}

View File

@ -363,6 +363,13 @@ extern "C" {
bool kv_unified; // use a unified buffer across the input sequences when computing the attention
// try to disable when n_seq_max > 1 for improved performance when the sequences do not share a large prefix
// ref: https://github.com/ggml-org/llama.cpp/pull/14363
// EAGLE3 extraction configuration
// When eagle3_model is set, layer extraction is automatically enabled
const struct llama_model * eagle3_model; // EAGLE3 model to read extract_layers configuration from
// If non-NULL, enables automatic feature extraction
const struct llama_model * target_model; // reference to target model
// only used to share embedding layer with eagle3 model
};
// model quantization parameters
@ -846,6 +853,23 @@ extern "C" {
llama_seq_id dest_seq_id,
llama_state_seq_flags flags);
//
// EAGLE3 draft model support
//
// Get pointer to target model features extracted for EAGLE3 encoder
// Returns NULL if no features are available
// Format: [3*n_embd, n_tokens] - use model.hparams.n_embd and batch.n_tokens for dimensions
LLAMA_API const float * llama_get_eagle3_target_features(struct llama_context * ctx);
// Set g_embeddings from EAGLE3 encoder output for decoder input
// g_embd: pointer to encoder output embeddings
LLAMA_API void llama_set_eagle3_g_embeddings(
struct llama_context * ctx,
const float * g_embd,
int32_t n_embd,
int32_t n_tokens);
//
// Decoding
//

View File

@ -58,6 +58,7 @@ add_library(llama
models/deepseek2.cpp
models/dots1.cpp
models/dream.cpp
models/eagle3.cpp
models/ernie4-5-moe.cpp
models/ernie4-5.cpp
models/exaone.cpp

View File

@ -112,6 +112,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_RND1, "rnd1" },
{ LLM_ARCH_PANGU_EMBED, "pangu-embedded" },
{ LLM_ARCH_MISTRAL3, "mistral3" },
{ LLM_ARCH_EAGLE3, "eagle3" },
{ LLM_ARCH_UNKNOWN, "(unknown)" },
};
@ -245,6 +246,9 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_CLASSIFIER_OUTPUT_LABELS, "%s.classifier.output_labels" },
{ LLM_KV_EAGLE3_EXTRACT_LAYERS, "%s.extract_layers" },
{ LLM_KV_EAGLE3_TARGET_HIDDEN_SIZE, "%s.target_hidden_size" },
{ LLM_KV_SHORTCONV_L_CACHE, "%s.shortconv.l_cache" },
// sentence-transformers dense modules feature dims
{ LLM_KV_DENSE_2_FEAT_IN, "%s.dense_2_feat_in" },
@ -2540,6 +2544,30 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
},
},
{
LLM_ARCH_EAGLE3,
{
// Token embeddings (optional - Llama 3.3 70B EAGLE3 has its own, Llama 3.1 8B EAGLE3 uses target model's)
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, // Optional - only if EAGLE3 config has rope_scaling
// Single decoder layer (blk.0)
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
// EAGLE-3 specific layers
{ LLM_TENSOR_EAGLE3_HIDDEN_NORM, "blk.%d.hidden_norm" },
{ LLM_TENSOR_EAGLE3_FC, "fc" },
{ LLM_TENSOR_EAGLE3_D2T, "d2t" },
},
},
{
LLM_ARCH_UNKNOWN,
{
@ -2742,6 +2770,10 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_NEXTN_HNORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
{LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
// EAGLE-3 tensors
{LLM_TENSOR_EAGLE3_FC, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_EAGLE3_HIDDEN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_EAGLE3_D2T, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}},
};
LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {}

View File

@ -117,6 +117,7 @@ enum llm_arch {
LLM_ARCH_PANGU_EMBED,
LLM_ARCH_MISTRAL3,
LLM_ARCH_UNKNOWN,
LLM_ARCH_EAGLE3,
};
enum llm_kv {
@ -287,6 +288,9 @@ enum llm_kv {
LLM_KV_CLASSIFIER_OUTPUT_LABELS,
LLM_KV_EAGLE3_EXTRACT_LAYERS,
LLM_KV_EAGLE3_TARGET_HIDDEN_SIZE,
LLM_KV_SHORTCONV_L_CACHE,
LLM_KV_XIELU_ALPHA_N,
@ -492,6 +496,9 @@ enum llm_tensor {
LLM_TENSOR_NEXTN_HNORM,
LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD,
LLM_TENSOR_NEXTN_SHARED_HEAD_NORM,
LLM_TENSOR_EAGLE3_FC, // eagle3: feature fusion layer
LLM_TENSOR_EAGLE3_HIDDEN_NORM, // eagle3: additional normalization layer
LLM_TENSOR_EAGLE3_D2T, // eagle3: draft to target vocabulary mapping
};
enum llm_tensor_layer {

View File

@ -135,6 +135,7 @@ llama_context::llama_context(
cparams.op_offload = params.op_offload;
cparams.kv_unified = params.kv_unified;
cparams.eagle3_extract_enabled = (params.eagle3_model != nullptr); // auto-enable if eagle3_model is provided
{
const char * LLAMA_GRAPH_REUSE_DISABLE = getenv("LLAMA_GRAPH_REUSE_DISABLE");
@ -333,6 +334,30 @@ llama_context::llama_context(
cross.v_embd.clear();
// Initialize EAGLE3 feature extraction configuration
if (cparams.eagle3_extract_enabled) {
// Feature extraction layers configuration must come from EAGLE3 model
if (!params.eagle3_model) {
LLAMA_LOG_ERROR("%s: EAGLE3 extraction enabled but eagle3_model not provided\n", __func__);
throw std::runtime_error("EAGLE3 extraction requires eagle3_model parameter");
}
const auto & eagle3_hparams = params.eagle3_model->hparams;
// Copy feature extraction layer indices from EAGLE3 model's hparams
eagle3.extract_layer_indices.assign(
eagle3_hparams.eagle3_extract_layers.begin(),
eagle3_hparams.eagle3_extract_layers.end()
);
// Allocate tensors array for extraction
eagle3.extract_tensors.resize(eagle3.extract_layer_indices.size(), nullptr);
LLAMA_LOG_INFO("%s: EAGLE3 extraction enabled for layers [%d, %d, %d]\n", __func__,
eagle3.extract_layer_indices[0],
eagle3.extract_layer_indices[1],
eagle3.extract_layer_indices[2]);
}
// avoid reserving graphs with zero outputs - assume one output per sequence
n_outputs = n_seqs;
@ -832,6 +857,14 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll
//const auto t_start_us = ggml_time_us();
res->set_inputs(&ubatch);
// EAGLE3: Fill g_embeddings for decoder input
if (model.arch == LLM_ARCH_EAGLE3 && gtype == LLM_GRAPH_TYPE_DECODER && !eagle3.g_embeddings.empty()) {
ggml_tensor * g_embd = ggml_graph_get_tensor(gf, "inp_g_embeddings");
if (g_embd) {
ggml_backend_tensor_set(g_embd, eagle3.g_embeddings.data(), 0, ggml_nbytes(g_embd));
}
}
//LLAMA_LOG_INFO("graph set inputs time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0);
}
@ -843,6 +876,11 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll
return nullptr;
}
// EAGLE3: Extract intermediate layer features after graph execution
if (cparams.eagle3_extract_enabled && !eagle3.extract_tensors.empty()) {
extract_eagle3_features(ubatch);
}
ret = GGML_STATUS_SUCCESS;
return res;
@ -858,7 +896,8 @@ int llama_context::encode(const llama_batch & batch_inp) {
const auto & hparams = model.hparams;
const int64_t n_embd = hparams.n_embd_inp();
// EAGLE3: use 3*target_hidden_size for concatenated features input
const int64_t n_embd = (model.arch == LLM_ARCH_EAGLE3 && batch_inp.embd) ? 3 * hparams.eagle3_target_hidden_size : hparams.n_embd;
const int64_t n_vocab = model.vocab.n_tokens();
// note: during encode, we always pass the full sequence starting from pos = 0
@ -941,8 +980,15 @@ int llama_context::encode(const llama_batch & batch_inp) {
// extract token embeddings
GGML_ASSERT(embd != nullptr);
GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_size);
ggml_backend_tensor_get_async(backend_embd, t_embd, embd, 0, n_tokens*n_embd*sizeof(float));
if (model.arch == LLM_ARCH_EAGLE3) {
// g_embeddings are stored temporarily in embd buffer
const int64_t out_embd = hparams.n_embd;
GGML_ASSERT(n_tokens * out_embd <= (int64_t) embd_size);
ggml_backend_tensor_get_async(backend_embd, t_embd, embd, 0, n_tokens * out_embd * sizeof(float));
} else {
GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_size);
ggml_backend_tensor_get_async(backend_embd, t_embd, embd, 0, n_tokens*n_embd*sizeof(float));
}
} break;
case LLAMA_POOLING_TYPE_MEAN:
case LLAMA_POOLING_TYPE_CLS:
@ -1181,7 +1227,8 @@ int llama_context::decode(const llama_batch & batch_inp) {
auto * t_logits = res->get_logits();
auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr;
if (t_embd && res->get_embd_pooled()) {
// For EAGLE3, don't override t_embd with t_embd_pooled - we need the prenorm value during eagle3 decoder autoregressive generation
if (t_embd && res->get_embd_pooled() && model.arch != LLM_ARCH_EAGLE3) {
t_embd = res->get_embd_pooled();
}
@ -1196,7 +1243,39 @@ int llama_context::decode(const llama_batch & batch_inp) {
if (n_outputs) {
GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
GGML_ASSERT((n_outputs_prev + n_outputs)*n_vocab <= (int64_t) logits_size);
ggml_backend_tensor_get_async(backend_res, t_logits, logits_out, 0, n_outputs*n_vocab*sizeof(float));
// EAGLE3: Map draft vocab to target vocab
if (model.arch == LLM_ARCH_EAGLE3 && model.d2t) {
static thread_local std::vector<int64_t> eagle3_d2t_map;
static thread_local std::vector<float> eagle3_draft_logits;
const int64_t draft_vocab_size = t_logits->ne[0];
const uint32_t last_idx = n_outputs - 1;
// Load d2t mapping once (on first call)
if (eagle3_d2t_map.empty()) {
eagle3_d2t_map.resize(model.d2t->ne[0]);
ggml_backend_tensor_get(model.d2t, eagle3_d2t_map.data(), 0, eagle3_d2t_map.size() * sizeof(int64_t));
}
// Read only the last token's draft logits
eagle3_draft_logits.resize(draft_vocab_size);
const size_t last_offset = last_idx * draft_vocab_size * sizeof(float);
ggml_backend_tensor_get(t_logits, eagle3_draft_logits.data(), last_offset, draft_vocab_size * sizeof(float));
// Map only the last token's draft logits to target vocab
float * last_logits_out = logits_out + last_idx * n_vocab;
std::fill(last_logits_out, last_logits_out + n_vocab, -std::numeric_limits<float>::infinity());
for (int64_t j = 0; j < draft_vocab_size; j++) {
const int64_t target_id = j + eagle3_d2t_map[j];
GGML_ASSERT(target_id >= 0 && target_id < n_vocab);
last_logits_out[target_id] = eagle3_draft_logits[j];
}
} else {
ggml_backend_tensor_get_async(backend_res, t_logits, logits_out, 0, n_outputs*n_vocab*sizeof(float));
}
}
}
@ -1455,7 +1534,16 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
auto * res = gf_res_reserve.get();
const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT);
// EAGLE3: auto-detect encoder (embeddings+no target_model) or decoder (has target_model)
llm_graph_type gtype = LLM_GRAPH_TYPE_DEFAULT;
if (model.arch == LLM_ARCH_EAGLE3) {
if (cparams.embeddings && model.target_tok_embd == nullptr) {
gtype = LLM_GRAPH_TYPE_ENCODER;
} else if (model.target_tok_embd != nullptr) {
gtype = LLM_GRAPH_TYPE_DECODER;
}
}
const auto gparams = graph_params(res, ubatch, mctx, gtype);
res->reset();
@ -1491,6 +1579,7 @@ llm_graph_params llama_context::graph_params(
/*.loras =*/ &loras,
/*.mctx =*/ mctx,
/*.cross =*/ &cross,
/*.eagle3 =*/ &eagle3,
/*.n_outputs =*/ n_outputs,
/*.cb =*/ graph_get_cb(),
/*.res =*/ res,
@ -1534,6 +1623,27 @@ llm_graph_cb llama_context::graph_get_cb() const {
ggml_set_name(cur, name);
}
// EAGLE3: Extract intermediate layer features if this is an extraction point
if (cparams.eagle3_extract_enabled) {
static constexpr const char * prefix = "eagle3_extract_";
static constexpr size_t prefix_len = 15; // strlen("eagle3_extract_")
if (strncmp(name, prefix, prefix_len) == 0) {
// Parse the extraction index from the name (e.g., "eagle3_extract_0" -> 0)
size_t extract_idx = 0;
if (sscanf(name + prefix_len, "%zu", &extract_idx) == 1 && extract_idx < eagle3.extract_tensors.size()) {
// Mark as output tensor to ensure proper backend assignment
ggml_set_output(cur);
// Store this tensor reference for post-execution extraction
eagle3.extract_tensors[extract_idx] = cur;
LLAMA_LOG_DEBUG("%s: EAGLE3 stored tensor reference for extraction: "
"index=%zu, layer=%d, target_layer=%d, tensor=%s\n",
__func__, extract_idx, il,
eagle3.extract_layer_indices[extract_idx], name);
}
}
}
if (!cparams.offload_kqv) {
if (strcmp(name, "kqv_merged_cont") == 0) {
// all nodes between the KV store and the attention output are run on the CPU
@ -1559,6 +1669,54 @@ llm_graph_cb llama_context::graph_get_cb() const {
};
}
// Copy the EAGLE3 extraction tensors off their backends and repack them into
// eagle3.target_features in token-major order, ready to feed the EAGLE3 encoder.
//
// Layout produced: one row per token, each row being the concatenation of the
// extracted layers' embeddings: [layer_0_embd | layer_1_embd | layer_2_embd].
void llama_context::extract_eagle3_features(const llama_ubatch & ubatch) {
    const int64_t n_tokens = ubatch.n_tokens;
    const int64_t n_embd   = model.hparams.n_embd;

    const size_t n_layers = eagle3.extract_tensors.size();

    // destination buffer: [n_layers * n_embd, n_tokens]
    const int64_t n_embd_concat = n_embd * n_layers;
    eagle3.target_features.resize(n_embd_concat * n_tokens);

    // scratch buffer holding one layer's [n_embd, n_tokens] block at a time
    static thread_local std::vector<float> layer_buf;
    layer_buf.resize(n_embd * n_tokens);

    LLAMA_LOG_DEBUG("%s: Start to extract EAGLE3 features: %zu layers, %lld tokens, %lld embd\n",
                    __func__, n_layers, (long long)n_tokens, (long long)n_embd);

    for (size_t il = 0; il < n_layers; ++il) {
        ggml_tensor * tensor = eagle3.extract_tensors[il];
        GGML_ASSERT(tensor != nullptr && "EAGLE3 extraction tensor is null");

        // locate the backend that holds this tensor's data
        ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched.get(), tensor);
        GGML_ASSERT(backend != nullptr && "EAGLE3 tensor has no backend");

        // expected shape: [n_embd, n_tokens]
        GGML_ASSERT(tensor->ne[0] == n_embd && tensor->ne[1] == n_tokens &&
                    "EAGLE3 extraction tensor has unexpected shape");

        // read the whole layer into the scratch buffer; the synchronize is
        // required before the scratch buffer is reused on the next iteration
        ggml_backend_tensor_get_async(backend, tensor, layer_buf.data(), 0, n_embd * n_tokens * sizeof(float));
        ggml_backend_sched_synchronize(sched.get());

        // scatter the layer into its slot of every token's concatenated row
        float       * dst_base = eagle3.target_features.data() + il * n_embd;
        const float * src_base = layer_buf.data();
        for (int64_t it = 0; it < n_tokens; ++it) {
            std::memcpy(dst_base + it * n_embd_concat, src_base + it * n_embd, n_embd * sizeof(float));
        }
    }
}
//
// state save/load
//
@ -2354,6 +2512,8 @@ llama_context_params llama_context_default_params() {
/*.op_offload =*/ true,
/*.swa_full =*/ true,
/*.kv_unified =*/ false,
/*.eagle3_model =*/ nullptr,
/*.target_model =*/ nullptr,
};
return result;
@ -2367,6 +2527,12 @@ llama_context * llama_init_from_model(
return nullptr;
}
// Auto-setup for EAGLE3: set target embedding if target_model is provided
if (model->arch == LLM_ARCH_EAGLE3 && params.target_model) {
model->target_tok_embd = params.target_model->tok_embd;
LLAMA_LOG_INFO("%s: EAGLE3 auto-setup: using target model's embedding layer\n", __func__);
}
if (params.n_batch == 0 && params.n_ubatch == 0) {
LLAMA_LOG_ERROR("%s: n_batch and n_ubatch cannot both be zero\n", __func__);
return nullptr;
@ -3016,3 +3182,33 @@ void llama_opt_epoch(
callback_train,
callback_eval);
}
//
// EAGLE3 member functions
//
// Return a pointer to the target-model features extracted for the EAGLE3
// encoder. Valid only after a llama_encode()/llama_decode() pass on the target
// model populated eagle3.target_features via extract_eagle3_features().
const float * llama_context::get_eagle3_target_features() const {
    const auto & feats = eagle3.target_features;
    GGML_ASSERT(!feats.empty() && "EAGLE3 target features not extracted - call llama_encode() on target model first");
    return feats.data();
}
// Store the EAGLE3 encoder output (g_embeddings) in context-owned storage so
// the decoder graph can read it when its inputs are filled.
//
// g_embd:   pointer to n_embd * n_tokens floats (copied, not retained)
// n_embd:   embedding width of the EAGLE3 model
// n_tokens: number of tokens in the encoder output
void llama_context::set_eagle3_g_embeddings(const float * g_embd, int32_t n_embd, int32_t n_tokens) {
    GGML_ASSERT(g_embd != nullptr && "g_embeddings cannot be null");
    GGML_ASSERT(n_embd > 0 && n_tokens > 0 && "invalid dimensions");

    // widen before multiplying: int32_t * int32_t can overflow (UB) for large
    // n_embd * n_tokens products
    const size_t size = (size_t) n_embd * (size_t) n_tokens;
    eagle3.g_embeddings.assign(g_embd, g_embd + size);
}
//
// C API wrappers
//
// C API wrapper: expose the extracted target-model features (see
// llama_context::get_eagle3_target_features for the validity requirements)
const float * llama_get_eagle3_target_features(llama_context * ctx) {
    return ctx->get_eagle3_target_features();
}
// C API wrapper: copy encoder-produced g_embeddings into the context for use
// as EAGLE3 decoder input
void llama_set_eagle3_g_embeddings(llama_context * ctx, const float * g_embd, int32_t n_embd, int32_t n_tokens) {
    ctx->set_eagle3_g_embeddings(g_embd, n_embd, n_tokens);
}

View File

@ -208,6 +208,12 @@ public:
// reserve a graph with a dummy ubatch of the specified size
ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only = false);
// EAGLE3: Get pointer to target model features extracted for EAGLE3 encoder
const float * get_eagle3_target_features() const;
// EAGLE3: Set g_embeddings from encoder output for decoder input
void set_eagle3_g_embeddings(const float * g_embd, int32_t n_embd, int32_t n_tokens);
private:
llm_graph_params graph_params(
llm_graph_result * res,
@ -217,6 +223,9 @@ private:
llm_graph_cb graph_get_cb() const;
// EAGLE3: Extract intermediate layer features from target model
void extract_eagle3_features(const llama_ubatch & ubatch);
// TODO: read/write lora adapters and cvec
size_t state_write_data(llama_io_write_i & io);
size_t state_read_data (llama_io_read_i & io);
@ -235,6 +244,9 @@ private:
llama_adapter_loras loras;
llama_cross cross; // TODO: tmp for handling cross-attention - need something better probably
mutable llama_eagle3 eagle3; // EAGLE3 draft model support - stores features from target model
// mutable because it's modified during graph building (const function)
std::unique_ptr<llama_memory_i> memory;

View File

@ -34,6 +34,7 @@ struct llama_cparams {
bool warmup;
bool op_offload;
bool kv_unified;
bool eagle3_extract_enabled; // enable layer extraction for EAGLE3 speculative decoding
enum llama_pooling_type pooling_type;

View File

@ -590,6 +590,7 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
loras (params.loras),
mctx (params.mctx),
cross (params.cross),
eagle3 (params.eagle3),
cb_func (params.cb),
res (params.res),
ctx0 (res->get_ctx()),

View File

@ -70,6 +70,30 @@ struct llama_cross {
std::vector<std::set<llama_seq_id>> seq_ids_enc;
};
// EAGLE3 support - stores intermediate features from target model
struct llama_eagle3 {
    // Configuration: which layers to extract from target model
    // (copied from the EAGLE3 model's hparams when the context is created)
    std::vector<int> extract_layer_indices;
    // Extracted features from target model (for encoder input)
    // Concatenated [layer_l, layer_m, layer_h] embeddings
    // Shape: [n_layers * n_embd, n_tokens] where n_layers = extract_layer_indices.size()
    std::vector<float> target_features;
    // Encoder output (for decoder input)
    std::vector<float> g_embeddings;
    // Tensor references for feature extraction from target model
    // NOTE(review): non-owning pointers set during graph building — presumably
    // only valid for the current graph; confirm before caching across batches
    std::vector<ggml_tensor *> extract_tensors;
    // Clear all stored per-batch data
    // note: extract_layer_indices (configuration) is deliberately not cleared
    void clear() {
        target_features.clear();
        g_embeddings.clear();
        extract_tensors.clear();
    }
};
struct llm_graph_params;
//
@ -416,6 +440,7 @@ struct llm_graph_params {
const llama_adapter_loras * loras;
const llama_memory_context_i * mctx;
const llama_cross * cross;
llama_eagle3 * eagle3; // non-const: we write extracted features here
uint32_t n_outputs;
@ -579,6 +604,7 @@ struct llm_graph_context {
const llama_adapter_loras * loras;
const llama_memory_context_i * mctx;
const llama_cross * cross;
llama_eagle3 * eagle3; // non-const: we write extracted features here
const llm_graph_cb & cb_func;

View File

@ -188,6 +188,13 @@ struct llama_hparams {
// qwen3vl deepstack
uint32_t n_deepstack_layers = 0;
// EAGLE3 draft model - layer indices to extract from target model
// e.g., for 32-layer target: [2, 16, 29] (low, middle, high)
std::array<int, 3> eagle3_extract_layers = {0, 0, 0};
// EAGLE3 draft model - target model hidden size
uint32_t eagle3_target_hidden_size = 0;
// needed by encoder-decoder models (e.g. T5, FLAN-T5)
// ref: https://github.com/ggerganov/llama.cpp/pull/8141
llama_token dec_start_token_id = LLAMA_TOKEN_NULL;

View File

@ -2230,6 +2230,28 @@ void llama_model::load_hparams(llama_model_loader & ml) {
default: type = LLM_TYPE_UNKNOWN;
}
} break;
case LLM_ARCH_EAGLE3:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
// EAGLE3 layer extraction configuration
// Use array<int, 4> (has template instantiation), then copy first 3 elements
std::array<int, 4> extract_layers_tmp = {};
if (!ml.get_key_or_arr(LLM_KV_EAGLE3_EXTRACT_LAYERS, extract_layers_tmp, 3, false)) {
throw std::runtime_error("EAGLE3 model requires 'extract_layers' in GGUF metadata");
}
std::copy_n(extract_layers_tmp.begin(), 3, hparams.eagle3_extract_layers.begin());
LLAMA_LOG_INFO("%s: EAGLE3 extract_layers = [%d, %d, %d]\n", __func__,
hparams.eagle3_extract_layers[0],
hparams.eagle3_extract_layers[1],
hparams.eagle3_extract_layers[2]);
// EAGLE3 target model hidden size
ml.get_key(LLM_KV_EAGLE3_TARGET_HIDDEN_SIZE, hparams.eagle3_target_hidden_size);
LLAMA_LOG_INFO("%s: EAGLE3 target_hidden_size = %u (draft n_embd = %u)\n", __func__,
hparams.eagle3_target_hidden_size, hparams.n_embd);
type = LLM_TYPE_UNKNOWN;
} break;
case LLM_ARCH_COGVLM:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
@ -6408,6 +6430,62 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, 0);
}
} break;
case LLM_ARCH_EAGLE3:
{
const int64_t n_embd_target_features = 3 * hparams.eagle3_target_hidden_size;
const int64_t n_embd_attn_input = 2 * n_embd;
// Get vocab size from the d2t tensor in the GGUF file
// d2t: draft to target mapping (size = draft_vocab_size)
const struct ggml_tensor * d2t_meta = ml.get_tensor_meta("d2t");
if (!d2t_meta) {
throw std::runtime_error("EAGLE3 model requires 'd2t' tensor but it was not found in the model file");
}
const int64_t n_draft_vocab = d2t_meta->ne[0];
// Feature fusion layer: projects 3 target layers to draft hidden size
fc = create_tensor(tn(LLM_TENSOR_EAGLE3_FC, "weight"), {n_embd_target_features, n_embd}, 0);
// Draft to target vocabulary mapping tensor
d2t = create_tensor(tn(LLM_TENSOR_EAGLE3_D2T), {n_draft_vocab}, 0);
// Output layer (uses draft vocab size)
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_draft_vocab}, 0);
// Token embeddings (optional - Llama 3.3 70B EAGLE3 has its own)
const struct ggml_tensor * tok_embd_meta = ml.get_tensor_meta(tn(LLM_TENSOR_TOKEN_EMBD, "weight").str().c_str());
if (tok_embd_meta) {
const int64_t n_target_vocab = tok_embd_meta->ne[1];
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_target_vocab}, 0);
LLAMA_LOG_INFO("%s: EAGLE3 using its own token_embd (vocab = %lld)\n", __func__, (long long)n_target_vocab);
}
// Single decoder layer
for (int i = 0; i < n_layer; ++i) {
auto & layer = layers[i];
// input_layernorm: applied to token embeddings
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
// Attention takes input_embeds_normed + fused_target_normed as input
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd_attn_input, n_embd_head_k * n_head}, 0);
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd_attn_input, n_embd_k_gqa}, 0);
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd_attn_input, n_embd_v_gqa}, 0);
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
// EAGLE-3 specific: hidden_norm applied to fused target features
layer.eagle3_hidden_norm = create_tensor(tn(LLM_TENSOR_EAGLE3_HIDDEN_NORM, "weight", i), {n_embd}, 0);
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
// rope_freqs for llama3 rope scaling (optional - only if EAGLE3 config has rope_scaling)
layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED);
}
} break;
case LLM_ARCH_COGVLM:
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@ -7564,6 +7642,14 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
{
llm = std::make_unique<llm_build_minimax_m2>(*this, params);
} break;
case LLM_ARCH_EAGLE3:
{
if (params.gtype == LLM_GRAPH_TYPE_ENCODER) {
llm = std::make_unique<llm_build_eagle3_encode>(*this, params);
} else {
llm = std::make_unique<llm_build_eagle3_decode>(*this, params);
}
} break;
case LLM_ARCH_COGVLM:
{
llm = std::make_unique<llm_build_cogvlm>(*this, params);
@ -7749,6 +7835,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
case LLM_ARCH_ERNIE4_5:
case LLM_ARCH_ERNIE4_5_MOE:
case LLM_ARCH_MISTRAL3:
case LLM_ARCH_EAGLE3:
return LLAMA_ROPE_TYPE_NORM;
// the pairs of head values are offset by n_rot/2

View File

@ -404,6 +404,9 @@ struct llama_layer {
struct ggml_tensor * ffn_act_beta = nullptr;
struct ggml_tensor * ffn_act_eps = nullptr;
// eagle3
struct ggml_tensor * eagle3_hidden_norm = nullptr;
struct llama_layer_posnet posnet;
struct llama_layer_convnext convnext;
@ -453,6 +456,13 @@ struct llama_model {
struct ggml_tensor * per_layer_model_proj = nullptr;
struct ggml_tensor * per_layer_proj_norm = nullptr;
// eagle3
struct ggml_tensor * fc = nullptr; // feature fusion layer
struct ggml_tensor * d2t = nullptr; // draft to target vocabulary mapping
// Reference to target model's embedding layer
// This allows EAGLE3 to use target model's embeddings without copying
struct ggml_tensor * target_tok_embd = nullptr;
std::vector<llama_layer> layers;
//Dense linear projections for SentenceTransformers models like embeddinggemma

187
src/models/eagle3.cpp Normal file
View File

@ -0,0 +1,187 @@
#include "models.h"
// EAGLE3 Encoder: processes target model features through feature fusion layer
// Input: target_features e.g. [12288, n_tokens] from target model layers low, middle, high
// Output: g_embeddings e.g. [4096, n_tokens] stored in context
// EAGLE3 encoder graph: a single matmul (the feature-fusion layer `fc`) that
// projects the concatenated target-model features down to the draft model's
// embedding width. The result is published as res->t_embd so the caller can
// read it back as g_embeddings for the decoder.
llm_build_eagle3_encode::llm_build_eagle3_encode(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
    // input width = 3 extracted layers, each of the target model's hidden size
    const int64_t n_embd_target_features = 3 * hparams.eagle3_target_hidden_size;
    ggml_tensor * cur;
    // Input: Target model features (3 layers concatenated: low, mid, high)
    // Data will be provided via ubatch->embd in encode_eagle3_features()
    auto inp_target = std::make_unique<llm_graph_input_embd>();
    inp_target->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd_target_features, n_tokens);
    ggml_set_input(inp_target->embd);
    ggml_tensor * target_features = inp_target->embd;
    res->add_input(std::move(inp_target));
    cb(target_features, "inp_target_features", -1);
    // Feature fusion layer: [3*target_hidden, n_tokens] -> [n_embd, n_tokens]
    ggml_tensor * fused_target = build_lora_mm(model.fc, target_features);
    cb(fused_target, "fc_out", -1);
    // Output: g_embeddings e.g. [4096, n_tokens]
    cur = fused_target;
    res->t_embd = cur;
    ggml_build_forward_expand(gf, cur);
}
// EAGLE3 Decoder: processes draft tokens using g_embeddings from encoder
// Input: draft tokens + g_embeddings from encoder
// Output: draft logits
// EAGLE3 decoder graph: a single transformer layer whose attention input is
// the channel-wise concatenation of (normalized) token embeddings and
// (normalized) g_embeddings from the encoder. Outputs:
//   res->t_embd   - the pre-norm hidden state (fed back as the next step's
//                   g_embeddings during autoregressive drafting)
//   res->t_logits - logits over the DRAFT vocabulary (mapped to the target
//                   vocabulary later via the d2t tensor)
llm_build_eagle3_decode::llm_build_eagle3_decode(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
    const int64_t n_embd_head = hparams.n_embd_head_v;
    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
    GGML_ASSERT(n_layer == 1); // EAGLE-3 has only one decoder layer
    ggml_tensor * cur;
    ggml_tensor * inpL;
    // EAGLE3 Decoder receives:
    // 1. Token embeddings (e.g. from EAGLE3's own tok_embd for Llama 3.3 70B, or target model for Llama 3.1 8B)
    // 2. g_embeddings from encoder
    // Choose token_embd_eagle3: prefer EAGLE3's own if available (Llama 3.3 70B), else use target's (Llama 3.1 8B)
    ggml_tensor * token_embd_eagle3 = (model.tok_embd != nullptr) ? model.tok_embd : model.target_tok_embd;
    GGML_ASSERT(token_embd_eagle3 != nullptr && "EAGLE3 decoder requires token embeddings (own or from target model)");
    ggml_tensor * input_embeds = build_inp_embd(token_embd_eagle3);
    cb(input_embeds, "token_embd_eagle3", -1);
    // g_embeddings input tensor; filled by name ("inp_g_embeddings") from
    // llama_context after set_inputs (see process_ubatch)
    ggml_tensor * g_embeddings = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens);
    ggml_set_input(g_embeddings);
    ggml_set_name(g_embeddings, "inp_g_embeddings");
    cb(g_embeddings, "inp_g_embeddings", -1);
    // Store raw g_embeddings as residual (the residual stream starts from the
    // un-normalized encoder output, not from the token embeddings)
    ggml_tensor * residual = g_embeddings;
    // Apply input_layernorm to the token embeddings
    ggml_tensor * input_embeds_normed = build_norm(input_embeds,
            model.layers[0].attn_norm, NULL,
            LLM_NORM_RMS, 0);
    cb(input_embeds_normed, "input_layernorm", -1);
    // Force a sync point between the two parallel RMS_NORM paths
    // This prevents buffer reuse issues on GPU (EAGLE3 GPU fix)
    ggml_set_sync(input_embeds_normed);
    // Apply hidden_norm to g_embeddings
    ggml_tensor * g_embeddings_normed = build_norm(g_embeddings,
            model.layers[0].eagle3_hidden_norm, NULL,
            LLM_NORM_RMS, -1);
    cb(g_embeddings_normed, "g_embeddings_normed", -1);
    // Concatenate normalized input_embeds and normalized g_embeddings
    // along the channel dim -> [2*n_embd, n_tokens], matching the widened
    // Q/K/V weight shapes loaded for this arch
    cur = ggml_concat(ctx0, input_embeds_normed, g_embeddings_normed, 0);
    cb(cur, "concat_embeds_g", -1);
    inpL = cur;
    // inp_pos - contains the positions
    ggml_tensor * inp_pos = build_inp_pos();
    auto * inp_attn = build_attn_inp_kv();
    ggml_tensor * inp_out_ids = build_inp_out_ids();
    const float kq_scale = 1.0f/sqrtf(float(n_embd_head));
    // Single decoder layer (il = 0)
    const int il = 0;
    {
        // inpL is the concatenated input (normalized input_embeds + normalized g_embeddings)
        ggml_tensor * inpSA = inpL;
        // Self-attention with concatenated input
        ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, inpL);
        cb(Qcur, "Qcur", il);
        ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, inpL);
        cb(Kcur, "Kcur", il);
        ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, inpL);
        cb(Vcur, "Vcur", il);
        Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
        Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
        Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
        // rope freq factors, returns nullptr if not available
        ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
        // RoPE
        Qcur = ggml_rope_ext(
            ctx0, Qcur, inp_pos, rope_factors,
            n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
            ext_factor, attn_factor, beta_fast, beta_slow
        );
        Kcur = ggml_rope_ext(
            ctx0, Kcur, inp_pos, rope_factors,
            n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
            ext_factor, attn_factor, beta_fast, beta_slow
        );
        cb(Qcur, "Qcur_rope", il);
        cb(Kcur, "Kcur_rope", il);
        cur = build_attn(inp_attn,
                model.layers[il].wo, NULL,
                Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
        // keep attention output, skip-input and residual row-aligned when only
        // a subset of tokens produce outputs
        if (inp_out_ids) {
            cur = ggml_get_rows(ctx0, cur, inp_out_ids);
            inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
            residual = ggml_get_rows(ctx0, residual, inp_out_ids);
        }
        // Add residual and update it
        ggml_tensor * attn_with_residual = ggml_add(ctx0, cur, residual);
        cb(attn_with_residual, "attn_with_residual", il);
        // Update residual
        residual = attn_with_residual;
        // Apply FFN norm to the sum
        ggml_tensor * ffn_inp = build_norm(attn_with_residual,
                model.layers[il].ffn_norm, NULL,
                LLM_NORM_RMS, il);
        cb(ffn_inp, "post_attn_norm", il);
        cur = ffn_inp;
        cur = build_ffn(cur,
                model.layers[il].ffn_up, NULL, NULL,
                model.layers[il].ffn_gate, NULL, NULL,
                model.layers[il].ffn_down, NULL, NULL,
                NULL,
                LLM_FFN_SILU, LLM_FFN_PAR, il);
        cb(cur, "ffn_out", il);
        inpL = cur;
    }
    cur = inpL;
    // Output norm with residual
    ggml_tensor * final_with_residual = ggml_add(ctx0, cur, residual);
    cb(final_with_residual, "eagle3_prenorm", -1);
    // Output prenorm state (for next token's g_embeddings in autoregressive generation)
    ggml_set_output(final_with_residual);
    res->t_embd = final_with_residual;
    cur = build_norm(final_with_residual,
            model.output_norm, NULL,
            LLM_NORM_RMS, -1);
    cb(cur, "result_norm", -1);
    // lm_head - projects to draft vocabulary
    cur = build_lora_mm(model.output, cur);
    cb(cur, "result_output", -1);
    res->t_logits = cur;
    ggml_build_forward_expand(gf, cur);
}

View File

@ -23,6 +23,16 @@ llm_build_llama::llm_build_llama(const llama_model & model, const llm_graph_para
for (int il = 0; il < n_layer; ++il) {
ggml_tensor * inpSA = inpL;
// EAGLE3: Extract intermediate layer features from target model at layer INPUT
if (eagle3 && cparams.eagle3_extract_enabled && !eagle3->extract_layer_indices.empty()) {
static const char * eagle3_extract_names[] = {"eagle3_extract_0", "eagle3_extract_1", "eagle3_extract_2"};
for (size_t i = 0; i < eagle3->extract_layer_indices.size() && i < 3; ++i) {
if (eagle3->extract_layer_indices[i] == il) {
cb(inpL, eagle3_extract_names[i], il);
break;
}
}
}
// norm
cur = build_norm(inpL,
model.layers[il].attn_norm, NULL,

View File

@ -150,6 +150,14 @@ struct llm_build_dream : public llm_graph_context {
llm_build_dream(const llama_model & model, const llm_graph_params & params);
};
// EAGLE3 encoder: fuses concatenated target-model layer features into
// g_embeddings via a single projection (see src/models/eagle3.cpp)
struct llm_build_eagle3_encode : public llm_graph_context {
    llm_build_eagle3_encode(const llama_model & model, const llm_graph_params & params);
};
// EAGLE3 decoder: single-layer transformer over token embeddings concatenated
// with encoder g_embeddings; emits draft-vocabulary logits
struct llm_build_eagle3_decode : public llm_graph_context {
    llm_build_eagle3_decode(const llama_model & model, const llm_graph_params & params);
};
struct llm_build_ernie4_5 : public llm_graph_context {
llm_build_ernie4_5(const llama_model & model, const llm_graph_params & params);
};