#include "common.h" #include "llama.h" #include #include #include #include /* def find_candidate_pred_tokens(input_ids, max_ngram_size=3, num_pred_tokens=10): input_length = input_ids.size(1) for ngram_size in range(max_ngram_size, 0, -1): # Extract the last n tokens as our search ngram ngram = input_ids[0, -ngram_size:].tolist() # Create sliding windows of size ngram_size windows = input_ids.unfold(dimension=1, size=ngram_size, step=1) # Convert ngram to a tensor for comparison ngram_tensor = torch.tensor(ngram, device=input_ids.device).unsqueeze(0) # Find where the windows match the ngram matches = (windows == ngram_tensor).all(dim=2) # Get the indices of matches match_indices = matches.nonzero(as_tuple=True)[1] # Iterate through match indices to find a valid continuation for idx in match_indices: start_idx = idx + ngram_size end_idx = start_idx + num_pred_tokens # Ensure we don't go beyond the length of input_ids and avoid self-match if end_idx <= input_length and start_idx < input_length - ngram_size: return input_ids[0, start_idx:end_idx] # If no match is found, return an empty tensor return torch.tensor([], dtype=torch.long, device=input_ids.device) */ int main(int argc, char ** argv){ gpt_params params; if(gpt_params_parse(argc, argv, params) == false){ return 1; } // maximum n-grams to search for in prompt const int max_ngram_size = 3; // length of the candidate sequence, if match is found const int num_pred_tokens = 10; #ifndef LOG_DISABLE_LOGS log_set_target(log_filename_generator("lookup", "log")); LOG_TEE("Log start\n"); log_dump_cmdline(argc, argv); #endif // LOG_DISABLE_LOGS // init llama.cpp llama_backend_init(params.numa); llama_model * model = NULL; llama_context * ctx = NULL; // load the model std::tie(model, ctx) = llama_init_from_gpt_params(params); // tokenize the prompt const bool add_bos = llama_should_add_bos_token(model); LOG("add_bos tgt: %d\n", add_bos); std::vector inp; inp = ::llama_tokenize(ctx, params.prompt, add_bos, true); const int max_context_size = llama_n_ctx(ctx); const int max_tokens_list_size = max_context_size - 4; if ((int) inp.size() > max_tokens_list_size) { fprintf(stderr, "%s: error: prompt too long (%d tokens, max %d)\n", __func__, (int) inp.size(), max_tokens_list_size); return 1; } fprintf(stderr, "\n\n"); for (auto id : inp) { fprintf(stderr, "%s", llama_token_to_piece(ctx, id).c_str()); } fflush(stderr); const int n_input = inp.size(); const auto t_enc_start = ggml_time_us(); llama_decode(ctx, llama_batch_get_one( inp.data(), n_input - 1, 0, 0)); llama_decode(ctx, llama_batch_get_one(&inp.back(), 1, n_input - 1, 0)); const auto t_enc_end = ggml_time_us(); int n_accept = 0; int n_past = inp.size(); bool has_eos = false; struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams); const auto t_dec_start = ggml_time_us(); }