diff --git a/common/jinja/parser.cpp b/common/jinja/parser.cpp index 7970336ac0..4ae4477445 100644 --- a/common/jinja/parser.cpp +++ b/common/jinja/parser.cpp @@ -53,6 +53,13 @@ private: return tokens[current + offset]; } + const token & next() { + if (current >= tokens.size()) { + throw parser_exception("Parser Error: Unexpected EOF", source, tokens.empty() ? 0 : tokens.back().pos); + } + return tokens[current++]; + } + token expect(token::type type, const std::string& error) { const auto & t = peek(); if (t.t != type) { @@ -90,9 +97,9 @@ private: size_t start_pos = current; switch (peek().t) { case token::comment: - return mk_stmt(start_pos, tokens[current++].value); + return mk_stmt(start_pos, next().value); case token::text: - return mk_stmt(start_pos, tokens[current++].value); + return mk_stmt(start_pos, next().value); case token::open_statement: return parse_jinja_statement(); case token::open_expression: @@ -119,8 +126,7 @@ private: } size_t start_pos = current; - std::string name = peek().value; - current++; // consume identifier + std::string name = next().value; statement_ptr result; if (name == "set") { @@ -202,7 +208,7 @@ private: // Ignore generation blocks (transformers-specific) // See https://github.com/huggingface/transformers/pull/30650 for more information. result = mk_stmt(start_pos); - current++; + ++current; } else { throw std::runtime_error("Unknown statement: " + name); @@ -217,7 +223,7 @@ private: statements body; if (is(token::equals)) { - current++; + ++current; value = parse_expression_sequence(); } else { // parsing multiline set here @@ -280,7 +286,7 @@ private: exprs.push_back(primary ? parse_primary_expression() : parse_expression()); bool is_tuple = is(token::comma); while (is(token::comma)) { - current++; // consume comma + ++current; // consume comma exprs.push_back(primary ? parse_primary_expression() : parse_expression()); } return is_tuple ? mk_stmt(start_pos, std::move(exprs)) : std::move(exprs[0]); @@ -290,7 +296,7 @@ private: // e.g., `message` in `for message in messages` auto loop_var = parse_expression_sequence(true); // should be an identifier/tuple if (!is_identifier("in")) throw std::runtime_error("Expected 'in'"); - current++; + ++current; // consume 'in' // `messages` in `for message in messages` auto iterable = parse_expression(); @@ -305,7 +311,8 @@ private: } if (is_statement({"else"})) { - current += 2; + ++current; // consume {% + ++current; // consume 'else' expect(token::close_statement, "Expected %}"); while (!is_statement({"endfor"})) { alternate.push_back(parse_any()); @@ -347,7 +354,7 @@ private: auto left = parse_logical_and_expression(); while (is_identifier("or")) { size_t start_pos = current; - token op = tokens[current++]; + token op = next(); left = mk_stmt(start_pos, op, std::move(left), parse_logical_and_expression()); } return left; @@ -357,7 +364,7 @@ private: auto left = parse_logical_negation_expression(); while (is_identifier("and")) { size_t start_pos = current; - auto op = tokens[current++]; + auto op = next(); left = mk_stmt(start_pos, op, std::move(left), parse_logical_negation_expression()); } return left; @@ -367,7 +374,7 @@ private: // Try parse unary operators if (is_identifier("not")) { size_t start_pos = current; - auto op = tokens[current++]; + auto op = next(); return mk_stmt(start_pos, op, parse_logical_negation_expression()); } return parse_comparison_expression(); @@ -382,11 +389,12 @@ private: size_t start_pos = current; if (is_identifier("not") && peek(1).t == token::identifier && peek(1).value == "in") { op = {token::identifier, "not in", tokens[current].pos}; - current += 2; + ++current; // consume 'not' + ++current; // consume 'in' } else if (is_identifier("in")) { - op = tokens[current++]; + op = next(); } else if (is(token::comparison_binary_operator)) { - op = tokens[current++]; + op = next(); } else break; left = mk_stmt(start_pos, op, std::move(left), parse_additive_expression()); } @@ -397,7 +405,7 @@ private: auto left = parse_multiplicative_expression(); while (is(token::additive_binary_operator)) { size_t start_pos = current; - auto op = tokens[current++]; + auto op = next(); left = mk_stmt(start_pos, op, std::move(left), parse_multiplicative_expression()); } return left; @@ -407,7 +415,7 @@ private: auto left = parse_test_expression(); while (is(token::multiplicative_binary_operator)) { size_t start_pos = current; - auto op = tokens[current++]; + auto op = next(); left = mk_stmt(start_pos, op, std::move(left), parse_test_expression()); } return left; @@ -417,9 +425,9 @@ private: auto operand = parse_filter_expression(); while (is_identifier("is")) { size_t start_pos = current; - current++; + ++current; // consume 'is' bool negate = false; - if (is_identifier("not")) { current++; negate = true; } + if (is_identifier("not")) { ++current; negate = true; } auto test_id = parse_primary_expression(); // FIXME: tests can also be expressed like this: if x is eq 3 if (is(token::open_paren)) test_id = parse_call_expression(std::move(test_id)); @@ -432,7 +440,7 @@ private: auto operand = parse_call_member_expression(); while (is(token::pipe)) { size_t start_pos = current; - current++; + ++current; // consume pipe auto filter = parse_primary_expression(); if (is(token::open_paren)) filter = parse_call_expression(std::move(filter)); operand = mk_stmt(start_pos, std::move(operand), std::move(filter)); @@ -490,7 +498,7 @@ private: statement_ptr parse_member_expression(statement_ptr object) { size_t start_pos = current; while (is(token::dot) || is(token::open_square_bracket)) { - auto op = tokens[current++]; + auto op = next(); bool computed = op.t == token::open_square_bracket; statement_ptr prop; if (computed) { @@ -536,7 +544,7 @@ private: statement_ptr parse_primary_expression() { size_t start_pos = current; - auto t = tokens[current++]; + auto t = next(); switch (t.t) { case token::numeric_literal: if (t.value.find('.') != std::string::npos) { @@ -547,7 +555,7 @@ private: case token::string_literal: { std::string val = t.value; while (is(token::string_literal)) { - val += tokens[current++].value; + val += next().value; } return mk_stmt(start_pos, val); } @@ -562,9 +570,9 @@ private: statements vals; while (!is(token::close_square_bracket)) { vals.push_back(parse_expression()); - if (is(token::comma)) current++; + if (is(token::comma)) ++current; } - current++; + ++current; return mk_stmt(start_pos, std::move(vals)); } case token::open_curly_bracket: { @@ -573,9 +581,9 @@ private: auto key = parse_expression(); expect(token::colon, "Expected :"); pairs.push_back({std::move(key), parse_expression()}); - if (is(token::comma)) current++; + if (is(token::comma)) ++current; } - current++; + ++current; return mk_stmt(start_pos, std::move(pairs)); } default: diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index dba190b480..0cd47645d3 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -4273,6 +4273,16 @@ class Qwen25OmniModel(Qwen2VLVisionModel): @ModelBase.register("InternVisionModel") class InternVisionModel(MmprojModel): + + min_dynamic_tiles: int = 0 + max_dynamic_tiles: int = 0 + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + assert self.hparams_vision is not None + self.min_dynamic_tiles = self.global_config.get("min_dynamic_patch", 0) + self.max_dynamic_tiles = self.global_config.get("max_dynamic_patch", 0) + def set_gguf_parameters(self): assert self.hparams_vision is not None if isinstance(self.hparams_vision['image_size'], list): @@ -4295,6 +4305,11 @@ class InternVisionModel(MmprojModel): downsample_ratio = self.global_config.get("downsample_ratio") assert downsample_ratio is not None self.gguf_writer.add_vision_projector_scale_factor(int(1.0 / downsample_ratio)) + # older models may not have min/max_dynamic_patch in config + if self.min_dynamic_tiles > 0: + self.gguf_writer.add_vision_preproc_min_tiles(self.min_dynamic_tiles) + if self.max_dynamic_tiles > 0: + self.gguf_writer.add_vision_preproc_max_tiles(self.max_dynamic_tiles) def tensor_force_quant(self, name, new_name, bid, n_dims): if ".position_embd." in new_name: diff --git a/ggml/src/ggml-cuda/CMakeLists.txt b/ggml/src/ggml-cuda/CMakeLists.txt index 262f88204e..419862101d 100644 --- a/ggml/src/ggml-cuda/CMakeLists.txt +++ b/ggml/src/ggml-cuda/CMakeLists.txt @@ -116,12 +116,11 @@ if (CUDAToolkit_FOUND) list(APPEND GGML_SOURCES_CUDA ${SRCS}) add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS) else() - file(GLOB SRCS "template-instances/fattn-vec*q4_0-q4_0.cu") - list(APPEND GGML_SOURCES_CUDA ${SRCS}) - file(GLOB SRCS "template-instances/fattn-vec*q8_0-q8_0.cu") - list(APPEND GGML_SOURCES_CUDA ${SRCS}) - file(GLOB SRCS "template-instances/fattn-vec*f16-f16.cu") - list(APPEND GGML_SOURCES_CUDA ${SRCS}) + list(APPEND GGML_SOURCES_CUDA + template-instances/fattn-vec-instance-f16-f16.cu + template-instances/fattn-vec-instance-q4_0-q4_0.cu + template-instances/fattn-vec-instance-q8_0-q8_0.cu + template-instances/fattn-vec-instance-bf16-bf16.cu) endif() ggml_add_backend_library(ggml-cuda diff --git a/ggml/src/ggml-cuda/convert.cuh b/ggml/src/ggml-cuda/convert.cuh index 09f9a33f90..f5d37c7b99 100644 --- a/ggml/src/ggml-cuda/convert.cuh +++ b/ggml/src/ggml-cuda/convert.cuh @@ -41,6 +41,16 @@ template return __bfloat162float(x); } else if constexpr(std::is_same_v && std::is_same_v) { return __float22half2_rn(x); + } else if constexpr(std::is_same_v && std::is_same_v) { +#ifdef GGML_USE_HIP + return make_float2(__bfloat162float(__low2bfloat16(x)), __bfloat162float(__high2bfloat16(x))); +#else +#if __CUDA_ARCH__ >= 800 + return __bfloat1622float2(x); +#else + return make_float2(__bfloat162float(x.x), __bfloat162float(x.y)); +#endif // __CUDA_ARCH__ >= 800 +#endif // GGML_USE_HIP } else if constexpr(std::is_same_v && std::is_same_v) { // bypass compile error on cuda 12.0.1 #ifdef GGML_USE_HIP diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index e9abdf288c..c59a4db399 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -74,6 +74,37 @@ static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_f16( return sum; } +template +static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_bf16( + const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds_v) { + + const nv_bfloat162 * K_bf16 = (const nv_bfloat162 *) K_c; + GGML_UNUSED(Q_q8); + GGML_UNUSED(Q_ds_v); + + constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes(); + constexpr int cpy_ne = cpy_nb / 4; + + float sum = 0.0f; + +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += nthreads*cpy_ne) { + __align__(16) nv_bfloat162 tmp[cpy_ne]; + ggml_cuda_memcpy_1(tmp, K_bf16 + k_KQ_0 + (threadIdx.x % nthreads)*cpy_ne); +#pragma unroll + for (int k_KQ_1 = 0; k_KQ_1 < cpy_ne; ++k_KQ_1) { +#ifdef V_DOT2_F32_F16_AVAILABLE + // FIXME replace macros in vector FA kernel with templating and use FP32 for BF16 + ggml_cuda_mad(sum, ggml_cuda_cast(tmp[k_KQ_1]), __half22float2(((const half2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1])); +#else + ggml_cuda_mad(sum, ggml_cuda_cast(tmp[k_KQ_1]), ((const float2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]); +#endif // V_DOT2_F32_F16_AVAILABLE + } + } + + return sum; +} + template static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_q4_0( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { @@ -321,6 +352,19 @@ static __device__ __forceinline__ void dequantize_V_f16(const void * __restrict_ } } +template +static __device__ __forceinline__ void dequantize_V_bf16(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) { + static_assert(std::is_same_v, "BF16 V dequantization only supports float output"); + static_assert(ne % 2 == 0, "bad ne"); + __align__(16) nv_bfloat162 tmp[ne/2]; + ggml_cuda_memcpy_1(tmp, (const nv_bfloat16 *) vx + i0); + float2 * dst_f2 = (float2 *) dst; +#pragma unroll + for (int l = 0; l < ne/2; ++l) { + dst_f2[l] = ggml_cuda_cast(tmp[l]); + } +} + template static __device__ __forceinline__ void dequantize_V_q4_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) { const block_q4_0 * x = (const block_q4_0 *) vx; @@ -547,6 +591,8 @@ constexpr __device__ vec_dot_KQ_t get_vec_dot_KQ() { return vec_dot_fattn_vec_KQ_q5_1; } else if constexpr (type_K == GGML_TYPE_Q8_0) { return vec_dot_fattn_vec_KQ_q8_0; + } else if constexpr (type_K == GGML_TYPE_BF16) { + return vec_dot_fattn_vec_KQ_bf16; } else { static_assert(type_K == -1, "bad type"); return nullptr; @@ -567,6 +613,8 @@ constexpr __device__ dequantize_V_t get_dequantize_V() { return dequantize_V_q5_1; } else if constexpr (type_V == GGML_TYPE_Q8_0) { return dequantize_V_q8_0; + } else if constexpr (type_V == GGML_TYPE_BF16) { + return dequantize_V_bf16; } else { static_assert(type_V == -1, "bad type"); return nullptr; diff --git a/ggml/src/ggml-cuda/fattn-vec.cuh b/ggml/src/ggml-cuda/fattn-vec.cuh index 7cbe32633e..f0bd42a576 100644 --- a/ggml/src/ggml-cuda/fattn-vec.cuh +++ b/ggml/src/ggml-cuda/fattn-vec.cuh @@ -75,17 +75,17 @@ static __global__ void flash_attn_ext_vec( #endif // GGML_USE_HIP constexpr int nthreads = ggml_cuda_fattn_vec_get_nthreads_device(); - constexpr int nthreads_KQ = type_K == GGML_TYPE_F16 ? 128 / cpy_nb : nthreads_KQ_q; - constexpr int nthreads_V = type_V == GGML_TYPE_F16 ? 128 / cpy_nb : nthreads_V_q; + constexpr int nthreads_KQ = (type_K == GGML_TYPE_F16 || type_K == GGML_TYPE_BF16) ? 128 / cpy_nb : nthreads_KQ_q; + constexpr int nthreads_V = (type_V == GGML_TYPE_F16 || type_V == GGML_TYPE_BF16) ? 128 / cpy_nb : nthreads_V_q; static_assert(WARP_SIZE % nthreads_KQ == 0, "bad nthreads_K"); static_assert(WARP_SIZE % nthreads_V == 0, "bad nthreads_V"); - constexpr int V_rows_per_thread = type_V == GGML_TYPE_F16 ? 2*cpy_ne : 4; + constexpr int V_rows_per_thread = (type_V == GGML_TYPE_F16 || type_V == GGML_TYPE_BF16) ? 2*cpy_ne : 4; constexpr int V_cols_per_iter = WARP_SIZE / nthreads_V; constexpr vec_dot_KQ_t vec_dot_KQ = get_vec_dot_KQ(); - constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16; + constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16 && type_K != GGML_TYPE_BF16; #ifdef V_DOT2_F32_F16_AVAILABLE constexpr dequantize_V_t dequantize_V = get_dequantize_V(); #else @@ -323,8 +323,18 @@ static __global__ void flash_attn_ext_vec( #pragma unroll for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) { half2 tmp[V_rows_per_thread/2]; - dequantize_V(V + k*nb21, tmp, - 2*i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*V_rows_per_thread); + if constexpr (type_V == GGML_TYPE_BF16) { + float2 tmp_f[V_rows_per_thread/2]; + dequantize_V(V + k*nb21, tmp_f, + 2*i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*V_rows_per_thread); +#pragma unroll + for (int i_VKQ_1 = 0; i_VKQ_1 < V_rows_per_thread/2; ++i_VKQ_1) { + tmp[i_VKQ_1] = __float22half2_rn(tmp_f[i_VKQ_1]); + } + } else { + dequantize_V(V + k*nb21, tmp, + 2*i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*V_rows_per_thread); + } #pragma unroll for (int i_VKQ_1 = 0; i_VKQ_1 < V_rows_per_thread/2; ++i_VKQ_1) { #pragma unroll @@ -563,6 +573,7 @@ void ggml_cuda_flash_attn_ext_vec_case(ggml_backend_cuda_context & ctx, ggml_ten extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q5_0); \ extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q5_1); \ extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q8_0); \ + extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_BF16); \ EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_F16) EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q4_0) @@ -570,6 +581,7 @@ EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q4_1) EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q5_0) EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q5_1) EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q8_0) +EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_BF16) EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_F16) EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q4_0) @@ -577,6 +589,7 @@ EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q4_1) EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q5_0) EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q5_1) EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q8_0) +EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_BF16) EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_F16) EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q4_0) @@ -584,3 +597,4 @@ EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q4_1) EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_0) EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_1) EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q8_0) +EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_BF16) diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 85c177f496..a25a890db6 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -224,6 +224,7 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_F16) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_F16) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_F16) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_F16) FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q4_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_0) @@ -231,6 +232,7 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q4_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q4_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q4_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_Q4_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q4_1) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_1) @@ -238,6 +240,7 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q4_1) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q4_1) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q4_1) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_Q4_1) FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q5_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q5_0) @@ -245,6 +248,7 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q5_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q5_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q5_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_Q5_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q5_1) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q5_1) @@ -252,6 +256,7 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q5_1) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q5_1) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q5_1) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_Q5_1) FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q8_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q8_0) @@ -259,10 +264,20 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q8_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q8_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q8_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_Q8_0) + + FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_BF16) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_BF16) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_BF16) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_BF16) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_BF16) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_BF16) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_BF16) #else FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_F16) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q8_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_BF16) #endif // GGML_CUDA_FA_ALL_QUANTS GGML_ABORT("fatal error"); @@ -355,6 +370,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const #endif // GGML_CUDA_FA_ALL_QUANTS case GGML_TYPE_Q4_0: case GGML_TYPE_Q8_0: + case GGML_TYPE_BF16: break; default: return BEST_FATTN_KERNEL_NONE; diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index 632246e43f..024b3d8cf2 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -33,7 +33,7 @@ static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) } } -static constexpr __device__ int get_vdr_mmvq(ggml_type type) { +static constexpr __host__ __device__ int get_vdr_mmvq(ggml_type type) { switch (type) { case GGML_TYPE_Q4_0: return VDR_Q4_0_Q8_1_MMVQ; case GGML_TYPE_Q4_1: return VDR_Q4_1_Q8_1_MMVQ; @@ -173,11 +173,11 @@ static constexpr __host__ __device__ int calc_nwarps(ggml_type type, int ncols_d return 1; } -static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int table_id) { +static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int table_id, bool small_k = false, int nwarps = 1) { if (table_id == MMVQ_PARAMETERS_GENERIC || table_id == MMVQ_PARAMETERS_GCN) { switch (ncols_dst) { case 1: - return 1; + return small_k ? nwarps : 1; case 2: case 3: case 4: @@ -193,7 +193,7 @@ static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int return 1; } -template +template __launch_bounds__(calc_nwarps(type, ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1) static __global__ void mul_mat_vec_q( const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst, @@ -208,7 +208,7 @@ static __global__ void mul_mat_vec_q( constexpr int vdr = get_vdr_mmvq(type); constexpr mmvq_parameter_table_id table_id = get_device_table_id(); constexpr int nwarps = calc_nwarps(type, ncols_dst, table_id); - constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_dst, table_id); + constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_dst, table_id, small_k, nwarps); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type); @@ -414,14 +414,16 @@ static __global__ void mul_mat_vec_q( template static std::pair calc_launch_params( const int ncols_dst, const int nrows_x, const int nchannels_dst, const int nsamples_or_ntokens, - const int warp_size, const mmvq_parameter_table_id table_id) { - const int64_t nblocks = (nrows_x + calc_rows_per_block(ncols_dst, table_id) - 1) / calc_rows_per_block(ncols_dst, table_id); + const int warp_size, const mmvq_parameter_table_id table_id, const bool small_k = false) { + const int nwarps = calc_nwarps(type, ncols_dst, table_id); + const int rpb = calc_rows_per_block(ncols_dst, table_id, small_k, nwarps); + const int64_t nblocks = (nrows_x + rpb - 1) / rpb; const dim3 block_nums(nblocks, nchannels_dst, nsamples_or_ntokens); - const dim3 block_dims(warp_size, calc_nwarps(type, ncols_dst, table_id), 1); + const dim3 block_dims(warp_size, nwarps, 1); return {block_nums, block_dims}; } -template +template static void mul_mat_vec_q_switch_fusion( const void * vx, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst, const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y, @@ -434,7 +436,7 @@ static void mul_mat_vec_q_switch_fusion( const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr; if constexpr (c_ncols_dst == 1) { if (has_fusion) { - mul_mat_vec_q<<>> + mul_mat_vec_q<<>> (vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride); @@ -444,7 +446,7 @@ static void mul_mat_vec_q_switch_fusion( GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1"); - mul_mat_vec_q<<>> + mul_mat_vec_q<<>> (vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride); @@ -488,11 +490,33 @@ static void mul_mat_vec_q_switch_ncols_dst( switch (ncols_dst) { case 1: { constexpr int c_ncols_dst = 1; - std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); - mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, - channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, - dims.first, dims.second, 0, ids_stride, stream); + + // When K is small, increase rows_per_block to match nwarps so each warp has more work to do + // Trigger when the full thread block covers all K blocks in a single loop iteration and few threads remain idle. + constexpr int qk = ggml_cuda_type_traits::qk; + constexpr int qi = ggml_cuda_type_traits::qi; + constexpr int vdr = get_vdr_mmvq(type); + const int blocks_per_row_x = ncols_x / qk; + const int blocks_per_iter_1warp = vdr * warp_size / qi; + const int nwarps = calc_nwarps(type, c_ncols_dst, table_id); + const bool use_small_k = nwarps > 1 && blocks_per_row_x < nwarps * blocks_per_iter_1warp; + if (use_small_k) { + std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, + warp_size, table_id, true); + mul_mat_vec_q_switch_fusion( + vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, + dims.first, dims.second, 0, ids_stride, stream); + } else { + std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, + warp_size, table_id); + mul_mat_vec_q_switch_fusion( + vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, + dims.first, dims.second, 0, ids_stride, stream); + } } break; case 2: { constexpr int c_ncols_dst = 2; diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu new file mode 100644 index 0000000000..3a2fa99b05 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_BF16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-f16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-f16.cu new file mode 100644 index 0000000000..60f0f6f795 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-f16.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_F16); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_F16); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_F16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_0.cu new file mode 100644 index 0000000000..489e05f08c --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_0.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_Q4_0); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_1.cu new file mode 100644 index 0000000000..6fa3c26d30 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_1.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_Q4_1); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_0.cu new file mode 100644 index 0000000000..421027fb29 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_0.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_Q5_0); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_1.cu new file mode 100644 index 0000000000..abbc943480 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_1.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_Q5_1); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q8_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q8_0.cu new file mode 100644 index 0000000000..d641f859d8 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q8_0.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_Q8_0); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-bf16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-bf16.cu new file mode 100644 index 0000000000..d1071dc243 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-bf16.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_BF16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-bf16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-bf16.cu new file mode 100644 index 0000000000..8afda31423 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-bf16.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_BF16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-bf16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-bf16.cu new file mode 100644 index 0000000000..506864ac18 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-bf16.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_BF16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-bf16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-bf16.cu new file mode 100644 index 0000000000..0bbda8371e --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-bf16.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_BF16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-bf16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-bf16.cu new file mode 100644 index 0000000000..79be24daf9 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-bf16.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_BF16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu new file mode 100644 index 0000000000..45636e5e70 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_BF16); diff --git a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py index e382df1ae2..3b5ab12fc4 100755 --- a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +++ b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py @@ -5,7 +5,7 @@ import os HEAD_SIZES_KQ = [40, 64, 72, 80, 96, 112, 128, 256, 576] -TYPES_KV = ["GGML_TYPE_F16", "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0"] +TYPES_KV = ["GGML_TYPE_F16", "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0", "GGML_TYPE_BF16"] SOURCE_FATTN_TILE = """// This file has been autogenerated by generate_cu_files.py, do not edit manually. diff --git a/ggml/src/ggml-hip/CMakeLists.txt b/ggml/src/ggml-hip/CMakeLists.txt index f96c6e09a9..291b483745 100644 --- a/ggml/src/ggml-hip/CMakeLists.txt +++ b/ggml/src/ggml-hip/CMakeLists.txt @@ -71,12 +71,11 @@ if (GGML_CUDA_FA_ALL_QUANTS) list(APPEND GGML_SOURCES_ROCM ${SRCS}) add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS) else() - file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*q4_0-q4_0.cu") - list(APPEND GGML_SOURCES_ROCM ${SRCS}) - file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*q8_0-q8_0.cu") - list(APPEND GGML_SOURCES_ROCM ${SRCS}) - file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*f16-f16.cu") - list(APPEND GGML_SOURCES_ROCM ${SRCS}) + list(APPEND GGML_SOURCES_ROCM + ../ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu + ../ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu + ../ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu + ../ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu) endif() ggml_add_backend_library(ggml-hip diff --git a/ggml/src/ggml-musa/CMakeLists.txt b/ggml/src/ggml-musa/CMakeLists.txt index d76cb51977..cc53c812ce 100644 --- a/ggml/src/ggml-musa/CMakeLists.txt +++ b/ggml/src/ggml-musa/CMakeLists.txt @@ -48,12 +48,11 @@ if (MUSAToolkit_FOUND) list(APPEND GGML_SOURCES_MUSA ${SRCS}) add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS) else() - file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*q4_0-q4_0.cu") - list(APPEND GGML_SOURCES_MUSA ${SRCS}) - file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*q8_0-q8_0.cu") - list(APPEND GGML_SOURCES_MUSA ${SRCS}) - file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*f16-f16.cu") - list(APPEND GGML_SOURCES_MUSA ${SRCS}) + list(APPEND GGML_SOURCES_MUSA + ../ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu + ../ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu + ../ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu + ../ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu) endif() set_source_files_properties(${GGML_SOURCES_MUSA} PROPERTIES LANGUAGE CXX) diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 2ec1421841..456b1699fa 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -4667,22 +4667,8 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g if (a->ne[3] != b->ne[3]) { return false; } - ggml_type a_type = a->type; - if (a_type == GGML_TYPE_IQ4_NL || a_type == GGML_TYPE_IQ4_XS || - a_type == GGML_TYPE_IQ3_XXS || a_type == GGML_TYPE_IQ3_S || - a_type == GGML_TYPE_IQ2_XXS || a_type == GGML_TYPE_IQ2_XS || a_type == GGML_TYPE_IQ2_S || - a_type == GGML_TYPE_IQ1_S || a_type == GGML_TYPE_IQ1_M - ) { - if (b->ne[1] == 1 && ggml_nrows(b) > 1) { - return false; - } - } + ggml_type src0_type = op->src[0]->type; - if (src0_type == GGML_TYPE_BF16 ) { - // TODO: support GGML_TYPE_BF16 - // FIXME: keep a list of supported types to avoid breaking the backend when a new type is added - return false; - } // TODO: The configuration below needs more work to be supported with oneDNN if (ggml_is_permuted(a) && !ggml_is_contiguous(a) && diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index c5f92c7700..9383644abf 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -301,6 +301,8 @@ class Keys: IMAGE_SIZE = "clip.vision.image_size" IMAGE_MIN_PIXELS = "clip.vision.image_min_pixels" IMAGE_MAX_PIXELS = "clip.vision.image_max_pixels" + PREPROC_MIN_TILES = "clip.vision.preproc_min_tiles" + PREPROC_MAX_TILES = "clip.vision.preproc_max_tiles" PREPROC_IMAGE_SIZE = "clip.vision.preproc_image_size" PATCH_SIZE = "clip.vision.patch_size" EMBEDDING_LENGTH = "clip.vision.embedding_length" diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 5f653d386d..010dfeea1c 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -1156,6 +1156,12 @@ class GGUFWriter: def add_vision_min_pixels(self, value: int) -> None: self.add_uint32(Keys.ClipVision.IMAGE_MIN_PIXELS, value) + def add_vision_preproc_max_tiles(self, value: int) -> None: + self.add_uint32(Keys.ClipVision.PREPROC_MAX_TILES, value) + + def add_vision_preproc_min_tiles(self, value: int) -> None: + self.add_uint32(Keys.ClipVision.PREPROC_MIN_TILES, value) + def add_vision_preproc_image_size(self, value: int) -> None: self.add_uint32(Keys.ClipVision.PREPROC_IMAGE_SIZE, value) diff --git a/tests/test-jinja.cpp b/tests/test-jinja.cpp index ef9c8f73c8..1550627bf0 100644 --- a/tests/test-jinja.cpp +++ b/tests/test-jinja.cpp @@ -2264,6 +2264,7 @@ static void test_fuzzing(testing & t) { t.test("malformed templates (should error, not crash)", [&](testing & t) { const std::vector malformed = { + "", "{{ x", "{% if %}", "{% for %}", @@ -2284,6 +2285,11 @@ static void test_fuzzing(testing & t) { for (const auto & tmpl : malformed) { t.assert_true("malformed: " + tmpl, fuzz_test_template(tmpl, json::object())); } + std::string tmpl = "{% for message in messages %}{{ message.role | string }} : {{ message.content if ('content' in message and message.content is not none) }}{% endfor %"; + while (tmpl.length() > 0) { + t.assert_true("malformed: " + tmpl, fuzz_test_template(tmpl, json::object())); + tmpl.pop_back(); + } }); t.test("type coercion edge cases", [&](testing & t) { diff --git a/tools/mtmd/clip-impl.h b/tools/mtmd/clip-impl.h index 3eb66f9145..bf55cec7ef 100644 --- a/tools/mtmd/clip-impl.h +++ b/tools/mtmd/clip-impl.h @@ -38,6 +38,8 @@ #define KEY_IMAGE_SIZE "clip.vision.image_size" #define KEY_IMAGE_MIN_PIXELS "clip.vision.image_min_pixels" #define KEY_IMAGE_MAX_PIXELS "clip.vision.image_max_pixels" +#define KEY_PREPROC_MIN_TILES "clip.vision.preproc_min_tiles" +#define KEY_PREPROC_MAX_TILES "clip.vision.preproc_max_tiles" #define KEY_PREPROC_IMAGE_SIZE "clip.vision.preproc_image_size" #define KEY_PATCH_SIZE "clip.vision.patch_size" #define KEY_IMAGE_MEAN "clip.vision.image_mean" diff --git a/tools/mtmd/clip-model.h b/tools/mtmd/clip-model.h index eeb8da58e0..265a17130f 100644 --- a/tools/mtmd/clip-model.h +++ b/tools/mtmd/clip-model.h @@ -42,6 +42,9 @@ struct clip_hparams { int32_t image_max_pixels = -1; int32_t n_merge = 0; // number of patch merges **per-side** + int32_t preproc_min_tiles = 0; + int32_t preproc_max_tiles = 0; + float image_mean[3]; float image_std[3]; diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 44a19189ea..a47f1f495d 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -1138,6 +1138,16 @@ struct clip_model_loader { } } break; case PROJECTOR_TYPE_INTERNVL: + { + // older version of internvl doesn't have min/max tiles, we need to provide default values for them to avoid issues + hparams.preproc_min_tiles = 1; + hparams.preproc_max_tiles = 12; + get_u32(KEY_PROJ_SCALE_FACTOR, hparams.n_merge, false); + get_u32(KEY_PREPROC_MIN_TILES, hparams.preproc_min_tiles, false); + get_u32(KEY_PREPROC_MAX_TILES, hparams.preproc_max_tiles, false); + GGML_ASSERT(hparams.preproc_min_tiles <= hparams.preproc_max_tiles && hparams.preproc_max_tiles < INT32_MAX); + set_internvl_dhr_res_candidates(model); + } break; case PROJECTOR_TYPE_NEMOTRON_V2_VL: { get_u32(KEY_PROJ_SCALE_FACTOR, hparams.n_merge, false); @@ -1161,7 +1171,6 @@ struct clip_model_loader { hparams.set_warmup_n_tokens(16*16); } break; case PROJECTOR_TYPE_PIXTRAL: - case PROJECTOR_TYPE_LIGHTONOCR: { // ref: https://huggingface.co/mistral-community/pixtral-12b/blob/main/preprocessor_config.json // TODO: verify the image_min_tokens @@ -1171,6 +1180,15 @@ struct clip_model_loader { hparams.set_limit_image_tokens(8, 1024); hparams.set_warmup_n_tokens(256); // avoid OOM on warmup } break; + case PROJECTOR_TYPE_LIGHTONOCR: + { + hparams.n_merge = 1; + hparams.rope_theta = 10000.0f; + get_u32(KEY_SPATIAL_MERGE_SIZE, hparams.n_merge, false); + hparams.image_longest_edge = hparams.image_size; + get_u32(KEY_PREPROC_IMAGE_SIZE, hparams.image_longest_edge, false); + hparams.set_warmup_n_tokens(256); // avoid OOM on warmup + } break; case PROJECTOR_TYPE_KIMIVL: { hparams.rope_theta = 10000.0f; @@ -2180,6 +2198,27 @@ struct clip_model_loader { } } } + + static void set_internvl_dhr_res_candidates(clip_model & model) { + auto & hparams = model.hparams; + int min_num = hparams.preproc_min_tiles; + int max_num = hparams.preproc_max_tiles; + if (min_num < 1) { + return; // avoid divide by 0 + } + for (int a = min_num; a <= max_num; ++a) { + int b_lo = (min_num + a - 1) / a; + int b_hi = max_num / a; + b_lo = std::max(b_lo, min_num); + b_hi = std::min(b_hi, max_num); + for (int b = b_lo; b <= b_hi; ++b) { + hparams.image_res_candidates.push_back(clip_image_size { + a*hparams.image_size, + b*hparams.image_size, + }); + } + } + } }; struct clip_init_result clip_init(const char * fname, struct clip_context_params ctx_params) { @@ -2726,17 +2765,22 @@ struct llava_uhd { return res; } - static std::vector slice_image(const clip_image_u8 * img, const slice_instructions & inst) { + static std::vector slice_image(const clip_image_u8 * img, const slice_instructions & inst, bool overview_first = true) { std::vector output; // resize to overview size clip_image_u8_ptr resized_img(clip_image_u8_init()); img_tool::resize(*img, *resized_img, inst.overview_size, inst.interpolation_overview, inst.padding_overview, inst.pad_color_overview); - output.push_back(std::move(resized_img)); + if (overview_first) { + output.push_back(std::move(resized_img)); + } if (inst.slices.empty()) { // no slices, just return the resized image + if (!overview_first) { + output.push_back(std::move(resized_img)); + } return output; } @@ -2757,6 +2801,10 @@ struct llava_uhd { output.push_back(std::move(img_slice)); } + if (!overview_first) { + output.push_back(std::move(resized_img)); + } + return output; } @@ -3141,10 +3189,20 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str res_imgs->grid_x = instructions.grid_size.width; res_imgs->grid_y = instructions.grid_size.height; } break; + case PROJECTOR_TYPE_INTERNVL: // support dynamic high-resolution + { + GGML_ASSERT(!params.image_res_candidates.empty()); + auto const inst = llava_uhd::get_slice_instructions(ctx, original_size); + std::vector imgs = llava_uhd::slice_image(img, inst, false); + for (size_t i = 0; i < imgs.size(); ++i) { + clip_image_f32_ptr res(clip_image_f32_init()); + normalize_image_u8_to_f32(*imgs[i], *res, params.image_mean, params.image_std); + res_imgs->entries.push_back(std::move(res)); + } + } break; case PROJECTOR_TYPE_GLM_EDGE: case PROJECTOR_TYPE_GEMMA3: - case PROJECTOR_TYPE_INTERNVL: // TODO @ngxson : support dynamic resolution case PROJECTOR_TYPE_NEMOTRON_V2_VL: { clip_image_u8 resized_image; @@ -3180,7 +3238,6 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str case PROJECTOR_TYPE_PHI4: case PROJECTOR_TYPE_PIXTRAL: - case PROJECTOR_TYPE_LIGHTONOCR: { GGML_ASSERT(params.image_min_pixels > 0 && params.image_max_pixels > 0); clip_image_u8 resized_image; @@ -3196,6 +3253,19 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str normalize_image_u8_to_f32(resized_image, *img_f32, params.image_mean, params.image_std); res_imgs->entries.push_back(std::move(img_f32)); } break; + case PROJECTOR_TYPE_LIGHTONOCR: + { + GGML_ASSERT(params.image_longest_edge > 0); + clip_image_u8 resized_image; + const clip_image_size target_size = img_tool::calc_size_preserved_ratio( + original_size, + params.patch_size * params.n_merge, + params.image_longest_edge); + img_tool::resize(*img, resized_image, target_size, img_tool::RESIZE_ALGO_BICUBIC); + clip_image_f32_ptr img_f32(clip_image_f32_init()); + normalize_image_u8_to_f32(resized_image, *img_f32, params.image_mean, params.image_std); + res_imgs->entries.push_back(std::move(img_f32)); + } break; case PROJECTOR_TYPE_LLAMA4: { diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp index f66c07345e..456ce7b73c 100644 --- a/tools/mtmd/mtmd.cpp +++ b/tools/mtmd/mtmd.cpp @@ -851,13 +851,15 @@ int32_t mtmd_encode(mtmd_context * ctx, const mtmd_image_tokens * image_tokens) LOG_ERR("%s: this API does not support non-vision input, please use mtmd_encode_chunk instead\n", __func__); return 1; } + auto proj_type = clip_get_projector_type(ctx_clip); int n_mmproj_embd = clip_n_mmproj_embd(ctx_clip); ctx->image_embd_v.resize(image_tokens->n_tokens() * n_mmproj_embd); bool ok = false; if (clip_is_llava(ctx_clip) || clip_is_minicpmv(ctx_clip) - || clip_is_glm(ctx_clip)) { + || clip_is_glm(ctx_clip) + || proj_type == PROJECTOR_TYPE_INTERNVL) { // TODO @ngxson : llava does not support batched encoding ; this should be fixed inside clip_image_batch_encode() const auto & entries = image_tokens->batch_f32.entries; for (size_t i = 0; i < entries.size(); i++) { diff --git a/tools/server/README.md b/tools/server/README.md index 554444d74b..f584b60026 100644 --- a/tools/server/README.md +++ b/tools/server/README.md @@ -1634,6 +1634,13 @@ The `status` object can be: } ``` +```json +"status": { + "value": "sleeping", + "args": ["llama-server", "-ctx", "4096"] +} +``` + ### POST `/models/load`: Load a model Load a model diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 9de554e900..b79a5270b5 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -3033,6 +3033,9 @@ struct server_res_generator : server_http_res { } }; +void server_context::on_sleeping_changed(std::function callback) { + impl->queue_tasks.on_sleeping_state(std::move(callback)); +} // diff --git a/tools/server/server-context.h b/tools/server/server-context.h index 75f3d2de56..a4d2201cbe 100644 --- a/tools/server/server-context.h +++ b/tools/server/server-context.h @@ -74,6 +74,10 @@ struct server_context { // get server metadata (read-only), can only be called after load_model() // not thread-safe, should only be used from the main thread server_context_meta get_meta() const; + + // register a callback to be called when sleeping state changes + // must be set before load_model() is called + void on_sleeping_changed(std::function callback); }; diff --git a/tools/server/server-models.cpp b/tools/server/server-models.cpp index 4ac55cd158..7e61844f08 100644 --- a/tools/server/server-models.cpp +++ b/tools/server/server-models.cpp @@ -39,7 +39,8 @@ extern char **environ; #define DEFAULT_STOP_TIMEOUT 10 // seconds #define CMD_ROUTER_TO_CHILD_EXIT "cmd_router_to_child:exit" -#define CMD_CHILD_TO_ROUTER_READY "cmd_child_to_router:ready" +#define CMD_CHILD_TO_ROUTER_READY "cmd_child_to_router:ready" // also sent when waking up from sleep +#define CMD_CHILD_TO_ROUTER_SLEEP "cmd_child_to_router:sleep" // address for child process, this is needed because router may run on 0.0.0.0 // ref: https://github.com/ggml-org/llama.cpp/issues/17862 @@ -380,7 +381,7 @@ void server_models::update_meta(const std::string & name, const server_model_met if (it != mapping.end()) { it->second.meta = meta; } - cv.notify_all(); // notify wait_until_loaded + cv.notify_all(); // notify wait_until_loading_finished } bool server_models::has_model(const std::string & name) { @@ -503,7 +504,7 @@ void server_models::unload_lru() { { std::unique_lock lk(mutex); for (const auto & m : mapping) { - if (m.second.meta.is_active()) { + if (m.second.meta.is_running()) { count_active++; if (m.second.meta.last_used < lru_last_used) { lru_model_name = m.first; @@ -546,7 +547,7 @@ void server_models::load(const std::string & name) { if (base_params.models_max > 0) { size_t count_active = 0; for (const auto & m : mapping) { - if (m.second.meta.is_active()) { + if (m.second.meta.is_running()) { count_active++; } } @@ -605,15 +606,15 @@ void server_models::load(const std::string & name) { std::thread log_thread([&]() { // read stdout/stderr and forward to main server log // also handle status report from child process - bool state_received = false; // true if child state received if (stdout_file) { char buffer[4096]; while (fgets(buffer, sizeof(buffer), stdout_file) != nullptr) { LOG("[%5d] %s", port, buffer); - if (!state_received && std::strstr(buffer, CMD_CHILD_TO_ROUTER_READY) != nullptr) { - // child process is ready + std::string str(buffer); + if (string_starts_with(buffer, CMD_CHILD_TO_ROUTER_READY)) { this->update_status(name, SERVER_MODEL_STATUS_LOADED, 0); - state_received = true; + } else if (string_starts_with(buffer, CMD_CHILD_TO_ROUTER_SLEEP)) { + this->update_status(name, SERVER_MODEL_STATUS_SLEEPING, 0); } } } else { @@ -706,13 +707,13 @@ void server_models::unload(const std::string & name) { std::lock_guard lk(mutex); auto it = mapping.find(name); if (it != mapping.end()) { - if (it->second.meta.is_active()) { - SRV_INF("unloading model instance name=%s\n", name.c_str()); + if (it->second.meta.is_running()) { + SRV_INF("stopping model instance name=%s\n", name.c_str()); stopping_models.insert(name); cv_stop.notify_all(); // status change will be handled by the managing thread } else { - SRV_WRN("model instance name=%s is not loaded\n", name.c_str()); + SRV_WRN("model instance name=%s is not running\n", name.c_str()); } } } @@ -722,8 +723,8 @@ void server_models::unload_all() { { std::lock_guard lk(mutex); for (auto & [name, inst] : mapping) { - if (inst.meta.is_active()) { - SRV_INF("unloading model instance name=%s\n", name.c_str()); + if (inst.meta.is_running()) { + SRV_INF("stopping model instance name=%s\n", name.c_str()); stopping_models.insert(name); cv_stop.notify_all(); // status change will be handled by the managing thread @@ -750,7 +751,7 @@ void server_models::update_status(const std::string & name, server_model_status cv.notify_all(); } -void server_models::wait_until_loaded(const std::string & name) { +void server_models::wait_until_loading_finished(const std::string & name) { std::unique_lock lk(mutex); cv.wait(lk, [this, &name]() { auto it = mapping.find(name); @@ -761,22 +762,25 @@ void server_models::wait_until_loaded(const std::string & name) { }); } -bool server_models::ensure_model_loaded(const std::string & name) { +bool server_models::ensure_model_ready(const std::string & name) { auto meta = get_meta(name); if (!meta.has_value()) { throw std::runtime_error("model name=" + name + " is not found"); } - if (meta->status == SERVER_MODEL_STATUS_LOADED) { - return false; // already loaded + if (meta->is_ready()) { + return false; // ready for taking requests + } + if (meta->status == SERVER_MODEL_STATUS_SLEEPING) { + return false; // child is sleeping but still running; new request will wake it up } if (meta->status == SERVER_MODEL_STATUS_UNLOADED) { SRV_INF("model name=%s is not loaded, loading...\n", name.c_str()); load(name); } - // for loading state + // wait for loading to complete SRV_INF("waiting until model name=%s is fully loaded...\n", name.c_str()); - wait_until_loaded(name); + wait_until_loading_finished(name); // check final status meta = get_meta(name); @@ -792,8 +796,8 @@ server_http_res_ptr server_models::proxy_request(const server_http_req & req, co if (!meta.has_value()) { throw std::runtime_error("model name=" + name + " is not found"); } - if (meta->status != SERVER_MODEL_STATUS_LOADED) { - throw std::invalid_argument("model name=" + name + " is not loaded"); + if (!meta->is_running()) { + throw std::invalid_argument("model name=" + name + " is not running"); } if (update_last_used) { std::unique_lock lk(mutex); @@ -819,6 +823,11 @@ server_http_res_ptr server_models::proxy_request(const server_http_req & req, co return proxy; } +bool server_models::is_child_server() { + const char * router_port = std::getenv("LLAMA_SERVER_ROUTER_PORT"); + return router_port != nullptr; +} + std::thread server_models::setup_child_server(const std::function & shutdown_handler) { // send a notification to the router server that a model instance is ready common_log_pause(common_log_main()); @@ -852,6 +861,13 @@ std::thread server_models::setup_child_server(const std::function & s }); } +void server_models::notify_router_sleeping_state(bool is_sleeping) { + common_log_pause(common_log_main()); + fflush(stdout); + fprintf(stdout, "%s\n", is_sleeping ? CMD_CHILD_TO_ROUTER_SLEEP : CMD_CHILD_TO_ROUTER_READY); + fflush(stdout); + common_log_resume(common_log_main()); +} // @@ -881,9 +897,9 @@ static bool router_validate_model(std::string & name, server_models & models, bo // resolve alias to canonical model name name = meta->name; if (models_autoload) { - models.ensure_model_loaded(name); + models.ensure_model_ready(name); } else { - if (meta->status != SERVER_MODEL_STATUS_LOADED) { + if (!meta->is_running()) { res_err(res, format_error_response("model is not loaded", ERROR_TYPE_INVALID_REQUEST)); return false; } @@ -956,8 +972,8 @@ void server_models_routes::init_routes() { res_err(res, format_error_response("model is not found", ERROR_TYPE_NOT_FOUND)); return res; } - if (meta->status == SERVER_MODEL_STATUS_LOADED) { - res_err(res, format_error_response("model is already loaded", ERROR_TYPE_INVALID_REQUEST)); + if (meta->is_running()) { + res_err(res, format_error_response("model is already running", ERROR_TYPE_INVALID_REQUEST)); return res; } models.load(meta->name); @@ -1015,8 +1031,8 @@ void server_models_routes::init_routes() { res_err(res, format_error_response("model is not found", ERROR_TYPE_INVALID_REQUEST)); return res; } - if (!model->is_active()) { - res_err(res, format_error_response("model is not loaded", ERROR_TYPE_INVALID_REQUEST)); + if (!model->is_running()) { + res_err(res, format_error_response("model is not running", ERROR_TYPE_INVALID_REQUEST)); return res; } models.unload(model->name); @@ -1181,7 +1197,8 @@ server_http_proxy::server_http_proxy( continue; } if (key == "Host" || key == "host") { - req.set_header(key, host); + bool is_default_port = (scheme == "https" && port == 443) || (scheme == "http" && port == 80); + req.set_header(key, is_default_port ? host : host + ":" + std::to_string(port)); } else { req.set_header(key, value); } diff --git a/tools/server/server-models.h b/tools/server/server-models.h index 2b392f299a..1db34b6c4d 100644 --- a/tools/server/server-models.h +++ b/tools/server/server-models.h @@ -14,17 +14,18 @@ /** * state diagram: * - * UNLOADED ──► LOADING ──► LOADED - * ▲ │ │ - * └───failed───┘ │ - * ▲ │ + * UNLOADED ──► LOADING ──► LOADED ◄──── SLEEPING + * ▲ │ │ ▲ + * └───failed───┘ │ │ + * ▲ └──sleeping─────┘ * └────────unloaded─────────┘ */ enum server_model_status { // TODO: also add downloading state when the logic is added SERVER_MODEL_STATUS_UNLOADED, SERVER_MODEL_STATUS_LOADING, - SERVER_MODEL_STATUS_LOADED + SERVER_MODEL_STATUS_LOADED, + SERVER_MODEL_STATUS_SLEEPING }; static server_model_status server_model_status_from_string(const std::string & status_str) { @@ -37,6 +38,9 @@ static server_model_status server_model_status_from_string(const std::string & s if (status_str == "loaded") { return SERVER_MODEL_STATUS_LOADED; } + if (status_str == "sleeping") { + return SERVER_MODEL_STATUS_SLEEPING; + } throw std::runtime_error("invalid server model status"); } @@ -45,6 +49,7 @@ static std::string server_model_status_to_string(server_model_status status) { case SERVER_MODEL_STATUS_UNLOADED: return "unloaded"; case SERVER_MODEL_STATUS_LOADING: return "loading"; case SERVER_MODEL_STATUS_LOADED: return "loaded"; + case SERVER_MODEL_STATUS_SLEEPING: return "sleeping"; default: return "unknown"; } } @@ -61,8 +66,12 @@ struct server_model_meta { int exit_code = 0; // exit code of the model instance process (only valid if status == FAILED) int stop_timeout = 0; // seconds to wait before force-killing the model instance during shutdown - bool is_active() const { - return status == SERVER_MODEL_STATUS_LOADED || status == SERVER_MODEL_STATUS_LOADING; + bool is_ready() const { + return status == SERVER_MODEL_STATUS_LOADED; + } + + bool is_running() const { + return status == SERVER_MODEL_STATUS_LOADED || status == SERVER_MODEL_STATUS_LOADING || status == SERVER_MODEL_STATUS_SLEEPING; } bool is_failed() const { @@ -130,19 +139,26 @@ public: void update_status(const std::string & name, server_model_status status, int exit_code); // wait until the model instance is fully loaded (thread-safe) - // return when the model is loaded or failed to load - void wait_until_loaded(const std::string & name); + // return when the model no longer in "loading" state + void wait_until_loading_finished(const std::string & name); - // load the model if not loaded, otherwise do nothing (thread-safe) - // return false if model is already loaded; return true otherwise (meta may need to be refreshed) - bool ensure_model_loaded(const std::string & name); + // ensure the model is in ready state (thread-safe) + // return false if model is ready + // otherwise, load the model and blocking wait until it's ready, then return true (meta may need to be refreshed) + bool ensure_model_ready(const std::string & name); // proxy an HTTP request to the model instance server_http_res_ptr proxy_request(const server_http_req & req, const std::string & method, const std::string & name, bool update_last_used); + // return true if the current process is a child server instance + static bool is_child_server(); + // notify the router server that a model instance is ready // return the monitoring thread (to be joined by the caller) static std::thread setup_child_server(const std::function & shutdown_handler); + + // notify the router server that the sleeping state has changed + static void notify_router_sleeping_state(bool sleeping); }; struct server_models_routes { diff --git a/tools/server/server-queue.h b/tools/server/server-queue.h index 164f09b195..35f010401f 100644 --- a/tools/server/server-queue.h +++ b/tools/server/server-queue.h @@ -95,11 +95,19 @@ public: callback_update_slots = std::move(callback); } - // Register callback for sleeping state change + // Register callback for sleeping state change; multiple callbacks are allowed // note: when entering sleeping state, the callback is called AFTER sleeping is set to true // when leaving sleeping state, the callback is called BEFORE sleeping is set to false void on_sleeping_state(std::function callback) { - callback_sleeping_state = std::move(callback); + if (callback_sleeping_state) { + auto prev_callback = std::move(callback_sleeping_state); + callback_sleeping_state = [prev_callback, callback](bool sleeping) { + prev_callback(sleeping); + callback(sleeping); + }; + } else { + callback_sleeping_state = std::move(callback); + } } private: diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 0bd6fda17d..73e3aa44e0 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -259,6 +259,12 @@ int main(int argc, char ** argv) { // load the model LOG_INF("%s: loading model\n", __func__); + if (server_models::is_child_server()) { + ctx_server.on_sleeping_changed([&](bool sleeping) { + server_models::notify_router_sleeping_state(sleeping); + }); + } + if (!ctx_server.load_model(params)) { clean_up(); if (ctx_http.thread.joinable()) { @@ -309,9 +315,8 @@ int main(int argc, char ** argv) { LOG_INF("%s: starting the main loop...\n", __func__); // optionally, notify router server that this instance is ready - const char * router_port = std::getenv("LLAMA_SERVER_ROUTER_PORT"); std::thread monitor_thread; - if (router_port != nullptr) { + if (server_models::is_child_server()) { monitor_thread = server_models::setup_child_server(shutdown_handler); }