Merge branch 'ggml-org:master' into master

This commit is contained in:
Kusha Gharahi 2026-03-22 20:55:57 -05:00 committed by GitHub
commit bae3440fa7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
39 changed files with 493 additions and 133 deletions

View File

@ -53,6 +53,13 @@ private:
return tokens[current + offset]; return tokens[current + offset];
} }
// Returns the token under the cursor and moves the cursor one past it.
//
// Throws parser_exception ("Parser Error: Unexpected EOF") when the cursor is
// already at or beyond the end of the token list. The position passed to the
// exception is the `pos` of the last token, or 0 when there are no tokens.
const token & next() {
    if (current < tokens.size()) {
        // post-increment: hand back the current token, then advance
        return tokens[current++];
    }
    throw parser_exception(
        "Parser Error: Unexpected EOF",
        source,
        tokens.empty() ? 0 : tokens.back().pos);
}
token expect(token::type type, const std::string& error) { token expect(token::type type, const std::string& error) {
const auto & t = peek(); const auto & t = peek();
if (t.t != type) { if (t.t != type) {
@ -90,9 +97,9 @@ private:
size_t start_pos = current; size_t start_pos = current;
switch (peek().t) { switch (peek().t) {
case token::comment: case token::comment:
return mk_stmt<comment_statement>(start_pos, tokens[current++].value); return mk_stmt<comment_statement>(start_pos, next().value);
case token::text: case token::text:
return mk_stmt<string_literal>(start_pos, tokens[current++].value); return mk_stmt<string_literal>(start_pos, next().value);
case token::open_statement: case token::open_statement:
return parse_jinja_statement(); return parse_jinja_statement();
case token::open_expression: case token::open_expression:
@ -119,8 +126,7 @@ private:
} }
size_t start_pos = current; size_t start_pos = current;
std::string name = peek().value; std::string name = next().value;
current++; // consume identifier
statement_ptr result; statement_ptr result;
if (name == "set") { if (name == "set") {
@ -202,7 +208,7 @@ private:
// Ignore generation blocks (transformers-specific) // Ignore generation blocks (transformers-specific)
// See https://github.com/huggingface/transformers/pull/30650 for more information. // See https://github.com/huggingface/transformers/pull/30650 for more information.
result = mk_stmt<noop_statement>(start_pos); result = mk_stmt<noop_statement>(start_pos);
current++; ++current;
} else { } else {
throw std::runtime_error("Unknown statement: " + name); throw std::runtime_error("Unknown statement: " + name);
@ -217,7 +223,7 @@ private:
statements body; statements body;
if (is(token::equals)) { if (is(token::equals)) {
current++; ++current;
value = parse_expression_sequence(); value = parse_expression_sequence();
} else { } else {
// parsing multiline set here // parsing multiline set here
@ -280,7 +286,7 @@ private:
exprs.push_back(primary ? parse_primary_expression() : parse_expression()); exprs.push_back(primary ? parse_primary_expression() : parse_expression());
bool is_tuple = is(token::comma); bool is_tuple = is(token::comma);
while (is(token::comma)) { while (is(token::comma)) {
current++; // consume comma ++current; // consume comma
exprs.push_back(primary ? parse_primary_expression() : parse_expression()); exprs.push_back(primary ? parse_primary_expression() : parse_expression());
} }
return is_tuple ? mk_stmt<tuple_literal>(start_pos, std::move(exprs)) : std::move(exprs[0]); return is_tuple ? mk_stmt<tuple_literal>(start_pos, std::move(exprs)) : std::move(exprs[0]);
@ -290,7 +296,7 @@ private:
// e.g., `message` in `for message in messages` // e.g., `message` in `for message in messages`
auto loop_var = parse_expression_sequence(true); // should be an identifier/tuple auto loop_var = parse_expression_sequence(true); // should be an identifier/tuple
if (!is_identifier("in")) throw std::runtime_error("Expected 'in'"); if (!is_identifier("in")) throw std::runtime_error("Expected 'in'");
current++; ++current; // consume 'in'
// `messages` in `for message in messages` // `messages` in `for message in messages`
auto iterable = parse_expression(); auto iterable = parse_expression();
@ -305,7 +311,8 @@ private:
} }
if (is_statement({"else"})) { if (is_statement({"else"})) {
current += 2; ++current; // consume {%
++current; // consume 'else'
expect(token::close_statement, "Expected %}"); expect(token::close_statement, "Expected %}");
while (!is_statement({"endfor"})) { while (!is_statement({"endfor"})) {
alternate.push_back(parse_any()); alternate.push_back(parse_any());
@ -347,7 +354,7 @@ private:
auto left = parse_logical_and_expression(); auto left = parse_logical_and_expression();
while (is_identifier("or")) { while (is_identifier("or")) {
size_t start_pos = current; size_t start_pos = current;
token op = tokens[current++]; token op = next();
left = mk_stmt<binary_expression>(start_pos, op, std::move(left), parse_logical_and_expression()); left = mk_stmt<binary_expression>(start_pos, op, std::move(left), parse_logical_and_expression());
} }
return left; return left;
@ -357,7 +364,7 @@ private:
auto left = parse_logical_negation_expression(); auto left = parse_logical_negation_expression();
while (is_identifier("and")) { while (is_identifier("and")) {
size_t start_pos = current; size_t start_pos = current;
auto op = tokens[current++]; auto op = next();
left = mk_stmt<binary_expression>(start_pos, op, std::move(left), parse_logical_negation_expression()); left = mk_stmt<binary_expression>(start_pos, op, std::move(left), parse_logical_negation_expression());
} }
return left; return left;
@ -367,7 +374,7 @@ private:
// Try parse unary operators // Try parse unary operators
if (is_identifier("not")) { if (is_identifier("not")) {
size_t start_pos = current; size_t start_pos = current;
auto op = tokens[current++]; auto op = next();
return mk_stmt<unary_expression>(start_pos, op, parse_logical_negation_expression()); return mk_stmt<unary_expression>(start_pos, op, parse_logical_negation_expression());
} }
return parse_comparison_expression(); return parse_comparison_expression();
@ -382,11 +389,12 @@ private:
size_t start_pos = current; size_t start_pos = current;
if (is_identifier("not") && peek(1).t == token::identifier && peek(1).value == "in") { if (is_identifier("not") && peek(1).t == token::identifier && peek(1).value == "in") {
op = {token::identifier, "not in", tokens[current].pos}; op = {token::identifier, "not in", tokens[current].pos};
current += 2; ++current; // consume 'not'
++current; // consume 'in'
} else if (is_identifier("in")) { } else if (is_identifier("in")) {
op = tokens[current++]; op = next();
} else if (is(token::comparison_binary_operator)) { } else if (is(token::comparison_binary_operator)) {
op = tokens[current++]; op = next();
} else break; } else break;
left = mk_stmt<binary_expression>(start_pos, op, std::move(left), parse_additive_expression()); left = mk_stmt<binary_expression>(start_pos, op, std::move(left), parse_additive_expression());
} }
@ -397,7 +405,7 @@ private:
auto left = parse_multiplicative_expression(); auto left = parse_multiplicative_expression();
while (is(token::additive_binary_operator)) { while (is(token::additive_binary_operator)) {
size_t start_pos = current; size_t start_pos = current;
auto op = tokens[current++]; auto op = next();
left = mk_stmt<binary_expression>(start_pos, op, std::move(left), parse_multiplicative_expression()); left = mk_stmt<binary_expression>(start_pos, op, std::move(left), parse_multiplicative_expression());
} }
return left; return left;
@ -407,7 +415,7 @@ private:
auto left = parse_test_expression(); auto left = parse_test_expression();
while (is(token::multiplicative_binary_operator)) { while (is(token::multiplicative_binary_operator)) {
size_t start_pos = current; size_t start_pos = current;
auto op = tokens[current++]; auto op = next();
left = mk_stmt<binary_expression>(start_pos, op, std::move(left), parse_test_expression()); left = mk_stmt<binary_expression>(start_pos, op, std::move(left), parse_test_expression());
} }
return left; return left;
@ -417,9 +425,9 @@ private:
auto operand = parse_filter_expression(); auto operand = parse_filter_expression();
while (is_identifier("is")) { while (is_identifier("is")) {
size_t start_pos = current; size_t start_pos = current;
current++; ++current; // consume 'is'
bool negate = false; bool negate = false;
if (is_identifier("not")) { current++; negate = true; } if (is_identifier("not")) { ++current; negate = true; }
auto test_id = parse_primary_expression(); auto test_id = parse_primary_expression();
// FIXME: tests can also be expressed like this: if x is eq 3 // FIXME: tests can also be expressed like this: if x is eq 3
if (is(token::open_paren)) test_id = parse_call_expression(std::move(test_id)); if (is(token::open_paren)) test_id = parse_call_expression(std::move(test_id));
@ -432,7 +440,7 @@ private:
auto operand = parse_call_member_expression(); auto operand = parse_call_member_expression();
while (is(token::pipe)) { while (is(token::pipe)) {
size_t start_pos = current; size_t start_pos = current;
current++; ++current; // consume pipe
auto filter = parse_primary_expression(); auto filter = parse_primary_expression();
if (is(token::open_paren)) filter = parse_call_expression(std::move(filter)); if (is(token::open_paren)) filter = parse_call_expression(std::move(filter));
operand = mk_stmt<filter_expression>(start_pos, std::move(operand), std::move(filter)); operand = mk_stmt<filter_expression>(start_pos, std::move(operand), std::move(filter));
@ -490,7 +498,7 @@ private:
statement_ptr parse_member_expression(statement_ptr object) { statement_ptr parse_member_expression(statement_ptr object) {
size_t start_pos = current; size_t start_pos = current;
while (is(token::dot) || is(token::open_square_bracket)) { while (is(token::dot) || is(token::open_square_bracket)) {
auto op = tokens[current++]; auto op = next();
bool computed = op.t == token::open_square_bracket; bool computed = op.t == token::open_square_bracket;
statement_ptr prop; statement_ptr prop;
if (computed) { if (computed) {
@ -536,7 +544,7 @@ private:
statement_ptr parse_primary_expression() { statement_ptr parse_primary_expression() {
size_t start_pos = current; size_t start_pos = current;
auto t = tokens[current++]; auto t = next();
switch (t.t) { switch (t.t) {
case token::numeric_literal: case token::numeric_literal:
if (t.value.find('.') != std::string::npos) { if (t.value.find('.') != std::string::npos) {
@ -547,7 +555,7 @@ private:
case token::string_literal: { case token::string_literal: {
std::string val = t.value; std::string val = t.value;
while (is(token::string_literal)) { while (is(token::string_literal)) {
val += tokens[current++].value; val += next().value;
} }
return mk_stmt<string_literal>(start_pos, val); return mk_stmt<string_literal>(start_pos, val);
} }
@ -562,9 +570,9 @@ private:
statements vals; statements vals;
while (!is(token::close_square_bracket)) { while (!is(token::close_square_bracket)) {
vals.push_back(parse_expression()); vals.push_back(parse_expression());
if (is(token::comma)) current++; if (is(token::comma)) ++current;
} }
current++; ++current;
return mk_stmt<array_literal>(start_pos, std::move(vals)); return mk_stmt<array_literal>(start_pos, std::move(vals));
} }
case token::open_curly_bracket: { case token::open_curly_bracket: {
@ -573,9 +581,9 @@ private:
auto key = parse_expression(); auto key = parse_expression();
expect(token::colon, "Expected :"); expect(token::colon, "Expected :");
pairs.push_back({std::move(key), parse_expression()}); pairs.push_back({std::move(key), parse_expression()});
if (is(token::comma)) current++; if (is(token::comma)) ++current;
} }
current++; ++current;
return mk_stmt<object_literal>(start_pos, std::move(pairs)); return mk_stmt<object_literal>(start_pos, std::move(pairs));
} }
default: default:

View File

@ -4273,6 +4273,16 @@ class Qwen25OmniModel(Qwen2VLVisionModel):
@ModelBase.register("InternVisionModel") @ModelBase.register("InternVisionModel")
class InternVisionModel(MmprojModel): class InternVisionModel(MmprojModel):
min_dynamic_tiles: int = 0
max_dynamic_tiles: int = 0
def __init__(self, *args, **kwargs):
    """Initialise the base model, then cache the dynamic-tile limits.

    The limits are read from ``self.global_config`` under the keys
    ``min_dynamic_patch`` / ``max_dynamic_patch``; a missing key yields 0.
    """
    super().__init__(*args, **kwargs)
    assert self.hparams_vision is not None

    config = self.global_config
    self.min_dynamic_tiles = config.get("min_dynamic_patch", 0)
    self.max_dynamic_tiles = config.get("max_dynamic_patch", 0)
def set_gguf_parameters(self): def set_gguf_parameters(self):
assert self.hparams_vision is not None assert self.hparams_vision is not None
if isinstance(self.hparams_vision['image_size'], list): if isinstance(self.hparams_vision['image_size'], list):
@ -4295,6 +4305,11 @@ class InternVisionModel(MmprojModel):
downsample_ratio = self.global_config.get("downsample_ratio") downsample_ratio = self.global_config.get("downsample_ratio")
assert downsample_ratio is not None assert downsample_ratio is not None
self.gguf_writer.add_vision_projector_scale_factor(int(1.0 / downsample_ratio)) self.gguf_writer.add_vision_projector_scale_factor(int(1.0 / downsample_ratio))
# older models may not have min/max_dynamic_patch in config
if self.min_dynamic_tiles > 0:
self.gguf_writer.add_vision_preproc_min_tiles(self.min_dynamic_tiles)
if self.max_dynamic_tiles > 0:
self.gguf_writer.add_vision_preproc_max_tiles(self.max_dynamic_tiles)
def tensor_force_quant(self, name, new_name, bid, n_dims): def tensor_force_quant(self, name, new_name, bid, n_dims):
if ".position_embd." in new_name: if ".position_embd." in new_name:

View File

@ -116,12 +116,11 @@ if (CUDAToolkit_FOUND)
list(APPEND GGML_SOURCES_CUDA ${SRCS}) list(APPEND GGML_SOURCES_CUDA ${SRCS})
add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS) add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS)
else() else()
file(GLOB SRCS "template-instances/fattn-vec*q4_0-q4_0.cu") list(APPEND GGML_SOURCES_CUDA
list(APPEND GGML_SOURCES_CUDA ${SRCS}) template-instances/fattn-vec-instance-f16-f16.cu
file(GLOB SRCS "template-instances/fattn-vec*q8_0-q8_0.cu") template-instances/fattn-vec-instance-q4_0-q4_0.cu
list(APPEND GGML_SOURCES_CUDA ${SRCS}) template-instances/fattn-vec-instance-q8_0-q8_0.cu
file(GLOB SRCS "template-instances/fattn-vec*f16-f16.cu") template-instances/fattn-vec-instance-bf16-bf16.cu)
list(APPEND GGML_SOURCES_CUDA ${SRCS})
endif() endif()
ggml_add_backend_library(ggml-cuda ggml_add_backend_library(ggml-cuda

View File

@ -41,6 +41,16 @@ template<typename dst_t, typename src_t>
return __bfloat162float(x); return __bfloat162float(x);
} else if constexpr(std::is_same_v<src_t, float2> && std::is_same_v<dst_t, half2>) { } else if constexpr(std::is_same_v<src_t, float2> && std::is_same_v<dst_t, half2>) {
return __float22half2_rn(x); return __float22half2_rn(x);
} else if constexpr(std::is_same_v<src_t, nv_bfloat162> && std::is_same_v<dst_t, float2>) {
#ifdef GGML_USE_HIP
return make_float2(__bfloat162float(__low2bfloat16(x)), __bfloat162float(__high2bfloat16(x)));
#else
#if __CUDA_ARCH__ >= 800
return __bfloat1622float2(x);
#else
return make_float2(__bfloat162float(x.x), __bfloat162float(x.y));
#endif // __CUDA_ARCH__ >= 800
#endif // GGML_USE_HIP
} else if constexpr(std::is_same_v<src_t, float2> && std::is_same_v<dst_t, nv_bfloat162>) { } else if constexpr(std::is_same_v<src_t, float2> && std::is_same_v<dst_t, nv_bfloat162>) {
// bypass compile error on cuda 12.0.1 // bypass compile error on cuda 12.0.1
#ifdef GGML_USE_HIP #ifdef GGML_USE_HIP

View File

@ -74,6 +74,37 @@ static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_f16(
return sum; return sum;
} }
// Dot product contribution of one thread between a BF16 K row and Q.
//
// K_c is reinterpreted as packed nv_bfloat162 pairs. In every outer iteration a
// thread copies cpy_ne consecutive pairs, starting at an offset derived from
// (threadIdx.x % nthreads), converts each pair to float2 and multiply-accumulates
// it with the matching Q pair in FP32.
//
// Q_v is read as half2 (widened to float2) when V_DOT2_F32_F16_AVAILABLE is
// defined and as float2 otherwise. Q_q8 and Q_ds_v are not used by this variant.
//
// Returns the sum accumulated by the calling thread only.
// NOTE(review): combining the per-thread sums is presumably done by the caller - confirm.
template <int D, int nthreads>
static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_bf16(
    const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds_v) {
    GGML_UNUSED(Q_q8);
    GGML_UNUSED(Q_ds_v);

    constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes();
    constexpr int cpy_ne = cpy_nb / 4; // number of 4-byte nv_bfloat162 pairs per copy

    const nv_bfloat162 * K_pairs = (const nv_bfloat162 *) K_c;

    // First pair handled by this thread inside each chunk of nthreads*cpy_ne pairs.
    const int lane_offset = (threadIdx.x % nthreads)*cpy_ne;

    float acc = 0.0f;

#pragma unroll
    for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += nthreads*cpy_ne) {
        __align__(16) nv_bfloat162 K_chunk[cpy_ne];
        ggml_cuda_memcpy_1<sizeof(K_chunk)>(K_chunk, K_pairs + k_KQ_0 + lane_offset);

#pragma unroll
        for (int k_KQ_1 = 0; k_KQ_1 < cpy_ne; ++k_KQ_1) {
            const int q_idx = k_KQ_0/nthreads + k_KQ_1;
#ifdef V_DOT2_F32_F16_AVAILABLE
            // FIXME replace macros in vector FA kernel with templating and use FP32 for BF16
            const float2 q_pair = __half22float2(((const half2 *) Q_v)[q_idx]);
#else
            const float2 q_pair = ((const float2 *) Q_v)[q_idx];
#endif // V_DOT2_F32_F16_AVAILABLE
            ggml_cuda_mad(acc, ggml_cuda_cast<float2>(K_chunk[k_KQ_1]), q_pair);
        }
    }

    return acc;
}
template<int D, int nthreads> template<int D, int nthreads>
static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_q4_0( static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_q4_0(
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
@ -321,6 +352,19 @@ static __device__ __forceinline__ void dequantize_V_f16(const void * __restrict_
} }
} }
// Converts `ne` consecutive BF16 values of V, starting at element index i0 of vx,
// to float and stores them to dst as ne/2 float2 pairs.
//
// T must be float (enforced at compile time) and ne must be even, because the
// values are copied and converted two at a time as nv_bfloat162.
template <typename T, int ne>
static __device__ __forceinline__ void dequantize_V_bf16(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
    static_assert(std::is_same_v<T, float>, "BF16 V dequantization only supports float output");
    static_assert(ne % 2 == 0, "bad ne");

    constexpr int n_pairs = ne/2;

    // Copy the raw BF16 data into a local, 16-byte aligned buffer first.
    const nv_bfloat16 * src = (const nv_bfloat16 *) vx + i0;
    __align__(16) nv_bfloat162 packed[n_pairs];
    ggml_cuda_memcpy_1<ne*sizeof(nv_bfloat16)>(packed, src);

    float2 * out = (float2 *) dst;
#pragma unroll
    for (int pair = 0; pair < n_pairs; ++pair) {
        out[pair] = ggml_cuda_cast<float2>(packed[pair]);
    }
}
template <typename T, int ne> template <typename T, int ne>
static __device__ __forceinline__ void dequantize_V_q4_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) { static __device__ __forceinline__ void dequantize_V_q4_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
const block_q4_0 * x = (const block_q4_0 *) vx; const block_q4_0 * x = (const block_q4_0 *) vx;
@ -547,6 +591,8 @@ constexpr __device__ vec_dot_KQ_t get_vec_dot_KQ() {
return vec_dot_fattn_vec_KQ_q5_1<D, nthreads>; return vec_dot_fattn_vec_KQ_q5_1<D, nthreads>;
} else if constexpr (type_K == GGML_TYPE_Q8_0) { } else if constexpr (type_K == GGML_TYPE_Q8_0) {
return vec_dot_fattn_vec_KQ_q8_0<D, nthreads>; return vec_dot_fattn_vec_KQ_q8_0<D, nthreads>;
} else if constexpr (type_K == GGML_TYPE_BF16) {
return vec_dot_fattn_vec_KQ_bf16<D, nthreads>;
} else { } else {
static_assert(type_K == -1, "bad type"); static_assert(type_K == -1, "bad type");
return nullptr; return nullptr;
@ -567,6 +613,8 @@ constexpr __device__ dequantize_V_t get_dequantize_V() {
return dequantize_V_q5_1<T, ne>; return dequantize_V_q5_1<T, ne>;
} else if constexpr (type_V == GGML_TYPE_Q8_0) { } else if constexpr (type_V == GGML_TYPE_Q8_0) {
return dequantize_V_q8_0<T, ne>; return dequantize_V_q8_0<T, ne>;
} else if constexpr (type_V == GGML_TYPE_BF16) {
return dequantize_V_bf16<float, ne>;
} else { } else {
static_assert(type_V == -1, "bad type"); static_assert(type_V == -1, "bad type");
return nullptr; return nullptr;

View File

@ -75,17 +75,17 @@ static __global__ void flash_attn_ext_vec(
#endif // GGML_USE_HIP #endif // GGML_USE_HIP
constexpr int nthreads = ggml_cuda_fattn_vec_get_nthreads_device(); constexpr int nthreads = ggml_cuda_fattn_vec_get_nthreads_device();
constexpr int nthreads_KQ = type_K == GGML_TYPE_F16 ? 128 / cpy_nb : nthreads_KQ_q; constexpr int nthreads_KQ = (type_K == GGML_TYPE_F16 || type_K == GGML_TYPE_BF16) ? 128 / cpy_nb : nthreads_KQ_q;
constexpr int nthreads_V = type_V == GGML_TYPE_F16 ? 128 / cpy_nb : nthreads_V_q; constexpr int nthreads_V = (type_V == GGML_TYPE_F16 || type_V == GGML_TYPE_BF16) ? 128 / cpy_nb : nthreads_V_q;
static_assert(WARP_SIZE % nthreads_KQ == 0, "bad nthreads_K"); static_assert(WARP_SIZE % nthreads_KQ == 0, "bad nthreads_K");
static_assert(WARP_SIZE % nthreads_V == 0, "bad nthreads_V"); static_assert(WARP_SIZE % nthreads_V == 0, "bad nthreads_V");
constexpr int V_rows_per_thread = type_V == GGML_TYPE_F16 ? 2*cpy_ne : 4; constexpr int V_rows_per_thread = (type_V == GGML_TYPE_F16 || type_V == GGML_TYPE_BF16) ? 2*cpy_ne : 4;
constexpr int V_cols_per_iter = WARP_SIZE / nthreads_V; constexpr int V_cols_per_iter = WARP_SIZE / nthreads_V;
constexpr vec_dot_KQ_t vec_dot_KQ = get_vec_dot_KQ<type_K, D, nthreads_KQ>(); constexpr vec_dot_KQ_t vec_dot_KQ = get_vec_dot_KQ<type_K, D, nthreads_KQ>();
constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16; constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16 && type_K != GGML_TYPE_BF16;
#ifdef V_DOT2_F32_F16_AVAILABLE #ifdef V_DOT2_F32_F16_AVAILABLE
constexpr dequantize_V_t dequantize_V = get_dequantize_V<type_V, half, V_rows_per_thread>(); constexpr dequantize_V_t dequantize_V = get_dequantize_V<type_V, half, V_rows_per_thread>();
#else #else
@ -323,8 +323,18 @@ static __global__ void flash_attn_ext_vec(
#pragma unroll #pragma unroll
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) { for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {
half2 tmp[V_rows_per_thread/2]; half2 tmp[V_rows_per_thread/2];
dequantize_V(V + k*nb21, tmp, if constexpr (type_V == GGML_TYPE_BF16) {
2*i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*V_rows_per_thread); float2 tmp_f[V_rows_per_thread/2];
dequantize_V(V + k*nb21, tmp_f,
2*i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*V_rows_per_thread);
#pragma unroll
for (int i_VKQ_1 = 0; i_VKQ_1 < V_rows_per_thread/2; ++i_VKQ_1) {
tmp[i_VKQ_1] = __float22half2_rn(tmp_f[i_VKQ_1]);
}
} else {
dequantize_V(V + k*nb21, tmp,
2*i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*V_rows_per_thread);
}
#pragma unroll #pragma unroll
for (int i_VKQ_1 = 0; i_VKQ_1 < V_rows_per_thread/2; ++i_VKQ_1) { for (int i_VKQ_1 = 0; i_VKQ_1 < V_rows_per_thread/2; ++i_VKQ_1) {
#pragma unroll #pragma unroll
@ -563,6 +573,7 @@ void ggml_cuda_flash_attn_ext_vec_case(ggml_backend_cuda_context & ctx, ggml_ten
extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q5_0); \ extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q5_0); \
extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q5_1); \ extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q5_1); \
extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q8_0); \ extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q8_0); \
extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_BF16); \
EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_F16) EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_F16)
EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q4_0) EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q4_0)
@ -570,6 +581,7 @@ EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q4_1)
EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q5_0) EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q5_0)
EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q5_1) EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q5_1)
EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q8_0) EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q8_0)
EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_BF16)
EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_F16) EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_F16)
EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q4_0) EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q4_0)
@ -577,6 +589,7 @@ EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q4_1)
EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q5_0) EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q5_0)
EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q5_1) EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q5_1)
EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q8_0) EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q8_0)
EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_BF16)
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_F16) EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_F16)
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q4_0) EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q4_0)
@ -584,3 +597,4 @@ EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q4_1)
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_0) EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_0)
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_1) EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_1)
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q8_0) EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q8_0)
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_BF16)

View File

@ -224,6 +224,7 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_F16) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_F16)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_F16) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_F16)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_F16) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_F16)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_F16)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q4_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q4_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
@ -231,6 +232,7 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q4_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q4_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q4_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q4_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q4_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q4_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_Q4_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q4_1) FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q4_1)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_1) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_1)
@ -238,6 +240,7 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q4_1) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q4_1)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q4_1) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q4_1)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q4_1) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q4_1)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_Q4_1)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q5_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q5_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q5_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q5_0)
@ -245,6 +248,7 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q5_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q5_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q5_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q5_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q5_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q5_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_Q5_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q5_1) FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q5_1)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q5_1) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q5_1)
@ -252,6 +256,7 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q5_1) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q5_1)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q5_1) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q5_1)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q5_1) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q5_1)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_Q5_1)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q8_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q8_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q8_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q8_0)
@ -259,10 +264,20 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q8_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q8_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q8_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q8_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q8_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_Q8_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_BF16)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_BF16)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_BF16)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_BF16)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_BF16)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_BF16)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_BF16)
#else #else
FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_F16) FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_F16)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q8_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_BF16)
#endif // GGML_CUDA_FA_ALL_QUANTS #endif // GGML_CUDA_FA_ALL_QUANTS
GGML_ABORT("fatal error"); GGML_ABORT("fatal error");
@ -355,6 +370,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
#endif // GGML_CUDA_FA_ALL_QUANTS #endif // GGML_CUDA_FA_ALL_QUANTS
case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_0:
case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_0:
case GGML_TYPE_BF16:
break; break;
default: default:
return BEST_FATTN_KERNEL_NONE; return BEST_FATTN_KERNEL_NONE;

View File

@ -33,7 +33,7 @@ static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type)
} }
} }
static constexpr __device__ int get_vdr_mmvq(ggml_type type) { static constexpr __host__ __device__ int get_vdr_mmvq(ggml_type type) {
switch (type) { switch (type) {
case GGML_TYPE_Q4_0: return VDR_Q4_0_Q8_1_MMVQ; case GGML_TYPE_Q4_0: return VDR_Q4_0_Q8_1_MMVQ;
case GGML_TYPE_Q4_1: return VDR_Q4_1_Q8_1_MMVQ; case GGML_TYPE_Q4_1: return VDR_Q4_1_Q8_1_MMVQ;
@ -173,11 +173,11 @@ static constexpr __host__ __device__ int calc_nwarps(ggml_type type, int ncols_d
return 1; return 1;
} }
static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int table_id) { static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int table_id, bool small_k = false, int nwarps = 1) {
if (table_id == MMVQ_PARAMETERS_GENERIC || table_id == MMVQ_PARAMETERS_GCN) { if (table_id == MMVQ_PARAMETERS_GENERIC || table_id == MMVQ_PARAMETERS_GCN) {
switch (ncols_dst) { switch (ncols_dst) {
case 1: case 1:
return 1; return small_k ? nwarps : 1;
case 2: case 2:
case 3: case 3:
case 4: case 4:
@ -193,7 +193,7 @@ static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int
return 1; return 1;
} }
template <ggml_type type, int ncols_dst, bool has_fusion, bool is_multi_token_id = false> template <ggml_type type, int ncols_dst, bool has_fusion, bool is_multi_token_id = false, bool small_k = false>
__launch_bounds__(calc_nwarps(type, ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1) __launch_bounds__(calc_nwarps(type, ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
static __global__ void mul_mat_vec_q( static __global__ void mul_mat_vec_q(
const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst, const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst,
@ -208,7 +208,7 @@ static __global__ void mul_mat_vec_q(
constexpr int vdr = get_vdr_mmvq(type); constexpr int vdr = get_vdr_mmvq(type);
constexpr mmvq_parameter_table_id table_id = get_device_table_id(); constexpr mmvq_parameter_table_id table_id = get_device_table_id();
constexpr int nwarps = calc_nwarps(type, ncols_dst, table_id); constexpr int nwarps = calc_nwarps(type, ncols_dst, table_id);
constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_dst, table_id); constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_dst, table_id, small_k, nwarps);
constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr int warp_size = ggml_cuda_get_physical_warp_size();
constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type); constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);
@ -414,14 +414,16 @@ static __global__ void mul_mat_vec_q(
template<ggml_type type> template<ggml_type type>
static std::pair<dim3, dim3> calc_launch_params( static std::pair<dim3, dim3> calc_launch_params(
const int ncols_dst, const int nrows_x, const int nchannels_dst, const int nsamples_or_ntokens, const int ncols_dst, const int nrows_x, const int nchannels_dst, const int nsamples_or_ntokens,
const int warp_size, const mmvq_parameter_table_id table_id) { const int warp_size, const mmvq_parameter_table_id table_id, const bool small_k = false) {
const int64_t nblocks = (nrows_x + calc_rows_per_block(ncols_dst, table_id) - 1) / calc_rows_per_block(ncols_dst, table_id); const int nwarps = calc_nwarps(type, ncols_dst, table_id);
const int rpb = calc_rows_per_block(ncols_dst, table_id, small_k, nwarps);
const int64_t nblocks = (nrows_x + rpb - 1) / rpb;
const dim3 block_nums(nblocks, nchannels_dst, nsamples_or_ntokens); const dim3 block_nums(nblocks, nchannels_dst, nsamples_or_ntokens);
const dim3 block_dims(warp_size, calc_nwarps(type, ncols_dst, table_id), 1); const dim3 block_dims(warp_size, nwarps, 1);
return {block_nums, block_dims}; return {block_nums, block_dims};
} }
template<ggml_type type, int c_ncols_dst, bool is_multi_token_id = false> template<ggml_type type, int c_ncols_dst, bool is_multi_token_id = false, bool small_k = false>
static void mul_mat_vec_q_switch_fusion( static void mul_mat_vec_q_switch_fusion(
const void * vx, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst, const void * vx, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y, const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y,
@ -434,7 +436,7 @@ static void mul_mat_vec_q_switch_fusion(
const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr; const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
if constexpr (c_ncols_dst == 1) { if constexpr (c_ncols_dst == 1) {
if (has_fusion) { if (has_fusion) {
mul_mat_vec_q<type, c_ncols_dst, true, is_multi_token_id><<<block_nums, block_dims, nbytes_shared, stream>>> mul_mat_vec_q<type, c_ncols_dst, true, is_multi_token_id, small_k><<<block_nums, block_dims, nbytes_shared, stream>>>
(vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, (vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride); sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride);
@ -444,7 +446,7 @@ static void mul_mat_vec_q_switch_fusion(
GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1"); GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1");
mul_mat_vec_q<type, c_ncols_dst, false, is_multi_token_id><<<block_nums, block_dims, nbytes_shared, stream>>> mul_mat_vec_q<type, c_ncols_dst, false, is_multi_token_id, small_k><<<block_nums, block_dims, nbytes_shared, stream>>>
(vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, (vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride); sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride);
@ -488,11 +490,33 @@ static void mul_mat_vec_q_switch_ncols_dst(
switch (ncols_dst) { switch (ncols_dst) {
case 1: { case 1: {
constexpr int c_ncols_dst = 1; constexpr int c_ncols_dst = 1;
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, // When K is small, increase rows_per_block to match nwarps so each warp has more work to do
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, // Trigger when the full thread block covers all K blocks in a single loop iteration and few threads remain idle.
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, constexpr int qk = ggml_cuda_type_traits<type>::qk;
dims.first, dims.second, 0, ids_stride, stream); constexpr int qi = ggml_cuda_type_traits<type>::qi;
constexpr int vdr = get_vdr_mmvq(type);
const int blocks_per_row_x = ncols_x / qk;
const int blocks_per_iter_1warp = vdr * warp_size / qi;
const int nwarps = calc_nwarps(type, c_ncols_dst, table_id);
const bool use_small_k = nwarps > 1 && blocks_per_row_x < nwarps * blocks_per_iter_1warp;
if (use_small_k) {
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst,
warp_size, table_id, true);
mul_mat_vec_q_switch_fusion<type, c_ncols_dst, false, true>(
vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
dims.first, dims.second, 0, ids_stride, stream);
} else {
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst,
warp_size, table_id);
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(
vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
dims.first, dims.second, 0, ids_stride, stream);
}
} break; } break;
case 2: { case 2: {
constexpr int c_ncols_dst = 2; constexpr int c_ncols_dst = 2;

View File

@ -0,0 +1,7 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec.cuh"
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_BF16);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_BF16);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_BF16);

View File

@ -0,0 +1,7 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec.cuh"
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_F16);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_F16);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_F16);

View File

@ -0,0 +1,7 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec.cuh"
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_Q4_0);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q4_0);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_Q4_0);

View File

@ -0,0 +1,7 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec.cuh"
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_Q4_1);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q4_1);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_Q4_1);

View File

@ -0,0 +1,7 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec.cuh"
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_Q5_0);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q5_0);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_Q5_0);

View File

@ -0,0 +1,7 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec.cuh"
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_Q5_1);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q5_1);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_Q5_1);

View File

@ -0,0 +1,7 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec.cuh"
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_Q8_0);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q8_0);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_Q8_0);

View File

@ -0,0 +1,7 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec.cuh"
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_BF16);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_BF16);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_BF16);

View File

@ -0,0 +1,7 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec.cuh"
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_BF16);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_BF16);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_BF16);

View File

@ -0,0 +1,7 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec.cuh"
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_BF16);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_BF16);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_BF16);

View File

@ -0,0 +1,7 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec.cuh"
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_BF16);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_BF16);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_BF16);

View File

@ -0,0 +1,7 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec.cuh"
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_BF16);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_BF16);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_BF16);

View File

@ -0,0 +1,7 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec.cuh"
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_BF16);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_BF16);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_BF16);

View File

@ -5,7 +5,7 @@ import os
HEAD_SIZES_KQ = [40, 64, 72, 80, 96, 112, 128, 256, 576] HEAD_SIZES_KQ = [40, 64, 72, 80, 96, 112, 128, 256, 576]
TYPES_KV = ["GGML_TYPE_F16", "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0"] TYPES_KV = ["GGML_TYPE_F16", "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0", "GGML_TYPE_BF16"]
SOURCE_FATTN_TILE = """// This file has been autogenerated by generate_cu_files.py, do not edit manually. SOURCE_FATTN_TILE = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.

View File

@ -71,12 +71,11 @@ if (GGML_CUDA_FA_ALL_QUANTS)
list(APPEND GGML_SOURCES_ROCM ${SRCS}) list(APPEND GGML_SOURCES_ROCM ${SRCS})
add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS) add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS)
else() else()
file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*q4_0-q4_0.cu") list(APPEND GGML_SOURCES_ROCM
list(APPEND GGML_SOURCES_ROCM ${SRCS}) ../ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu
file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*q8_0-q8_0.cu") ../ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu
list(APPEND GGML_SOURCES_ROCM ${SRCS}) ../ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu
file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*f16-f16.cu") ../ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu)
list(APPEND GGML_SOURCES_ROCM ${SRCS})
endif() endif()
ggml_add_backend_library(ggml-hip ggml_add_backend_library(ggml-hip

View File

@ -48,12 +48,11 @@ if (MUSAToolkit_FOUND)
list(APPEND GGML_SOURCES_MUSA ${SRCS}) list(APPEND GGML_SOURCES_MUSA ${SRCS})
add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS) add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS)
else() else()
file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*q4_0-q4_0.cu") list(APPEND GGML_SOURCES_MUSA
list(APPEND GGML_SOURCES_MUSA ${SRCS}) ../ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu
file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*q8_0-q8_0.cu") ../ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu
list(APPEND GGML_SOURCES_MUSA ${SRCS}) ../ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu
file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*f16-f16.cu") ../ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu)
list(APPEND GGML_SOURCES_MUSA ${SRCS})
endif() endif()
set_source_files_properties(${GGML_SOURCES_MUSA} PROPERTIES LANGUAGE CXX) set_source_files_properties(${GGML_SOURCES_MUSA} PROPERTIES LANGUAGE CXX)

View File

@ -4667,22 +4667,8 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
if (a->ne[3] != b->ne[3]) { if (a->ne[3] != b->ne[3]) {
return false; return false;
} }
ggml_type a_type = a->type;
if (a_type == GGML_TYPE_IQ4_NL || a_type == GGML_TYPE_IQ4_XS ||
a_type == GGML_TYPE_IQ3_XXS || a_type == GGML_TYPE_IQ3_S ||
a_type == GGML_TYPE_IQ2_XXS || a_type == GGML_TYPE_IQ2_XS || a_type == GGML_TYPE_IQ2_S ||
a_type == GGML_TYPE_IQ1_S || a_type == GGML_TYPE_IQ1_M
) {
if (b->ne[1] == 1 && ggml_nrows(b) > 1) {
return false;
}
}
ggml_type src0_type = op->src[0]->type; ggml_type src0_type = op->src[0]->type;
if (src0_type == GGML_TYPE_BF16 ) {
// TODO: support GGML_TYPE_BF16
// FIXME: keep a list of supported types to avoid breaking the backend when a new type is added
return false;
}
// TODO: The configuration below needs more work to be supported with oneDNN // TODO: The configuration below needs more work to be supported with oneDNN
if (ggml_is_permuted(a) && !ggml_is_contiguous(a) && if (ggml_is_permuted(a) && !ggml_is_contiguous(a) &&

View File

@ -301,6 +301,8 @@ class Keys:
IMAGE_SIZE = "clip.vision.image_size" IMAGE_SIZE = "clip.vision.image_size"
IMAGE_MIN_PIXELS = "clip.vision.image_min_pixels" IMAGE_MIN_PIXELS = "clip.vision.image_min_pixels"
IMAGE_MAX_PIXELS = "clip.vision.image_max_pixels" IMAGE_MAX_PIXELS = "clip.vision.image_max_pixels"
PREPROC_MIN_TILES = "clip.vision.preproc_min_tiles"
PREPROC_MAX_TILES = "clip.vision.preproc_max_tiles"
PREPROC_IMAGE_SIZE = "clip.vision.preproc_image_size" PREPROC_IMAGE_SIZE = "clip.vision.preproc_image_size"
PATCH_SIZE = "clip.vision.patch_size" PATCH_SIZE = "clip.vision.patch_size"
EMBEDDING_LENGTH = "clip.vision.embedding_length" EMBEDDING_LENGTH = "clip.vision.embedding_length"

View File

@ -1156,6 +1156,12 @@ class GGUFWriter:
def add_vision_min_pixels(self, value: int) -> None: def add_vision_min_pixels(self, value: int) -> None:
self.add_uint32(Keys.ClipVision.IMAGE_MIN_PIXELS, value) self.add_uint32(Keys.ClipVision.IMAGE_MIN_PIXELS, value)
# Write `value` as a uint32 under the Keys.ClipVision.PREPROC_MAX_TILES metadata key.
def add_vision_preproc_max_tiles(self, value: int) -> None:
self.add_uint32(Keys.ClipVision.PREPROC_MAX_TILES, value)
# Write `value` as a uint32 under the Keys.ClipVision.PREPROC_MIN_TILES metadata key.
def add_vision_preproc_min_tiles(self, value: int) -> None:
self.add_uint32(Keys.ClipVision.PREPROC_MIN_TILES, value)
def add_vision_preproc_image_size(self, value: int) -> None: def add_vision_preproc_image_size(self, value: int) -> None:
self.add_uint32(Keys.ClipVision.PREPROC_IMAGE_SIZE, value) self.add_uint32(Keys.ClipVision.PREPROC_IMAGE_SIZE, value)

View File

@ -2264,6 +2264,7 @@ static void test_fuzzing(testing & t) {
t.test("malformed templates (should error, not crash)", [&](testing & t) { t.test("malformed templates (should error, not crash)", [&](testing & t) {
const std::vector<std::string> malformed = { const std::vector<std::string> malformed = {
"",
"{{ x", "{{ x",
"{% if %}", "{% if %}",
"{% for %}", "{% for %}",
@ -2284,6 +2285,11 @@ static void test_fuzzing(testing & t) {
for (const auto & tmpl : malformed) { for (const auto & tmpl : malformed) {
t.assert_true("malformed: " + tmpl, fuzz_test_template(tmpl, json::object())); t.assert_true("malformed: " + tmpl, fuzz_test_template(tmpl, json::object()));
} }
std::string tmpl = "{% for message in messages %}{{ message.role | string }} : {{ message.content if ('content' in message and message.content is not none) }}{% endfor %";
while (tmpl.length() > 0) {
t.assert_true("malformed: " + tmpl, fuzz_test_template(tmpl, json::object()));
tmpl.pop_back();
}
}); });
t.test("type coercion edge cases", [&](testing & t) { t.test("type coercion edge cases", [&](testing & t) {

View File

@ -38,6 +38,8 @@
#define KEY_IMAGE_SIZE "clip.vision.image_size" #define KEY_IMAGE_SIZE "clip.vision.image_size"
#define KEY_IMAGE_MIN_PIXELS "clip.vision.image_min_pixels" #define KEY_IMAGE_MIN_PIXELS "clip.vision.image_min_pixels"
#define KEY_IMAGE_MAX_PIXELS "clip.vision.image_max_pixels" #define KEY_IMAGE_MAX_PIXELS "clip.vision.image_max_pixels"
#define KEY_PREPROC_MIN_TILES "clip.vision.preproc_min_tiles"
#define KEY_PREPROC_MAX_TILES "clip.vision.preproc_max_tiles"
#define KEY_PREPROC_IMAGE_SIZE "clip.vision.preproc_image_size" #define KEY_PREPROC_IMAGE_SIZE "clip.vision.preproc_image_size"
#define KEY_PATCH_SIZE "clip.vision.patch_size" #define KEY_PATCH_SIZE "clip.vision.patch_size"
#define KEY_IMAGE_MEAN "clip.vision.image_mean" #define KEY_IMAGE_MEAN "clip.vision.image_mean"

View File

@ -42,6 +42,9 @@ struct clip_hparams {
int32_t image_max_pixels = -1; int32_t image_max_pixels = -1;
int32_t n_merge = 0; // number of patch merges **per-side** int32_t n_merge = 0; // number of patch merges **per-side**
int32_t preproc_min_tiles = 0;
int32_t preproc_max_tiles = 0;
float image_mean[3]; float image_mean[3];
float image_std[3]; float image_std[3];

View File

@ -1138,6 +1138,16 @@ struct clip_model_loader {
} }
} break; } break;
case PROJECTOR_TYPE_INTERNVL: case PROJECTOR_TYPE_INTERNVL:
{
// older version of internvl doesn't have min/max tiles, we need to provide default values for them to avoid issues
hparams.preproc_min_tiles = 1;
hparams.preproc_max_tiles = 12;
get_u32(KEY_PROJ_SCALE_FACTOR, hparams.n_merge, false);
get_u32(KEY_PREPROC_MIN_TILES, hparams.preproc_min_tiles, false);
get_u32(KEY_PREPROC_MAX_TILES, hparams.preproc_max_tiles, false);
GGML_ASSERT(hparams.preproc_min_tiles <= hparams.preproc_max_tiles && hparams.preproc_max_tiles < INT32_MAX);
set_internvl_dhr_res_candidates(model);
} break;
case PROJECTOR_TYPE_NEMOTRON_V2_VL: case PROJECTOR_TYPE_NEMOTRON_V2_VL:
{ {
get_u32(KEY_PROJ_SCALE_FACTOR, hparams.n_merge, false); get_u32(KEY_PROJ_SCALE_FACTOR, hparams.n_merge, false);
@ -1161,7 +1171,6 @@ struct clip_model_loader {
hparams.set_warmup_n_tokens(16*16); hparams.set_warmup_n_tokens(16*16);
} break; } break;
case PROJECTOR_TYPE_PIXTRAL: case PROJECTOR_TYPE_PIXTRAL:
case PROJECTOR_TYPE_LIGHTONOCR:
{ {
// ref: https://huggingface.co/mistral-community/pixtral-12b/blob/main/preprocessor_config.json // ref: https://huggingface.co/mistral-community/pixtral-12b/blob/main/preprocessor_config.json
// TODO: verify the image_min_tokens // TODO: verify the image_min_tokens
@ -1171,6 +1180,15 @@ struct clip_model_loader {
hparams.set_limit_image_tokens(8, 1024); hparams.set_limit_image_tokens(8, 1024);
hparams.set_warmup_n_tokens(256); // avoid OOM on warmup hparams.set_warmup_n_tokens(256); // avoid OOM on warmup
} break; } break;
case PROJECTOR_TYPE_LIGHTONOCR:
{
hparams.n_merge = 1;
hparams.rope_theta = 10000.0f;
get_u32(KEY_SPATIAL_MERGE_SIZE, hparams.n_merge, false);
hparams.image_longest_edge = hparams.image_size;
get_u32(KEY_PREPROC_IMAGE_SIZE, hparams.image_longest_edge, false);
hparams.set_warmup_n_tokens(256); // avoid OOM on warmup
} break;
case PROJECTOR_TYPE_KIMIVL: case PROJECTOR_TYPE_KIMIVL:
{ {
hparams.rope_theta = 10000.0f; hparams.rope_theta = 10000.0f;
@ -2180,6 +2198,27 @@ struct clip_model_loader {
} }
} }
} }
// Fill hparams.image_res_candidates with the tile grids used for InternVL
// dynamic high-resolution preprocessing.
//
// Each (na, nb) pair produced by the loops below adds one candidate of size
// (na * image_size, nb * image_size). The loop bounds are derived from
// hparams.preproc_min_tiles and hparams.preproc_max_tiles so that na * nb
// does not exceed preproc_max_tiles. Nothing is added when
// preproc_min_tiles < 1.
//
// NOTE(review): when preproc_min_tiles > 1, both na and nb are forced to be
// >= preproc_min_tiles, so grids such as 1 x N are never generated - confirm
// this is intended. With the default minimum of 1 the clamp has no effect.
static void set_internvl_dhr_res_candidates(clip_model & model) {
    auto & hp = model.hparams;

    const int tiles_min = hp.preproc_min_tiles;
    const int tiles_max = hp.preproc_max_tiles;

    // the bounds below divide by the tile count, so it has to be at least 1
    if (tiles_min < 1) {
        return;
    }

    for (int na = tiles_min; na <= tiles_max; ++na) {
        // ceil(tiles_min / na), but never below tiles_min
        const int nb_first = std::max((tiles_min + na - 1) / na, tiles_min);
        // floor(tiles_max / na), but never above tiles_max
        const int nb_last  = std::min(tiles_max / na, tiles_max);

        for (int nb = nb_first; nb <= nb_last; ++nb) {
            hp.image_res_candidates.push_back(clip_image_size {
                na*hp.image_size,
                nb*hp.image_size,
            });
        }
    }
}
}; };
struct clip_init_result clip_init(const char * fname, struct clip_context_params ctx_params) { struct clip_init_result clip_init(const char * fname, struct clip_context_params ctx_params) {
@ -2726,17 +2765,22 @@ struct llava_uhd {
return res; return res;
} }
static std::vector<clip_image_u8_ptr> slice_image(const clip_image_u8 * img, const slice_instructions & inst) { static std::vector<clip_image_u8_ptr> slice_image(const clip_image_u8 * img, const slice_instructions & inst, bool overview_first = true) {
std::vector<clip_image_u8_ptr> output; std::vector<clip_image_u8_ptr> output;
// resize to overview size // resize to overview size
clip_image_u8_ptr resized_img(clip_image_u8_init()); clip_image_u8_ptr resized_img(clip_image_u8_init());
img_tool::resize(*img, *resized_img, inst.overview_size, inst.interpolation_overview, img_tool::resize(*img, *resized_img, inst.overview_size, inst.interpolation_overview,
inst.padding_overview, inst.pad_color_overview); inst.padding_overview, inst.pad_color_overview);
output.push_back(std::move(resized_img)); if (overview_first) {
output.push_back(std::move(resized_img));
}
if (inst.slices.empty()) { if (inst.slices.empty()) {
// no slices, just return the resized image // no slices, just return the resized image
if (!overview_first) {
output.push_back(std::move(resized_img));
}
return output; return output;
} }
@ -2757,6 +2801,10 @@ struct llava_uhd {
output.push_back(std::move(img_slice)); output.push_back(std::move(img_slice));
} }
if (!overview_first) {
output.push_back(std::move(resized_img));
}
return output; return output;
} }
@ -3141,10 +3189,20 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
res_imgs->grid_x = instructions.grid_size.width; res_imgs->grid_x = instructions.grid_size.width;
res_imgs->grid_y = instructions.grid_size.height; res_imgs->grid_y = instructions.grid_size.height;
} break; } break;
case PROJECTOR_TYPE_INTERNVL: // support dynamic high-resolution
{
GGML_ASSERT(!params.image_res_candidates.empty());
auto const inst = llava_uhd::get_slice_instructions(ctx, original_size);
std::vector<clip_image_u8_ptr> imgs = llava_uhd::slice_image(img, inst, false);
for (size_t i = 0; i < imgs.size(); ++i) {
clip_image_f32_ptr res(clip_image_f32_init());
normalize_image_u8_to_f32(*imgs[i], *res, params.image_mean, params.image_std);
res_imgs->entries.push_back(std::move(res));
}
} break;
case PROJECTOR_TYPE_GLM_EDGE: case PROJECTOR_TYPE_GLM_EDGE:
case PROJECTOR_TYPE_GEMMA3: case PROJECTOR_TYPE_GEMMA3:
case PROJECTOR_TYPE_INTERNVL: // TODO @ngxson : support dynamic resolution
case PROJECTOR_TYPE_NEMOTRON_V2_VL: case PROJECTOR_TYPE_NEMOTRON_V2_VL:
{ {
clip_image_u8 resized_image; clip_image_u8 resized_image;
@ -3180,7 +3238,6 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
case PROJECTOR_TYPE_PHI4: case PROJECTOR_TYPE_PHI4:
case PROJECTOR_TYPE_PIXTRAL: case PROJECTOR_TYPE_PIXTRAL:
case PROJECTOR_TYPE_LIGHTONOCR:
{ {
GGML_ASSERT(params.image_min_pixels > 0 && params.image_max_pixels > 0); GGML_ASSERT(params.image_min_pixels > 0 && params.image_max_pixels > 0);
clip_image_u8 resized_image; clip_image_u8 resized_image;
@ -3196,6 +3253,19 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
normalize_image_u8_to_f32(resized_image, *img_f32, params.image_mean, params.image_std); normalize_image_u8_to_f32(resized_image, *img_f32, params.image_mean, params.image_std);
res_imgs->entries.push_back(std::move(img_f32)); res_imgs->entries.push_back(std::move(img_f32));
} break; } break;
case PROJECTOR_TYPE_LIGHTONOCR:
{
GGML_ASSERT(params.image_longest_edge > 0);
clip_image_u8 resized_image;
const clip_image_size target_size = img_tool::calc_size_preserved_ratio(
original_size,
params.patch_size * params.n_merge,
params.image_longest_edge);
img_tool::resize(*img, resized_image, target_size, img_tool::RESIZE_ALGO_BICUBIC);
clip_image_f32_ptr img_f32(clip_image_f32_init());
normalize_image_u8_to_f32(resized_image, *img_f32, params.image_mean, params.image_std);
res_imgs->entries.push_back(std::move(img_f32));
} break;
case PROJECTOR_TYPE_LLAMA4: case PROJECTOR_TYPE_LLAMA4:
{ {

View File

@ -851,13 +851,15 @@ int32_t mtmd_encode(mtmd_context * ctx, const mtmd_image_tokens * image_tokens)
LOG_ERR("%s: this API does not support non-vision input, please use mtmd_encode_chunk instead\n", __func__); LOG_ERR("%s: this API does not support non-vision input, please use mtmd_encode_chunk instead\n", __func__);
return 1; return 1;
} }
auto proj_type = clip_get_projector_type(ctx_clip);
int n_mmproj_embd = clip_n_mmproj_embd(ctx_clip); int n_mmproj_embd = clip_n_mmproj_embd(ctx_clip);
ctx->image_embd_v.resize(image_tokens->n_tokens() * n_mmproj_embd); ctx->image_embd_v.resize(image_tokens->n_tokens() * n_mmproj_embd);
bool ok = false; bool ok = false;
if (clip_is_llava(ctx_clip) if (clip_is_llava(ctx_clip)
|| clip_is_minicpmv(ctx_clip) || clip_is_minicpmv(ctx_clip)
|| clip_is_glm(ctx_clip)) { || clip_is_glm(ctx_clip)
|| proj_type == PROJECTOR_TYPE_INTERNVL) {
// TODO @ngxson : llava does not support batched encoding ; this should be fixed inside clip_image_batch_encode() // TODO @ngxson : llava does not support batched encoding ; this should be fixed inside clip_image_batch_encode()
const auto & entries = image_tokens->batch_f32.entries; const auto & entries = image_tokens->batch_f32.entries;
for (size_t i = 0; i < entries.size(); i++) { for (size_t i = 0; i < entries.size(); i++) {

View File

@ -1634,6 +1634,13 @@ The `status` object can be:
} }
``` ```
```json
"status": {
"value": "sleeping",
"args": ["llama-server", "-ctx", "4096"]
}
```
### POST `/models/load`: Load a model ### POST `/models/load`: Load a model
Load a model Load a model

View File

@ -3033,6 +3033,9 @@ struct server_res_generator : server_http_res {
} }
}; };
// Registers `callback` for sleeping-state changes by handing it over (moved)
// to impl->queue_tasks.on_sleeping_state().
void server_context::on_sleeping_changed(std::function<void(bool)> callback) {
impl->queue_tasks.on_sleeping_state(std::move(callback));
}
// //

View File

@ -74,6 +74,10 @@ struct server_context {
// get server metadata (read-only), can only be called after load_model() // get server metadata (read-only), can only be called after load_model()
// not thread-safe, should only be used from the main thread // not thread-safe, should only be used from the main thread
server_context_meta get_meta() const; server_context_meta get_meta() const;
// register a callback to be called when sleeping state changes
// must be set before load_model() is called
void on_sleeping_changed(std::function<void(bool)> callback);
}; };

View File

@ -39,7 +39,8 @@ extern char **environ;
#define DEFAULT_STOP_TIMEOUT 10 // seconds #define DEFAULT_STOP_TIMEOUT 10 // seconds
#define CMD_ROUTER_TO_CHILD_EXIT "cmd_router_to_child:exit" #define CMD_ROUTER_TO_CHILD_EXIT "cmd_router_to_child:exit"
#define CMD_CHILD_TO_ROUTER_READY "cmd_child_to_router:ready" #define CMD_CHILD_TO_ROUTER_READY "cmd_child_to_router:ready" // also sent when waking up from sleep
#define CMD_CHILD_TO_ROUTER_SLEEP "cmd_child_to_router:sleep"
// address for child process, this is needed because router may run on 0.0.0.0 // address for child process, this is needed because router may run on 0.0.0.0
// ref: https://github.com/ggml-org/llama.cpp/issues/17862 // ref: https://github.com/ggml-org/llama.cpp/issues/17862
@ -380,7 +381,7 @@ void server_models::update_meta(const std::string & name, const server_model_met
if (it != mapping.end()) { if (it != mapping.end()) {
it->second.meta = meta; it->second.meta = meta;
} }
cv.notify_all(); // notify wait_until_loaded cv.notify_all(); // notify wait_until_loading_finished
} }
bool server_models::has_model(const std::string & name) { bool server_models::has_model(const std::string & name) {
@ -503,7 +504,7 @@ void server_models::unload_lru() {
{ {
std::unique_lock<std::mutex> lk(mutex); std::unique_lock<std::mutex> lk(mutex);
for (const auto & m : mapping) { for (const auto & m : mapping) {
if (m.second.meta.is_active()) { if (m.second.meta.is_running()) {
count_active++; count_active++;
if (m.second.meta.last_used < lru_last_used) { if (m.second.meta.last_used < lru_last_used) {
lru_model_name = m.first; lru_model_name = m.first;
@ -546,7 +547,7 @@ void server_models::load(const std::string & name) {
if (base_params.models_max > 0) { if (base_params.models_max > 0) {
size_t count_active = 0; size_t count_active = 0;
for (const auto & m : mapping) { for (const auto & m : mapping) {
if (m.second.meta.is_active()) { if (m.second.meta.is_running()) {
count_active++; count_active++;
} }
} }
@ -605,15 +606,15 @@ void server_models::load(const std::string & name) {
std::thread log_thread([&]() { std::thread log_thread([&]() {
// read stdout/stderr and forward to main server log // read stdout/stderr and forward to main server log
// also handle status report from child process // also handle status report from child process
bool state_received = false; // true if child state received
if (stdout_file) { if (stdout_file) {
char buffer[4096]; char buffer[4096];
while (fgets(buffer, sizeof(buffer), stdout_file) != nullptr) { while (fgets(buffer, sizeof(buffer), stdout_file) != nullptr) {
LOG("[%5d] %s", port, buffer); LOG("[%5d] %s", port, buffer);
if (!state_received && std::strstr(buffer, CMD_CHILD_TO_ROUTER_READY) != nullptr) { std::string str(buffer);
// child process is ready if (string_starts_with(buffer, CMD_CHILD_TO_ROUTER_READY)) {
this->update_status(name, SERVER_MODEL_STATUS_LOADED, 0); this->update_status(name, SERVER_MODEL_STATUS_LOADED, 0);
state_received = true; } else if (string_starts_with(buffer, CMD_CHILD_TO_ROUTER_SLEEP)) {
this->update_status(name, SERVER_MODEL_STATUS_SLEEPING, 0);
} }
} }
} else { } else {
@ -706,13 +707,13 @@ void server_models::unload(const std::string & name) {
std::lock_guard<std::mutex> lk(mutex); std::lock_guard<std::mutex> lk(mutex);
auto it = mapping.find(name); auto it = mapping.find(name);
if (it != mapping.end()) { if (it != mapping.end()) {
if (it->second.meta.is_active()) { if (it->second.meta.is_running()) {
SRV_INF("unloading model instance name=%s\n", name.c_str()); SRV_INF("stopping model instance name=%s\n", name.c_str());
stopping_models.insert(name); stopping_models.insert(name);
cv_stop.notify_all(); cv_stop.notify_all();
// status change will be handled by the managing thread // status change will be handled by the managing thread
} else { } else {
SRV_WRN("model instance name=%s is not loaded\n", name.c_str()); SRV_WRN("model instance name=%s is not running\n", name.c_str());
} }
} }
} }
@ -722,8 +723,8 @@ void server_models::unload_all() {
{ {
std::lock_guard<std::mutex> lk(mutex); std::lock_guard<std::mutex> lk(mutex);
for (auto & [name, inst] : mapping) { for (auto & [name, inst] : mapping) {
if (inst.meta.is_active()) { if (inst.meta.is_running()) {
SRV_INF("unloading model instance name=%s\n", name.c_str()); SRV_INF("stopping model instance name=%s\n", name.c_str());
stopping_models.insert(name); stopping_models.insert(name);
cv_stop.notify_all(); cv_stop.notify_all();
// status change will be handled by the managing thread // status change will be handled by the managing thread
@ -750,7 +751,7 @@ void server_models::update_status(const std::string & name, server_model_status
cv.notify_all(); cv.notify_all();
} }
void server_models::wait_until_loaded(const std::string & name) { void server_models::wait_until_loading_finished(const std::string & name) {
std::unique_lock<std::mutex> lk(mutex); std::unique_lock<std::mutex> lk(mutex);
cv.wait(lk, [this, &name]() { cv.wait(lk, [this, &name]() {
auto it = mapping.find(name); auto it = mapping.find(name);
@ -761,22 +762,25 @@ void server_models::wait_until_loaded(const std::string & name) {
}); });
} }
bool server_models::ensure_model_loaded(const std::string & name) { bool server_models::ensure_model_ready(const std::string & name) {
auto meta = get_meta(name); auto meta = get_meta(name);
if (!meta.has_value()) { if (!meta.has_value()) {
throw std::runtime_error("model name=" + name + " is not found"); throw std::runtime_error("model name=" + name + " is not found");
} }
if (meta->status == SERVER_MODEL_STATUS_LOADED) { if (meta->is_ready()) {
return false; // already loaded return false; // ready for taking requests
}
if (meta->status == SERVER_MODEL_STATUS_SLEEPING) {
return false; // child is sleeping but still running; new request will wake it up
} }
if (meta->status == SERVER_MODEL_STATUS_UNLOADED) { if (meta->status == SERVER_MODEL_STATUS_UNLOADED) {
SRV_INF("model name=%s is not loaded, loading...\n", name.c_str()); SRV_INF("model name=%s is not loaded, loading...\n", name.c_str());
load(name); load(name);
} }
// for loading state // wait for loading to complete
SRV_INF("waiting until model name=%s is fully loaded...\n", name.c_str()); SRV_INF("waiting until model name=%s is fully loaded...\n", name.c_str());
wait_until_loaded(name); wait_until_loading_finished(name);
// check final status // check final status
meta = get_meta(name); meta = get_meta(name);
@ -792,8 +796,8 @@ server_http_res_ptr server_models::proxy_request(const server_http_req & req, co
if (!meta.has_value()) { if (!meta.has_value()) {
throw std::runtime_error("model name=" + name + " is not found"); throw std::runtime_error("model name=" + name + " is not found");
} }
if (meta->status != SERVER_MODEL_STATUS_LOADED) { if (!meta->is_running()) {
throw std::invalid_argument("model name=" + name + " is not loaded"); throw std::invalid_argument("model name=" + name + " is not running");
} }
if (update_last_used) { if (update_last_used) {
std::unique_lock<std::mutex> lk(mutex); std::unique_lock<std::mutex> lk(mutex);
@ -819,6 +823,11 @@ server_http_res_ptr server_models::proxy_request(const server_http_req & req, co
return proxy; return proxy;
} }
bool server_models::is_child_server() {
const char * router_port = std::getenv("LLAMA_SERVER_ROUTER_PORT");
return router_port != nullptr;
}
std::thread server_models::setup_child_server(const std::function<void(int)> & shutdown_handler) { std::thread server_models::setup_child_server(const std::function<void(int)> & shutdown_handler) {
// send a notification to the router server that a model instance is ready // send a notification to the router server that a model instance is ready
common_log_pause(common_log_main()); common_log_pause(common_log_main());
@ -852,6 +861,13 @@ std::thread server_models::setup_child_server(const std::function<void(int)> & s
}); });
} }
void server_models::notify_router_sleeping_state(bool is_sleeping) {
common_log_pause(common_log_main());
fflush(stdout);
fprintf(stdout, "%s\n", is_sleeping ? CMD_CHILD_TO_ROUTER_SLEEP : CMD_CHILD_TO_ROUTER_READY);
fflush(stdout);
common_log_resume(common_log_main());
}
// //
@ -881,9 +897,9 @@ static bool router_validate_model(std::string & name, server_models & models, bo
// resolve alias to canonical model name // resolve alias to canonical model name
name = meta->name; name = meta->name;
if (models_autoload) { if (models_autoload) {
models.ensure_model_loaded(name); models.ensure_model_ready(name);
} else { } else {
if (meta->status != SERVER_MODEL_STATUS_LOADED) { if (!meta->is_running()) {
res_err(res, format_error_response("model is not loaded", ERROR_TYPE_INVALID_REQUEST)); res_err(res, format_error_response("model is not loaded", ERROR_TYPE_INVALID_REQUEST));
return false; return false;
} }
@ -956,8 +972,8 @@ void server_models_routes::init_routes() {
res_err(res, format_error_response("model is not found", ERROR_TYPE_NOT_FOUND)); res_err(res, format_error_response("model is not found", ERROR_TYPE_NOT_FOUND));
return res; return res;
} }
if (meta->status == SERVER_MODEL_STATUS_LOADED) { if (meta->is_running()) {
res_err(res, format_error_response("model is already loaded", ERROR_TYPE_INVALID_REQUEST)); res_err(res, format_error_response("model is already running", ERROR_TYPE_INVALID_REQUEST));
return res; return res;
} }
models.load(meta->name); models.load(meta->name);
@ -1015,8 +1031,8 @@ void server_models_routes::init_routes() {
res_err(res, format_error_response("model is not found", ERROR_TYPE_INVALID_REQUEST)); res_err(res, format_error_response("model is not found", ERROR_TYPE_INVALID_REQUEST));
return res; return res;
} }
if (!model->is_active()) { if (!model->is_running()) {
res_err(res, format_error_response("model is not loaded", ERROR_TYPE_INVALID_REQUEST)); res_err(res, format_error_response("model is not running", ERROR_TYPE_INVALID_REQUEST));
return res; return res;
} }
models.unload(model->name); models.unload(model->name);
@ -1181,7 +1197,8 @@ server_http_proxy::server_http_proxy(
continue; continue;
} }
if (key == "Host" || key == "host") { if (key == "Host" || key == "host") {
req.set_header(key, host); bool is_default_port = (scheme == "https" && port == 443) || (scheme == "http" && port == 80);
req.set_header(key, is_default_port ? host : host + ":" + std::to_string(port));
} else { } else {
req.set_header(key, value); req.set_header(key, value);
} }

View File

@ -14,17 +14,18 @@
/** /**
* state diagram: * state diagram:
* *
* UNLOADED LOADING LOADED * UNLOADED LOADING LOADED SLEEPING
* *
* failed * failed
* * sleeping
* unloaded * unloaded
*/ */
enum server_model_status { enum server_model_status {
// TODO: also add downloading state when the logic is added // TODO: also add downloading state when the logic is added
SERVER_MODEL_STATUS_UNLOADED, SERVER_MODEL_STATUS_UNLOADED,
SERVER_MODEL_STATUS_LOADING, SERVER_MODEL_STATUS_LOADING,
SERVER_MODEL_STATUS_LOADED SERVER_MODEL_STATUS_LOADED,
SERVER_MODEL_STATUS_SLEEPING
}; };
static server_model_status server_model_status_from_string(const std::string & status_str) { static server_model_status server_model_status_from_string(const std::string & status_str) {
@ -37,6 +38,9 @@ static server_model_status server_model_status_from_string(const std::string & s
if (status_str == "loaded") { if (status_str == "loaded") {
return SERVER_MODEL_STATUS_LOADED; return SERVER_MODEL_STATUS_LOADED;
} }
if (status_str == "sleeping") {
return SERVER_MODEL_STATUS_SLEEPING;
}
throw std::runtime_error("invalid server model status"); throw std::runtime_error("invalid server model status");
} }
@ -45,6 +49,7 @@ static std::string server_model_status_to_string(server_model_status status) {
case SERVER_MODEL_STATUS_UNLOADED: return "unloaded"; case SERVER_MODEL_STATUS_UNLOADED: return "unloaded";
case SERVER_MODEL_STATUS_LOADING: return "loading"; case SERVER_MODEL_STATUS_LOADING: return "loading";
case SERVER_MODEL_STATUS_LOADED: return "loaded"; case SERVER_MODEL_STATUS_LOADED: return "loaded";
case SERVER_MODEL_STATUS_SLEEPING: return "sleeping";
default: return "unknown"; default: return "unknown";
} }
} }
@ -61,8 +66,12 @@ struct server_model_meta {
int exit_code = 0; // exit code of the model instance process (only valid if status == FAILED) int exit_code = 0; // exit code of the model instance process (only valid if status == FAILED)
int stop_timeout = 0; // seconds to wait before force-killing the model instance during shutdown int stop_timeout = 0; // seconds to wait before force-killing the model instance during shutdown
bool is_active() const { bool is_ready() const {
return status == SERVER_MODEL_STATUS_LOADED || status == SERVER_MODEL_STATUS_LOADING; return status == SERVER_MODEL_STATUS_LOADED;
}
bool is_running() const {
return status == SERVER_MODEL_STATUS_LOADED || status == SERVER_MODEL_STATUS_LOADING || status == SERVER_MODEL_STATUS_SLEEPING;
} }
bool is_failed() const { bool is_failed() const {
@ -130,19 +139,26 @@ public:
void update_status(const std::string & name, server_model_status status, int exit_code); void update_status(const std::string & name, server_model_status status, int exit_code);
// wait until the model instance is fully loaded (thread-safe) // wait until the model instance is fully loaded (thread-safe)
// return when the model is loaded or failed to load // return when the model no longer in "loading" state
void wait_until_loaded(const std::string & name); void wait_until_loading_finished(const std::string & name);
// load the model if not loaded, otherwise do nothing (thread-safe) // ensure the model is in ready state (thread-safe)
// return false if model is already loaded; return true otherwise (meta may need to be refreshed) // return false if model is ready
bool ensure_model_loaded(const std::string & name); // otherwise, load the model and blocking wait until it's ready, then return true (meta may need to be refreshed)
bool ensure_model_ready(const std::string & name);
// proxy an HTTP request to the model instance // proxy an HTTP request to the model instance
server_http_res_ptr proxy_request(const server_http_req & req, const std::string & method, const std::string & name, bool update_last_used); server_http_res_ptr proxy_request(const server_http_req & req, const std::string & method, const std::string & name, bool update_last_used);
// return true if the current process is a child server instance
static bool is_child_server();
// notify the router server that a model instance is ready // notify the router server that a model instance is ready
// return the monitoring thread (to be joined by the caller) // return the monitoring thread (to be joined by the caller)
static std::thread setup_child_server(const std::function<void(int)> & shutdown_handler); static std::thread setup_child_server(const std::function<void(int)> & shutdown_handler);
// notify the router server that the sleeping state has changed
static void notify_router_sleeping_state(bool sleeping);
}; };
struct server_models_routes { struct server_models_routes {

View File

@ -95,11 +95,19 @@ public:
callback_update_slots = std::move(callback); callback_update_slots = std::move(callback);
} }
// Register callback for sleeping state change // Register callback for sleeping state change; multiple callbacks are allowed
// note: when entering sleeping state, the callback is called AFTER sleeping is set to true // note: when entering sleeping state, the callback is called AFTER sleeping is set to true
// when leaving sleeping state, the callback is called BEFORE sleeping is set to false // when leaving sleeping state, the callback is called BEFORE sleeping is set to false
void on_sleeping_state(std::function<void(bool)> callback) { void on_sleeping_state(std::function<void(bool)> callback) {
callback_sleeping_state = std::move(callback); if (callback_sleeping_state) {
auto prev_callback = std::move(callback_sleeping_state);
callback_sleeping_state = [prev_callback, callback](bool sleeping) {
prev_callback(sleeping);
callback(sleeping);
};
} else {
callback_sleeping_state = std::move(callback);
}
} }
private: private:

View File

@ -259,6 +259,12 @@ int main(int argc, char ** argv) {
// load the model // load the model
LOG_INF("%s: loading model\n", __func__); LOG_INF("%s: loading model\n", __func__);
if (server_models::is_child_server()) {
ctx_server.on_sleeping_changed([&](bool sleeping) {
server_models::notify_router_sleeping_state(sleeping);
});
}
if (!ctx_server.load_model(params)) { if (!ctx_server.load_model(params)) {
clean_up(); clean_up();
if (ctx_http.thread.joinable()) { if (ctx_http.thread.joinable()) {
@ -309,9 +315,8 @@ int main(int argc, char ** argv) {
LOG_INF("%s: starting the main loop...\n", __func__); LOG_INF("%s: starting the main loop...\n", __func__);
// optionally, notify router server that this instance is ready // optionally, notify router server that this instance is ready
const char * router_port = std::getenv("LLAMA_SERVER_ROUTER_PORT");
std::thread monitor_thread; std::thread monitor_thread;
if (router_port != nullptr) { if (server_models::is_child_server()) {
monitor_thread = server_models::setup_child_server(shutdown_handler); monitor_thread = server_models::setup_child_server(shutdown_handler);
} }