// Parallel Lazy Beam Search for llama.cpp
// Optimized for encoder-decoder models (NLLB, T5, etc.)

#pragma once

#include "llama.h"

#include <functional>
#include <stdexcept>
#include <utility>
#include <vector>

namespace llama_beam {

// Configuration for beam search
struct beam_search_params {
    int   beam_size            = 5;     // Number of beams to maintain
    float length_penalty_alpha = 1.0f;  // Length penalty (1.0 = neutral, >1.0 = favor longer)
    int   max_length           = 200;   // Maximum tokens to generate
    bool  early_stopping       = true;  // Stop when all beams finish
    int   min_length           = 1;     // Minimum tokens to generate
    float diversity_penalty    = 0.0f;  // Diversity penalty (0.0 = disabled)

    // Advanced options
    int   top_k_per_beam   = 0;      // Top-K candidates per beam (0 = use all)
    float score_threshold  = -1e9f;  // Minimum score threshold
    bool  normalize_scores = true;   // Normalize scores by length
};
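
// Illustrative configuration sketch for NLLB-style translation. The values
// below are assumptions for demonstration, not tuned recommendations.
inline beam_search_params make_translation_params() {
    beam_search_params params;
    params.beam_size            = 4;     // beam sizes of 4-5 are typical for MT
    params.length_penalty_alpha = 1.0f;  // neutral length penalty
    params.max_length           = 256;   // generous cap for sentence-level output
    return params;
}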

// Single beam hypothesis
struct beam_hypothesis {
    std::vector<llama_token> tokens;  // Generated tokens
    float        score;               // Cumulative log probability
    float        normalized_score;    // Score / length^alpha
    llama_seq_id seq_id;              // Sequence ID in KV cache
    bool         finished;            // Has this beam finished (EOS)?

    beam_hypothesis()
        : score(0.0f), normalized_score(0.0f), seq_id(-1), finished(false) {}
};

// Candidate during expansion (before pruning)
struct beam_candidate {
    beam_hypothesis hyp;              // The hypothesis
    int             parent_beam_idx;  // Which beam it came from
    llama_seq_id    parent_seq_id;    // Parent's seq_id
    llama_token     last_token;       // Token that was just added
    float           token_log_prob;   // Log prob of last token

    beam_candidate()
        : parent_beam_idx(-1), parent_seq_id(-1), last_token(-1), token_log_prob(0.0f) {}
};

// Result of beam search
struct beam_search_result {
    std::vector<beam_hypothesis> hypotheses;  // All final hypotheses (sorted by score)
    int  n_steps;                             // Number of decode steps taken
    bool stopped_early;                       // Did we hit early stopping?

    beam_search_result() : n_steps(0), stopped_early(false) {}

    // Get best hypothesis; throws if the search produced none
    const beam_hypothesis & best() const {
        if (hypotheses.empty()) {
            throw std::runtime_error("beam_search_result::best(): no hypotheses");
        }
        return hypotheses[0];
    }
};

// Main beam search engine
class beam_search_engine {
public:
    // Constructor
    beam_search_engine(
        llama_context * ctx,
        const beam_search_params & params
    );

    // Destructor
    ~beam_search_engine();

    // Run beam search
    //   initial_tokens: Starting tokens (e.g., [EOS, target_lang])
    //   is_eos:         Function to check if token is EOS
    beam_search_result search(
        const std::vector<llama_token> & initial_tokens,
        std::function<bool(llama_token)> is_eos
    );

    // Step-by-step interface (for advanced control); usage sketch below
    void initialize(const std::vector<llama_token> & initial_tokens);
    bool step(std::function<bool(llama_token)> is_eos);  // Returns false when done
    beam_search_result get_results();
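    //
    // Usage sketch for the step-by-step interface, assuming the caller
    // supplies `bos_tokens` and an `is_eos` predicate:
    //
    //   engine.initialize(bos_tokens);
    //   while (engine.step(is_eos)) {
    //       // inspect intermediate beams here, or via the step callback
    //   }
    //   beam_search_result res = engine.get_results();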

    // Callbacks for monitoring
    using step_callback_t = std::function<void(int step, const std::vector<beam_hypothesis> &)>;
    void set_step_callback(step_callback_t callback);
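    //
    // Example: log the best beam's normalized score at each step (a sketch;
    // beams are assumed to be passed sorted best-first):
    //
    //   engine.set_step_callback([](int step, const std::vector<beam_hypothesis> & beams) {
    //       if (!beams.empty()) {
    //           fprintf(stderr, "step %d: best %.3f\n", step, beams[0].normalized_score);
    //       }
    //   });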

private:
    llama_context *    ctx_;
    beam_search_params params_;

    std::vector<beam_hypothesis> beams_;
    std::vector<beam_candidate>  candidates_;

    int  current_step_;
    bool initialized_;

    step_callback_t step_callback_;

    // Internal methods
    void expand_beams(std::function<bool(llama_token)> is_eos);
    void prune_candidates();
    void rearrange_kv_caches();
    float compute_score(const beam_hypothesis & hyp) const;
    float apply_length_penalty(float score, int length) const;
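    // apply_length_penalty is assumed to follow the convention documented on
    // normalized_score: score / length^alpha. Worked example: a cumulative log
    // probability of -12.0 over 10 tokens with alpha = 1.0 normalizes to
    // -12.0 / 10^1.0 = -1.2. (Some systems use the GNMT variant,
    // ((5 + length) / 6)^alpha, instead.)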

    // Helper to get top-K tokens from logits
    std::vector<std::pair<llama_token, float>> get_top_k_tokens(
        const float * logits,
        int n_vocab,
        int k
    ) const;
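    // Note: beam scores accumulate log probabilities (see beam_hypothesis::score),
    // so the float in each returned pair is assumed to be a log-softmaxed logit,
    // not a raw logit.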
};

// Utility functions

// Default EOS checker
inline bool is_eos_token(llama_token token, const llama_vocab * vocab) {
    return llama_vocab_is_eog(vocab, token);
}

// Print hypothesis for debugging
void print_hypothesis(
    const beam_hypothesis & hyp,
    const llama_vocab * vocab,
    const char * prefix = ""
);

// Compare hypotheses by score (for sorting)
inline bool compare_hypotheses_by_score(
    const beam_hypothesis & a,
    const beam_hypothesis & b
) {
    return a.normalized_score > b.normalized_score;
}

inline bool compare_candidates_by_score(
    const beam_candidate & a,
    const beam_candidate & b
) {
    return a.hyp.normalized_score > b.hyp.normalized_score;
}
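
// End-to-end usage sketch. Assumes the caller has already run the encoder on
// the source sentence (llama_encode) and that `decoder_start` holds the
// decoder prefix, e.g. { eos_id, target_lang_id } for NLLB. Hypothetical
// helper for illustration, not part of the engine itself.
inline std::vector<llama_token> beam_translate_sketch(
        llama_context * ctx,
        const llama_vocab * vocab,
        const std::vector<llama_token> & decoder_start) {
    beam_search_params params;
    params.beam_size  = 5;
    params.max_length = 200;

    beam_search_engine engine(ctx, params);
    beam_search_result result = engine.search(decoder_start,
        [vocab](llama_token t) { return is_eos_token(t, vocab); });

    return result.best().tokens;  // throws if no hypothesis was produced
}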

} // namespace llama_beam