Merge branch 'master' into xsn/mtmd_custom_min_max_tokens

This commit is contained in:
Xuan Son Nguyen 2025-11-02 22:14:03 +01:00
commit 79b98dbf96
43 changed files with 909 additions and 182 deletions

View File

@ -24,8 +24,9 @@ RUN --mount=type=cache,target=/root/.ccache \
-DCMAKE_C_COMPILER_LAUNCHER=ccache \
-DCMAKE_CXX_COMPILER_LAUNCHER=ccache \
-DLLAMA_BUILD_TESTS=OFF \
-DGGML_BACKEND_DL=OFF \
-DGGML_NATIVE=OFF \
-DGGML_BACKEND_DL=ON \
-DGGML_CPU_ALL_VARIANTS=ON \
-DGGML_BLAS=ON \
-DGGML_BLAS_VENDOR=OpenBLAS && \
cmake --build build --config Release -j $(nproc) && \
@ -103,6 +104,7 @@ FROM base AS light
WORKDIR /llama.cpp/bin
# Copy llama.cpp binaries and libraries
COPY --from=collector /llama.cpp/bin/*.so /llama.cpp/bin
COPY --from=collector /llama.cpp/bin/llama-cli /llama.cpp/bin
ENTRYPOINT [ "/llama.cpp/bin/llama-cli" ]
@ -116,6 +118,7 @@ ENV LLAMA_ARG_HOST=0.0.0.0
WORKDIR /llama.cpp/bin
# Copy llama.cpp binaries and libraries
COPY --from=collector /llama.cpp/bin/*.so /llama.cpp/bin
COPY --from=collector /llama.cpp/bin/llama-server /llama.cpp/bin
EXPOSE 8080

View File

@ -134,8 +134,8 @@ jobs:
include:
- build: 'x64'
os: ubuntu-22.04
- build: 's390x-z15' # z15 because our CI runners are on z15
os: ubuntu-22.04-s390x
- build: 's390x'
os: ubuntu-24.04-s390x
# GGML_BACKEND_DL and GGML_CPU_ALL_VARIANTS are not currently supported on arm
# - build: 'arm64'
# os: ubuntu-22.04-arm

View File

@ -313,7 +313,6 @@ json common_chat_msgs_to_json_oaicompat(const std::vector<common_chat_msg> & msg
}
if (!msg.reasoning_content.empty()) {
jmsg["reasoning_content"] = msg.reasoning_content;
jmsg["thinking"] = msg.reasoning_content; // gpt-oss
}
if (!msg.tool_name.empty()) {
jmsg["name"] = msg.tool_name;
@ -1810,7 +1809,23 @@ static void common_chat_parse_deepseek_v3_1(common_chat_msg_parser & builder) {
static common_chat_params common_chat_params_init_gpt_oss(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
auto prompt = apply(tmpl, inputs);
// Copy reasoning to the "thinking" field as expected by the gpt-oss template
auto adjusted_messages = json::array();
for (const auto & msg : inputs.messages) {
auto has_reasoning_content = msg.contains("reasoning_content") && msg.at("reasoning_content").is_string();
auto has_tool_calls = msg.contains("tool_calls") && msg.at("tool_calls").is_array();
if (has_reasoning_content && has_tool_calls) {
auto adjusted_message = msg;
adjusted_message["thinking"] = msg.at("reasoning_content");
adjusted_messages.push_back(adjusted_message);
} else {
adjusted_messages.push_back(msg);
}
}
auto prompt = apply(tmpl, inputs, /* messages_override= */ adjusted_messages);
// Check if we need to replace the return token with end token during
// inference and without generation prompt. For more details see:

View File

@ -9802,6 +9802,113 @@ class CogVLMModel(LlamaModel):
return [(self.map_tensor_name(name), data_torch)]
@ModelBase.register("JanusForConditionalGeneration")
class JanusProModel(LlamaModel):
model_arch = gguf.MODEL_ARCH.LLAMA # reuse Llama arch
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# Skip vision, aligner, and generation tensors
skip_prefixes = (
'model.vision_model.',
'model.aligner.',
'model.vqmodel.',
'model.generation_embeddings.',
'model.generation_aligner.',
'model.generation_head.',
)
if name.startswith(skip_prefixes):
return []
if name.startswith('model.language_model.'):
name = name.replace('model.language_model.', 'model.')
elif name.startswith('language_model.'):
name = name.replace('language_model.', '')
return super().modify_tensors(data_torch, name, bid)
@ModelBase.register("JanusForConditionalGeneration")
class JanusProVisionModel(MmprojModel):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
assert self.hparams_vision is not None
if "intermediate_size" not in self.hparams_vision:
mlp_ratio = self.hparams_vision.get("mlp_ratio")
hidden_size = self.hparams_vision.get("hidden_size")
if mlp_ratio is not None and hidden_size is not None:
self.hparams_vision["intermediate_size"] = int(round(hidden_size * mlp_ratio))
def set_gguf_parameters(self):
super().set_gguf_parameters()
assert self.hparams_vision is not None
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.JANUS_PRO)
self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams_vision.get("layer_norm_eps", 1e-6))
hidden_act = str(self.hparams_vision.get("hidden_act", "")).lower()
if hidden_act == "gelu":
self.gguf_writer.add_vision_use_gelu(True)
elif hidden_act == "silu":
self.gguf_writer.add_vision_use_silu(True)
def _map_aligner_tensor(self, data_torch: Tensor, name: str) -> Iterable[tuple[str, Tensor]]:
"""Map aligner tensors to projector format"""
suffix = ".bias" if name.endswith(".bias") else ".weight"
if name.startswith("model.aligner."):
local_name = name[len("model.aligner."):]
elif name.startswith("aligner."):
local_name = name[len("aligner."):]
else:
raise ValueError(f"Unsupported Janus aligner prefix: {name}")
if local_name.startswith("fc1."):
mm_index = 0
elif local_name.startswith("hidden_layers."):
parts = local_name.split(".", 2)
if len(parts) < 3:
raise ValueError(f"Unexpected Janus aligner tensor name: {name}")
mm_index = int(parts[1]) + 1
else:
raise ValueError(f"Unsupported Janus aligner tensor: {name}")
tensor_name = self.format_tensor_name(gguf.MODEL_TENSOR.V_MMPROJ, mm_index, suffix=suffix)
return [(tensor_name, data_torch)]
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unused
# Skip language model tensors as they will be handled by `JanusProModel`
if name.startswith(('model.language_model.', 'language_model.')):
return []
# Skip generation-related components
skip_generation_prefixes = (
'model.vqmodel.',
'vqmodel.',
'model.generation_embeddings.',
'generation_embeddings.',
'model.generation_aligner.',
'generation_aligner.',
'model.generation_head.',
'generation_head.',
)
if name.startswith(skip_generation_prefixes):
return []
# Handle aligner tensors
if name.startswith(('model.aligner.', 'aligner.')):
return list(self._map_aligner_tensor(data_torch, name))
# Handle vision tensors
if name.startswith(('model.vision_model.', 'vision_model.')):
return [(self.map_tensor_name(name), data_torch)]
return []
###### CONVERSION LOGIC ######

View File

@ -7,9 +7,9 @@
## Images
We have three Docker images available for this project:
1. `ghcr.io/ggml-org/llama.cpp:full`: This image includes both the main executable file and the tools to convert LLaMA models into ggml and convert into 4-bit quantization. (platforms: `linux/amd64`, `linux/arm64`)
2. `ghcr.io/ggml-org/llama.cpp:light`: This image only includes the main executable file. (platforms: `linux/amd64`, `linux/arm64`)
3. `ghcr.io/ggml-org/llama.cpp:server`: This image only includes the server executable file. (platforms: `linux/amd64`, `linux/arm64`)
1. `ghcr.io/ggml-org/llama.cpp:full`: This image includes both the main executable file and the tools to convert LLaMA models into ggml and convert into 4-bit quantization. (platforms: `linux/amd64`, `linux/arm64`, `linux/s390x`)
2. `ghcr.io/ggml-org/llama.cpp:light`: This image only includes the main executable file. (platforms: `linux/amd64`, `linux/arm64`, `linux/s390x`)
3. `ghcr.io/ggml-org/llama.cpp:server`: This image only includes the server executable file. (platforms: `linux/amd64`, `linux/arm64`, `linux/s390x`)
Additionally, there the following images, similar to the above:

View File

@ -308,6 +308,10 @@ function(ggml_add_cpu_backend_variant tag_name)
set(GGML_INTERNAL_${feat} ON)
endforeach()
elseif (GGML_SYSTEM_ARCH STREQUAL "s390x")
foreach (feat VXE2 NNPA)
set(GGML_INTERNAL_${feat} OFF)
endforeach()
foreach (feat ${ARGN})
set(GGML_INTERNAL_${feat} ON)
endforeach()
@ -377,9 +381,8 @@ if (GGML_CPU_ALL_VARIANTS)
endif()
elseif (GGML_SYSTEM_ARCH STREQUAL "s390x")
if (CMAKE_SYSTEM_NAME MATCHES "Linux")
ggml_add_cpu_backend_variant(s390x_z15 Z15 VXE)
# ggml_add_cpu_backend_variant(s390x_z16 Z16 VXE)
# ggml_add_cpu_backend_variant(s390x_z17 Z17 VXE)
ggml_add_cpu_backend_variant(z15 Z15 VXE2)
ggml_add_cpu_backend_variant(z16 Z16 VXE2 NNPA)
else()
message(FATAL_ERROR "Unsupported s390x target OS: ${CMAKE_SYSTEM_NAME}")
endif()

View File

@ -504,11 +504,18 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
endforeach()
endif()
if (GGML_VXE OR GGML_INTERNAL_VXE)
message(STATUS "VX/VXE/VXE2 enabled")
if (GGML_VXE OR GGML_INTERNAL_VXE2)
message(STATUS "VXE2 enabled")
list(APPEND ARCH_FLAGS -mvx -mzvector)
list(APPEND ARCH_DEFINITIONS GGML_VXE)
list(APPEND ARCH_DEFINITIONS GGML_USE_VXE2)
endif()
if (GGML_INTERNAL_NNPA)
message(STATUS "NNPA enabled")
list(APPEND ARCH_DEFINITIONS GGML_USE_NNPA)
endif()
ggml_add_cpu_backend_features(${GGML_CPU_NAME} s390 ${ARCH_DEFINITIONS})
elseif (CMAKE_SYSTEM_PROCESSOR MATCHES "wasm")
message(STATUS "Wasm detected")
list (APPEND GGML_CPU_SOURCES ggml-cpu/arch/wasm/quants.c)

View File

@ -0,0 +1,50 @@
#include "ggml-backend-impl.h"
#if defined(__s390x__)
#include <sys/auxv.h>
// find hwcap bits in asm/elf.h
#ifndef HWCAP_VXRS_EXT2
#define HWCAP_VXRS_EXT2 (1 << 15)
#endif
#ifndef HWCAP_NNPA
#define HWCAP_NNPA (1 << 20)
#endif
struct s390x_features {
bool has_vxe2 = false;
bool has_nnpa = false;
s390x_features() {
uint32_t hwcap = getauxval(AT_HWCAP);
// NOTE: use hwcap2 with DFLT for z17 and later
// uint32_t hwcap2 = getauxval(AT_HWCAP2);
has_vxe2 = !!(hwcap & HWCAP_VXRS_EXT2);
has_nnpa = !!(hwcap & HWCAP_NNPA);
}
};
static int ggml_backend_cpu_s390x_score() {
int score = 1;
s390x_features sf;
// IBM z15 / LinuxONE 3
#ifdef GGML_USE_VXE2
if (!sf.has_vxe2) { return 0; }
score += 1 << 1;
#endif
// IBM z16 / LinuxONE 4 and z17 / LinuxONE 5
#ifdef GGML_USE_NNPA
if (!sf.has_nnpa) { return 0; }
score += 1 << 2;
#endif
return score;
}
GGML_BACKEND_DL_SCORE_IMPL(ggml_backend_cpu_s390x_score)
#endif // __s390x__

View File

@ -2499,6 +2499,18 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_UNARY_OP_XIELU:
ggml_cuda_op_xielu(ctx, dst);
break;
case GGML_UNARY_OP_FLOOR:
ggml_cuda_op_floor(ctx, dst);
break;
case GGML_UNARY_OP_CEIL:
ggml_cuda_op_ceil(ctx, dst);
break;
case GGML_UNARY_OP_ROUND:
ggml_cuda_op_round(ctx, dst);
break;
case GGML_UNARY_OP_TRUNC:
ggml_cuda_op_trunc(ctx, dst);
break;
default:
return false;
}
@ -3769,6 +3781,10 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_UNARY_OP_TANH:
case GGML_UNARY_OP_EXP:
case GGML_UNARY_OP_ELU:
case GGML_UNARY_OP_FLOOR:
case GGML_UNARY_OP_CEIL:
case GGML_UNARY_OP_ROUND:
case GGML_UNARY_OP_TRUNC:
return ggml_is_contiguous(op->src[0]);
default:
return false;

View File

@ -85,6 +85,22 @@ static __device__ __forceinline__ float op_elu(float x) {
return (x > 0.f) ? x : expm1f(x);
}
static __device__ __forceinline__ float op_floor(float x) {
return floorf(x);
}
static __device__ __forceinline__ float op_ceil(float x) {
return ceilf(x);
}
static __device__ __forceinline__ float op_round(float x) {
return round(x);
}
static __device__ __forceinline__ float op_trunc(float x) {
return trunc(x);
}
template <float (*op)(float), typename T>
static __global__ void unary_op_kernel(const T * x, T * dst, const int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
@ -201,6 +217,22 @@ void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
void ggml_cuda_op_elu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_unary<op_elu>(ctx, dst);
}
void ggml_cuda_op_floor(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_unary<op_floor>(ctx, dst);
}
void ggml_cuda_op_ceil(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_unary<op_ceil>(ctx, dst);
}
void ggml_cuda_op_round(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_unary<op_round>(ctx, dst);
}
void ggml_cuda_op_trunc(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_unary<op_trunc>(ctx, dst);
}
/* gated ops */
template <float (*op)(float), typename T>

View File

@ -63,6 +63,14 @@ void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_elu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_floor(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_ceil(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_round(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_trunc(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_reglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_geglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@ -707,6 +707,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
if (op->src[0]->ne[0] != 32 &&
op->src[0]->ne[0] != 40 &&
op->src[0]->ne[0] != 64 &&
op->src[0]->ne[0] != 72 &&
op->src[0]->ne[0] != 80 &&
op->src[0]->ne[0] != 96 &&
op->src[0]->ne[0] != 112 &&

View File

@ -5362,6 +5362,7 @@ typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, hal
template [[host_name("kernel_flash_attn_ext_f32_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 32, 32>;
template [[host_name("kernel_flash_attn_ext_f32_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 40, 40>;
template [[host_name("kernel_flash_attn_ext_f32_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 64, 64>;
template [[host_name("kernel_flash_attn_ext_f32_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 72, 72>;
template [[host_name("kernel_flash_attn_ext_f32_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 80, 80>;
template [[host_name("kernel_flash_attn_ext_f32_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 96, 96>;
template [[host_name("kernel_flash_attn_ext_f32_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 112, 112>;
@ -5374,6 +5375,7 @@ template [[host_name("kernel_flash_attn_ext_f32_dk576_dv512")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_f16_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 32, 32>;
template [[host_name("kernel_flash_attn_ext_f16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 40, 40>;
template [[host_name("kernel_flash_attn_ext_f16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>;
template [[host_name("kernel_flash_attn_ext_f16_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 72, 72>;
template [[host_name("kernel_flash_attn_ext_f16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 80, 80>;
template [[host_name("kernel_flash_attn_ext_f16_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 96, 96>;
template [[host_name("kernel_flash_attn_ext_f16_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 112, 112>;
@ -5387,6 +5389,7 @@ template [[host_name("kernel_flash_attn_ext_f16_dk576_dv512")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_bf16_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 32, 32>;
template [[host_name("kernel_flash_attn_ext_bf16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 40, 40>;
template [[host_name("kernel_flash_attn_ext_bf16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 64, 64>;
template [[host_name("kernel_flash_attn_ext_bf16_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 72, 72>;
template [[host_name("kernel_flash_attn_ext_bf16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 80, 80>;
template [[host_name("kernel_flash_attn_ext_bf16_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 96, 96>;
template [[host_name("kernel_flash_attn_ext_bf16_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 112, 112>;
@ -5400,6 +5403,7 @@ template [[host_name("kernel_flash_attn_ext_bf16_dk576_dv512")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_q4_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 32, 32>;
template [[host_name("kernel_flash_attn_ext_q4_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 40, 40>;
template [[host_name("kernel_flash_attn_ext_q4_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 64, 64>;
template [[host_name("kernel_flash_attn_ext_q4_0_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 72, 72>;
template [[host_name("kernel_flash_attn_ext_q4_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 80, 80>;
template [[host_name("kernel_flash_attn_ext_q4_0_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 96, 96>;
template [[host_name("kernel_flash_attn_ext_q4_0_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 112, 112>;
@ -5412,6 +5416,7 @@ template [[host_name("kernel_flash_attn_ext_q4_0_dk576_dv512")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_q4_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 32, 32>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 40, 40>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 64, 64>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 72, 72>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 80, 80>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 96, 96>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 112, 112>;
@ -5424,6 +5429,7 @@ template [[host_name("kernel_flash_attn_ext_q4_1_dk576_dv512")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_q5_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 32, 32>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 40, 40>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 64, 64>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 72, 72>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 80, 80>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 96, 96>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 112, 112>;
@ -5436,6 +5442,7 @@ template [[host_name("kernel_flash_attn_ext_q5_0_dk576_dv512")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_q5_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 32, 32>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 40, 40>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 64, 64>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 72, 72>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 80, 80>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 96, 96>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 112, 112>;
@ -5448,6 +5455,7 @@ template [[host_name("kernel_flash_attn_ext_q5_1_dk576_dv512")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_q8_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 32, 32>;
template [[host_name("kernel_flash_attn_ext_q8_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 40, 40>;
template [[host_name("kernel_flash_attn_ext_q8_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 64, 64>;
template [[host_name("kernel_flash_attn_ext_q8_0_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 72, 72>;
template [[host_name("kernel_flash_attn_ext_q8_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 80, 80>;
template [[host_name("kernel_flash_attn_ext_q8_0_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 96, 96>;
template [[host_name("kernel_flash_attn_ext_q8_0_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 112, 112>;

View File

@ -3186,6 +3186,7 @@ class VisionProjectorType:
KIMIVL = "kimivl"
LIGHTONOCR = "lightonocr"
COGVLM = "cogvlm"
JANUS_PRO = "janus_pro"
# Items here are (block size, type size)

View File

@ -1183,6 +1183,7 @@ class TensorNameMap:
"model.mm_projector.mlp.mlp.{bid}",
"vision_model.vision_adapter.mlp.fc{bid}", # llama 4
"mlp1.{bid}", # InternVL
"model.aligner.fc1.hidden_layers.{bid}", # Janus Pro
),
MODEL_TENSOR.V_MMPROJ_PEG: (
@ -1291,6 +1292,7 @@ class TensorNameMap:
"model.vision_tower.encoder.layer.{bid}.attention.projection_layer", # Intern-S1
"vpm.encoder.layers.{bid}.self_attn.out_proj",
"model.vision_model.encoder.layers.{bid}.self_attn.out_proj", # SmolVLM
"model.vision_model.encoder.layers.{bid}.self_attn.projection_layer", # Janus Pro
"vision_model.model.layers.{bid}.self_attn.o_proj", # llama4
"vision_tower.transformer.layers.{bid}.attention.o_proj", # pixtral-hf
"vision_encoder.transformer.layers.{bid}.attention.wo", # pixtral

View File

@ -461,7 +461,10 @@ extern "C" {
LLAMA_API bool llama_supports_gpu_offload(void);
LLAMA_API bool llama_supports_rpc (void);
// NOTE: After creating a llama_context, it is recommended to query the actual values using these functions
// In some cases the requested values via llama_context_params may differ from the actual values used by the context
LLAMA_API uint32_t llama_n_ctx (const struct llama_context * ctx);
LLAMA_API uint32_t llama_n_ctx_seq (const struct llama_context * ctx);
LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx);
LLAMA_API uint32_t llama_n_ubatch (const struct llama_context * ctx);
LLAMA_API uint32_t llama_n_seq_max (const struct llama_context * ctx);
@ -585,7 +588,7 @@ extern "C" {
LLAMA_API int32_t llama_adapter_meta_val_str_by_index(const struct llama_adapter_lora * adapter, int32_t i, char * buf, size_t buf_size);
// Manually free a LoRA adapter
// Note: loaded adapters will be free when the associated model is deleted
// NOTE: loaded adapters will be free when the associated model is deleted
LLAMA_API void llama_adapter_lora_free(struct llama_adapter_lora * adapter);
// Get the invocation tokens if the current lora is an alora
@ -1111,8 +1114,6 @@ extern "C" {
// // sample from the logits of the last token in the batch
// const llama_token id = llama_sampler_sample(smpl, ctx, -1);
//
// // accepting the token updates the internal state of certain samplers (e.g. grammar, repetition, etc.)
// llama_sampler_accept(smpl, id);
// ...
// }
//

74
scripts/bench-models.sh Normal file
View File

@ -0,0 +1,74 @@
#!/usr/bin/env bash
RESULTS="bench-models-results.txt"
: > "$RESULTS"
ARGS_BB="-c 270336 -npp 512,4096,8192 -npl 1,2,4,8,16,32 -ntg 32"
ARGS_B="-d 0,4096,8192,16384,32768 -p 2048 -n 32"
QUICK=0
while (( "$#" )); do
case "$1" in
--quick) QUICK=1; shift ;;
*) shift ;;
esac
done
if (( QUICK )); then
ARGS_BB="-c 20480 -npp 512,4096 -npl 1,2,4 -ntg 32"
ARGS_B="-d 0 -p 2048 -n 32"
fi
run_model() {
local HFR=$1
local HFF=$2
printf "## ${HFR}\n" | tee -a "$RESULTS"
printf "\n" | tee -a "$RESULTS"
printf "Model: https://huggingface.co/${HFR}\n" | tee -a "$RESULTS"
printf "\n" | tee -a "$RESULTS"
printf -- "- \`llama-batched-bench\`\n" | tee -a "$RESULTS"
printf "\n" | tee -a "$RESULTS"
./bin/llama-batched-bench \
-hfr "${HFR}" -hff "${HFF}" \
-m "${HFF}" -fa 1 -ub 2048 --no-mmap \
${ARGS_BB} | tee -a "$RESULTS"
printf "\n" | tee -a "$RESULTS"
printf -- "- \`llama-bench\`\n" | tee -a "$RESULTS"
printf "\n" | tee -a "$RESULTS"
./bin/llama-bench \
-m "${HFF}" -fa 1 -ub 2048 -mmp 0 \
${ARGS_B} | tee -a "$RESULTS"
printf "\n" | tee -a "$RESULTS"
printf "\n"
}
run_model "ggml-org/gpt-oss-20b-GGUF" "gpt-oss-20b-mxfp4.gguf"
run_model "ggml-org/gpt-oss-120b-GGUF" "gpt-oss-120b-mxfp4-00001-of-00003.gguf"
run_model "ggml-org/Qwen3-Coder-30B-A3B-Instruct-Q8_0-GGUF" "qwen3-coder-30b-a3b-instruct-q8_0.gguf"
run_model "ggml-org/Qwen2.5-Coder-7B-Q8_0-GGUF" "qwen2.5-coder-7b-q8_0.gguf"
run_model "ggml-org/gemma-3-4b-it-qat-GGUF" "gemma-3-4b-it-qat-Q4_0.gguf"
if [[ -f models-extra.txt ]]; then
while read -r HFR HFF; do
[[ -z "$HFR" ]] && continue
run_model "$HFR" "$HFF"
done < models-extra.txt
fi
printf "\n=====================================\n"
printf "\n"
cat "$RESULTS"
printf "\n"
printf "Done! Results are written to $RESULTS\n"
printf "\n"

View File

@ -112,11 +112,24 @@ llama_context::llama_context(
}
}
const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
if (cparams.kv_unified) {
cparams.n_ctx_seq = cparams.n_ctx;
} else {
cparams.n_ctx_seq = cparams.n_ctx / cparams.n_seq_max;
if (cparams.n_ctx_seq == 0) {
throw std::runtime_error("n_ctx_seq == 0");
}
if (cparams.n_ctx != cparams.n_ctx_seq * cparams.n_seq_max) {
cparams.n_ctx = cparams.n_ctx_seq * cparams.n_seq_max;
LLAMA_LOG_WARN("%s: n_ctx is not divisible by n_seq_max - rounding down to %u\n", __func__, cparams.n_ctx);
}
}
LLAMA_LOG_INFO("%s: n_seq_max = %u\n", __func__, cparams.n_seq_max);
LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
LLAMA_LOG_INFO("%s: n_ctx_per_seq = %u\n", __func__, n_ctx_per_seq);
LLAMA_LOG_INFO("%s: n_ctx_seq = %u\n", __func__, cparams.n_ctx_seq);
LLAMA_LOG_INFO("%s: n_batch = %u\n", __func__, cparams.n_batch);
LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch);
LLAMA_LOG_INFO("%s: causal_attn = %d\n", __func__, cparams.causal_attn);
@ -125,14 +138,14 @@ llama_context::llama_context(
LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);
if (n_ctx_per_seq < hparams.n_ctx_train) {
LLAMA_LOG_WARN("%s: n_ctx_per_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n",
__func__, n_ctx_per_seq, hparams.n_ctx_train);
if (cparams.n_ctx_seq < hparams.n_ctx_train) {
LLAMA_LOG_WARN("%s: n_ctx_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n",
__func__, cparams.n_ctx_seq, hparams.n_ctx_train);
}
if (n_ctx_per_seq > hparams.n_ctx_train) {
LLAMA_LOG_WARN("%s: n_ctx_per_seq (%u) > n_ctx_train (%u) -- possible training context overflow\n",
__func__, n_ctx_per_seq, hparams.n_ctx_train);
if (cparams.n_ctx_seq > hparams.n_ctx_train) {
LLAMA_LOG_WARN("%s: n_ctx_seq (%u) > n_ctx_train (%u) -- possible training context overflow\n",
__func__, cparams.n_ctx_seq, hparams.n_ctx_train);
}
if (!hparams.vocab_only) {
@ -453,8 +466,8 @@ uint32_t llama_context::n_ctx() const {
return cparams.n_ctx;
}
uint32_t llama_context::n_ctx_per_seq() const {
return cparams.n_ctx / cparams.n_seq_max;
uint32_t llama_context::n_ctx_seq() const {
return cparams.n_ctx_seq;
}
uint32_t llama_context::n_batch() const {
@ -2383,6 +2396,10 @@ uint32_t llama_n_ctx(const llama_context * ctx) {
return ctx->n_ctx();
}
uint32_t llama_n_ctx_seq(const llama_context * ctx) {
return ctx->n_ctx_seq();
}
uint32_t llama_n_batch(const llama_context * ctx) {
return ctx->n_batch();
}

View File

@ -43,11 +43,11 @@ struct llama_context {
ggml_backend_sched_t get_sched() const;
uint32_t n_ctx() const;
uint32_t n_ctx_per_seq() const;
uint32_t n_batch() const;
uint32_t n_ubatch() const;
uint32_t n_seq_max() const;
uint32_t n_ctx() const;
uint32_t n_ctx_seq() const;
uint32_t n_batch() const;
uint32_t n_ubatch() const;
uint32_t n_seq_max() const;
uint32_t n_threads() const;
uint32_t n_threads_batch() const;

View File

@ -8,6 +8,7 @@
struct llama_cparams {
uint32_t n_ctx; // context size used during inference
uint32_t n_ctx_seq; // context for a single sequence
uint32_t n_batch;
uint32_t n_ubatch;
uint32_t n_seq_max;

View File

@ -6712,14 +6712,14 @@ float llama_model::get_rope_freq_scale(const llama_cparams & cparams, int il) co
}
ggml_tensor * llama_model::get_rope_factors(const llama_cparams & cparams, int il) const {
const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
const uint32_t n_ctx_seq = cparams.n_ctx_seq;
// choose long/short freq factors based on the context size
if (layers[il].rope_freqs != nullptr) {
return layers[il].rope_freqs;
}
if (n_ctx_per_seq > hparams.n_ctx_orig_yarn) {
if (n_ctx_seq > hparams.n_ctx_orig_yarn) {
return layers[il].rope_long;
}
@ -6795,12 +6795,6 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
/* filter_attn */ std::move(filter_attn),
/* filter_recr */ std::move(filter_recr));
} else {
uint32_t n_ctx_per_stream = cparams.n_ctx;
if (!cparams.kv_unified) {
n_ctx_per_stream = (cparams.n_ctx + cparams.n_seq_max - 1)/cparams.n_seq_max;
}
llama_memory_i::layer_reuse_cb reuse = nullptr;
if (arch == LLM_ARCH_GEMMA3N) {
@ -6824,7 +6818,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
cparams.offload_kqv,
params.swa_full,
cparams.kv_unified,
n_ctx_per_stream,
cparams.n_ctx_seq,
cparams.n_seq_max,
cparams.n_ubatch,
1,
@ -6840,7 +6834,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
!cparams.flash_attn,
cparams.offload_kqv,
cparams.kv_unified,
n_ctx_per_stream,
cparams.n_ctx_seq,
cparams.n_seq_max,
1,
hparams.n_swa,

View File

@ -7225,8 +7225,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_pad_ext(GGML_TYPE_F32, {11, 22, 33, 44}, 1, 2, 3, 4, 5, 6, 7, 8, v));
}
for (int hsk : { 40, 64, 80, 96, 128, 192, 256, 576 }) {
for (int hsv : { 40, 64, 80, 96, 128, 192, 256, 512 }) {
for (int hsk : { 40, 64, 72, 80, 96, 128, 192, 256, 576 }) {
for (int hsv : { 40, 64, 72, 80, 96, 128, 192, 256, 512 }) {
if (hsk != 192 && hsk != 576 && hsk != hsv) continue;
if (hsk == 192 && (hsv != 128 && hsv != 192)) continue;
if (hsk == 576 && hsv != 512) continue; // DeepSeek MLA

View File

@ -131,7 +131,14 @@ int main(int argc, char ** argv) {
}
batch = llama_batch_get_one(&token, 1);
if (llama_decode(ctx.get(), batch)) {
int ret = llama_decode(ctx.get(), batch);
if (ret == 1 && i > 0) {
LOG_INF("Context full, stopping generation.\n");
break;
}
if (ret != 0) {
LOG_ERR("Model %d/%d, Context %d/%d: failed to decode\n", m + 1, num_models, c + 1, num_contexts);
failed.store(true);
return;

View File

@ -221,7 +221,5 @@ int main(int argc, char ** argv) {
llama_backend_free();
LOG("\n\n");
return 0;
}

View File

@ -155,6 +155,7 @@ enum projector_type {
PROJECTOR_TYPE_KIMIVL,
PROJECTOR_TYPE_LIGHTONOCR,
PROJECTOR_TYPE_COGVLM,
PROJECTOR_TYPE_JANUS_PRO,
PROJECTOR_TYPE_UNKNOWN,
};
@ -180,6 +181,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
{ PROJECTOR_TYPE_KIMIVL, "kimivl"},
{ PROJECTOR_TYPE_LIGHTONOCR,"lightonocr"},
{ PROJECTOR_TYPE_COGVLM, "cogvlm"},
{ PROJECTOR_TYPE_JANUS_PRO, "janus_pro"},
};
static projector_type clip_projector_type_from_string(const std::string & str) {

View File

@ -6,7 +6,6 @@
#include "clip-impl.h"
#include "ggml.h"
#include "ggml-cpp.h"
#include "ggml-cpu.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"
#include "gguf.h"
@ -17,17 +16,15 @@
#include <cstring>
#include <fstream>
#include <map>
#include <regex>
#include <stdexcept>
#include <unordered_set>
#include <vector>
#include <sstream>
#include <cinttypes>
#include <limits>
#include <array>
#include <numeric>
#include <functional>
// TODO: allow to pass callback from user code
struct clip_logger_state g_logger_state = {GGML_LOG_LEVEL_CONT, clip_log_callback_default, NULL};
enum ffn_op_type {
@ -431,12 +428,14 @@ struct clip_ctx {
int max_nodes = 8192;
ggml_backend_sched_ptr sched;
clip_flash_attn_type flash_attn_type = CLIP_FLASH_ATTN_TYPE_AUTO;
// for debugging
bool debug_graph = false;
std::vector<ggml_tensor *> debug_print_tensors;
clip_ctx(clip_context_params & ctx_params) {
flash_attn_type = ctx_params.flash_attn_type;
debug_graph = std::getenv("MTMD_DEBUG_GRAPH") != nullptr;
backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr);
if (!backend_cpu) {
@ -601,6 +600,15 @@ struct clip_graph {
cur = ggml_gelu(ctx0, cur);
cur = ggml_mul_mat(ctx0, model.mm_2_w, cur);
cur = ggml_add(ctx0, cur, model.mm_2_b);
} else if (ctx->proj_type() == PROJECTOR_TYPE_JANUS_PRO) {
cur = build_ffn(cur,
model.mm_0_w, model.mm_0_b,
nullptr, nullptr,
model.mm_1_w, model.mm_1_b,
hparams.ffn_op,
-1);
} else {
GGML_ABORT("SigLIP: Unsupported projector type");
}
@ -1742,7 +1750,6 @@ struct clip_graph {
return gf;
}
// whisper encoder with custom projector
ggml_cgraph * build_whisper_enc() {
const int n_frames = img.nx;
@ -2272,17 +2279,25 @@ private:
ggml_tensor * k = ggml_permute(ctx0, k_cur, 0, 2, 1, 3);
//cb(k, "k", il);
ggml_tensor * v = ggml_permute(ctx0, v_cur, 1, 2, 0, 3);
v = ggml_cont(ctx0, v);
//cb(k, "v", il);
ggml_tensor * cur;
// TODO @ngxson : support flash attention
{
if (ctx->flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) {
ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
k = ggml_cast(ctx0, k, GGML_TYPE_F16);
v = ggml_cast(ctx0, v, GGML_TYPE_F16);
cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, 0.0f, 0.0f);
ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]);
} else {
ggml_tensor * v = ggml_permute(ctx0, v_cur, 1, 2, 0, 3);
v = ggml_cont(ctx0, v);
const auto n_tokens = q->ne[1];
const auto n_head = q->ne[2];
// const auto n_kv = k->ne[1]; // for flash attention
ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
// F32 may not needed for vision encoders?
@ -2462,6 +2477,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
{
res = graph.build_kimivl();
} break;
case PROJECTOR_TYPE_JANUS_PRO:
{
res = graph.build_siglip();
} break;
case PROJECTOR_TYPE_COGVLM:
{
res = graph.build_cogvlm();
@ -3176,6 +3195,13 @@ struct clip_model_loader {
model.mm_boi = get_tensor(TN_TOK_BOI);
model.mm_eoi = get_tensor(TN_TOK_EOI);
} break;
case PROJECTOR_TYPE_JANUS_PRO:
{
model.mm_0_w = get_tensor(string_format(TN_LLAVA_PROJ, 0, "weight"));
model.mm_0_b = get_tensor(string_format(TN_LLAVA_PROJ, 0, "bias"));
model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight"));
model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias"));
} break;
default:
GGML_ASSERT(false && "unknown projector type");
}
@ -3217,7 +3243,87 @@ struct clip_model_loader {
}
}
void alloc_compute_meta(clip_ctx & ctx_clip) {
struct support_info_op {
ggml_tensor * op;
// true if the op runs on the accelerated ctx_clip.backend
bool is_accel = true;
};
struct support_info_graph {
// whether the clip_ctx.backend supports flash attention
bool fattn = true;
ggml_tensor * fattn_op = nullptr; // for debugging
std::vector<support_info_op> ops;
};
static void warmup(clip_ctx & ctx_clip) {
support_info_graph info;
if (ctx_clip.flash_attn_type == CLIP_FLASH_ATTN_TYPE_AUTO) {
// try to enable flash attention to see if it's supported
ctx_clip.flash_attn_type = CLIP_FLASH_ATTN_TYPE_ENABLED;
info = alloc_compute_meta(ctx_clip);
if (!info.fattn && info.fattn_op) {
auto op = info.fattn_op;
LOG_WRN("%s: *****************************************************************\n", __func__);
LOG_WRN("%s: WARNING: flash attention not supported by %s, memory usage will increase\n", __func__, ggml_backend_name(ctx_clip.backend));
LOG_WRN("%s: op params: \n", __func__);
static auto print_shape = [](const char * fn, const char * name, ggml_tensor * t) {
LOG_WRN("%s: %s: type = %s, ne = [%d %d %d %d], nb = [%d %d %d %d]\n", fn,
name, ggml_type_name(t->type),
t->ne[0], t->ne[1], t->ne[2], t->ne[3],
t->nb[0], t->nb[1], t->nb[2], t->nb[3]);
};
print_shape(__func__, " dst", op);
print_shape(__func__, "src0", op->src[0]);
print_shape(__func__, "src1", op->src[1]);
print_shape(__func__, "src2", op->src[2]);
LOG_WRN("%s: please report this on github as an issue\n", __func__);
LOG_WRN("%s: *****************************************************************\n", __func__);
ctx_clip.flash_attn_type = CLIP_FLASH_ATTN_TYPE_DISABLED;
alloc_compute_meta(ctx_clip);
}
} else {
info = alloc_compute_meta(ctx_clip);
if (!info.fattn && ctx_clip.flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) {
LOG_WRN("%s: flash attention is not supported by the current backend; falling back to CPU (performance will be degraded)\n", __func__);
}
}
LOG_INF("%s: flash attention is %s\n", __func__,
(ctx_clip.flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) ? "enabled" : "disabled");
// print ops that are not supported by the GPU backend (if there is one)
if (ctx_clip.backend && ctx_clip.backend != ctx_clip.backend_cpu) {
std::vector<support_info_op> unsupported_ops;
for (const auto & op : info.ops) {
if (!op.is_accel) {
unsupported_ops.push_back(op);
}
}
if (!unsupported_ops.empty()) {
LOG_WRN("%s: *****************************************************************\n", __func__);
LOG_WRN("%s: WARNING: the CLIP graph uses unsupported operators by the backend\n", __func__);
LOG_WRN("%s: the performance will be suboptimal \n", __func__);
LOG_WRN("%s: list of unsupported ops (backend=%s):\n", __func__, ggml_backend_name(ctx_clip.backend));
for (const auto & op : unsupported_ops) {
LOG_WRN("%s: %16s: type = %s, ne = [%d %d %d %d]\n", __func__,
ggml_op_name(op.op->op),
ggml_type_name(op.op->type),
op.op->ne[0], op.op->ne[1], op.op->ne[2], op.op->ne[3]);
}
LOG_WRN("%s: flash attention is %s\n", __func__,
(ctx_clip.flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) ? "enabled" : "disabled");
LOG_WRN("%s: please report this on github as an issue\n", __func__);
LOG_WRN("%s: ref: https://github.com/ggml-org/llama.cpp/pull/16837#issuecomment-3461676118\n", __func__);
LOG_WRN("%s: *****************************************************************\n", __func__);
}
}
}
static support_info_graph alloc_compute_meta(clip_ctx & ctx_clip) {
const auto & hparams = ctx_clip.model.hparams;
ctx_clip.buf_compute_meta.resize(ctx_clip.max_nodes * ggml_tensor_overhead() + ggml_graph_overhead());
@ -3248,57 +3354,95 @@ struct clip_model_loader {
size / 1024.0 / 1024.0);
}
}
const int n_splits = ggml_backend_sched_get_n_splits(ctx_clip.sched.get());
const int n_nodes = ggml_graph_n_nodes(gf);
LOG_INF("%s: graph splits = %d, nodes = %d\n", __func__, n_splits, n_nodes);
support_info_graph res {
/*.fattn = */ true,
/*.fattn_op = */ nullptr,
/*.ops = */ {},
};
// check op support
for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
ggml_tensor * node = ggml_graph_node(gf, i);
res.ops.push_back({node, true});
if (!ggml_backend_supports_op(ctx_clip.backend, node)) {
res.ops.back().is_accel = false;
if (node->op == GGML_OP_FLASH_ATTN_EXT) {
res.fattn = false;
res.fattn_op = node;
}
}
}
return res;
}
void get_bool(const std::string & key, bool & output, bool required = true) {
void get_bool(const std::string & key, bool & output, bool required = true) const {
const int i = gguf_find_key(ctx_gguf.get(), key.c_str());
if (i < 0) {
if (required) throw std::runtime_error("Key not found: " + key);
if (required) {
throw std::runtime_error("Key not found: " + key);
}
return;
}
output = gguf_get_val_bool(ctx_gguf.get(), i);
}
void get_i32(const std::string & key, int & output, bool required = true) {
void get_i32(const std::string & key, int & output, bool required = true) const {
const int i = gguf_find_key(ctx_gguf.get(), key.c_str());
if (i < 0) {
if (required) throw std::runtime_error("Key not found: " + key);
if (required) {
throw std::runtime_error("Key not found: " + key);
}
return;
}
output = gguf_get_val_i32(ctx_gguf.get(), i);
}
void get_u32(const std::string & key, int & output, bool required = true) {
void get_u32(const std::string & key, int & output, bool required = true) const {
const int i = gguf_find_key(ctx_gguf.get(), key.c_str());
if (i < 0) {
if (required) throw std::runtime_error("Key not found: " + key);
if (required) {
throw std::runtime_error("Key not found: " + key);
}
return;
}
output = gguf_get_val_u32(ctx_gguf.get(), i);
}
void get_f32(const std::string & key, float & output, bool required = true) {
void get_f32(const std::string & key, float & output, bool required = true) const {
const int i = gguf_find_key(ctx_gguf.get(), key.c_str());
if (i < 0) {
if (required) throw std::runtime_error("Key not found: " + key);
if (required) {
throw std::runtime_error("Key not found: " + key);
}
return;
}
output = gguf_get_val_f32(ctx_gguf.get(), i);
}
void get_string(const std::string & key, std::string & output, bool required = true) {
void get_string(const std::string & key, std::string & output, bool required = true) const {
const int i = gguf_find_key(ctx_gguf.get(), key.c_str());
if (i < 0) {
if (required) throw std::runtime_error("Key not found: " + key);
if (required) {
throw std::runtime_error("Key not found: " + key);
}
return;
}
output = std::string(gguf_get_val_str(ctx_gguf.get(), i));
}
void get_arr_int(const std::string & key, std::vector<int> & output, bool required = true) {
void get_arr_int(const std::string & key, std::vector<int> & output, bool required = true) const {
const int i = gguf_find_key(ctx_gguf.get(), key.c_str());
if (i < 0) {
if (required) throw std::runtime_error("Key not found: " + key);
if (required) {
throw std::runtime_error("Key not found: " + key);
}
return;
}
int n = gguf_get_arr_n(ctx_gguf.get(), i);
@ -3309,7 +3453,7 @@ struct clip_model_loader {
}
}
void set_llava_uhd_res_candidates(clip_model & model, const int max_patches_per_side) {
static void set_llava_uhd_res_candidates(clip_model & model, const int max_patches_per_side) {
auto & hparams = model.hparams;
for (int x = 1; x <= max_patches_per_side; x++) {
for (int y = 1; y <= max_patches_per_side; y++) {
@ -3337,24 +3481,22 @@ struct clip_init_result clip_init(const char * fname, struct clip_context_params
ctx_vision = new clip_ctx(ctx_params);
loader.load_hparams(ctx_vision->model, CLIP_MODALITY_VISION);
loader.load_tensors(*ctx_vision);
loader.alloc_compute_meta(*ctx_vision);
loader.warmup(*ctx_vision);
}
if (loader.has_audio) {
ctx_audio = new clip_ctx(ctx_params);
loader.load_hparams(ctx_audio->model, CLIP_MODALITY_AUDIO);
loader.load_tensors(*ctx_audio);
loader.alloc_compute_meta(*ctx_audio);
loader.warmup(*ctx_audio);
}
} catch (const std::exception & e) {
LOG_ERR("%s: failed to load model '%s': %s\n", __func__, fname, e.what());
if (ctx_vision) {
delete ctx_vision;
}
if (ctx_audio) {
delete ctx_audio;
}
delete ctx_vision;
delete ctx_audio;
return {nullptr, nullptr};
}
@ -3392,10 +3534,10 @@ void clip_image_size_free(struct clip_image_size * load_image_size) {
}
delete load_image_size;
}
void clip_image_u8_free(struct clip_image_u8 * img) { if (img) delete img; }
void clip_image_f32_free(struct clip_image_f32 * img) { if (img) delete img; }
void clip_image_u8_batch_free(struct clip_image_u8_batch * batch) { if (batch) delete batch; }
void clip_image_f32_batch_free(struct clip_image_f32_batch * batch) { if (batch) delete batch; }
void clip_image_u8_free(struct clip_image_u8 * img) { delete img; }
void clip_image_f32_free(struct clip_image_f32 * img) { delete img; }
void clip_image_u8_batch_free(struct clip_image_u8_batch * batch) { delete batch; }
void clip_image_f32_batch_free(struct clip_image_f32_batch * batch) { delete batch; }
size_t clip_image_f32_batch_n_images(const struct clip_image_f32_batch * batch) {
return batch->entries.size();
@ -4121,6 +4263,18 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
res_imgs->entries.push_back(std::move(img_f32));
} break;
case PROJECTOR_TYPE_JANUS_PRO:
{
// Janus Pro preprocessing: pad to square with gray(127), resize to 384x384
const std::array<uint8_t, 3> pad_color = {127, 127, 127};
clip_image_u8 resized_image;
int sz = params.image_size;
img_tool::resize(*img, resized_image, {sz, sz}, img_tool::RESIZE_ALGO_BILINEAR, true, pad_color);
clip_image_f32_ptr img_f32(clip_image_f32_init());
normalize_image_u8_to_f32(resized_image, *img_f32, params.image_mean, params.image_std);
res_imgs->entries.push_back(std::move(img_f32));
} break;
case PROJECTOR_TYPE_PIXTRAL:
case PROJECTOR_TYPE_LIGHTONOCR:
{
@ -4297,6 +4451,7 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
switch (proj) {
case PROJECTOR_TYPE_MLP:
case PROJECTOR_TYPE_MLP_NORM:
case PROJECTOR_TYPE_JANUS_PRO:
{
// do nothing
} break;
@ -4807,6 +4962,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
case PROJECTOR_TYPE_ULTRAVOX:
case PROJECTOR_TYPE_LFM2:
case PROJECTOR_TYPE_VOXTRAL:
case PROJECTOR_TYPE_JANUS_PRO:
case PROJECTOR_TYPE_COGVLM:
{
// do nothing
@ -4895,6 +5051,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
return ctx->model.mm_model_mlp_3_w->ne[1];
case PROJECTOR_TYPE_QWEN2VL:
case PROJECTOR_TYPE_QWEN25VL:
case PROJECTOR_TYPE_JANUS_PRO:
return ctx->model.mm_1_b->ne[0];
case PROJECTOR_TYPE_QWEN3VL:
// main path + deepstack paths

View File

@ -1,6 +1,7 @@
#pragma once
#include "ggml.h"
#include <stddef.h>
#include <stdint.h>
@ -22,9 +23,16 @@ enum clip_modality {
CLIP_MODALITY_AUDIO,
};
enum clip_flash_attn_type {
CLIP_FLASH_ATTN_TYPE_AUTO = -1,
CLIP_FLASH_ATTN_TYPE_DISABLED = 0,
CLIP_FLASH_ATTN_TYPE_ENABLED = 1,
};
struct clip_context_params {
bool use_gpu;
enum ggml_log_level verbosity;
enum clip_flash_attn_type flash_attn_type;
int image_min_tokens;
int image_max_tokens;
};

View File

@ -136,6 +136,7 @@ struct mtmd_cli_context {
mparams.print_timings = true;
mparams.n_threads = params.cpuparams.n_threads;
mparams.verbosity = params.verbosity > 0 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_INFO;
mparams.flash_attn_type = params.flash_attn_type;
mparams.image_min_tokens = params.image_min_tokens;
mparams.image_max_tokens = params.image_max_tokens;
ctx_vision.reset(mtmd_init_from_file(clip_path, model, mparams));

View File

@ -19,7 +19,6 @@
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <limits>
#include <vector>
// represents raw image data, layout is RGBRGBRGB...
@ -92,6 +91,15 @@ const char * mtmd_default_marker() {
return "<__media__>";
}
static clip_flash_attn_type mtmd_get_clip_flash_attn_type(enum llama_flash_attn_type flash_attn_type) {
switch (flash_attn_type) {
case LLAMA_FLASH_ATTN_TYPE_AUTO: return CLIP_FLASH_ATTN_TYPE_AUTO;
case LLAMA_FLASH_ATTN_TYPE_DISABLED: return CLIP_FLASH_ATTN_TYPE_DISABLED;
case LLAMA_FLASH_ATTN_TYPE_ENABLED: return CLIP_FLASH_ATTN_TYPE_ENABLED;
}
return CLIP_FLASH_ATTN_TYPE_AUTO;
}
mtmd_context_params mtmd_context_params_default() {
mtmd_context_params params;
params.use_gpu = true;
@ -100,6 +108,7 @@ mtmd_context_params mtmd_context_params_default() {
params.verbosity = GGML_LOG_LEVEL_INFO;
params.image_marker = MTMD_DEFAULT_IMAGE_MARKER;
params.media_marker = mtmd_default_marker();
params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_AUTO;
params.image_min_tokens = -1;
params.image_max_tokens = -1;
return params;
@ -164,11 +173,13 @@ struct mtmd_context {
}
clip_context_params ctx_clip_params;
ctx_clip_params.use_gpu = ctx_params.use_gpu;
ctx_clip_params.verbosity = ctx_params.verbosity;
ctx_clip_params.use_gpu = ctx_params.use_gpu;
ctx_clip_params.verbosity = ctx_params.verbosity;
ctx_clip_params.flash_attn_type = mtmd_get_clip_flash_attn_type(ctx_params.flash_attn_type);
// custom image token limits
ctx_clip_params.image_min_tokens = ctx_params.image_min_tokens;
ctx_clip_params.image_max_tokens = ctx_params.image_max_tokens;
auto res = clip_init(mmproj_fname, ctx_clip_params);
ctx_v = res.ctx_v;
ctx_a = res.ctx_a;
@ -383,9 +394,7 @@ mtmd_context * mtmd_init_from_file(const char * mmproj_fname,
}
void mtmd_free(mtmd_context * ctx) {
if (ctx) {
delete ctx;
}
delete ctx;
}
struct mtmd_tokenizer {

View File

@ -82,6 +82,7 @@ struct mtmd_context_params {
enum ggml_log_level verbosity;
const char * image_marker; // deprecated, use media_marker instead
const char * media_marker;
enum llama_flash_attn_type flash_attn_type;
// limit number of image tokens, only for vision models with dynamic resolution
int image_min_tokens; // minimum number of tokens for image input (default: read from metadata)

Binary file not shown.

View File

@ -2407,7 +2407,7 @@ struct server_context {
params_dft.devices = params_base.speculative.devices;
params_dft.model = params_base.speculative.model;
params_dft.n_ctx = params_base.speculative.n_ctx == 0 ? params_base.n_ctx / params_base.n_parallel : params_base.speculative.n_ctx;
params_dft.n_ctx = params_base.speculative.n_ctx == 0 ? llama_n_ctx_seq(ctx) : params_base.speculative.n_ctx;
params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers;
params_dft.n_parallel = 1;
params_dft.cache_type_k = params_base.speculative.cache_type_k;
@ -2456,6 +2456,7 @@ struct server_context {
mparams.print_timings = false;
mparams.n_threads = params_base.cpuparams.n_threads;
mparams.verbosity = params_base.verbosity > 0 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_INFO;
mparams.flash_attn_type = params_base.flash_attn_type;
mparams.image_min_tokens = params_base.image_min_tokens;
mparams.image_max_tokens = params_base.image_max_tokens;
mctx = mtmd_init_from_file(mmproj_path.c_str(), model, mparams);
@ -2497,10 +2498,16 @@ struct server_context {
}
void init() {
const int32_t n_ctx_slot = n_ctx / params_base.n_parallel;
SRV_INF("initializing slots, n_slots = %d\n", params_base.n_parallel);
const int n_ctx_train = llama_model_n_ctx_train(model);
int n_ctx_slot = llama_n_ctx_seq(ctx);
if (n_ctx_slot > n_ctx_train) {
SRV_WRN("the slot context (%d) exceeds the training context of the model (%d) - capping\n", n_ctx_slot, n_ctx_train);
n_ctx_slot = n_ctx_train;
}
for (int i = 0; i < params_base.n_parallel; i++) {
server_slot slot;
@ -2529,7 +2536,7 @@ struct server_context {
}
}
SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx);
SLT_INF(slot, "new slot, n_ctx = %d\n", slot.n_ctx);
slot.callback_on_release = [this](int) {
queue_tasks.pop_deferred_task();
@ -2701,6 +2708,39 @@ struct server_context {
return ret;
}
// return true if at least one slot has been purged
// TODO: improve logic
// - smarter decision which slot to purge (LRU or longest prompt?)
// - move slot to level 2 cache instead of removing?
// - instead of purging, try to store and resume later?
bool try_purge_idle_slots() {
bool res = false;
if (!params_base.kv_unified) {
return res;
}
for (auto & slot : slots) {
if (slot.is_processing()) {
continue;
}
if (slot.prompt.n_tokens() > 0) {
SRV_WRN("purging slot %d with %zu tokens\n", slot.id, slot.prompt.tokens.size());
llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);
slot.prompt.tokens.clear();
res = true;
// purge slots one by one
break;
}
}
return res;
}
bool launch_slot_with_task(server_slot & slot, server_task && task) {
slot.reset();
@ -3637,9 +3677,10 @@ struct server_context {
int32_t n_batch = llama_n_batch(ctx);
int32_t n_ubatch = llama_n_ubatch(ctx);
// next, batch any pending prompts without exceeding n_batch
float alora_scale = -1.0f;
float alora_scale = -1.0f;
size_t alora_disabled_id = 0;
// next, batch any pending prompts without exceeding n_batch
if (params_base.cont_batching || batch.n_tokens == 0) {
for (auto & slot : slots) {
// check if we can batch this slot with the previous one
@ -3916,8 +3957,11 @@ struct server_context {
// truncate any tokens that are beyond n_past for this slot
const llama_pos p0 = slot.prompt.tokens.pos_next();
SLT_INF(slot, "n_tokens = %d, memory_seq_rm [%d, end)\n", slot.prompt.n_tokens(), p0);
if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, p0, -1)) {
SLT_WRN(slot, "failed to truncate tokens with position >= %d\n", p0);
SLT_WRN(slot, "failed to truncate tokens with position >= %d - clearing the memory\n", p0);
llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);
// there is no common part left
@ -3926,8 +3970,6 @@ struct server_context {
slot.prompt.tokens.clear();
}
SLT_INF(slot, "n_tokens = %d, memory_seq_rm [%d, end)\n", slot.prompt.n_tokens(), p0);
// check if we should process the image
if (slot.prompt.n_tokens() < slot.task->n_tokens() && input_tokens[slot.prompt.n_tokens()] == LLAMA_TOKEN_NULL) {
// process the image
@ -4128,6 +4170,8 @@ struct server_context {
std::string err;
if (n_batch == 1 && ret == 1) {
// TODO: try to terminate only the largest active slot/sequence and continue with the rest
// need to remove the tokens from the current batch too
err = "Context size has been exceeded.";
}
@ -4143,17 +4187,23 @@ struct server_context {
// TODO: handle ret == 2 (abort) when we start aborting
if (!err.empty()) {
SRV_ERR("%s, i = %d, n_batch = %d, ret = %d\n", err.c_str(), i, n_batch, ret);
SRV_ERR("%s i = %d, n_batch = %d, ret = %d\n", err.c_str(), i, n_batch, ret);
for (auto & slot : slots) {
send_error(slot, err);
slot.release();
if (slot.is_processing()) {
send_error(slot, err);
slot.release();
}
}
break;
}
}
// retry with half the batch size to try to find a free slot in the KV cache
n_batch /= 2;
if (!try_purge_idle_slots()) {
n_batch /= 2;
}
SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret);
@ -4393,6 +4443,15 @@ int main(int argc, char ** argv) {
return 1;
}
// TODO: should we have a separate n_parallel parameter for the server?
// https://github.com/ggml-org/llama.cpp/pull/16736#discussion_r2483763177
if (params.n_parallel == 1 && params.kv_unified == false) {
LOG_WRN("%s: setting n_parallel = 4 and kv_unified = true\n", __func__);
params.n_parallel = 4;
params.kv_unified = true;
}
common_init();
// struct that contains llama context and inference
@ -4946,7 +5005,7 @@ int main(int argc, char ** argv) {
// Everything else, including multimodal completions.
inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
}
const size_t n_ctx_slot = ctx_server.n_ctx / ctx_server.params_base.n_parallel;
const size_t n_ctx_slot = ctx_server.slots.front().n_ctx;
tasks.reserve(inputs.size());
for (size_t i = 0; i < inputs.size(); i++) {
auto n_prompt_tokens = inputs[i].size();

View File

@ -433,21 +433,21 @@ def test_context_size_exceeded_stream():
@pytest.mark.parametrize(
"n_batch,batch_count,reuse_cache",
[
(64, 15, False),
(64, 3, False),
(64, 1, True),
]
)
def test_return_progresssss(n_batch, batch_count, reuse_cache):
def test_return_progress(n_batch, batch_count, reuse_cache):
global server
server.n_batch = n_batch
server.n_ctx = 2048
server.n_ctx = 256
server.n_slots = 1
server.start()
def make_cmpl_request():
return server.make_stream_request("POST", "/chat/completions", data={
"max_tokens": 10,
"messages": [
{"role": "user", "content": "This is a test" * 100},
{"role": "user", "content": "This is a test" * 10},
],
"stream": True,
"return_progress": True,

View File

@ -368,6 +368,37 @@ def test_completion_parallel_slots(n_slots: int, n_requests: int):
# assert match_regex(re_content, res.body["content"])
@pytest.mark.parametrize(
"n_ctx,n_slots,n_predict_vals,expected_success",
[
(256, 4, [80, 40, 80, 80], [True, True, True, True]),
(256, 4, [70, 70, 70, 70], [False, False, False, False]),
(256, 4, [90, 90, 40, 90], [False, False, True, False]),
(256, 4, [90, 90, 40, 75], [True, True, True, True]),
],
)
def test_completion_unified(n_ctx, n_slots, n_predict_vals, expected_success):
global server
server.n_slots = n_slots
server.kv_unified = True
server.n_ctx = n_ctx
server.start()
prompt = "A"
tasks = []
for n_predict in n_predict_vals:
tasks.append((server.make_request, ("POST", "/completion", {"prompt": prompt, "n_predict": n_predict})))
results = parallel_function_calls(tasks)
for res, n_predict, expect_ok in zip(results, n_predict_vals, expected_success):
if expect_ok:
assert res.status_code == 200
assert "content" in res.body
if "timings" in res.body:
assert res.body["timings"]["predicted_n"] == n_predict
else:
assert res.status_code == 500
assert "content" not in res.body
@pytest.mark.parametrize(
"prompt,n_predict,response_fields",
[

View File

@ -18,7 +18,7 @@ def test_infill_without_input_extra():
"input_suffix": "}\n",
})
assert res.status_code == 200
assert match_regex("(Ann|small|shiny|Daddy)+", res.body["content"])
assert match_regex("(Ann|small|shiny|Daddy|Jimmy)+", res.body["content"])
def test_infill_with_input_extra():
@ -34,7 +34,7 @@ def test_infill_with_input_extra():
"input_suffix": "}\n",
})
assert res.status_code == 200
assert match_regex("(Dad|excited|park)+", res.body["content"])
assert match_regex("(Dad|excited|park|Jimmy)+", res.body["content"])
@pytest.mark.parametrize("input_extra", [

View File

@ -78,6 +78,7 @@ class ServerProcess:
server_embeddings: bool | None = False
server_reranking: bool | None = False
server_metrics: bool | None = False
kv_unified: bool | None = False
server_slots: bool | None = False
pooling: str | None = None
draft: int | None = None
@ -159,6 +160,8 @@ class ServerProcess:
server_args.append("--reranking")
if self.server_metrics:
server_args.append("--metrics")
if self.kv_unified:
server_args.append("--kv-unified")
if self.server_slots:
server_args.append("--slots")
else:

View File

@ -1212,7 +1212,7 @@ public:
for (auto it = tokens.map_idx_to_media.begin(); it != tokens.map_idx_to_media.end(); ) {
auto * chunk = tokens.map_idx_to_media[it->first].get();
mtmd::input_chunk_ptr new_chunk(mtmd_input_chunk_copy(chunk));
map_idx_to_media[start_idx+it->first] = std::move(new_chunk);
map_idx_to_media[start_idx + it->first] = std::move(new_chunk);
}
}
}
@ -1244,6 +1244,7 @@ public:
}
void clear() {
map_idx_to_media.clear();
tokens.clear();
}

View File

@ -85,8 +85,8 @@
let displayedModel = $derived((): string | null => {
if (!currentConfig.showModelInfo) return null;
if (currentConfig.modelSelectorEnabled) {
return message.model ?? null;
if (message.model) {
return message.model;
}
return serverModel;

View File

@ -54,6 +54,7 @@ export class ChatService {
onError,
onReasoningChunk,
onModel,
onFirstValidChunk,
// Generation parameters
temperature,
max_tokens,
@ -201,6 +202,7 @@ export class ChatService {
onError,
onReasoningChunk,
onModel,
onFirstValidChunk,
conversationId,
abortController.signal
);
@ -267,6 +269,7 @@ export class ChatService {
onError?: (error: Error) => void,
onReasoningChunk?: (chunk: string) => void,
onModel?: (model: string) => void,
onFirstValidChunk?: () => void,
conversationId?: string,
abortSignal?: AbortSignal
): Promise<void> {
@ -283,6 +286,7 @@ export class ChatService {
let lastTimings: ChatMessageTimings | undefined;
let streamFinished = false;
let modelEmitted = false;
let firstValidChunkEmitted = false;
try {
let chunk = '';
@ -311,10 +315,12 @@ export class ChatService {
try {
const parsed: ApiChatCompletionStreamChunk = JSON.parse(data);
const chunkModel = this.extractModelName(parsed);
if (chunkModel && !modelEmitted) {
modelEmitted = true;
onModel?.(chunkModel);
if (!firstValidChunkEmitted && parsed.object === 'chat.completion.chunk') {
firstValidChunkEmitted = true;
if (!abortSignal?.aborted) {
onFirstValidChunk?.();
}
}
const content = parsed.choices[0]?.delta?.content;
@ -322,6 +328,12 @@ export class ChatService {
const timings = parsed.timings;
const promptProgress = parsed.prompt_progress;
const chunkModel = this.extractModelName(parsed);
if (chunkModel && !modelEmitted) {
modelEmitted = true;
onModel?.(chunkModel);
}
if (timings || promptProgress) {
this.updateProcessingState(timings, promptProgress, conversationId);
if (timings) {

View File

@ -1,6 +1,7 @@
import { DatabaseStore } from '$lib/stores/database';
import { chatService, slotsService } from '$lib/services';
import { config } from '$lib/stores/settings.svelte';
import { serverStore } from '$lib/stores/server.svelte';
import { normalizeModelName } from '$lib/utils/model-names';
import { filterByLeafNodeId, findLeafNode, findDescendantMessages } from '$lib/utils/branching';
import { browser } from '$app/environment';
@ -362,9 +363,41 @@ class ChatStore {
let resolvedModel: string | null = null;
let modelPersisted = false;
const currentConfig = config();
const preferServerPropsModel = !currentConfig.modelSelectorEnabled;
let serverPropsRefreshed = false;
let updateModelFromServerProps: ((persistImmediately?: boolean) => void) | null = null;
const recordModel = (modelName: string, persistImmediately = true): void => {
const normalizedModel = normalizeModelName(modelName);
const refreshServerPropsOnce = () => {
if (serverPropsRefreshed) {
return;
}
serverPropsRefreshed = true;
const hasExistingProps = serverStore.serverProps !== null;
serverStore
.fetchServerProps({ silent: hasExistingProps })
.then(() => {
updateModelFromServerProps?.(true);
})
.catch((error) => {
console.warn('Failed to refresh server props after streaming started:', error);
});
};
const recordModel = (modelName: string | null | undefined, persistImmediately = true): void => {
const serverModelName = serverStore.modelName;
const preferredModelSource = preferServerPropsModel
? (serverModelName ?? modelName ?? null)
: (modelName ?? serverModelName ?? null);
if (!preferredModelSource) {
return;
}
const normalizedModel = normalizeModelName(preferredModelSource);
if (!normalizedModel || normalizedModel === resolvedModel) {
return;
@ -388,6 +421,20 @@ class ChatStore {
}
};
if (preferServerPropsModel) {
updateModelFromServerProps = (persistImmediately = true) => {
const currentServerModel = serverStore.modelName;
if (!currentServerModel) {
return;
}
recordModel(currentServerModel, persistImmediately);
};
updateModelFromServerProps(false);
}
slotsService.startStreaming();
slotsService.setActiveConversation(assistantMessage.convId);
@ -396,6 +443,9 @@ class ChatStore {
{
...this.getApiOptions(),
onFirstValidChunk: () => {
refreshServerPropsOnce();
},
onChunk: (chunk: string) => {
streamedContent += chunk;
this.setConversationStreaming(

View File

@ -52,6 +52,7 @@ class ServerStore {
private _error = $state<string | null>(null);
private _serverWarning = $state<string | null>(null);
private _slotsEndpointAvailable = $state<boolean | null>(null);
private fetchServerPropsPromise: Promise<void> | null = null;
private readCachedServerProps(): ApiLlamaCppServerProps | null {
if (!browser) return null;
@ -171,73 +172,65 @@ class ServerStore {
/**
* Fetches server properties from the server
*/
async fetchServerProps(): Promise<void> {
this._loading = true;
this._error = null;
this._serverWarning = null;
async fetchServerProps(options: { silent?: boolean } = {}): Promise<void> {
const { silent = false } = options;
const isSilent = silent && this._serverProps !== null;
try {
console.log('Fetching server properties...');
const props = await ChatService.getServerProps();
this._serverProps = props;
this.persistServerProps(props);
console.log('Server properties loaded:', props);
if (this.fetchServerPropsPromise) {
return this.fetchServerPropsPromise;
}
// Check slots endpoint availability after server props are loaded
await this.checkSlotsEndpointAvailability();
} catch (error) {
const hadCachedProps = this._serverProps !== null;
let errorMessage = 'Failed to connect to server';
let isOfflineLikeError = false;
let isServerSideError = false;
if (!isSilent) {
this._loading = true;
this._error = null;
this._serverWarning = null;
}
if (error instanceof Error) {
// Handle specific error types with user-friendly messages
if (error.name === 'TypeError' && error.message.includes('fetch')) {
errorMessage = 'Server is not running or unreachable';
isOfflineLikeError = true;
} else if (error.message.includes('ECONNREFUSED')) {
errorMessage = 'Connection refused - server may be offline';
isOfflineLikeError = true;
} else if (error.message.includes('ENOTFOUND')) {
errorMessage = 'Server not found - check server address';
isOfflineLikeError = true;
} else if (error.message.includes('ETIMEDOUT')) {
errorMessage = 'Request timed out - the server took too long to respond';
isOfflineLikeError = true;
} else if (error.message.includes('503')) {
errorMessage = 'Server temporarily unavailable - try again shortly';
isServerSideError = true;
} else if (error.message.includes('500')) {
errorMessage = 'Server error - check server logs';
isServerSideError = true;
} else if (error.message.includes('404')) {
errorMessage = 'Server endpoint not found';
} else if (error.message.includes('403') || error.message.includes('401')) {
errorMessage = 'Access denied';
const hadProps = this._serverProps !== null;
const fetchPromise = (async () => {
try {
const props = await ChatService.getServerProps();
this._serverProps = props;
this.persistServerProps(props);
this._error = null;
this._serverWarning = null;
await this.checkSlotsEndpointAvailability();
} catch (error) {
if (isSilent && hadProps) {
console.warn('Silent server props refresh failed, keeping cached data:', error);
return;
}
this.handleFetchServerPropsError(error, hadProps);
} finally {
if (!isSilent) {
this._loading = false;
}
this.fetchServerPropsPromise = null;
}
})();
let cachedProps: ApiLlamaCppServerProps | null = null;
this.fetchServerPropsPromise = fetchPromise;
if (!hadCachedProps) {
cachedProps = this.readCachedServerProps();
if (cachedProps) {
this._serverProps = cachedProps;
this._error = null;
await fetchPromise;
}
if (isOfflineLikeError || isServerSideError) {
this._serverWarning = errorMessage;
}
/**
* Handles fetch failures by attempting to recover cached server props and
* updating the user-facing error or warning state appropriately.
*/
private handleFetchServerPropsError(error: unknown, hadProps: boolean): void {
const { errorMessage, isOfflineLikeError, isServerSideError } = this.normalizeFetchError(error);
console.warn(
'Failed to refresh server properties, using cached values from localStorage:',
errorMessage
);
} else {
this._error = errorMessage;
}
} else {
let cachedProps: ApiLlamaCppServerProps | null = null;
if (!hadProps) {
cachedProps = this.readCachedServerProps();
if (cachedProps) {
this._serverProps = cachedProps;
this._error = null;
if (isOfflineLikeError || isServerSideError) {
@ -245,14 +238,66 @@ class ServerStore {
}
console.warn(
'Failed to refresh server properties, continuing with cached values:',
'Failed to refresh server properties, using cached values from localStorage:',
errorMessage
);
} else {
this._error = errorMessage;
}
console.error('Error fetching server properties:', error);
} finally {
this._loading = false;
} else {
this._error = null;
if (isOfflineLikeError || isServerSideError) {
this._serverWarning = errorMessage;
}
console.warn(
'Failed to refresh server properties, continuing with cached values:',
errorMessage
);
}
console.error('Error fetching server properties:', error);
}
private normalizeFetchError(error: unknown): {
errorMessage: string;
isOfflineLikeError: boolean;
isServerSideError: boolean;
} {
let errorMessage = 'Failed to connect to server';
let isOfflineLikeError = false;
let isServerSideError = false;
if (error instanceof Error) {
const message = error.message || '';
if (error.name === 'TypeError' && message.includes('fetch')) {
errorMessage = 'Server is not running or unreachable';
isOfflineLikeError = true;
} else if (message.includes('ECONNREFUSED')) {
errorMessage = 'Connection refused - server may be offline';
isOfflineLikeError = true;
} else if (message.includes('ENOTFOUND')) {
errorMessage = 'Server not found - check server address';
isOfflineLikeError = true;
} else if (message.includes('ETIMEDOUT')) {
errorMessage = 'Request timed out - the server took too long to respond';
isOfflineLikeError = true;
} else if (message.includes('503')) {
errorMessage = 'Server temporarily unavailable - try again shortly';
isServerSideError = true;
} else if (message.includes('500')) {
errorMessage = 'Server error - check server logs';
isServerSideError = true;
} else if (message.includes('404')) {
errorMessage = 'Server endpoint not found';
} else if (message.includes('403') || message.includes('401')) {
errorMessage = 'Access denied';
}
}
return { errorMessage, isOfflineLikeError, isServerSideError };
}
/**
@ -264,6 +309,7 @@ class ServerStore {
this._serverWarning = null;
this._loading = false;
this._slotsEndpointAvailable = null;
this.fetchServerPropsPromise = null;
this.persistServerProps(null);
}
}

View File

@ -186,6 +186,7 @@ export interface ApiChatCompletionRequest {
}
export interface ApiChatCompletionStreamChunk {
object?: string;
model?: string;
choices: Array<{
model?: string;

View File

@ -42,6 +42,7 @@ export interface SettingsChatServiceOptions {
onChunk?: (chunk: string) => void;
onReasoningChunk?: (chunk: string) => void;
onModel?: (model: string) => void;
onFirstValidChunk?: () => void;
onComplete?: (response: string, reasoningContent?: string, timings?: ChatMessageTimings) => void;
onError?: (error: Error) => void;
}