Merge ad1b60abc4 into 58062860af
commit e4f3cc7782

@@ -1671,6 +1671,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.sampling.grammar = json_schema_to_grammar(json::parse(schema));
}
).set_sparam());
add_opt(common_arg(
{"--backend-sampling"},
"enable backend sampling (default: disabled)",
[](common_params & params) {
params.sampling.backend_sampling = true;
}
).set_sparam().set_env("LLAMA_ARG_BACKEND_SAMPLING"));
add_opt(common_arg(
{"--pooling"}, "{none,mean,cls,last,rank}",
"pooling type for embeddings, use model default if unspecified",
@@ -1084,6 +1084,7 @@ struct common_init_result::impl {
std::vector<llama_adapter_lora_ptr> lora;

std::vector<common_sampler_ptr> samplers;
std::vector<llama_sampler_seq_config> samplers_seq_config;
};

common_init_result::common_init_result(common_params & params) :

@@ -1141,10 +1142,19 @@ common_init_result::common_init_result(common_params & params) :
// params.sampling.dry_penalty_last_n = llama_n_ctx(lctx);
//}

// init the backend samplers as part of the context creation
pimpl->samplers.resize(cparams.n_seq_max);
pimpl->samplers_seq_config.resize(cparams.n_seq_max);

for (int i = 0; i < (int) cparams.n_seq_max; ++i) {
pimpl->samplers[i].reset(common_sampler_init(model, params.sampling));
pimpl->samplers_seq_config[i] = { i, common_sampler_get(pimpl->samplers[i].get()) };
}

// TODO: temporarily gated behind a flag
if (params.sampling.backend_sampling) {
cparams.samplers = pimpl->samplers_seq_config.data();
cparams.n_samplers = pimpl->samplers_seq_config.size();
}

llama_context * lctx = llama_init_from_model(model, cparams);
@@ -216,6 +216,8 @@ struct common_params_sampling {
std::vector<llama_logit_bias> logit_bias; // logit biases to apply
std::vector<llama_logit_bias> logit_bias_eog; // pre-calculated logit biases for EOG tokens

bool backend_sampling = false;

bool has_logit_bias() const {
return !logit_bias.empty();
}
@@ -112,6 +112,10 @@ static llama_sampler_i llama_sampler_llg_i = {
/* .reset = */ llama_sampler_llg_reset,
/* .clone = */ llama_sampler_llg_clone,
/* .free = */ llama_sampler_llg_free,
/* .backend_init = */ NULL,
/* .backend_accept = */ NULL,
/* .backend_apply = */ NULL,
/* .backend_set_input = */ NULL,
};

static size_t llama_sampler_llg_tokenize_fn(const void * user_data, const uint8_t * bytes, size_t bytes_len,
@@ -121,18 +121,35 @@ struct common_sampler {
}

void set_logits(struct llama_context * ctx, int idx) {
const auto * logits = llama_get_logits_ith(ctx, idx);
const float * sampled_probs = llama_get_sampled_probs_ith (ctx, idx);
const float * sampled_logits = llama_get_sampled_logits_ith (ctx, idx);
const llama_token * sampled_ids = llama_get_sampled_candidates_ith(ctx, idx);

const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model);

const int n_vocab = llama_vocab_n_tokens(vocab);

if (sampled_probs) {
const uint32_t sampled_probs_count = llama_get_sampled_probs_count_ith(ctx, idx);
cur.resize(sampled_probs_count);
for (uint32_t i = 0; i < sampled_probs_count; ++i) {
cur[i] = llama_token_data{sampled_ids[i], sampled_logits[i], sampled_probs[i]};
}
} else if (sampled_logits) {
const uint32_t sampled_logits_count = llama_get_sampled_logits_count_ith(ctx, idx);
cur.resize(sampled_logits_count);
for (uint32_t i = 0; i < sampled_logits_count; i++) {
cur[i] = llama_token_data{sampled_ids[i], sampled_logits[i], 0.0f};
}
} else {
const auto * logits = llama_get_logits_ith(ctx, idx);
GGML_ASSERT(logits != nullptr);
cur.resize(n_vocab);

for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
}
}

cur_p = { cur.data(), cur.size(), -1, false };
}

@@ -421,6 +438,23 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
auto & chain = gsmpl->chain;
auto & cur_p = gsmpl->cur_p; // initialized by set_logits

// Check if a backend sampler has already sampled a token, in which case we
// return that token id directly.
{
id = llama_get_sampled_token_ith(ctx, idx);

if (id != LLAMA_TOKEN_NULL) {
LOG_DBG("%s: Backend sampler selected token: '%d'. Will not run any CPU samplers\n", __func__, id);

// TODO: simplify
gsmpl->cur.resize(1);
gsmpl->cur[0] = { id, 0.0f, 1.0f };
cur_p = { gsmpl->cur.data(), gsmpl->cur.size(), 0, true };

return id;
}
}

gsmpl->set_logits(ctx, idx);

llama_sampler_apply(chain, &cur_p);
@@ -68,7 +68,7 @@ int main(int argc, char ** argv) {
auto sparams = llama_sampler_chain_default_params();
sparams.no_perf = false;

std::vector<llama_sampler *> samplers;
std::vector<llama_sampler_seq_config> sampler_configs;

for (int32_t i = 0; i < n_parallel; ++i) {
llama_sampler * smpl = llama_sampler_chain_init(sparams);

@@ -78,7 +78,13 @@ int main(int argc, char ** argv) {
llama_sampler_chain_add(smpl, llama_sampler_init_temp (params.sampling.temp));
llama_sampler_chain_add(smpl, llama_sampler_init_dist (params.sampling.seed));

samplers.push_back(smpl);
sampler_configs.push_back({ i, smpl });
}

// TODO: temporarily gated behind a flag
if (params.sampling.backend_sampling) {
ctx_params.samplers = sampler_configs.data();
ctx_params.n_samplers = sampler_configs.size();
}

llama_context * ctx = llama_init_from_model(model, ctx_params);

@@ -180,7 +186,7 @@ int main(int argc, char ** argv) {
continue;
}

const llama_token new_token_id = llama_sampler_sample(samplers[i], ctx, i_batch[i]);
const llama_token new_token_id = llama_sampler_sample(sampler_configs[i].sampler, ctx, i_batch[i]);

// is it an end of generation? -> mark the stream as finished
if (llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_predict) {

@@ -236,15 +242,15 @@ int main(int argc, char ** argv) {
__func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));

LOG("\n");
llama_perf_sampler_print(samplers[0]);
llama_perf_sampler_print(sampler_configs[0].sampler);
llama_perf_context_print(ctx);

fprintf(stderr, "\n");

llama_batch_free(batch);

for (auto & sampler_config : samplers) {
llama_sampler_free(sampler_config);
for (auto & sampler_config : sampler_configs) {
llama_sampler_free(sampler_config.sampler);
}

llama_free(ctx);
@@ -40,6 +40,20 @@ if (CUDAToolkit_FOUND)

enable_language(CUDA)

# Remove once CCCL 3.2 has been released and bundled with CUDA Toolkit
if (GGML_CUDA_CUB_3DOT2)
include(FetchContent)

FetchContent_Declare(
CCCL
GIT_REPOSITORY https://github.com/nvidia/cccl.git
GIT_TAG v3.2.0-rc1
GIT_SHALLOW TRUE
)

FetchContent_MakeAvailable(CCCL)
endif()

file(GLOB GGML_HEADERS_CUDA "*.cuh")
list(APPEND GGML_HEADERS_CUDA "../../include/ggml-cuda.h")

@@ -102,6 +116,9 @@ if (CUDAToolkit_FOUND)
# As of 12.3.1 CUDA Toolkit for Windows does not offer a static cublas library
target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas)
else ()
if (GGML_CUDA_CUB_3DOT2)
target_link_libraries(ggml-cuda PRIVATE CCCL::CCCL)
endif()
if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "10.1")
target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
else()

@@ -109,6 +126,9 @@ if (CUDAToolkit_FOUND)
endif()
endif()
else()
if (GGML_CUDA_CUB_3DOT2)
target_link_libraries(ggml-cuda PRIVATE CCCL::CCCL)
endif()
target_link_libraries(ggml-cuda PRIVATE CUDA::cudart CUDA::cublas)
endif()

@@ -177,6 +197,10 @@ if (CUDAToolkit_FOUND)

if (NOT MSVC)
list(APPEND CUDA_CXX_FLAGS -Wno-pedantic)
else()
# CCCL 3.2 onwards will require a cpp-standard-compliant preprocessor for MSVC
# https://github.com/NVIDIA/cccl/pull/6827
list(APPEND CUDA_CXX_FLAGS /Zc:preprocessor)
endif()

list(JOIN CUDA_CXX_FLAGS " " CUDA_CXX_FLAGS_JOINED) # pass host compiler flags as a single argument
@@ -22,7 +22,7 @@ static __global__ void init_offsets(int * offsets, const int ncols, const int nr
}

#ifdef GGML_CUDA_USE_CUB
static void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
const float * x,
int * dst,
const int ncols,

@@ -49,28 +49,49 @@ static void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
size_t temp_storage_bytes = 0;

if (order == GGML_SORT_ORDER_ASC) {
DeviceSegmentedRadixSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
if (nrows == 1) {
DeviceRadixSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
temp_indices, dst, // values (indices)
ncols, 0, sizeof(float) * 8, stream);
} else {
DeviceSegmentedSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
temp_indices, dst, // values (indices)
ncols * nrows, nrows, // num items, num segments
d_offsets, d_offsets + 1, 0, sizeof(float) * 8, // all bits
stream);
d_offsets, d_offsets + 1, stream);
}
} else {
DeviceSegmentedRadixSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices,
dst, ncols * nrows, nrows, d_offsets, d_offsets + 1, 0,
sizeof(float) * 8, stream);
if (nrows == 1) {
DeviceRadixSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
temp_indices, dst, // values (indices)
ncols, 0, sizeof(float) * 8, stream);
} else {
DeviceSegmentedSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices,
dst, ncols * nrows, nrows, d_offsets, d_offsets + 1, stream);
}
}

ggml_cuda_pool_alloc<uint8_t> temp_storage_alloc(pool, temp_storage_bytes);
void * d_temp_storage = temp_storage_alloc.get();

if (order == GGML_SORT_ORDER_ASC) {
DeviceSegmentedRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst,
ncols * nrows, nrows, d_offsets, d_offsets + 1, 0, sizeof(float) * 8,
stream);
if (nrows == 1) {
DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
temp_indices, dst, // values (indices)
ncols, 0, sizeof(float) * 8, stream);
} else {
DeviceSegmentedRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys,
DeviceSegmentedSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst,
ncols * nrows, nrows, d_offsets, d_offsets + 1, stream);
}
} else {
if (nrows == 1) {
DeviceRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
temp_indices, dst, // values (indices)
ncols, 0, sizeof(float) * 8, stream);
} else {
DeviceSegmentedSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys,
temp_indices, dst, ncols * nrows, nrows, d_offsets, d_offsets + 1,
0, sizeof(float) * 8, stream);
stream);
}
}
}
#endif // GGML_CUDA_USE_CUB

@@ -141,7 +162,7 @@ static int next_power_of_2(int x) {
return n;
}

static void argsort_f32_i32_cuda_bitonic(const float * x,
void argsort_f32_i32_cuda_bitonic(const float * x,
int * dst,
const int ncols,
const int nrows,
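Note on the CUB calls in the hunk above: every cub::Device* routine is invoked twice, first with a null temp-storage pointer to query the scratch size and then again with real scratch to do the work, and the segmented variants describe segment i as [offsets[i], offsets[i + 1]), which is why the code passes d_offsets and d_offsets + 1. A minimal standalone sketch of that convention, independent of the ggml pool allocator (illustrative only, not the PR's code):

#include <cub/cub.cuh>
#include <cuda_runtime.h>
#include <vector>

// Sort each row of a row-major [nrows x ncols] float matrix (ascending),
// carrying int payloads along -- the same pattern the argsort hunk uses.
static void segmented_sort_pairs_sketch(const float * d_keys_in, float * d_keys_out,
                                        const int * d_vals_in, int * d_vals_out,
                                        int ncols, int nrows, cudaStream_t stream) {
    // Segment i spans [offsets[i], offsets[i + 1]).
    std::vector<int> h_offsets(nrows + 1);
    for (int i = 0; i <= nrows; ++i) {
        h_offsets[i] = i * ncols;
    }
    int * d_offsets = nullptr;
    cudaMalloc(&d_offsets, (nrows + 1) * sizeof(int));
    cudaMemcpy(d_offsets, h_offsets.data(), (nrows + 1) * sizeof(int), cudaMemcpyHostToDevice);

    // First call: null temp storage only fills temp_storage_bytes.
    size_t temp_storage_bytes = 0;
    cub::DeviceSegmentedSort::SortPairs(nullptr, temp_storage_bytes,
                                        d_keys_in, d_keys_out, d_vals_in, d_vals_out,
                                        ncols * nrows, nrows,
                                        d_offsets, d_offsets + 1, stream);

    void * d_temp_storage = nullptr;
    cudaMalloc(&d_temp_storage, temp_storage_bytes);

    // Second call: identical arguments, now with real scratch memory.
    cub::DeviceSegmentedSort::SortPairs(d_temp_storage, temp_storage_bytes,
                                        d_keys_in, d_keys_out, d_vals_in, d_vals_out,
                                        ncols * nrows, nrows,
                                        d_offsets, d_offsets + 1, stream);

    cudaFree(d_temp_storage);
    cudaFree(d_offsets);
}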
@@ -1,3 +1,19 @@
#include "common.cuh"

void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

#ifdef GGML_CUDA_USE_CUB
void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
const float * x,
int * dst,
const int ncols,
const int nrows,
ggml_sort_order order,
cudaStream_t stream);
#endif // GGML_CUDA_USE_CUB
void argsort_f32_i32_cuda_bitonic(const float * x,
int * dst,
const int ncols,
const int nrows,
ggml_sort_order order,
cudaStream_t stream);
@@ -924,6 +924,7 @@ struct ggml_cuda_device_info {
size_t vmm_granularity; // granularity of virtual memory
size_t total_vram;
int warp_size; // Number of threads in a dispatch
bool supports_cooperative_launch; // whether cooperative launch is supported
};

cuda_device_info devices[GGML_CUDA_MAX_DEVICES] = {};
@@ -149,9 +149,34 @@ static __global__ void cumsum_kernel(
}
}

#ifdef GGML_CUDA_USE_CUB
template <typename T>
static void cumsum_cub(ggml_cuda_pool & pool,
const T * src,
T * dst,
int64_t ne,
cudaStream_t stream) {
size_t tmp_size = 0;

// Query how much temp storage CUDA UnBound (CUB) needs
cub::DeviceScan::InclusiveSum(nullptr, // d_temp_storage (null = just query size)
tmp_size, // reference to size (will be set by CUB)
src, // input pointer
dst, // output pointer
ne, // number of elements
stream // CUDA stream to use
);

ggml_cuda_pool_alloc<uint8_t> tmp_alloc(pool, tmp_size);

// Perform the inclusive scan
cub::DeviceScan::InclusiveSum((void *) tmp_alloc.get(), tmp_size, src, dst, ne, stream);
}
#endif // GGML_CUDA_USE_CUB

template<typename T>
static void cumsum_cuda(
const T * src, T * dst,
[[maybe_unused]] ggml_backend_cuda_context & ctx, const T * src, T * dst,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3,

@@ -165,6 +190,15 @@ static void cumsum_cuda(

if (is_contiguous) {
use_cub = true;
const int64_t nrows = ne01 * ne02 * ne03;
// TODO: Compare with DeviceSegmentedScan::InclusiveSegmentedSum for nrows > 1 once InclusiveSegmentedSum is released
// Heuristics were determined as part of https://github.com/ggml-org/llama.cpp/pull/17004
if (((nrows == 1) && (ne00 > 1024)) || (ne00 / nrows > 4096)) {
for (int i=0; i<nrows; i++) {
cumsum_cub(ctx.pool(), src + i * ne00, dst + i * ne00, ne00, stream);
}
return;
}
}
#endif // GGML_CUDA_USE_CUB
dim3 grid_dims(ne01, ne02, ne03);

@@ -203,7 +237,7 @@ void ggml_cuda_op_cumsum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
case GGML_TYPE_F32:
{
cumsum_cuda(
(const float *)src0->data, (float *)dst->data,
ctx, (const float *)src0->data, (float *)dst->data,
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
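The dispatch condition added above reads as: take the per-row CUB scan either for a single long row, or when the average row length clearly dominates the row count; otherwise keep the custom kernel. A small self-contained illustration of how the predicate evaluates for a few shapes (the shapes are made up for illustration; only the expression itself comes from the hunk, which additionally requires contiguity and GGML_CUDA_USE_CUB):

#include <cstdint>
#include <cstdio>

// Mirrors the cumsum dispatch heuristic from the hunk above.
static bool use_cub_for_cumsum(int64_t ne00 /* row length */, int64_t nrows) {
    return ((nrows == 1) && (ne00 > 1024)) || (ne00 / nrows > 4096);
}

int main() {
    std::printf("%d\n", use_cub_for_cumsum(8192, 1));  // 1: single long row  -> CUB
    std::printf("%d\n", use_cub_for_cumsum(512, 1));   // 0: short row        -> custom kernel
    std::printf("%d\n", use_cub_for_cumsum(65536, 8)); // 1: 65536/8 = 8192   -> CUB
    std::printf("%d\n", use_cub_for_cumsum(8192, 4));  // 0: 8192/4  = 2048   -> custom kernel
    return 0;
}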
@@ -19,6 +19,7 @@
#include "ggml-cuda/count-equal.cuh"
#include "ggml-cuda/cpy.cuh"
#include "ggml-cuda/cross-entropy-loss.cuh"
#include "ggml-cuda/cumsum.cuh"
#include "ggml-cuda/diagmask.cuh"
#include "ggml-cuda/diag.cuh"
#include "ggml-cuda/fattn.cuh"

@@ -44,6 +45,7 @@
#include "ggml-cuda/ssm-scan.cuh"
#include "ggml-cuda/sum.cuh"
#include "ggml-cuda/sumrows.cuh"
#include "ggml-cuda/top-k.cuh"
#include "ggml-cuda/mean.cuh"
#include "ggml-cuda/tsembd.cuh"
#include "ggml-cuda/topk-moe.cuh"

@@ -241,6 +243,14 @@ static ggml_cuda_device_info ggml_cuda_init() {
info.devices[id].nsm = prop.multiProcessorCount;
info.devices[id].smpb = prop.sharedMemPerBlock;
info.devices[id].warp_size = prop.warpSize;

#ifndef GGML_USE_MUSA
int supports_coop_launch = 0;
CUDA_CHECK(cudaDeviceGetAttribute(&supports_coop_launch, cudaDevAttrCooperativeLaunch, id));
info.devices[id].supports_cooperative_launch = !!supports_coop_launch;
#else
info.devices[id].supports_cooperative_launch = false;
#endif // !(GGML_USE_MUSA)
#if defined(GGML_USE_HIP)
info.devices[id].smpbo = prop.sharedMemPerBlock;

@@ -2687,6 +2697,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_SUM:
ggml_cuda_op_sum(ctx, dst);
break;
case GGML_OP_CUMSUM:
ggml_cuda_op_cumsum(ctx, dst);
break;
case GGML_OP_SUM_ROWS:
ggml_cuda_op_sum_rows(ctx, dst);
break;

@@ -2699,6 +2712,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_SSM_SCAN:
ggml_cuda_op_ssm_scan(ctx, dst);
break;
case GGML_OP_TOP_K:
ggml_cuda_op_top_k(ctx, dst);
break;
case GGML_OP_ARGSORT:
ggml_cuda_op_argsort(ctx, dst);
break;

@@ -2708,9 +2724,6 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_CROSS_ENTROPY_LOSS:
ggml_cuda_cross_entropy_loss(ctx, dst);
break;
case GGML_OP_CUMSUM:
ggml_cuda_op_cumsum(ctx, dst);
break;
case GGML_OP_TRI:
ggml_cuda_op_tri(ctx, dst);
break;

@@ -4600,6 +4613,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
return true;
case GGML_OP_SUM:
return ggml_is_contiguous_rows(op->src[0]);
case GGML_OP_TOP_K:
case GGML_OP_ARGSORT:
#ifndef GGML_CUDA_USE_CUB
return op->src[0]->ne[0] <= 1024;
@@ -1,6 +1,14 @@
#include "common.cuh"
#include "ggml.h"
#include "softmax.cuh"

#ifdef GGML_USE_HIP
#include <hip/hip_cooperative_groups.h>
#else
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#endif // GGML_USE_HIP

#include <cstdint>
#include <utility>

@@ -160,6 +168,156 @@ static __global__ void soft_max_f32(
dst[col] = vals[col] * inv_sum;
}
}

// TODO: This is a common pattern used across kernels that could be moved to common.cuh + templated
static __device__ float two_stage_warp_reduce_max(float val) {
val = warp_reduce_max(val);
if (blockDim.x > WARP_SIZE) {
assert((blockDim.x <= 1024) && (blockDim.x % WARP_SIZE) == 0);
__shared__ float local_vals[32];
const int warp_id = threadIdx.x / WARP_SIZE;
const int lane_id = threadIdx.x % WARP_SIZE;
if (lane_id == 0) {
local_vals[warp_id] = val;
}
__syncthreads();
val = -INFINITY;
if (lane_id < (static_cast<int>(blockDim.x) / WARP_SIZE)) {
val = local_vals[lane_id];
}
return warp_reduce_max(val);
} else {
return val;
}
}

static __device__ float two_stage_warp_reduce_sum(float val) {
val = warp_reduce_sum(val);
if (blockDim.x > WARP_SIZE) {
assert((blockDim.x <= 1024) && (blockDim.x % WARP_SIZE) == 0);
__shared__ float local_vals[32];
const int warp_id = threadIdx.x / WARP_SIZE;
const int lane_id = threadIdx.x % WARP_SIZE;
if (lane_id == 0) {
local_vals[warp_id] = val;
}
__syncthreads();
val = 0.0f;
if (lane_id < (static_cast<int>(blockDim.x) / WARP_SIZE)) {
val = local_vals[lane_id];
}
return warp_reduce_sum(val);
} else {
return val;
}
}

// TODO: Template to allow keeping ncols in registers if they fit
static __device__ void soft_max_f32_parallelize_cols_single_row(const float * __restrict__ x,
float * __restrict__ dst,
float * __restrict__ tmp_maxs,
float * __restrict__ tmp_sums,
const soft_max_params p) {
namespace cg = cooperative_groups;

const cg::grid_group g = cg::this_grid();

const int tid = threadIdx.x;
const int col_start = blockIdx.x * blockDim.x + tid;
const int n_elem_per_thread = 4;

float local_vals[n_elem_per_thread] = { -INFINITY, -INFINITY, -INFINITY, -INFINITY };
float local_max = -INFINITY;
const int step_size = gridDim.x * blockDim.x;

// Compute thread-local max
for (int col = col_start; col < p.ncols;) {
#pragma unroll
for (int i = 0; i < n_elem_per_thread; i++) {
const int idx = col + i * step_size;
local_vals[i] = idx < p.ncols ? x[idx] : -INFINITY;
}
#pragma unroll
for (int i = 0; i < n_elem_per_thread; i++) {
local_max = fmaxf(local_max, local_vals[i]);
}
col += step_size * n_elem_per_thread;
}

// Compute CTA-level max
local_max = two_stage_warp_reduce_max(local_max);

// Store CTA-level max to GMEM
if (tid == 0) {
tmp_maxs[blockIdx.x] = local_max;
}
g.sync();

// Compute global max from CTA-level maxs
assert(gridDim.x < blockDim.x); // currently we only support this case
if (tid < gridDim.x) {
local_max = tmp_maxs[tid];
} else {
local_max = -INFINITY;
}
local_max = two_stage_warp_reduce_max(local_max);

// Compute softmax dividends, accumulate divisor
float tmp_expf = 0.0f;
for (int col = col_start; col < p.ncols;) {
#pragma unroll
for (int i = 0; i < n_elem_per_thread; i++) {
const int idx = col + i * step_size;
local_vals[i] = idx < p.ncols ? x[idx] : -INFINITY;
}
#pragma unroll
for (int i = 0; i < n_elem_per_thread; i++) {
const int idx = col + i * step_size;
if (idx < p.ncols) {
const float tmp = expf(local_vals[i] - local_max);
tmp_expf += tmp;
dst[idx] = tmp;
}
}
col += step_size * n_elem_per_thread;
}

// Reduce divisor within CTA
tmp_expf = two_stage_warp_reduce_sum(tmp_expf);

// Store CTA-level sum to GMEM
if (tid == 0) {
tmp_sums[blockIdx.x] = tmp_expf;
}
g.sync();

// Compute global sum from CTA-level sums
if (tid < gridDim.x) {
tmp_expf = tmp_sums[tid];
} else {
tmp_expf = 0.0f;
}
tmp_expf = two_stage_warp_reduce_sum(tmp_expf);

// Divide dividend by global sum + store data
for (int col = col_start; col < p.ncols;) {
#pragma unroll
for (int i = 0; i < n_elem_per_thread; i++) {
const int idx = col + i * step_size;
local_vals[i] = idx < p.ncols ? dst[idx] : -INFINITY;
}
#pragma unroll
for (int i = 0; i < n_elem_per_thread; i++) {
const int idx = col + i * step_size;
if (idx < p.ncols) {
dst[idx] = local_vals[i] / tmp_expf;
}
}
col += step_size * n_elem_per_thread;
}
}

#ifdef __clang__
#pragma clang diagnostic pop
#endif // __clang__
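For reference, the single-row path above computes the standard numerically stable softmax in two grid-wide passes, a max reduction followed by an exponential-sum reduction:

m = \max_j x_j, \qquad \operatorname{softmax}(x)_i = \frac{e^{x_i - m}}{\sum_j e^{x_j - m}}

Subtracting m leaves the result unchanged and keeps expf() from overflowing for large logits.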
@@ -216,9 +374,31 @@ static void launch_soft_max_kernels(const float * x, const T * mask, const float
soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>(x, mask, sinks, dst, p);
}

__launch_bounds__(8*WARP_SIZE, 1) static __global__ void soft_max_f32_parallelize_cols(const float * __restrict__ x,
float * __restrict__ dst,
float * __restrict__ tmp_maxs,
float * __restrict__ tmp_sums,
const soft_max_params p)
// We loop over all rows instead of parallelizing across gridDim.y, as cooperative groups
// currently only support synchronizing the complete grid if not launched as a cluster group
// (which requires CC > 9.0)
// https://docs.nvidia.com/cuda/cuda-programming-guide/05-appendices/device-callable-apis.html#grid-synchronization
// https://docs.nvidia.com/cuda/cuda-programming-guide/05-appendices/device-callable-apis.html#class-cluster-group
{
for (int rowx = 0; rowx < p.ne01 * p.ne02 * p.ne03; rowx++) {
soft_max_f32_parallelize_cols_single_row(x + int64_t(rowx) * p.ncols, dst + int64_t(rowx) * p.ncols, tmp_maxs,
tmp_sums, p);
}
}

template<typename T>
static void soft_max_f32_cuda(const float * x, const T * mask, const float * sinks, float * dst, const soft_max_params & params, cudaStream_t stream) {
template <typename T>
static void soft_max_f32_cuda(const float * x,
const T * mask,
const float * sinks,
float * dst,
const soft_max_params & params,
cudaStream_t stream,
[[maybe_unused]] ggml_backend_cuda_context & ctx) {
int nth = WARP_SIZE;
const int64_t ncols_x = params.ncols;

@@ -236,8 +416,25 @@ static void soft_max_f32_cuda(const float * x, const T * mask, const float * sin
if (nbytes_shared <= smpbo) {
launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>(x, mask, sinks, dst, params, stream, block_dims, block_nums, nbytes_shared);
} else {
const size_t nbytes_shared_low = WARP_SIZE*sizeof(float);
soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, sinks, dst, params);
// Parallelize across SMs for top-p/dist-sampling
// The heuristic for parallelizing rows across SMs vs parallelizing a single row and looping over all rows was determined on a B6000 GPU and
// can be adapted further for lower-SM-count GPUs, though keeping data in registers should be implemented first as that is the optimal solution.
if (ggml_cuda_info().devices[id].supports_cooperative_launch &&
ncols_x / (params.ne01 * params.ne02 * params.ne03) > 8192 && mask == nullptr && sinks == nullptr &&
params.scale == 1.0f && params.max_bias == 0.0f) {
ggml_cuda_pool_alloc<float> tmp_maxs_alloc(ctx.pool(), ggml_cuda_info().devices[id].nsm * sizeof(float));
ggml_cuda_pool_alloc<float> tmp_sums_alloc(ctx.pool(), ggml_cuda_info().devices[id].nsm * sizeof(float));

void * kernel_args[] = { (void *) &x, (void *) &dst, (void *) &tmp_maxs_alloc.ptr,
(void *) &tmp_sums_alloc.ptr, (void *) const_cast<soft_max_params *>(&params) };
CUDA_CHECK(cudaLaunchCooperativeKernel((void *) soft_max_f32_parallelize_cols,
dim3(ggml_cuda_info().devices[id].nsm, 1, 1),
dim3(WARP_SIZE * 8, 1, 1), kernel_args, 0, stream));
} else {
const size_t nbytes_shared_low = WARP_SIZE * sizeof(float);
soft_max_f32<false, 0, 0>
<<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, sinks, dst, params);
}
}
}

@@ -315,9 +512,9 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
params.m1 = m1;

if (use_f16) {
soft_max_f32_cuda(src0_d, (const half *) src1_d, (const float *) src2_d, dst_d, params, stream);
soft_max_f32_cuda(src0_d, (const half *) src1_d, (const float *) src2_d, dst_d, params, stream, ctx);
} else {
soft_max_f32_cuda(src0_d, (const float *) src1_d, (const float *) src2_d, dst_d, params, stream);
soft_max_f32_cuda(src0_d, (const float *) src1_d, (const float *) src2_d, dst_d, params, stream, ctx);
}
}
@@ -0,0 +1,110 @@
#include "argsort.cuh"
#include "top-k.cuh"

#ifdef GGML_CUDA_USE_CUB
# include <cub/cub.cuh>
# if (CCCL_MAJOR_VERSION >= 3 && CCCL_MINOR_VERSION >= 2)
# define CUB_TOP_K_AVAILABLE
using namespace cub;
# endif // CCCL_MAJOR_VERSION >= 3 && CCCL_MINOR_VERSION >= 2
#endif // GGML_CUDA_USE_CUB

#ifdef CUB_TOP_K_AVAILABLE
static __global__ void init_indices(int * indices, const int ncols) {
const int col = blockIdx.x * blockDim.x + threadIdx.x;

if (col < ncols) {
indices[col] = col;
}
}

static void top_k_cub(ggml_cuda_pool & pool,
const float * src,
int * dst,
const int ncols,
const int k,
cudaStream_t stream) {
auto requirements = cuda::execution::require(cuda::execution::determinism::not_guaranteed,
cuda::execution::output_ordering::unsorted);
auto stream_env = cuda::stream_ref{ stream };
auto env = cuda::std::execution::env{ stream_env, requirements };

ggml_cuda_pool_alloc<int> temp_indices_alloc(pool, ncols);
ggml_cuda_pool_alloc<float> temp_keys_alloc(pool, ncols);

int * temp_indices = temp_indices_alloc.get();
float * temp_keys = temp_keys_alloc.get();

static const int block_size = 256;
const dim3 grid_size((ncols + block_size - 1) / block_size, 1);
init_indices<<<grid_size, block_size, 0, stream>>>(temp_indices, ncols);

CUDA_CHECK(cudaMemcpyAsync(temp_keys, src, ncols * sizeof(float), cudaMemcpyDeviceToDevice, stream));

size_t temp_storage_bytes = 0;
DeviceTopK::MaxPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst, ncols, k, env);

ggml_cuda_pool_alloc<uint8_t> temp_storage_alloc(pool, temp_storage_bytes);
void * d_temp_storage = temp_storage_alloc.get();

DeviceTopK::MaxPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst, ncols, k, env);
}

#elif defined(GGML_CUDA_USE_CUB) // CUB_TOP_K_AVAILABLE

static int next_power_of_2(int x) {
int n = 1;
while (n < x) {
n *= 2;
}
return n;
}

#endif // CUB_TOP_K_AVAILABLE

void ggml_cuda_op_top_k(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *) src0->data;
int * dst_d = (int *) dst->data;
cudaStream_t stream = ctx.stream();

// are these asserts truly necessary?
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_I32);
GGML_ASSERT(ggml_is_contiguous(src0));

const int64_t ncols = src0->ne[0];
const int64_t nrows = ggml_nrows(src0);
const int64_t k = dst->ne[0];
ggml_cuda_pool & pool = ctx.pool();
#ifdef CUB_TOP_K_AVAILABLE
// TODO: Switch to `DeviceSegmentedTopK` for multi-row TopK once implemented
// https://github.com/NVIDIA/cccl/issues/6391
// TODO: investigate if there exists a point where parallelized argsort is faster than sequential top-k
for (int i = 0; i < nrows; i++) {
top_k_cub(pool, src0_d + i * ncols, dst_d + i * k, ncols, k, stream);
}
#elif defined(GGML_CUDA_USE_CUB) // CUB_TOP_K_AVAILABLE
// Fall back to argsort + copy
const int ncols_pad = next_power_of_2(ncols);
const size_t shared_mem = ncols_pad * sizeof(int);
const size_t max_shared_mem = ggml_cuda_info().devices[ggml_cuda_get_device()].smpb;

ggml_cuda_pool_alloc<int> temp_dst_alloc(pool, ncols * nrows);
int * tmp_dst = temp_dst_alloc.get();

if (shared_mem > max_shared_mem || ncols > 1024) {
argsort_f32_i32_cuda_cub(pool, src0_d, tmp_dst, ncols, nrows, GGML_SORT_ORDER_DESC, stream);
} else {
argsort_f32_i32_cuda_bitonic(src0_d, tmp_dst, ncols, nrows, GGML_SORT_ORDER_DESC, stream);
}
CUDA_CHECK(cudaMemcpy2DAsync(dst_d, k * sizeof(int), tmp_dst, ncols * sizeof(int), k * sizeof(int), nrows,
cudaMemcpyDeviceToDevice, stream));
#else // GGML_CUDA_USE_CUB
ggml_cuda_pool_alloc<int> temp_dst_alloc(pool, ncols * nrows);
int * tmp_dst = temp_dst_alloc.get();
argsort_f32_i32_cuda_bitonic(src0_d, tmp_dst, ncols, nrows, GGML_SORT_ORDER_DESC, stream);
CUDA_CHECK(cudaMemcpy2DAsync(dst_d, k * sizeof(int), tmp_dst, ncols * sizeof(int), k * sizeof(int), nrows,
cudaMemcpyDeviceToDevice, stream));
#endif
}
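In the argsort fallback paths above, cudaMemcpy2DAsync acts as a strided copy: each source row is ncols ints wide but only its first k ints are kept, so dpitch = k * sizeof(int), spitch = ncols * sizeof(int), and width = k * sizeof(int). The same pitch arithmetic in isolation (a sketch, not new functionality):

#include <cuda_runtime.h>

// Copy the first k ints of each of nrows rows from a row-major [nrows x ncols]
// device matrix into a dense row-major [nrows x k] device matrix.
static cudaError_t copy_top_k_indices(int * dst, const int * src,
                                      int ncols, int nrows, int k,
                                      cudaStream_t stream) {
    return cudaMemcpy2DAsync(dst, k * sizeof(int),     // dst, dpitch (bytes between dst rows)
                             src, ncols * sizeof(int), // src, spitch (bytes between src rows)
                             k * sizeof(int), nrows,   // width in bytes, height in rows
                             cudaMemcpyDeviceToDevice, stream);
}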
@@ -0,0 +1,3 @@
#include "common.cuh"

void ggml_cuda_op_top_k(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
@@ -45,9 +45,11 @@
#define cublasSgemm hipblasSgemm
#define cublasStatus_t hipblasStatus_t
#define cublasOperation_t hipblasOperation_t
#define cudaDevAttrCooperativeLaunch hipDeviceAttributeCooperativeLaunch
#define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer
#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
#define cudaDeviceGetAttribute hipDeviceGetAttribute
#define cudaDeviceProp hipDeviceProp_t
#define cudaDeviceSynchronize hipDeviceSynchronize
#define cudaError_t hipError_t

@@ -70,6 +72,7 @@
#define cudaHostRegisterPortable hipHostRegisterPortable
#define cudaHostRegisterReadOnly hipHostRegisterReadOnly
#define cudaHostUnregister hipHostUnregister
#define cudaLaunchCooperativeKernel hipLaunchCooperativeKernel
#define cudaLaunchHostFunc hipLaunchHostFunc
#define cudaMalloc hipMalloc
#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)
@@ -61,6 +61,7 @@
#define cudaHostRegisterPortable musaHostRegisterPortable
#define cudaHostRegisterReadOnly musaHostRegisterReadOnly
#define cudaHostUnregister musaHostUnregister
#define cudaLaunchCooperativeKernel musaLaunchCooperativeKernel
#define cudaLaunchHostFunc musaLaunchHostFunc
#define cudaMalloc musaMalloc
#define cudaMallocHost musaMallocHost
@@ -316,6 +316,11 @@ extern "C" {
bool no_alloc; // only load metadata and simulate memory allocations
};

struct llama_sampler_seq_config {
llama_seq_id seq_id;
struct llama_sampler * sampler;
};

// NOTE: changing the default values of parameters marked as [EXPERIMENTAL] may cause crashes or incorrect results in certain configurations
// https://github.com/ggml-org/llama.cpp/pull/7544
struct llama_context_params {

@@ -364,6 +369,11 @@ extern "C" {
bool kv_unified; // use a unified buffer across the input sequences when computing the attention
// try to disable when n_seq_max > 1 for improved performance when the sequences do not share a large prefix
// ref: https://github.com/ggml-org/llama.cpp/pull/14363

// backend sampler chain configuration (make sure the caller keeps the sampler chains alive) [EXPERIMENTAL]
// note: the samplers must be sampler chains (i.e. use llama_sampler_chain_init)
struct llama_sampler_seq_config * samplers;
size_t n_samplers;
};

// model quantization parameters
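A minimal sketch of how a caller could use the new llama_sampler_seq_config array, based only on the declarations in this hunk. The chain contents (top-k/temp/dist) are illustrative; per the context-constructor hunk later in this diff the config array appears to be read only during llama_init_from_model, while the chains themselves must outlive the context and be freed by the caller:

#include "llama.h"
#include <vector>

// Build one sampler chain per sequence and hand them to the context at creation time.
// chains_out receives the chains so the caller can llama_sampler_free() them later.
static llama_context * init_ctx_with_backend_samplers(llama_model * model, int n_seq,
                                                      std::vector<llama_sampler *> & chains_out) {
    std::vector<llama_sampler_seq_config> configs;
    for (int i = 0; i < n_seq; ++i) {
        llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());
        llama_sampler_chain_add(chain, llama_sampler_init_top_k(40));
        llama_sampler_chain_add(chain, llama_sampler_init_temp(0.8f));
        llama_sampler_chain_add(chain, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));
        chains_out.push_back(chain);
        configs.push_back({ (llama_seq_id) i, chain });
    }

    llama_context_params cparams = llama_context_default_params();
    cparams.n_seq_max  = (uint32_t) n_seq;
    cparams.samplers   = configs.data();   // [EXPERIMENTAL] fields added in this diff
    cparams.n_samplers = configs.size();

    return llama_init_from_model(model, cparams);
}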
@@ -983,6 +993,32 @@ extern "C" {
// otherwise: float[n_embd] (1-dimensional)
LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id);

// Get the backend sampled token for the ith token.
// Returns LLAMA_TOKEN_NULL if no token was sampled.
LLAMA_API llama_token llama_get_sampled_token_ith(struct llama_context * ctx, int32_t i);

// Get the backend sampled probabilities for the ith token.
// The index matches llama_get_sampled_token_ith().
// Returns NULL if no probabilities were generated.
LLAMA_API float * llama_get_sampled_probs_ith(struct llama_context * ctx, int32_t i);
//
// Get the number of backend sampled probabilities for the ith token.
LLAMA_API uint32_t llama_get_sampled_probs_count_ith(struct llama_context * ctx, int32_t i);

// Get the backend sampled logits for the ith token.
// Returns NULL if no logits were sampled.
LLAMA_API float * llama_get_sampled_logits_ith(struct llama_context * ctx, int32_t i);
//
// Get the number of backend sampled logits for the ith token.
LLAMA_API uint32_t llama_get_sampled_logits_count_ith(struct llama_context * ctx, int32_t i);

// Get the backend sampled candidates (token ids) for the ith token.
// Returns NULL if no candidates were sampled.
LLAMA_API llama_token * llama_get_sampled_candidates_ith(struct llama_context * ctx, int32_t i);
//
// Get the number of backend sampled candidates for the ith token.
LLAMA_API uint32_t llama_get_sampled_candidates_count_ith(struct llama_context * ctx, int32_t i);

//
// Vocab
//
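A hedged sketch of the intended consumption pattern for these getters after llama_decode(); error handling and the CPU-sampling fallback are elided:

#include "llama.h"

// After llama_decode(), check whether the backend already picked a token for
// output index i; if not, fall back to reading the full logits as before.
// LLAMA_TOKEN_NULL means "nothing was sampled on the backend" per the comments above.
static llama_token read_token_for_output(llama_context * ctx, int32_t i) {
    const llama_token sampled = llama_get_sampled_token_ith(ctx, i);
    if (sampled != LLAMA_TOKEN_NULL) {
        return sampled; // backend sampler already selected a token
    }

    // Fallback: no backend-sampled token, run a CPU sampler on the logits instead.
    const float * logits = llama_get_logits_ith(ctx, i);
    (void) logits; // ... apply a CPU sampler chain here ...
    return LLAMA_TOKEN_NULL;
}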
@@ -1154,11 +1190,16 @@ extern "C" {
//
// llama_sampler_free(smpl);
//
// TODO: In the future, llama_sampler will be utilized to offload the sampling to the backends (e.g. GPU).
//

typedef void * llama_sampler_context_t;

struct llama_sampler_data {
struct ggml_tensor * logits;
struct ggml_tensor * probs;
struct ggml_tensor * sampled;
struct ggml_tensor * candidates;
};

// user code can implement the interface below in order to create custom llama_sampler
struct llama_sampler_i {
const char * (*name) (const struct llama_sampler * smpl); // can be NULL

@@ -1168,17 +1209,40 @@ extern "C" {
struct llama_sampler * (*clone) (const struct llama_sampler * smpl); // can be NULL if ctx is NULL
void (*free) ( struct llama_sampler * smpl); // can be NULL if ctx is NULL

// TODO: API for internal libllama usage for appending the sampling to an existing ggml_cgraph
//void (*apply_ggml) (struct llama_sampler * smpl, ...);
// backend sampling interface:

// return true if the backend supports all ops needed by the sampler
// note: call once per sampler
bool (*backend_init)(struct llama_sampler * smpl, ggml_backend_buffer_type_t buft);

// call after .backend_accept()
void (*backend_accept)(
struct llama_sampler * smpl,
struct ggml_context * ctx,
struct ggml_cgraph * gf,
struct ggml_tensor * selected_token);

// call after .backend_init()
void (*backend_apply)(
struct llama_sampler * smpl,
struct ggml_context * ctx,
struct ggml_cgraph * gf,
struct llama_sampler_data * data);

// call before .backend_apply()
void (*backend_set_input)(struct llama_sampler * smpl);
};

struct llama_sampler {
const struct llama_sampler_i * iface;
struct llama_sampler_i * iface;

llama_sampler_context_t ctx;
};

LLAMA_API bool llama_set_sampler(struct llama_context * ctx, llama_seq_id seq_id, struct llama_sampler * smpl);

// mirror of llama_sampler_i:
LLAMA_API struct llama_sampler * llama_sampler_init (const struct llama_sampler_i * iface, llama_sampler_context_t ctx);
LLAMA_API struct llama_sampler * llama_sampler_init ( struct llama_sampler_i * iface, llama_sampler_context_t ctx);
LLAMA_API const char * llama_sampler_name (const struct llama_sampler * smpl);
LLAMA_API void llama_sampler_accept( struct llama_sampler * smpl, llama_token token);
LLAMA_API void llama_sampler_apply ( struct llama_sampler * smpl, llama_token_data_array * cur_p);

@@ -1194,7 +1258,15 @@ extern "C" {

// important: takes ownership of the sampler object and will free it when llama_sampler_free is called
LLAMA_API void llama_sampler_chain_add( struct llama_sampler * chain, struct llama_sampler * smpl);
LLAMA_API struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i);

// return NULL if:
// - the sampler is NULL
// - the sampler is not a llama_sampler_chain
// - the index is out of bounds, unless i == -1
// - if i == -1, returns the chain itself (can be used to check if the sampler is a chain)
LLAMA_API struct llama_sampler * llama_sampler_chain_get( struct llama_sampler * chain, int32_t i);

// the total number of samplers in the chain
LLAMA_API int llama_sampler_chain_n (const struct llama_sampler * chain);

// after removing a sampler, the chain will no longer own it, and it will not be freed when the chain is freed
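For custom samplers, the two mandatory backend hooks could look roughly like the sketch below. This is an assumption-heavy illustration: only the callback signatures and the llama_sampler_data field names come from this header, and whether backend_apply may simply replace data->logits with a new graph node is an assumption.

#include "llama.h"
#include "ggml.h"

// Hypothetical context for a custom temperature sampler.
struct my_temp_ctx {
    float temp;
};

// Report that the sampler can run on the given buffer type. A real implementation
// would verify that every ggml op it emits is supported by that backend.
static bool my_temp_backend_init(struct llama_sampler * smpl, ggml_backend_buffer_type_t buft) {
    (void) smpl; (void) buft;
    return true;
}

// Rewrite the logits tensor in-graph instead of editing a llama_token_data_array on the CPU.
static void my_temp_backend_apply(struct llama_sampler * smpl,
                                  struct ggml_context * ctx,
                                  struct ggml_cgraph  * gf,
                                  struct llama_sampler_data * data) {
    const auto * sctx = (const my_temp_ctx *) smpl->ctx;
    data->logits = ggml_scale(ctx, data->logits, 1.0f / sctx->temp); // divide logits by temperature
    ggml_build_forward_expand(gf, data->logits);
}

// These two functions would be assigned to .backend_init / .backend_apply in a
// llama_sampler_i, alongside the usual CPU-side name/apply/clone/free callbacks.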
@@ -28,7 +28,8 @@ bool llama_batch_allocr::init(
const llama_memory_i * memory,
uint32_t n_embd,
uint32_t n_seq_max,
bool output_all) {
bool output_all,
bool sampling) {
clear();

batch = batch_inp;

@@ -145,6 +146,24 @@ bool llama_batch_allocr::init(
}
}

if (sampling) {
std::vector<int32_t> seq_output_count(n_seq_max, 0);

for (int32_t i = 0; i < batch.n_tokens; ++i) {
if (batch.logits[i] == 0) {
continue;
}
for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
const llama_seq_id seq_id = batch.seq_id[i][s];
seq_output_count[seq_id]++;
if (seq_output_count[seq_id] > 1) {
LLAMA_LOG_ERROR("%s: backend sampling requires at most one output token per sequence (%d)\n", __func__, seq_id);
return false;
}
}
}
}

//
// compute stats
//

@@ -81,7 +81,8 @@ public:
const llama_memory_i * memory,
uint32_t n_embd,
uint32_t n_seq_max,
bool output_all);
bool output_all,
bool sampling = false);

const llama_batch & get_batch() const;
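The new check above means that, with backend sampling enabled, a batch may request at most one output token per sequence. A minimal sketch of a conforming batch (one sequence, logits set only on the last prompt token); llama_batch_init/llama_batch_free are the existing public API and the field layout matches llama_batch:

#include "llama.h"
#include <vector>

// Build a batch where exactly one token (the last one) is marked as an output
// for the given sequence. Free it later with llama_batch_free().
static llama_batch make_single_output_batch(const std::vector<llama_token> & prompt, llama_seq_id seq_id) {
    llama_batch batch = llama_batch_init((int32_t) prompt.size(), /*embd =*/ 0, /*n_seq_max =*/ 1);
    for (size_t i = 0; i < prompt.size(); ++i) {
        batch.token   [i]    = prompt[i];
        batch.pos     [i]    = (llama_pos) i;
        batch.n_seq_id[i]    = 1;
        batch.seq_id  [i][0] = seq_id;
        batch.logits  [i]    = (i == prompt.size() - 1) ? 1 : 0; // exactly one output for this sequence
    }
    batch.n_tokens = (int32_t) prompt.size();
    return batch;
}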
@@ -60,6 +60,25 @@ llama_context::llama_context(
cparams.cb_eval = params.cb_eval;
cparams.cb_eval_user_data = params.cb_eval_user_data;

// Initialize backend samplers here so they are part of the sampling graph
// before the reserve passes run later in this function. This avoids a later
// re-reserve when graph nodes change.
if (params.samplers != nullptr && params.n_samplers > 0) {
for (size_t i = 0; i < params.n_samplers; ++i) {
const auto & config = params.samplers[i];

if (llama_sampler_chain_get(config.sampler, -1) == nullptr) {
throw std::runtime_error("the backend samplers must be of type llama_sampler_chain");
}

if (set_sampler(config.seq_id, config.sampler)) {
const int n_samplers = llama_sampler_chain_n(config.sampler);

LLAMA_LOG_INFO("%s: setting backend sampler for seq_id %d (n = %d)\n", __func__, config.seq_id, n_samplers);
}
}
}

auto rope_scaling_type = params.rope_scaling_type;
if (rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED) {
rope_scaling_type = hparams.rope_scaling_type_train;

@@ -231,7 +250,10 @@ llama_context::llama_context(
// graph outputs buffer
{
// resized during inference when a batch uses more outputs
if (output_reserve(params.n_seq_max) < params.n_seq_max) {
// Create a dummy batch for initialization.
llama_batch dummy_batch = {};
dummy_batch.n_tokens = 0;
if (output_reserve(params.n_seq_max, dummy_batch) < params.n_seq_max) {
throw std::runtime_error("failed to reserve initial output buffer");
}

@@ -456,6 +478,16 @@ llama_context::llama_context(
LLAMA_LOG_INFO("%s: graph splits = %d (with bs=%d), %d (with bs=1)\n", __func__, n_splits_pp, n_tokens, n_splits_tg);
}
}

// Initialize the full vocabulary token ids for backend samplers.
{
const int n_vocab = model.vocab.n_tokens();

sampling.token_ids_full_vocab.resize(n_vocab);
for (int i = 0; i < n_vocab; ++i) {
sampling.token_ids_full_vocab[i] = i;
}
}
}

llama_context::~llama_context() {

@@ -617,6 +649,35 @@ float * llama_context::get_logits() {
return logits;
}

int64_t llama_context::resolve_output_row(int32_t i) const {
int64_t j = -1;

// support negative indices (last output row)
if (i < 0) {
j = n_outputs + i;
if (j < 0) {
throw std::runtime_error(format("negative index out of range [0, %d)", n_outputs));
}
} else if ((size_t) i >= output_ids.size()) {
throw std::runtime_error(format("out of range [0, %zu)", output_ids.size()));
} else {
// use output_ids to translate the batch token index into a row number
// that holds this token's data.
j = output_ids[i];
}

if (j < 0) {
// the batch token was not configured to output anything
throw std::runtime_error(format("batch.logits[%d] != true", i));
}

if (j >= n_outputs) {
throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
}

return j;
}

float * llama_context::get_logits_ith(int32_t i) {
int64_t j = -1;

@@ -663,6 +724,10 @@ float * llama_context::get_embeddings() {
return embd;
}

llama_token * llama_context::get_sampled_tokens() {
return sampling.sampled;
}

float * llama_context::get_embeddings_ith(int32_t i) {
int64_t j = -1;

@@ -712,6 +777,136 @@ float * llama_context::get_embeddings_seq(llama_seq_id seq_id) {
return it->second.data();
}

llama_token llama_context::get_sampled_token_ith(int32_t idx) {
output_reorder();

if (sampling.sampled == nullptr) {
return LLAMA_TOKEN_NULL;
}

try {
const int64_t row = resolve_output_row(idx);
GGML_ASSERT(row < (int64_t) sampling.sampled_size);
return sampling.sampled[row];
} catch (const std::exception & err) {
LLAMA_LOG_ERROR("%s: invalid backend sampled token id %d, reason: %s\n", __func__, idx, err.what());
return LLAMA_TOKEN_NULL;
}
}

float * llama_context::get_sampled_probs_ith(int32_t idx) {
output_reorder();

if (sampling.probs == nullptr) {
return nullptr;
}

try {
const int64_t row = resolve_output_row(idx);
if ((size_t) row >= sampling.probs_count.size() || sampling.probs_count[row] == 0) {
return nullptr;
}
return sampling.probs + row*model.vocab.n_tokens();
} catch (const std::exception & err) {
LLAMA_LOG_ERROR("%s: invalid backend sampled probs id %d, reason: %s\n", __func__, idx, err.what());
return nullptr;
}
}

float * llama_context::get_sampled_logits_ith(int32_t idx) {
output_reorder();

if (sampling.logits == nullptr) {
return nullptr;
}

try {
const int64_t row = resolve_output_row(idx);
if ((size_t) row >= sampling.logits_count.size() || sampling.logits_count[row] == 0) {
return nullptr;
}
return sampling.logits + row*model.vocab.n_tokens();
} catch (const std::exception & err) {
LLAMA_LOG_ERROR("%s: invalid backend sampled logits id %d, reason: %s\n", __func__, idx, err.what());
return nullptr;
}
}

const llama_token * llama_context::get_sampled_candidates_ith(int32_t idx) {
output_reorder();

try {
const int64_t row = resolve_output_row(idx);
if (sampling.candidates != nullptr &&
(size_t) row < sampling.candidates_count.size() &&
sampling.candidates_count[row] > 0) {
return sampling.candidates + row*model.vocab.n_tokens();
}
} catch (const std::exception & err) {
// fallback to full vocab list
}

return sampling.token_ids_full_vocab.data();
}

size_t llama_context::get_sampled_candidates_count(int32_t idx) {
output_reorder();

if (sampling.candidates == nullptr) {
return 0;
}

try {
const int64_t row = resolve_output_row(idx);
if ((size_t) row >= sampling.candidates_count.size()) {
return 0;
}
return sampling.candidates_count[row];
} catch (const std::exception & err) {
LLAMA_LOG_ERROR("%s: invalid backend sampled candidates count id %d, reason: %s\n", __func__, idx, err.what());
return 0;
}
}

size_t llama_context::get_sampled_logits_count(int32_t idx) {
output_reorder();

if (sampling.logits == nullptr) {
return model.vocab.n_tokens();
}

try {
const int64_t row = resolve_output_row(idx);
if ((size_t) row >= sampling.logits_count.size()) {
return 0;
}
return sampling.logits_count[row];
} catch (const std::exception & err) {
LLAMA_LOG_ERROR("%s: invalid backend sampled logits count id %d, reason: %s\n", __func__, idx, err.what());
return 0;
}
}

size_t llama_context::get_sampled_probs_count(int32_t idx) {
output_reorder();

if (sampling.probs == nullptr) {
return 0;
}

try {
const int64_t row = resolve_output_row(idx);
if ((size_t) row >= sampling.probs_count.size()) {
return 0;
}
return sampling.probs_count[row];
} catch (const std::exception & err) {
LLAMA_LOG_ERROR("%s: invalid backend sampled probs count id %d, reason: %s\n", __func__, idx, err.what());
return 0;
}
}

void llama_context::attach_threadpool(
ggml_threadpool_t threadpool,
ggml_threadpool_t threadpool_batch) {

@@ -768,6 +963,42 @@ void llama_context::set_warmup(bool value) {
cparams.warmup = value;
}

bool llama_context::set_sampler(llama_seq_id seq_id, llama_sampler * sampler) {
LLAMA_LOG_DEBUG("%s: seq_id = %d, sampler = %p\n", __func__, (int) seq_id, (void *) sampler);

const bool can_offload =
sampler &&
sampler->iface->backend_init &&
sampler->iface->backend_apply &&
llama_sampler_chain_n(sampler) > 0;

if (sampler && can_offload) {
ggml_backend_buffer_type_t buft = ggml_backend_dev_buffer_type(model.dev_output());
auto * host_buft = ggml_backend_dev_host_buffer_type(model.dev_output());
if (host_buft) {
buft = host_buft;
}

sampler->iface->backend_init(sampler, buft);

sampling.samplers[seq_id] = sampler;

return true;
}

if (sampler && !can_offload) {
LLAMA_LOG_WARN("%s: sampler '%s' for seq_id = %d, cannot be offloaded to the backend\n", __func__, llama_sampler_name(sampler), seq_id);

sampling.samplers.erase(seq_id);

return false;
}

sampling.samplers.erase(seq_id);

return true;
}

void llama_context::set_adapter_lora(
llama_adapter_lora * adapter,
float scale) {
@ -908,7 +1139,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
|
|||
n_queued_tokens += n_tokens;
|
||||
|
||||
// reserve output buffer
|
||||
if (output_reserve(n_tokens) < n_tokens) {
|
||||
if (output_reserve(n_tokens, batch_inp) < n_tokens) {
|
||||
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_tokens);
|
||||
return -2;
|
||||
};
|
||||
|
|
@ -1032,6 +1263,112 @@ int llama_context::encode(const llama_batch & batch_inp) {
|
|||
return 0;
|
||||
}
|
||||
|
||||
static std::map<llama_seq_id, uint32_t> build_seq_to_output_row(const llama_ubatch & ubatch, uint32_t row_offset) {
|
||||
std::map<llama_seq_id, uint32_t> seq_to_row;
|
||||
// how many output tokens we have seen so far for this ubatch.
|
||||
uint32_t local = 0;
|
||||
for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
|
||||
// skip tokens that are not output.
|
||||
if (!ubatch.output[i]) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const llama_seq_id seq_id = ubatch.seq_id[i][0];
|
||||
// row_offset is the number of output tokens before this ubatch.
|
||||
seq_to_row[seq_id] = row_offset + local;
|
||||
++local;
|
||||
}
|
||||
return seq_to_row;
|
||||
}
|
||||
|
||||
static void copy_tensor_async_ints(
|
||||
const std::map<llama_seq_id, ggml_tensor*> & tensor_map,
|
||||
llama_token * sampled,
|
||||
size_t sampled_size,
|
||||
const std::map<llama_seq_id, uint32_t> & seq_to_row,
|
||||
ggml_backend_sched_t sched) {
|
||||
if (sampled == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
for (const auto & [seq_id, tensor] : tensor_map) {
|
||||
auto it = seq_to_row.find(seq_id);
|
||||
if (it == seq_to_row.end()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const uint32_t row = it->second;
|
||||
GGML_ASSERT(row < sampled_size);
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous(tensor) && "sampled tokens tensor must be contiguous for async copy");
|
||||
|
||||
ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor);
|
||||
ggml_backend_tensor_get_async(backend, tensor, sampled + row, 0, sizeof(sampled[row]));
|
||||
}
|
||||
}
|
||||
|
||||
static void copy_tensor_async_floats(
|
||||
const std::map<llama_seq_id, ggml_tensor*> & tensor_map,
|
||||
float * dst,
|
||||
size_t stride,
|
||||
std::vector<uint32_t> & counts,
|
||||
const std::map<llama_seq_id, uint32_t> & seq_to_row,
|
||||
ggml_backend_sched_t sched) {
|
||||
if (dst == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
for (const auto & [seq_id, tensor] : tensor_map) {
|
||||
auto it = seq_to_row.find(seq_id);
|
||||
if (it == seq_to_row.end()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const uint32_t row = it->second;
|
||||
GGML_ASSERT(row < counts.size());
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous(tensor) && "logits/probs tensor must be contiguous for async copy");
|
||||
|
||||
ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor);
|
||||
float * row_ptr = dst + (size_t) row * stride;
|
||||
ggml_backend_tensor_get_async(backend, tensor, row_ptr, 0, ggml_nbytes(tensor));
|
||||
|
||||
// Update the actual number of logits/probabilities that were written for this row.
|
||||
counts[row] = ggml_nelements(tensor);
|
||||
}
|
||||
}
|
||||
|
||||
static void copy_tensor_async_candidates(
|
||||
const std::map<llama_seq_id, ggml_tensor*> & tensor_map,
|
||||
llama_token * dst,
|
||||
size_t stride,
|
||||
std::vector<uint32_t> & counts,
|
||||
const std::map<llama_seq_id, uint32_t> & seq_to_row,
|
||||
ggml_backend_sched_t sched) {
|
||||
if (dst == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
for (const auto & [seq_id, tensor] : tensor_map) {
|
||||
auto it = seq_to_row.find(seq_id);
|
||||
if (it == seq_to_row.end()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const uint32_t row = it->second;
|
||||
GGML_ASSERT(row < counts.size());
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous(tensor) && "candidates tensor must be contiguous for async copy");
|
||||
|
||||
ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor);
|
||||
llama_token * row_ptr = dst + (size_t) row * stride;
|
||||
ggml_backend_tensor_get_async(backend, tensor, row_ptr, 0, ggml_nbytes(tensor));
|
||||
|
||||
// Update the actual number of candidates that were written.
|
||||
counts[row] = ggml_nelements(tensor);
|
||||
}
|
||||
}
|
||||
|
||||
int llama_context::decode(const llama_batch & batch_inp) {
|
||||
GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT
|
||||
|
||||
|
|
@ -1053,8 +1390,12 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
|||
|
||||
// when computing embeddings, all tokens are output
|
||||
const bool output_all = cparams.embeddings;
|
||||
const bool has_samplers = !sampling.samplers.empty();
|
||||
|
||||
if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, output_all)) {
|
||||
if (!balloc->init(batch_inp, vocab, memory.get(), n_embd,
|
||||
cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max,
|
||||
output_all,
|
||||
has_samplers)) {
|
||||
LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
|
||||
return -1;
|
||||
}
|
||||
|
|
@ -1135,7 +1476,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
|||
}
|
||||
|
||||
// reserve output buffer
|
||||
if (output_reserve(n_outputs_all) < n_outputs_all) {
|
||||
if (output_reserve(n_outputs_all, balloc->get_batch()) < n_outputs_all) {
|
||||
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all);
|
||||
return -2;
|
||||
};
|
||||
|
|
@ -1200,6 +1541,28 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
|||
// ggml_graph_dump_dot(gf, NULL, "llama.dot");
|
||||
//}
|
||||
|
||||
// This flag indicates whether a backend sampler has actually sampled a specific
|
||||
// token, or if it has produced probabilities. If true, we can skip the normal copying of logits and embeddings.
|
||||
const bool has_sampled = !res->t_sampled.empty() || !res->t_sampled_probs.empty() || !res->t_sampled_logits.empty();
|
||||
|
||||
if (has_samplers && has_sampled) {
|
||||
const auto seq_to_output_row = build_seq_to_output_row(ubatch, n_outputs_prev);
|
||||
const auto stride = n_vocab;
|
||||
|
||||
// async copy the sampled tokens from the backend to the host.
|
||||
copy_tensor_async_ints(res->t_sampled, sampling.sampled, sampling.sampled_size, seq_to_output_row, sched.get());
|
||||
|
||||
// async copy the sampled logits from the backend to the host.
|
||||
copy_tensor_async_floats(res->t_sampled_logits, sampling.logits, stride, sampling.logits_count, seq_to_output_row, sched.get());
|
||||
|
||||
// async copy the sampled probabilities from the backend to the host.
|
||||
copy_tensor_async_floats(res->t_sampled_probs, sampling.probs, stride, sampling.probs_count, seq_to_output_row, sched.get());
|
||||
|
||||
// async copy the candidate token ids from the backend to the host.
|
||||
// These are needed by CPU samplers to map probability/logit indices to vocab token ids.
|
||||
copy_tensor_async_candidates(res->t_candidates, sampling.candidates, stride, sampling.candidates_count, seq_to_output_row, sched.get());
|
||||
}
|
||||
|
||||
auto * t_logits = res->get_logits();
|
||||
auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr;
|
||||
|
||||
|
|
@ -1208,7 +1571,10 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
|||
}
|
||||
|
||||
// extract logits
|
||||
if (t_logits && n_outputs > 0) {
|
||||
// For multi-sequence batches that mix backend samplers and CPU samplers,
|
||||
// this is currently inefficient as we copy all logits even for the
|
||||
// backend sampled tokens.
|
||||
if (logits && t_logits && n_outputs > 0) {
|
||||
ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits);
|
||||
GGML_ASSERT(backend_res != nullptr);
|
||||
GGML_ASSERT(logits != nullptr);
|
||||
|
|
@ -1223,7 +1589,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
|||
}
|
||||
|
||||
// extract embeddings
|
||||
if (t_embd && n_outputs > 0) {
|
||||
if (embd && t_embd && n_outputs > 0) {
|
||||
ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
|
||||
GGML_ASSERT(backend_embd != nullptr);
|
||||
|
||||
|
|
@ -1340,7 +1706,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
|||
// output
|
||||
//
|
||||
|
||||
uint32_t llama_context::output_reserve(int32_t n_outputs) {
|
||||
uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & batch) {
|
||||
const auto & hparams = model.hparams;
|
||||
const auto & vocab = model.vocab;
|
||||
|
||||
|
|
@ -1359,16 +1725,60 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
|
|||
has_embd = true;
|
||||
}
|
||||
|
||||
logits_size = has_logits ? n_vocab*n_outputs_max : 0;
|
||||
// Check which sampling modes are needed by sequences in the current batch.
|
||||
bool batch_has_sampling = false;
|
||||
bool batch_needs_cpu_logits = false;
|
||||
|
||||
if (batch.logits) {
|
||||
for (int32_t i = 0; i < batch.n_tokens; i++) {
|
||||
if (!batch.logits[i]) {
|
||||
continue;
|
||||
}
|
||||
for (int32_t j = 0; j < batch.n_seq_id[i]; j++) {
|
||||
llama_seq_id seq_id = batch.seq_id[i][j];
|
||||
if (sampling.samplers.find(seq_id) != sampling.samplers.end()) {
|
||||
batch_has_sampling = true;
|
||||
} else {
|
||||
batch_needs_cpu_logits = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// When batch.logits is nullptr (when loading state with a dummy batch),
|
||||
// allocate CPU logits.
|
||||
batch_needs_cpu_logits = true;
|
||||
}
|
||||
|
||||
size_t backend_float_count = 0;
|
||||
size_t backend_token_count = 0;
|
||||
|
||||
// Allocate CPU logits buffer only if needed by sequences in this batch
|
||||
logits_size = (has_logits && batch_needs_cpu_logits) ? n_vocab*n_outputs_max : 0;
|
||||
embd_size = has_embd ? n_embd*n_outputs_max : 0;
|
||||
|
||||
if (!batch_has_sampling) {
|
||||
sampling.logits_size = 0;
|
||||
sampling.probs_size = 0;
|
||||
sampling.sampled_size = 0;
|
||||
sampling.candidates_size = 0;
|
||||
} else {
|
||||
sampling.logits_size = n_vocab*n_outputs_max;
|
||||
sampling.probs_size = n_vocab*n_outputs_max;
|
||||
sampling.sampled_size = n_outputs_max;
|
||||
sampling.candidates_size = n_vocab*n_outputs_max;
|
||||
|
||||
backend_float_count = sampling.logits_size + sampling.probs_size;
|
||||
backend_token_count = sampling.sampled_size + sampling.candidates_size;
|
||||
}
|
||||
|
||||
if (output_ids.empty()) {
|
||||
// init, never resized afterwards
|
||||
output_ids.resize(n_batch);
|
||||
}
|
||||
|
||||
const size_t prev_size = buf_output ? ggml_backend_buffer_get_size(buf_output.get()) : 0;
|
||||
const size_t new_size = (logits_size + embd_size) * sizeof(float);
|
||||
const size_t new_size = (logits_size + embd_size + backend_float_count) * sizeof(float)
|
||||
+ backend_token_count * sizeof(llama_token);
|
||||
|
||||
// alloc only when more than the current capacity is required
|
||||
// TODO: also consider shrinking the buffer
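To make the sizing above concrete, here is a small standalone calculation (the model dimensions are assumptions chosen only for this example, not values from the patch): when the batch contains at least one backend-sampled sequence, the single host output buffer grows to hold the optional CPU logits, the embeddings, and the four backend-sampling regions.

    #include <cstddef>
    #include <cstdint>
    #include <cstdio>

    int main() {
        // assumed example dimensions
        const size_t n_vocab       = 32000;
        const size_t n_embd        = 4096;
        const size_t n_outputs_max = 4;

        // regions mirroring output_reserve() when batch_has_sampling == true
        const size_t logits_size     = n_vocab * n_outputs_max; // CPU logits (only if some seq has no backend sampler)
        const size_t embd_size       = n_embd  * n_outputs_max;
        const size_t smpl_logits     = n_vocab * n_outputs_max;
        const size_t smpl_probs      = n_vocab * n_outputs_max;
        const size_t smpl_sampled    = n_outputs_max;            // one sampled token id per output row
        const size_t smpl_candidates = n_vocab * n_outputs_max;  // candidate token ids per output row

        const size_t new_size =
            (logits_size + embd_size + smpl_logits + smpl_probs) * sizeof(float) +
            (smpl_sampled + smpl_candidates) * sizeof(int32_t);   // llama_token is a 32-bit id

        std::printf("output buffer: %zu bytes (%.2f MiB)\n", new_size, new_size / 1024.0 / 1024.0);
    }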
|
||||
|
|
@ -1376,7 +1786,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
|
|||
if (buf_output) {
|
||||
#ifndef NDEBUG
|
||||
// This doesn't happen often, but may be annoying in some cases (like the HellaSwag benchmark)
|
||||
LLAMA_LOG_INFO("%s: reallocating output buffer from size %.02f MiB to %.02f MiB\n", __func__, prev_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0);
|
||||
LLAMA_LOG_DEBUG("%s: reallocating output buffer from size %.02f MiB to %.02f MiB\n", __func__, prev_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0);
|
||||
#endif
|
||||
synchronize();
|
||||
buf_output = nullptr;
|
||||
|
|
@ -1400,8 +1810,58 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
|
|||
|
||||
float * output_base = (float *) ggml_backend_buffer_get_base(buf_output.get());
|
||||
|
||||
logits = has_logits ? output_base : nullptr;
|
||||
embd = has_embd ? output_base + logits_size : nullptr;
|
||||
logits = nullptr;
|
||||
embd = nullptr;
|
||||
|
||||
// reset sampling pointers.
|
||||
sampling.logits = nullptr;
|
||||
sampling.probs = nullptr;
|
||||
sampling.sampled = nullptr;
|
||||
sampling.candidates = nullptr;
|
||||
|
||||
size_t offset = 0;
|
||||
uint8_t * base = (uint8_t *) output_base;
|
||||
|
||||
logits = (has_logits && batch_needs_cpu_logits) ? output_base : nullptr;
|
||||
offset += logits_size * sizeof(float);
|
||||
|
||||
embd = has_embd ? (float *) (base + offset) : nullptr;
|
||||
offset += embd_size * sizeof(float);
|
||||
|
||||
if (batch_has_sampling) {
|
||||
sampling.logits = (float *) (base + offset);
|
||||
offset += sampling.logits_size * sizeof(float);
|
||||
|
||||
sampling.probs = (float *) (base + offset);
|
||||
offset += sampling.probs_size * sizeof(float);
|
||||
|
||||
sampling.sampled = (llama_token *) (base + offset);
|
||||
offset += sampling.sampled_size * sizeof(llama_token);
|
||||
|
||||
sampling.candidates = (llama_token *) (base + offset);
|
||||
offset += sampling.candidates_size * sizeof(llama_token);
|
||||
|
||||
// The count vectors keep track of the actual number of logits/probs/candidates
|
||||
// copied from the backend for each output row.
|
||||
const size_t n_rows = (size_t) n_outputs_max;
|
||||
if (sampling.outputs_capacity < n_rows) {
|
||||
// The output size has increased, so resize and reset the count vectors.
|
||||
sampling.outputs_capacity = n_rows;
|
||||
|
||||
sampling.logits_count.assign(n_rows, 0);
|
||||
sampling.probs_count.assign(n_rows, 0);
|
||||
sampling.candidates_count.assign(n_rows, 0);
|
||||
} else {
|
||||
// The output size has not increased so just reset the counts to zero.
|
||||
std::fill(sampling.logits_count.begin(), sampling.logits_count.end(), 0);
|
||||
std::fill(sampling.probs_count.begin(), sampling.probs_count.end(), 0);
|
||||
std::fill(sampling.candidates_count.begin(), sampling.candidates_count.end(), 0);
|
||||
}
|
||||
|
||||
if (sampling.sampled) {
|
||||
std::fill_n(sampling.sampled, sampling.sampled_size, LLAMA_TOKEN_NULL);
|
||||
}
|
||||
}
|
||||
|
||||
// set all ids as invalid (negative)
|
||||
std::fill(output_ids.begin(), output_ids.end(), -1);
|
||||
|
|
@ -1430,6 +1890,40 @@ void llama_context::output_reorder() {
|
|||
std::swap(embd[i0*n_embd + k], embd[i1*n_embd + k]);
|
||||
}
|
||||
}
|
||||
|
||||
if (sampling.logits && sampling.logits_size > 0) {
|
||||
for (uint64_t k = 0; k < n_vocab; ++k) {
|
||||
std::swap(sampling.logits[i0*n_vocab + k], sampling.logits[i1*n_vocab + k]);
|
||||
}
|
||||
}
|
||||
|
||||
if (sampling.probs && sampling.probs_size > 0) {
|
||||
for (uint64_t k = 0; k < n_vocab; ++k) {
|
||||
std::swap(sampling.probs[i0*n_vocab + k], sampling.probs[i1*n_vocab + k]);
|
||||
}
|
||||
}
|
||||
|
||||
if (sampling.candidates && sampling.candidates_size > 0) {
|
||||
for (uint64_t k = 0; k < n_vocab; ++k) {
|
||||
std::swap(sampling.candidates[i0*n_vocab + k], sampling.candidates[i1*n_vocab + k]);
|
||||
}
|
||||
}
|
||||
|
||||
if (sampling.sampled && sampling.sampled_size > 0) {
|
||||
std::swap(sampling.sampled[i0], sampling.sampled[i1]);
|
||||
}
|
||||
|
||||
if (!sampling.logits_count.empty()) {
|
||||
std::swap(sampling.logits_count[i0], sampling.logits_count[i1]);
|
||||
}
|
||||
|
||||
if (!sampling.probs_count.empty()) {
|
||||
std::swap(sampling.probs_count[i0], sampling.probs_count[i1]);
|
||||
}
|
||||
|
||||
if (!sampling.candidates_count.empty()) {
|
||||
std::swap(sampling.candidates_count[i0], sampling.candidates_count[i1]);
|
||||
}
|
||||
}
|
||||
|
||||
output_swaps.clear();
|
||||
|
|
@ -1476,6 +1970,15 @@ ggml_cgraph * llama_context::graph_reserve(
|
|||
llama_batch_allocr balloc(model.hparams.n_pos_per_embd());
|
||||
llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens/n_seqs, n_seqs);
|
||||
|
||||
// set one output token per sequence in order to activate all backend samplers
|
||||
std::vector<llama_seq_id> seq_ids(n_seqs);
|
||||
for (uint32_t i = 0; i < n_seqs; ++i) {
|
||||
seq_ids[i] = i;
|
||||
ubatch.n_seq_id[i] = 1;
|
||||
ubatch.seq_id[i] = &seq_ids[i];
|
||||
ubatch.output[i] = true;
|
||||
}
|
||||
|
||||
auto * res = gf_res_reserve.get();
|
||||
|
||||
const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT);
|
||||
|
|
@ -1519,6 +2022,7 @@ llm_graph_params llama_context::graph_params(
|
|||
/*.loras =*/ &loras,
|
||||
/*.mctx =*/ mctx,
|
||||
/*.cross =*/ &cross,
|
||||
/*.samplers =*/ sampling.samplers,
|
||||
/*.n_outputs =*/ n_outputs,
|
||||
/*.cb =*/ graph_get_cb(),
|
||||
/*.res =*/ res,
|
||||
|
|
@ -2006,7 +2510,10 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
|
|||
auto n_outputs = this->n_outputs;
|
||||
io.read_to(&n_outputs, sizeof(n_outputs));
|
||||
|
||||
if (n_outputs > output_reserve(n_outputs)) {
|
||||
// Create a dummy batch for state loading.
|
||||
llama_batch dummy_batch = {};
|
||||
dummy_batch.n_tokens = 0;
|
||||
if (n_outputs > output_reserve(n_outputs, dummy_batch)) {
|
||||
throw std::runtime_error("could not reserve outputs");
|
||||
}
|
||||
|
||||
|
|
@ -2248,7 +2755,7 @@ void llama_context::opt_epoch_iter(
|
|||
}
|
||||
|
||||
// reserve output buffer
|
||||
if (output_reserve(n_outputs_all) < n_outputs_all) {
|
||||
if (output_reserve(n_outputs_all, balloc->get_batch()) < n_outputs_all) {
|
||||
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all);
|
||||
GGML_ABORT("TODO: handle this error");
|
||||
};
|
||||
|
|
@ -2393,6 +2900,8 @@ llama_context_params llama_context_default_params() {
|
|||
/*.op_offload =*/ true,
|
||||
/*.swa_full =*/ true,
|
||||
/*.kv_unified =*/ false,
|
||||
/*.sampler =*/ nullptr,
|
||||
/*.n_sampler =*/ 0,
|
||||
};
|
||||
|
||||
return result;
|
||||
|
|
@ -2552,7 +3061,15 @@ float * llama_get_logits(llama_context * ctx) {
|
|||
float * llama_get_logits_ith(llama_context * ctx, int32_t i) {
|
||||
ctx->synchronize();
|
||||
|
||||
return ctx->get_logits_ith(i);
|
||||
float * res = nullptr;
|
||||
|
||||
res = ctx->get_sampled_logits_ith(i);
|
||||
|
||||
if (!res) {
|
||||
res = ctx->get_logits_ith(i);
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
float * llama_get_embeddings(llama_context * ctx) {
|
||||
|
|
@ -2573,6 +3090,52 @@ float * llama_get_embeddings_seq(llama_context * ctx, llama_seq_id seq_id) {
|
|||
return ctx->get_embeddings_seq(seq_id);
|
||||
}
|
||||
|
||||
bool llama_set_sampler(llama_context * ctx, llama_seq_id seq_id, llama_sampler * smpl) {
|
||||
return ctx->set_sampler(seq_id, smpl);
|
||||
}
|
||||
|
||||
llama_token llama_get_sampled_token_ith(llama_context * ctx, int32_t i) {
|
||||
ctx->synchronize();
|
||||
|
||||
return ctx->get_sampled_token_ith(i);
|
||||
}
|
||||
|
||||
float * llama_get_sampled_probs_ith(llama_context * ctx, int32_t i) {
|
||||
ctx->synchronize();
|
||||
|
||||
return ctx->get_sampled_probs_ith(i);
|
||||
}
|
||||
|
||||
float * llama_get_sampled_logits_ith(llama_context * ctx, int32_t i) {
|
||||
ctx->synchronize();
|
||||
|
||||
return ctx->get_sampled_logits_ith(i);
|
||||
}
|
||||
|
||||
llama_token * llama_get_sampled_candidates_ith(llama_context * ctx, int32_t i) {
|
||||
ctx->synchronize();
|
||||
|
||||
return const_cast<llama_token *>(ctx->get_sampled_candidates_ith(i));
|
||||
}
|
||||
|
||||
uint32_t llama_get_sampled_candidates_count_ith(llama_context * ctx, int32_t i) {
|
||||
ctx->synchronize();
|
||||
|
||||
return static_cast<uint32_t>(ctx->get_sampled_candidates_count(i));
|
||||
}
|
||||
|
||||
uint32_t llama_get_sampled_logits_count_ith(llama_context * ctx, int32_t i) {
|
||||
ctx->synchronize();
|
||||
|
||||
return static_cast<uint32_t>(ctx->get_sampled_logits_count(i));
|
||||
}
|
||||
|
||||
uint32_t llama_get_sampled_probs_count_ith(llama_context * ctx, int32_t i) {
|
||||
ctx->synchronize();
|
||||
|
||||
return static_cast<uint32_t>(ctx->get_sampled_probs_count(i));
|
||||
}
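A possible way to consume these new accessors from application code, shown only as a sketch under the assumption that a row with no backend-sampled token reports LLAMA_TOKEN_NULL (the buffer is pre-filled with LLAMA_TOKEN_NULL in output_reserve()): read the token that was sampled on the backend during llama_decode() and fall back to a regular CPU sampler otherwise.

    #include "llama.h"

    // returns the token for output row `i`, preferring the backend-sampled result
    static llama_token read_output_token(llama_context * ctx, llama_sampler * cpu_smpl, int32_t i) {
        const llama_token tok = llama_get_sampled_token_ith(ctx, i);
        if (tok != LLAMA_TOKEN_NULL) {
            return tok; // sampled on the backend as part of llama_decode()
        }
        // no backend sampler ran for this row - sample on the host from the logits
        return llama_sampler_sample(cpu_smpl, ctx, i);
    }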
|
||||
|
||||
// llama adapter API
|
||||
|
||||
int32_t llama_set_adapter_lora(
|
||||
|
|
|
|||
|
|
@ -70,6 +70,18 @@ struct llama_context {
|
|||
float * get_embeddings_ith(int32_t i);
|
||||
float * get_embeddings_seq(llama_seq_id seq_id);
|
||||
|
||||
llama_token * get_sampled_tokens();
|
||||
llama_token get_sampled_token_ith(int32_t idx);
|
||||
|
||||
float * get_sampled_logits_ith(int32_t idx);
|
||||
size_t get_sampled_logits_count(int32_t idx);
|
||||
|
||||
float * get_sampled_probs_ith(int32_t idx);
|
||||
size_t get_sampled_probs_count(int32_t idx);
|
||||
|
||||
const llama_token * get_sampled_candidates_ith(int32_t idx);
|
||||
size_t get_sampled_candidates_count(int32_t idx);
|
||||
|
||||
void attach_threadpool(
|
||||
ggml_threadpool_t threadpool,
|
||||
ggml_threadpool_t threadpool_batch);
|
||||
|
|
@ -192,9 +204,10 @@ private:
|
|||
|
||||
// Make sure enough space is available for outputs.
|
||||
// Returns max number of outputs for which space was reserved.
|
||||
uint32_t output_reserve(int32_t n_outputs);
|
||||
uint32_t output_reserve(int32_t n_outputs, const llama_batch & batch);
|
||||
|
||||
void output_reorder();
|
||||
int64_t resolve_output_row(int32_t i) const;
|
||||
|
||||
//
|
||||
// graph
|
||||
|
|
@ -213,6 +226,8 @@ public:
|
|||
ggml_cgraph * graph_reserve(
|
||||
uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only = false, size_t * sizes = nullptr);
|
||||
|
||||
bool set_sampler(llama_seq_id seq_id, llama_sampler * sampler);
|
||||
|
||||
private:
|
||||
llm_graph_params graph_params(
|
||||
llm_graph_result * res,
|
||||
|
|
@ -247,6 +262,31 @@ private:
|
|||
size_t logits_size = 0; // capacity (of floats) for logits
|
||||
float * logits = nullptr;
|
||||
|
||||
struct sampling_info {
|
||||
std::map<llama_seq_id, llama_sampler *> samplers;
|
||||
|
||||
float * logits = nullptr;
|
||||
size_t logits_size = 0;
|
||||
|
||||
llama_token * sampled = nullptr;
|
||||
size_t sampled_size = 0;
|
||||
|
||||
float * probs = nullptr;
|
||||
size_t probs_size = 0;
|
||||
|
||||
llama_token * candidates = nullptr;
|
||||
size_t candidates_size = 0;
|
||||
|
||||
size_t outputs_capacity = 0;
|
||||
std::vector<uint32_t> logits_count;
|
||||
std::vector<uint32_t> probs_count;
|
||||
std::vector<uint32_t> candidates_count;
|
||||
|
||||
std::vector<llama_token> token_ids_full_vocab;
|
||||
};
|
||||
|
||||
sampling_info sampling;
|
||||
|
||||
// embeddings output (2-dimensional array: [n_outputs][n_embd])
|
||||
// populated only when pooling_type == LLAMA_POOLING_TYPE_NONE
|
||||
size_t embd_size = 0; // capacity (of floats) for embeddings
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@
|
|||
#include <cassert>
|
||||
#include <cmath>
|
||||
#include <cstring>
|
||||
#include <unordered_set>
|
||||
|
||||
void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
|
||||
if (ubatch->token) {
|
||||
|
|
@ -521,6 +522,43 @@ bool llm_graph_input_mem_hybrid::can_reuse(const llm_graph_params & params) {
|
|||
return res;
|
||||
}
|
||||
|
||||
void llm_graph_input_sampling::set_input(const llama_ubatch * ubatch) {
|
||||
// set the inputs only for the active samplers in the current ubatch
|
||||
std::unordered_set<llama_seq_id> active_samplers;
|
||||
for (uint32_t i = 0; i < ubatch->n_tokens; i++) {
|
||||
if (ubatch->output[i]) {
|
||||
llama_seq_id seq_id = ubatch->seq_id[i][0];
|
||||
active_samplers.insert(seq_id);
|
||||
}
|
||||
}
|
||||
|
||||
for (auto seq_id : active_samplers) {
|
||||
if (samplers.find(seq_id) == samplers.end()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto & sampler = samplers[seq_id];
|
||||
|
||||
if (sampler->iface->backend_set_input) {
|
||||
sampler->iface->backend_set_input(sampler);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool llm_graph_input_sampling::can_reuse(const llm_graph_params & params) {
|
||||
if (samplers.size() != params.samplers.size()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (const auto & [seq_id, sampler] : params.samplers) {
|
||||
auto it = samplers.find(seq_id);
if (it == samplers.end() || it->second != sampler) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
//
|
||||
// llm_graph_result
|
||||
//
|
||||
|
|
@ -541,6 +579,10 @@ void llm_graph_result::reset() {
|
|||
t_logits = nullptr;
|
||||
t_embd = nullptr;
|
||||
t_embd_pooled = nullptr;
|
||||
t_sampled.clear();
|
||||
t_sampled_probs.clear();
|
||||
t_sampled_logits.clear();
|
||||
t_candidates.clear();
|
||||
|
||||
params = {};
|
||||
|
||||
|
|
@ -565,6 +607,38 @@ void llm_graph_result::set_inputs(const llama_ubatch * ubatch) {
|
|||
}
|
||||
}
|
||||
|
||||
void llm_graph_result::set_outputs() {
|
||||
if (t_logits != nullptr) {
|
||||
ggml_set_output(t_logits);
|
||||
}
|
||||
if (t_embd != nullptr) {
|
||||
ggml_set_output(t_embd);
|
||||
}
|
||||
if (t_embd_pooled != nullptr) {
|
||||
ggml_set_output(t_embd_pooled);
|
||||
}
|
||||
for (auto & [seq_id, t] : t_sampled) {
|
||||
if (t != nullptr) {
|
||||
ggml_set_output(t);
|
||||
}
|
||||
}
|
||||
for (auto & [seq_id, t] : t_sampled_probs) {
|
||||
if (t != nullptr) {
|
||||
ggml_set_output(t);
|
||||
}
|
||||
}
|
||||
for (auto & [seq_id, t] : t_sampled_logits) {
|
||||
if (t != nullptr) {
|
||||
ggml_set_output(t);
|
||||
}
|
||||
}
|
||||
for (auto & [seq_id, t] : t_candidates) {
|
||||
if (t != nullptr) {
|
||||
ggml_set_output(t);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool llm_graph_result::can_reuse(const llm_graph_params & params) {
|
||||
if (!this->params.allow_reuse(params)) {
|
||||
if (debug > 1) {
|
||||
|
|
@ -646,6 +720,7 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
|
|||
loras (params.loras),
|
||||
mctx (params.mctx),
|
||||
cross (params.cross),
|
||||
samplers (params.samplers),
|
||||
cb_func (params.cb),
|
||||
res (params.res),
|
||||
ctx0 (res->get_ctx()),
|
||||
|
|
@ -1834,8 +1909,10 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
|
|||
|
||||
inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
|
||||
ggml_set_input(inp->self_kq_mask);
|
||||
ggml_set_name(inp->self_kq_mask, "self_kq_mask");
|
||||
|
||||
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
|
||||
ggml_set_name(inp->self_kq_mask_cnv, "self_kq_mask_cnv");
|
||||
}
|
||||
|
||||
{
|
||||
|
|
@ -1848,8 +1925,10 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
|
|||
|
||||
inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
|
||||
ggml_set_input(inp->self_kq_mask_swa);
|
||||
ggml_set_name(inp->self_kq_mask_swa, "self_kq_mask_swa");
|
||||
|
||||
inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
|
||||
ggml_set_name(inp->self_kq_mask_swa_cnv, "self_kq_mask_swa_cnv");
|
||||
}
|
||||
|
||||
return (llm_graph_input_attn_kv_iswa *) res->add_input(std::move(inp));
|
||||
|
|
@ -2086,6 +2165,86 @@ void llm_graph_context::build_pooling(
|
|||
ggml_build_forward_expand(gf, cur);
|
||||
}
|
||||
|
||||
void llm_graph_context::build_sampling() const {
|
||||
if (samplers.empty() || !res->t_logits) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto inp_sampling = std::make_unique<llm_graph_input_sampling>(samplers);
|
||||
res->add_input(std::move(inp_sampling));
|
||||
|
||||
std::map<llama_seq_id, int32_t> seq_to_logit_row;
|
||||
int32_t logit_row_idx = 0;
|
||||
|
||||
for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
|
||||
if (ubatch.output[i]) {
|
||||
llama_seq_id seq_id = ubatch.seq_id[i][0];
|
||||
seq_to_logit_row[seq_id] = logit_row_idx;
|
||||
logit_row_idx++;
|
||||
}
|
||||
}
|
||||
|
||||
// res->t_logits contains the logits for all tokens that requested them (logits=1 or output=1)
|
||||
GGML_ASSERT(res->t_logits != nullptr && "missing t_logits tensor");
|
||||
|
||||
// add a dummy row of logits
|
||||
// this trick makes the graph static, regardless of which samplers are activated
|
||||
// this is important in order to minimize graph reallocations
|
||||
ggml_tensor * logits_t = ggml_pad(ctx0, res->t_logits, 0, 1, 0, 0);
|
||||
|
||||
for (const auto & [seq_id, sampler] : samplers) {
|
||||
const auto it = seq_to_logit_row.find(seq_id);
|
||||
|
||||
// inactive samplers always work on the first row
|
||||
const auto row_idx = it != seq_to_logit_row.end() ? it->second : 0;
|
||||
|
||||
ggml_tensor * logits_seq = ggml_view_1d(ctx0, logits_t, logits_t->ne[0], row_idx * logits_t->nb[1]);
|
||||
ggml_format_name(logits_seq, "logits_seq_%d", seq_id);
|
||||
|
||||
struct llama_sampler_data data = {
|
||||
/*.logits =*/ logits_seq,
|
||||
/*.probs =*/ nullptr,
|
||||
/*.sampled =*/ nullptr,
|
||||
/*.candidates =*/ nullptr,
|
||||
};
|
||||
|
||||
assert(sampler->iface->backend_apply);
|
||||
sampler->iface->backend_apply(sampler, ctx0, gf, &data);
|
||||
|
||||
if (data.sampled != nullptr) {
|
||||
res->t_sampled[seq_id] = data.sampled;
|
||||
ggml_build_forward_expand(gf, data.sampled);
|
||||
}
|
||||
|
||||
if (data.probs != nullptr) {
|
||||
res->t_sampled_probs[seq_id] = data.probs;
|
||||
ggml_build_forward_expand(gf, data.probs);
|
||||
}
|
||||
|
||||
if (data.logits != nullptr) {
|
||||
res->t_sampled_logits[seq_id] = data.logits;
|
||||
ggml_build_forward_expand(gf, data.logits);
|
||||
}
|
||||
|
||||
if (data.candidates != nullptr) {
|
||||
res->t_candidates[seq_id] = data.candidates;
|
||||
ggml_build_forward_expand(gf, data.candidates);
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Call llama_sampler_accept_ggml after all samplers have been applied.
|
||||
/*
|
||||
for (const auto & [seq_id, sampler] : samplers) {
|
||||
if (auto it = res->t_sampled.find(seq_id); it != res->t_sampled.end()) {
|
||||
ggml_tensor * selected_token = it->second;
|
||||
if (selected_token != nullptr) {
|
||||
llama_sampler_accept_ggml(sampler, ctx0, gf, selected_token);
|
||||
}
|
||||
}
|
||||
}
|
||||
*/
|
||||
}
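The dummy-row trick above can be illustrated with a small standalone helper (an illustrative sketch, not code from the patch): padding the logits matrix with one extra row keeps the graph topology identical regardless of which sequences produced outputs in the current ubatch, and samplers that are inactive simply read row 0.

    #include "ggml.h"

    // select the logits row a sampler should operate on; active samplers read their
    // real row, inactive ones read row 0 of the padded matrix
    static ggml_tensor * select_logits_row(ggml_context * ctx, ggml_tensor * logits, int64_t row, bool active) {
        // append one zero row so at least one row exists even when there are no outputs,
        // keeping the graph shape independent of the current batch
        ggml_tensor * padded = ggml_pad(ctx, logits, 0, 1, 0, 0);

        const int64_t row_idx = active ? row : 0;
        return ggml_view_1d(ctx, padded, padded->ne[0], row_idx * padded->nb[1]);
    }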
|
||||
|
||||
int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
|
||||
// TODO move to hparams if a T5 variant appears that uses a different value
|
||||
const int64_t max_distance = 128;
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@
|
|||
#include <memory>
|
||||
#include <set>
|
||||
#include <functional>
|
||||
#include <map>
|
||||
|
||||
struct ggml_cgraph;
|
||||
struct ggml_context;
|
||||
|
|
@ -396,6 +397,18 @@ public:
|
|||
const llama_memory_hybrid_context * mctx;
|
||||
};
|
||||
|
||||
class llm_graph_input_sampling : public llm_graph_input_i {
|
||||
public:
|
||||
llm_graph_input_sampling(std::map<llama_seq_id, llama_sampler *> samplers) :
|
||||
samplers(std::move(samplers)) { }
|
||||
virtual ~llm_graph_input_sampling() = default;
|
||||
|
||||
void set_input(const llama_ubatch * ubatch) override;
|
||||
bool can_reuse(const llm_graph_params & params) override;
|
||||
|
||||
std::map<llama_seq_id, llama_sampler *> samplers;
|
||||
};
|
||||
|
||||
//
|
||||
// llm_graph_result
|
||||
//
|
||||
|
|
@ -429,6 +442,23 @@ struct llm_graph_params {
|
|||
const llama_memory_context_i * mctx;
|
||||
const llama_cross * cross;
|
||||
|
||||
std::map<llama_seq_id, llama_sampler *> samplers;
|
||||
|
||||
static bool samplers_equal(
|
||||
const std::map<llama_seq_id, llama_sampler *> & lhs,
|
||||
const std::map<llama_seq_id, llama_sampler *> & rhs) {
|
||||
if (lhs.size() != rhs.size()) {
|
||||
return false;
|
||||
}
|
||||
for (const auto & [seq_id, sampler] : lhs) {
|
||||
auto it = rhs.find(seq_id);
|
||||
if (it == rhs.end() || it->second != sampler) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
uint32_t n_outputs;
|
||||
|
||||
llm_graph_cb cb;
|
||||
|
|
@ -468,6 +498,28 @@ struct llm_graph_params {
|
|||
return false;
|
||||
}
|
||||
|
||||
if (n_outputs != other.n_outputs) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!samplers_equal(samplers, other.samplers)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (samplers.size() > 0) {
|
||||
if (!ubatch.data || !other.ubatch.data) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// check that the outputs are the same for all samplers
|
||||
for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
|
||||
if (ubatch.output[i] != other.ubatch.output[i] ||
|
||||
ubatch.seq_id[i][0] != other.ubatch.seq_id[i][0]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
cparams.embeddings == other.cparams.embeddings &&
|
||||
cparams.causal_attn == other.cparams.causal_attn &&
|
||||
|
|
@ -475,8 +527,7 @@ struct llm_graph_params {
|
|||
gtype == other.gtype &&
|
||||
cvec == other.cvec &&
|
||||
loras == other.loras &&
|
||||
cross == other.cross &&
|
||||
n_outputs == other.n_outputs;
|
||||
cross == other.cross;
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -499,6 +550,7 @@ public:
|
|||
void reset();
|
||||
|
||||
void set_inputs(const llama_ubatch * ubatch);
|
||||
void set_outputs();
|
||||
|
||||
// try to update the existing graph result using the new graph parameters in order to reuse it
|
||||
// this can only be done if we determine that the resulting graph using the new graph parameters
|
||||
|
|
@ -517,6 +569,11 @@ public:
|
|||
ggml_tensor * t_embd = nullptr;
|
||||
ggml_tensor * t_embd_pooled = nullptr;
|
||||
|
||||
std::map<llama_seq_id, ggml_tensor*> t_sampled_logits;
|
||||
std::map<llama_seq_id, ggml_tensor*> t_candidates;
|
||||
std::map<llama_seq_id, ggml_tensor*> t_sampled;
|
||||
std::map<llama_seq_id, ggml_tensor*> t_sampled_probs;
|
||||
|
||||
std::vector<llm_graph_input_ptr> inputs;
|
||||
|
||||
ggml_context_ptr ctx_compute;
|
||||
|
|
@ -592,6 +649,8 @@ struct llm_graph_context {
|
|||
const llama_memory_context_i * mctx;
|
||||
const llama_cross * cross;
|
||||
|
||||
std::map<llama_seq_id, llama_sampler *> samplers;
|
||||
|
||||
const llm_graph_cb & cb_func;
|
||||
|
||||
llm_graph_result * res;
|
||||
|
|
@ -832,6 +891,12 @@ struct llm_graph_context {
|
|||
ggml_tensor * cls_out,
|
||||
ggml_tensor * cls_out_b) const;
|
||||
|
||||
//
|
||||
// sampling (backend sampling)
|
||||
//
|
||||
|
||||
void build_sampling() const;
|
||||
|
||||
//
|
||||
// dense (out)
|
||||
//
|
||||
|
|
|
|||
|
|
@ -7642,12 +7642,17 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
|
|||
// add on pooling layer
|
||||
llm->build_pooling(cls, cls_b, cls_out, cls_out_b);
|
||||
|
||||
// add backend sampling layers (if any)
|
||||
llm->build_sampling();
|
||||
|
||||
// if the gguf model was converted with --sentence-transformers-dense-modules
|
||||
// there will be two additional dense projection layers
|
||||
// dense linear projections are applied after pooling
|
||||
// TODO: move reranking logic here and generalize
|
||||
llm->build_dense_out(dense_2_out_layers, dense_3_out_layers);
|
||||
|
||||
llm->res->set_outputs();
|
||||
|
||||
return llm->res->get_gf();
|
||||
}
|
||||
|
||||
|
|
|
|||
File diff suppressed because it is too large
|
|
@ -14,7 +14,16 @@ struct llama_grammar;
|
|||
struct llama_sampler_chain {
|
||||
llama_sampler_chain_params params;
|
||||
|
||||
std::vector<struct llama_sampler *> samplers;
|
||||
// has .backend_init() been called?
|
||||
bool is_init = false;
|
||||
|
||||
struct info {
|
||||
bool is_backend;
|
||||
|
||||
llama_sampler * ptr;
|
||||
};
|
||||
|
||||
std::vector<info> samplers;
|
||||
|
||||
// timing
|
||||
|
||||
|
|
@ -29,4 +38,4 @@ struct llama_sampler * llama_sampler_init_dry_testing(
|
|||
float dry_base,
|
||||
int32_t dry_allowed_length,
|
||||
int32_t dry_penalty_last_n,
|
||||
const std::vector<std::vector<llama_token>>& seq_breakers);
|
||||
const std::vector<std::vector<llama_token>> & seq_breakers);
|
||||
|
|
|
|||
|
|
@ -222,6 +222,17 @@ llama_build_and_test(test-backend-ops.cpp)
|
|||
llama_build_and_test(test-model-load-cancel.cpp LABEL "model")
|
||||
llama_build_and_test(test-autorelease.cpp LABEL "model")
|
||||
|
||||
llama_build_and_test(test-backend-sampler.cpp LABEL "model")
|
||||
target_include_directories(test-backend-sampler PRIVATE ${PROJECT_SOURCE_DIR}/src)
|
||||
llama_test(test-backend-sampler NAME test-backend-sampler-greedy ARGS --test greedy)
|
||||
llama_test(test-backend-sampler NAME test-backend-sampler-temp ARGS --test temp)
|
||||
llama_test(test-backend-sampler NAME test-backend-sampler-top_k ARGS --test top_k)
|
||||
llama_test(test-backend-sampler NAME test-backend-sampler-dist ARGS --test dist)
|
||||
llama_test(test-backend-sampler NAME test-backend-sampler-dist-and-cpu ARGS --test dist_and_cpu)
|
||||
llama_test(test-backend-sampler NAME test-backend-sampler-logit-bias ARGS --test logit_bias)
|
||||
llama_test(test-backend-sampler NAME test-backend-sampler-mul_seq ARGS --test multi_sequence)
|
||||
llama_test(test-backend-sampler NAME test-backend-sampler-set-sampler ARGS --test set_sampler)
|
||||
|
||||
# Test for state restore with fragmented KV cache
|
||||
# Requires a model, uses same args pattern as test-thread-safety
|
||||
if (NOT ${CMAKE_SYSTEM_PROCESSOR} MATCHES "s390x")
|
||||
|
|
|
|||
|
|
@ -7613,6 +7613,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
|||
exponent <<= 1;
|
||||
}
|
||||
#endif
|
||||
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {200000, 1, 1, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
|
||||
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {200000, 4, 1, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
|
||||
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {643251, 3, 1, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
|
||||
for (bool mask : {false, true}) {
|
||||
for (bool sinks : {false, true}) {
|
||||
for (float max_bias : {0.0f, 8.0f}) {
|
||||
|
|
@ -7666,7 +7669,6 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (bool fw : {true, false}) { // fw == forward
|
||||
bool all = true;
|
||||
|
||||
|
|
@ -7841,6 +7843,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
|||
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 2048, 5, 4, 3 }));
|
||||
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 201*1204, 1, 1, 1 }));
|
||||
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 312*1205, 1, 1, 1 }));
|
||||
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 20481, 4, 1, 1 }));
|
||||
|
||||
test_cases.emplace_back(new test_xielu());
|
||||
|
||||
|
|
@ -8162,6 +8165,12 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
|
|||
}
|
||||
}
|
||||
|
||||
for (int col: {8192, 16384, 32768, 65536, 131072, 262144, 524288}) {
|
||||
for (int rows: {1, 4, 16}){
|
||||
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {col, rows, 1, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
|
||||
}
|
||||
}
|
||||
|
||||
test_cases.emplace_back(new test_conv_2d_dw({512, 512, 256, 1}, {3, 3, 1, 256}, 1, 1, 1, false));
|
||||
test_cases.emplace_back(new test_conv_2d_dw({512, 512, 256, 1}, {3, 3, 1, 256}, 1, 1, 1, true));
|
||||
|
||||
|
|
@ -8206,6 +8215,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
|
|||
}
|
||||
|
||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {65000, 16, 1, 1}));
|
||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {200000, 1, 1, 1}));
|
||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {200000, 16, 1, 1}));
|
||||
|
||||
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {2, 1, 1, 1}, 1));
|
||||
for (auto k : {1, 10, 40, 400}) {
|
||||
|
|
@ -8216,13 +8227,18 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
|
|||
}
|
||||
}
|
||||
|
||||
for (auto nrows : {1, 4, 8, 16}) {
|
||||
for (auto cols : {128, 1024, 4096, 8192, 16384, 32768, 65536, 131072, 200000, 2000000}) {
|
||||
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, {cols, nrows, 1, 1}));
|
||||
}
|
||||
}
|
||||
|
||||
// Examples from granite-4.0-h-1b/ggml-model-Q8_0.gguf
|
||||
test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {515, 3328, 1, 1}, {4, 3328, 1, 1})); // prefill
|
||||
test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, 3328, 1, 1}, {4, 3328, 1, 1})); // generate
|
||||
test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 64, 48, 1, 512, 1)); // prefill
|
||||
test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 64, 48, 1, 1, 1)); // generate
|
||||
|
||||
|
||||
return test_cases;
|
||||
}
|
||||
|
||||
|
|
|
|||
File diff suppressed because it is too large
Binary file not shown.
|
|
@ -1397,17 +1397,22 @@ json format_response_rerank(
|
|||
|
||||
std::vector<llama_token_data> get_token_probabilities(llama_context * ctx, int idx) {
|
||||
std::vector<llama_token_data> cur;
|
||||
|
||||
const auto * logits = llama_get_logits_ith(ctx, idx);
|
||||
const llama_token * sampled_ids = llama_get_sampled_candidates_ith(ctx, idx);
|
||||
|
||||
const llama_model * model = llama_get_model(ctx);
|
||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||
const int n_logits = llama_get_sampled_logits_count_ith(ctx, idx);
|
||||
|
||||
const int n_vocab = llama_vocab_n_tokens(vocab);
|
||||
|
||||
cur.resize(n_vocab);
|
||||
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
|
||||
cur.resize(n_logits);
|
||||
if (sampled_ids) {
|
||||
for (int i = 0; i < n_logits; i++) {
|
||||
cur[i] = llama_token_data{sampled_ids[i], logits[i], 0.0f};
|
||||
}
|
||||
} else {
|
||||
for (llama_token token_id = 0; token_id < n_logits; token_id++) {
|
||||
cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
|
||||
}
|
||||
}
|
||||
|
||||
// sort tokens by logits
|
||||
std::sort(cur.begin(), cur.end(), [](const llama_token_data & a, const llama_token_data & b) {
|
||||
|
|
|
|||
|
|
@ -1056,6 +1056,25 @@ struct server_context_impl {
|
|||
return false;
|
||||
}
|
||||
|
||||
const bool need_logits = task.params.sampling.n_probs > 0;
|
||||
|
||||
bool backend_sampling = true;
|
||||
|
||||
backend_sampling &= task.params.sampling.backend_sampling;
|
||||
|
||||
// TODO: speculative decoding requires multiple samples per batch - not supported yet
|
||||
backend_sampling &= !(slot.ctx_dft && task.params.speculative.n_max > 0);
|
||||
|
||||
// TODO: getting post/pre sampling logits is not yet supported with backend sampling
|
||||
backend_sampling &= !need_logits;
|
||||
|
||||
// TODO: tmp until backend sampling is fully implemented
|
||||
if (backend_sampling) {
|
||||
llama_set_sampler(ctx, slot.id, common_sampler_get(slot.smpl.get()));
|
||||
} else {
|
||||
llama_set_sampler(ctx, slot.id, nullptr);
|
||||
}
|
||||
|
||||
SLT_INF(slot, "sampler chain: %s\n", common_sampler_print(slot.smpl.get()).c_str());
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -78,6 +78,7 @@ json task_params::to_json(bool only_metrics) const {
|
|||
{"speculative.p_min", speculative.p_min},
|
||||
{"timings_per_token", timings_per_token},
|
||||
{"post_sampling_probs", post_sampling_probs},
|
||||
{"backend_sampling", sampling.backend_sampling},
|
||||
{"lora", lora},
|
||||
};
|
||||
}
|
||||
|
|
@ -136,6 +137,7 @@ json task_params::to_json(bool only_metrics) const {
|
|||
{"speculative.p_min", speculative.p_min},
|
||||
{"timings_per_token", timings_per_token},
|
||||
{"post_sampling_probs", post_sampling_probs},
|
||||
{"backend_sampling", sampling.backend_sampling},
|
||||
{"lora", lora},
|
||||
};
|
||||
}
|
||||
|
|
@ -206,8 +208,12 @@ task_params server_task::params_from_json_cmpl(
|
|||
params.sampling.seed = json_value(data, "seed", defaults.sampling.seed);
|
||||
params.sampling.n_probs = json_value(data, "n_probs", defaults.sampling.n_probs);
|
||||
params.sampling.min_keep = json_value(data, "min_keep", defaults.sampling.min_keep);
|
||||
params.sampling.backend_sampling = json_value(data, "backend_sampling", defaults.sampling.backend_sampling);
|
||||
params.post_sampling_probs = json_value(data, "post_sampling_probs", defaults.post_sampling_probs);
|
||||
|
||||
printf("params.sampling.backend_sampling = %d\n", params.sampling.backend_sampling);
|
||||
printf("defaults.sampling.backend_sampling = %d\n", defaults.sampling.backend_sampling);
|
||||
|
||||
params.speculative.n_min = json_value(data, "speculative.n_min", defaults.speculative.n_min);
|
||||
params.speculative.n_max = json_value(data, "speculative.n_max", defaults.speculative.n_max);
|
||||
params.speculative.p_min = json_value(data, "speculative.p_min", defaults.speculative.p_min);
|
||||
|
|
|
|||
|
|
@ -13,16 +13,16 @@ def create_server():
|
|||
@pytest.mark.parametrize(
|
||||
"model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason,jinja,chat_template",
|
||||
[
|
||||
(None, "Book", "Hey", 8, "But she couldn't", 69, 8, "length", False, None),
|
||||
(None, "Book", "Hey", 8, "But she couldn't", 69, 8, "length", True, None),
|
||||
(None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.", 77, 8, "length", False, None),
|
||||
(None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.", 77, 8, "length", True, None),
|
||||
(None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.", 77, 8, "length", True, 'chatml'),
|
||||
(None, "Book", "What is the best book", 8, "^ blue", 23, 8, "length", True, "This is not a chat template, it is"),
|
||||
("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 128, "length", False, None),
|
||||
("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 128, "length", True, None),
|
||||
(None, "Book", [{"type": "text", "text": "What is"}, {"type": "text", "text": "the best book"}], 8, "Whillicter", 79, 8, "length", False, None),
|
||||
(None, "Book", [{"type": "text", "text": "What is"}, {"type": "text", "text": "the best book"}], 8, "Whillicter", 79, 8, "length", True, None),
|
||||
(None, "Book", "Hey", 8, "But she couldn't|Some of her", 69, 8, "length", False, None),
|
||||
(None, "Book", "Hey", 8, "But she couldn't|Some of her", 69, 8, "length", True, None),
|
||||
(None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.|Timmy", 77, 8, "length", False, None),
|
||||
(None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.|Timmy", 77, 8, "length", True, None),
|
||||
(None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.|Timmy", 77, 8, "length", True, 'chatml'),
|
||||
(None, "Book", "What is the best book", 8, "^ blue|very teaful|very busy", 23, 8, "length", True, "This is not a chat template, it is"),
|
||||
("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger|shake)+", 104, 128, "length", False, None),
|
||||
("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger|shake)+", 104, 128, "length", True, None),
|
||||
(None, "Book", [{"type": "text", "text": "What is"}, {"type": "text", "text": "the best book"}], 8, "Whillicter|Some", 79, 8, "length", False, None),
|
||||
(None, "Book", [{"type": "text", "text": "What is"}, {"type": "text", "text": "the best book"}], 8, "Whillicter|Some", 79, 8, "length", True, None),
|
||||
]
|
||||
)
|
||||
def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason, jinja, chat_template):
|
||||
|
|
@ -54,8 +54,8 @@ def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_conte
|
|||
@pytest.mark.parametrize(
|
||||
"system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason",
|
||||
[
|
||||
("Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length"),
|
||||
("You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 128, "length"),
|
||||
("Book", "What is the best book", 8, "(Suddenly)+|Timmy", 77, 8, "length"),
|
||||
("You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger|shake)+", 104, 128, "length"),
|
||||
]
|
||||
)
|
||||
def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason):
|
||||
|
|
@ -115,7 +115,7 @@ def test_chat_completion_with_openai_library():
|
|||
assert res.system_fingerprint is not None and res.system_fingerprint.startswith("b")
|
||||
assert res.choices[0].finish_reason == "length"
|
||||
assert res.choices[0].message.content is not None
|
||||
assert match_regex("(Suddenly)+", res.choices[0].message.content)
|
||||
assert match_regex("(Suddenly)+|Timmy", res.choices[0].message.content)
|
||||
|
||||
|
||||
def test_chat_template():
|
||||
|
|
@ -494,5 +494,5 @@ def test_chat_completions_multiple_choices():
|
|||
assert len(res.body["choices"]) == 2
|
||||
for choice in res.body["choices"]:
|
||||
assert "assistant" == choice["message"]["role"]
|
||||
assert match_regex("Suddenly", choice["message"]["content"])
|
||||
assert match_regex("Suddenly|Timmy", choice["message"]["content"])
|
||||
assert choice["finish_reason"] == "length"
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ def create_server():
|
|||
server = ServerPreset.tinyllama2()
|
||||
|
||||
@pytest.mark.parametrize("prompt,n_predict,re_content,n_prompt,n_predicted,truncated,return_tokens", [
|
||||
("I believe the meaning of life is", 8, "(going|bed)+", 18, 8, False, False),
|
||||
("I believe the meaning of life is", 8, "(going|bed)+|froze and every|froze and bri", 18, 8, False, False),
|
||||
("Write a joke about AI from a very long prompt which will not be truncated", 64, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False, True),
|
||||
])
|
||||
def test_completion(prompt: str, n_predict: int, re_content: str, n_prompt: int, n_predicted: int, truncated: bool, return_tokens: bool):
|
||||
|
|
@ -42,7 +42,7 @@ def test_completion(prompt: str, n_predict: int, re_content: str, n_prompt: int,
|
|||
|
||||
|
||||
@pytest.mark.parametrize("prompt,n_predict,re_content,n_prompt,n_predicted,truncated", [
|
||||
("I believe the meaning of life is", 8, "(going|bed)+", 18, 8, False),
|
||||
("I believe the meaning of life is", 8, "(going|bed)+|froze and every|froze and bri", 18, 8, False),
|
||||
("Write a joke about AI from a very long prompt which will not be truncated", 64, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False),
|
||||
])
|
||||
def test_completion_stream(prompt: str, n_predict: int, re_content: str, n_prompt: int, n_predicted: int, truncated: bool):
|
||||
|
|
@ -103,7 +103,7 @@ def test_completion_with_openai_library():
|
|||
assert res.system_fingerprint is not None and res.system_fingerprint.startswith("b")
|
||||
assert res.choices[0].finish_reason == "length"
|
||||
assert res.choices[0].text is not None
|
||||
assert match_regex("(going|bed)+", res.choices[0].text)
|
||||
assert match_regex("(going|bed)+|froze and every|froze and bri", res.choices[0].text)
|
||||
|
||||
|
||||
def test_completion_stream_with_openai_library():
|
||||
|
|
@ -122,7 +122,7 @@ def test_completion_stream_with_openai_library():
|
|||
if choice.finish_reason is None:
|
||||
assert choice.text is not None
|
||||
output_text += choice.text
|
||||
assert match_regex("(going|bed)+", output_text)
|
||||
assert match_regex("(going|bed)+|froze and every|froze and bri", output_text)
|
||||
|
||||
|
||||
# Test case from https://github.com/ggml-org/llama.cpp/issues/13780
|
||||
|
|
@ -146,7 +146,7 @@ def test_completion_stream_with_openai_library_stops():
|
|||
if choice.finish_reason is None:
|
||||
assert choice.text is not None
|
||||
output_text += choice.text
|
||||
assert match_regex("Sure, here's one for[\\s\\S]*", output_text), f'Unexpected output: {output_text}'
|
||||
assert match_regex("Sure, here's one for[\\s\\S]*|Sure thing..Why don't|Sure! Here's one for you:", output_text), f'Unexpected output: {output_text}'
|
||||
|
||||
|
||||
@pytest.mark.parametrize("n_slots", [1, 2])
|
||||
|
|
@ -511,8 +511,8 @@ def test_n_probs_post_sampling():
|
|||
assert "token" in prob and type(prob["token"]) == str
|
||||
assert "prob" in prob and 0.0 <= prob["prob"] <= 1.0
|
||||
assert "bytes" in prob and type(prob["bytes"]) == list
|
||||
# because the test model usually output token with either 100% or 0% probability, we need to check all the top_probs
|
||||
assert any(prob["prob"] == 1.0 for prob in tok["top_probs"])
|
||||
# at low temperature, one of the token has a very high probability
|
||||
assert any(prob["prob"] >= 0.99 for prob in tok["top_probs"])
|
||||
|
||||
|
||||
@pytest.mark.parametrize("tokenize,openai_style", [(False, False), (False, True), (True, False), (True, True)])
|
||||
|
|
|
|||
|
|
@ -185,6 +185,11 @@
|
|||
key: 'samplers',
|
||||
label: 'Samplers',
|
||||
type: 'input'
|
||||
},
|
||||
{
|
||||
key: 'backend_sampling',
|
||||
label: 'Backend sampling',
|
||||
type: 'checkbox'
|
||||
}
|
||||
]
|
||||
},
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ export const SETTING_CONFIG_DEFAULT: Record<string, string | number | boolean> =
|
|||
autoMicOnEmpty: false,
|
||||
// make sure these default values are in sync with `common.h`
|
||||
samplers: 'top_k;typ_p;top_p;min_p;temperature',
|
||||
backend_sampling: false,
|
||||
temperature: 0.8,
|
||||
dynatemp_range: 0.0,
|
||||
dynatemp_exponent: 1.0,
|
||||
|
|
@ -57,6 +58,8 @@ export const SETTING_CONFIG_INFO: Record<string, string> = {
|
|||
'When copying a message with text attachments, combine them into a single plain text string instead of a special format that can be pasted back as attachments.',
|
||||
samplers:
|
||||
'The order at which samplers are applied, in simplified way. Default is "top_k;typ_p;top_p;min_p;temperature": top_k->typ_p->top_p->min_p->temperature',
|
||||
backend_sampling:
|
||||
'Enable backend-based samplers. When enabled, supported samplers run on the accelerator backend for faster sampling.',
|
||||
temperature:
|
||||
'Controls the randomness of the generated text by affecting the probability distribution of the output tokens. Higher = more random, lower = more focused.',
|
||||
dynatemp_range:
|
||||
|
|
|
|||
|
|
@ -86,6 +86,7 @@ export class ChatService {
|
|||
dry_penalty_last_n,
|
||||
// Other parameters
|
||||
samplers,
|
||||
backend_sampling,
|
||||
custom,
|
||||
timings_per_token,
|
||||
// Config options
|
||||
|
|
@ -158,6 +159,8 @@ export class ChatService {
|
|||
: samplers;
|
||||
}
|
||||
|
||||
if (backend_sampling !== undefined) requestBody.backend_sampling = backend_sampling;
|
||||
|
||||
if (timings_per_token !== undefined) requestBody.timings_per_token = timings_per_token;
|
||||
|
||||
if (custom) {
|
||||
|
|
|
|||
|
|
@ -1401,6 +1401,8 @@ class ChatStore {
|
|||
if (hasValue(currentConfig.dry_penalty_last_n))
|
||||
apiOptions.dry_penalty_last_n = Number(currentConfig.dry_penalty_last_n);
|
||||
if (currentConfig.samplers) apiOptions.samplers = currentConfig.samplers;
|
||||
if (currentConfig.backend_sampling)
|
||||
apiOptions.backend_sampling = currentConfig.backend_sampling;
|
||||
if (currentConfig.custom) apiOptions.custom = currentConfig.custom;
|
||||
|
||||
return apiOptions;
|
||||
|
|
|
|||
|
|
@ -149,6 +149,7 @@ export interface ApiLlamaCppServerProps {
|
|||
reasoning_in_content: boolean;
|
||||
thinking_forced_open: boolean;
|
||||
samplers: string[];
|
||||
backend_sampling: boolean;
|
||||
'speculative.n_max': number;
|
||||
'speculative.n_min': number;
|
||||
'speculative.p_min': number;
|
||||
|
|
@ -210,6 +211,7 @@ export interface ApiChatCompletionRequest {
|
|||
dry_penalty_last_n?: number;
|
||||
// Sampler configuration
|
||||
samplers?: string[];
|
||||
backend_sampling?: boolean;
|
||||
// Custom parameters (JSON string)
|
||||
custom?: Record<string, unknown>;
|
||||
timings_per_token?: boolean;
|
||||
|
|
@ -310,6 +312,7 @@ export interface ApiSlotData {
|
|||
reasoning_in_content: boolean;
|
||||
thinking_forced_open: boolean;
|
||||
samplers: string[];
|
||||
backend_sampling: boolean;
|
||||
'speculative.n_max': number;
|
||||
'speculative.n_min': number;
|
||||
'speculative.p_min': number;
|
||||
|
|
|
|||
|
|
@ -43,6 +43,7 @@ export interface SettingsChatServiceOptions {
|
|||
dry_penalty_last_n?: number;
|
||||
// Sampler configuration
|
||||
samplers?: string | string[];
|
||||
backend_sampling?: boolean;
|
||||
// Custom parameters
|
||||
custom?: string;
|
||||
timings_per_token?: boolean;
|
||||
|
|
|
|||