Merge branch 'master' into xsn/mtmd_custom_min_max_tokens

commit 79b98dbf96
@@ -24,8 +24,9 @@ RUN --mount=type=cache,target=/root/.ccache \
         -DCMAKE_C_COMPILER_LAUNCHER=ccache \
         -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \
         -DLLAMA_BUILD_TESTS=OFF \
-        -DGGML_BACKEND_DL=OFF \
         -DGGML_NATIVE=OFF \
+        -DGGML_BACKEND_DL=ON \
+        -DGGML_CPU_ALL_VARIANTS=ON \
         -DGGML_BLAS=ON \
         -DGGML_BLAS_VENDOR=OpenBLAS && \
     cmake --build build --config Release -j $(nproc) && \
@@ -103,6 +104,7 @@ FROM base AS light
 WORKDIR /llama.cpp/bin
 
 # Copy llama.cpp binaries and libraries
+COPY --from=collector /llama.cpp/bin/*.so /llama.cpp/bin
 COPY --from=collector /llama.cpp/bin/llama-cli /llama.cpp/bin
 
 ENTRYPOINT [ "/llama.cpp/bin/llama-cli" ]
@@ -116,6 +118,7 @@ ENV LLAMA_ARG_HOST=0.0.0.0
 WORKDIR /llama.cpp/bin
 
 # Copy llama.cpp binaries and libraries
+COPY --from=collector /llama.cpp/bin/*.so /llama.cpp/bin
 COPY --from=collector /llama.cpp/bin/llama-server /llama.cpp/bin
 
 EXPOSE 8080
@@ -134,8 +134,8 @@ jobs:
         include:
           - build: 'x64'
             os: ubuntu-22.04
-          - build: 's390x-z15' # z15 because our CI runners are on z15
-            os: ubuntu-22.04-s390x
+          - build: 's390x'
+            os: ubuntu-24.04-s390x
           # GGML_BACKEND_DL and GGML_CPU_ALL_VARIANTS are not currently supported on arm
           # - build: 'arm64'
           #   os: ubuntu-22.04-arm
@@ -313,7 +313,6 @@ json common_chat_msgs_to_json_oaicompat(const std::vector<common_chat_msg> & msg
         }
         if (!msg.reasoning_content.empty()) {
             jmsg["reasoning_content"] = msg.reasoning_content;
-            jmsg["thinking"] = msg.reasoning_content; // gpt-oss
         }
         if (!msg.tool_name.empty()) {
             jmsg["name"] = msg.tool_name;
@@ -1810,7 +1809,23 @@ static void common_chat_parse_deepseek_v3_1(common_chat_msg_parser & builder) {
 
 static common_chat_params common_chat_params_init_gpt_oss(const common_chat_template & tmpl, const struct templates_params & inputs) {
     common_chat_params data;
-    auto prompt = apply(tmpl, inputs);
+
+    // Copy reasoning to the "thinking" field as expected by the gpt-oss template
+    auto adjusted_messages = json::array();
+    for (const auto & msg : inputs.messages) {
+        auto has_reasoning_content = msg.contains("reasoning_content") && msg.at("reasoning_content").is_string();
+        auto has_tool_calls = msg.contains("tool_calls") && msg.at("tool_calls").is_array();
+
+        if (has_reasoning_content && has_tool_calls) {
+            auto adjusted_message = msg;
+            adjusted_message["thinking"] = msg.at("reasoning_content");
+            adjusted_messages.push_back(adjusted_message);
+        } else {
+            adjusted_messages.push_back(msg);
+        }
+    }
+
+    auto prompt = apply(tmpl, inputs, /* messages_override= */ adjusted_messages);
 
     // Check if we need to replace the return token with end token during
     // inference and without generation prompt. For more details see:
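For illustration only (not part of the commit), a minimal Python sketch of the adjustment above — an assistant message carrying both reasoning_content and tool_calls gets its reasoning copied into the "thinking" field that the gpt-oss template reads; all other messages pass through unchanged:

    msg = {
        "role": "assistant",
        "reasoning_content": "Need the weather, so call the tool.",  # hypothetical content
        "tool_calls": [{"type": "function", "function": {"name": "get_weather", "arguments": "{}"}}],
    }
    adjusted = dict(msg)
    adjusted["thinking"] = msg["reasoning_content"]  # field name expected by the gpt-oss template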
@@ -9802,6 +9802,113 @@ class CogVLMModel(LlamaModel):
 
         return [(self.map_tensor_name(name), data_torch)]
 
 
+@ModelBase.register("JanusForConditionalGeneration")
+class JanusProModel(LlamaModel):
+    model_arch = gguf.MODEL_ARCH.LLAMA  # reuse Llama arch
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        # Skip vision, aligner, and generation tensors
+        skip_prefixes = (
+            'model.vision_model.',
+            'model.aligner.',
+            'model.vqmodel.',
+            'model.generation_embeddings.',
+            'model.generation_aligner.',
+            'model.generation_head.',
+        )
+        if name.startswith(skip_prefixes):
+            return []
+
+        if name.startswith('model.language_model.'):
+            name = name.replace('model.language_model.', 'model.')
+        elif name.startswith('language_model.'):
+            name = name.replace('language_model.', '')
+
+        return super().modify_tensors(data_torch, name, bid)
+
+
+@ModelBase.register("JanusForConditionalGeneration")
+class JanusProVisionModel(MmprojModel):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        assert self.hparams_vision is not None
+        if "intermediate_size" not in self.hparams_vision:
+            mlp_ratio = self.hparams_vision.get("mlp_ratio")
+            hidden_size = self.hparams_vision.get("hidden_size")
+            if mlp_ratio is not None and hidden_size is not None:
+                self.hparams_vision["intermediate_size"] = int(round(hidden_size * mlp_ratio))
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        assert self.hparams_vision is not None
+
+        self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.JANUS_PRO)
+
+        self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams_vision.get("layer_norm_eps", 1e-6))
+
+        hidden_act = str(self.hparams_vision.get("hidden_act", "")).lower()
+        if hidden_act == "gelu":
+            self.gguf_writer.add_vision_use_gelu(True)
+        elif hidden_act == "silu":
+            self.gguf_writer.add_vision_use_silu(True)
+
+    def _map_aligner_tensor(self, data_torch: Tensor, name: str) -> Iterable[tuple[str, Tensor]]:
+        """Map aligner tensors to projector format"""
+        suffix = ".bias" if name.endswith(".bias") else ".weight"
+
+        if name.startswith("model.aligner."):
+            local_name = name[len("model.aligner."):]
+        elif name.startswith("aligner."):
+            local_name = name[len("aligner."):]
+        else:
+            raise ValueError(f"Unsupported Janus aligner prefix: {name}")
+
+        if local_name.startswith("fc1."):
+            mm_index = 0
+        elif local_name.startswith("hidden_layers."):
+            parts = local_name.split(".", 2)
+            if len(parts) < 3:
+                raise ValueError(f"Unexpected Janus aligner tensor name: {name}")
+            mm_index = int(parts[1]) + 1
+        else:
+            raise ValueError(f"Unsupported Janus aligner tensor: {name}")
+
+        tensor_name = self.format_tensor_name(gguf.MODEL_TENSOR.V_MMPROJ, mm_index, suffix=suffix)
+        return [(tensor_name, data_torch)]
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        del bid  # unused
+
+        # Skip language model tensors as they will be handled by `JanusProModel`
+        if name.startswith(('model.language_model.', 'language_model.')):
+            return []
+
+        # Skip generation-related components
+        skip_generation_prefixes = (
+            'model.vqmodel.',
+            'vqmodel.',
+            'model.generation_embeddings.',
+            'generation_embeddings.',
+            'model.generation_aligner.',
+            'generation_aligner.',
+            'model.generation_head.',
+            'generation_head.',
+        )
+        if name.startswith(skip_generation_prefixes):
+            return []
+
+        # Handle aligner tensors
+        if name.startswith(('model.aligner.', 'aligner.')):
+            return list(self._map_aligner_tensor(data_torch, name))
+
+        # Handle vision tensors
+        if name.startswith(('model.vision_model.', 'vision_model.')):
+            return [(self.map_tensor_name(name), data_torch)]
+
+        return []
+
+
 ###### CONVERSION LOGIC ######
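A worked example of the index math in _map_aligner_tensor above (tensor names hypothetical, arithmetic verbatim from the code): "fc1.*" maps to mm_index 0, and "hidden_layers.N.*" maps to mm_index N + 1:

    local_name = "hidden_layers.1.bias"   # from e.g. "model.aligner.hidden_layers.1.bias"
    parts = local_name.split(".", 2)      # ['hidden_layers', '1', 'bias']
    mm_index = int(parts[1]) + 1          # 2
    suffix = ".bias"                      # because name.endswith(".bias")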
@@ -7,9 +7,9 @@
 ## Images
 We have three Docker images available for this project:
 
-1. `ghcr.io/ggml-org/llama.cpp:full`: This image includes both the main executable file and the tools to convert LLaMA models into ggml and convert into 4-bit quantization. (platforms: `linux/amd64`, `linux/arm64`)
-2. `ghcr.io/ggml-org/llama.cpp:light`: This image only includes the main executable file. (platforms: `linux/amd64`, `linux/arm64`)
-3. `ghcr.io/ggml-org/llama.cpp:server`: This image only includes the server executable file. (platforms: `linux/amd64`, `linux/arm64`)
+1. `ghcr.io/ggml-org/llama.cpp:full`: This image includes both the main executable file and the tools to convert LLaMA models into ggml and convert into 4-bit quantization. (platforms: `linux/amd64`, `linux/arm64`, `linux/s390x`)
+2. `ghcr.io/ggml-org/llama.cpp:light`: This image only includes the main executable file. (platforms: `linux/amd64`, `linux/arm64`, `linux/s390x`)
+3. `ghcr.io/ggml-org/llama.cpp:server`: This image only includes the server executable file. (platforms: `linux/amd64`, `linux/arm64`, `linux/s390x`)
 
 Additionally, there the following images, similar to the above:
@@ -308,6 +308,10 @@ function(ggml_add_cpu_backend_variant tag_name)
             set(GGML_INTERNAL_${feat} ON)
         endforeach()
     elseif (GGML_SYSTEM_ARCH STREQUAL "s390x")
+        foreach (feat VXE2 NNPA)
+            set(GGML_INTERNAL_${feat} OFF)
+        endforeach()
+
         foreach (feat ${ARGN})
             set(GGML_INTERNAL_${feat} ON)
         endforeach()
@@ -377,9 +381,8 @@ if (GGML_CPU_ALL_VARIANTS)
         endif()
     elseif (GGML_SYSTEM_ARCH STREQUAL "s390x")
         if (CMAKE_SYSTEM_NAME MATCHES "Linux")
-            ggml_add_cpu_backend_variant(s390x_z15 Z15 VXE)
-            # ggml_add_cpu_backend_variant(s390x_z16 Z16 VXE)
-            # ggml_add_cpu_backend_variant(s390x_z17 Z17 VXE)
+            ggml_add_cpu_backend_variant(z15 Z15 VXE2)
+            ggml_add_cpu_backend_variant(z16 Z16 VXE2 NNPA)
         else()
             message(FATAL_ERROR "Unsupported s390x target OS: ${CMAKE_SYSTEM_NAME}")
         endif()
@@ -504,11 +504,18 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
             endforeach()
         endif()
 
-        if (GGML_VXE OR GGML_INTERNAL_VXE)
-            message(STATUS "VX/VXE/VXE2 enabled")
+        if (GGML_VXE OR GGML_INTERNAL_VXE2)
+            message(STATUS "VXE2 enabled")
             list(APPEND ARCH_FLAGS -mvx -mzvector)
-            list(APPEND ARCH_DEFINITIONS GGML_VXE)
+            list(APPEND ARCH_DEFINITIONS GGML_USE_VXE2)
         endif()
 
+        if (GGML_INTERNAL_NNPA)
+            message(STATUS "NNPA enabled")
+            list(APPEND ARCH_DEFINITIONS GGML_USE_NNPA)
+        endif()
+
+        ggml_add_cpu_backend_features(${GGML_CPU_NAME} s390 ${ARCH_DEFINITIONS})
     elseif (CMAKE_SYSTEM_PROCESSOR MATCHES "wasm")
         message(STATUS "Wasm detected")
         list (APPEND GGML_CPU_SOURCES ggml-cpu/arch/wasm/quants.c)
@@ -0,0 +1,50 @@
+#include "ggml-backend-impl.h"
+
+#if defined(__s390x__)
+#include <sys/auxv.h>
+
+// find hwcap bits in asm/elf.h
+#ifndef HWCAP_VXRS_EXT2
+#define HWCAP_VXRS_EXT2 (1 << 15)
+#endif
+
+#ifndef HWCAP_NNPA
+#define HWCAP_NNPA (1 << 20)
+#endif
+
+struct s390x_features {
+    bool has_vxe2 = false;
+    bool has_nnpa = false;
+
+    s390x_features() {
+        uint32_t hwcap = getauxval(AT_HWCAP);
+        // NOTE: use hwcap2 with DFLT for z17 and later
+        // uint32_t hwcap2 = getauxval(AT_HWCAP2);
+
+        has_vxe2 = !!(hwcap & HWCAP_VXRS_EXT2);
+        has_nnpa = !!(hwcap & HWCAP_NNPA);
+    }
+};
+
+static int ggml_backend_cpu_s390x_score() {
+    int score = 1;
+    s390x_features sf;
+
+// IBM z15 / LinuxONE 3
+#ifdef GGML_USE_VXE2
+    if (!sf.has_vxe2) { return 0; }
+    score += 1 << 1;
+#endif
+
+// IBM z16 / LinuxONE 4 and z17 / LinuxONE 5
+#ifdef GGML_USE_NNPA
+    if (!sf.has_nnpa) { return 0; }
+    score += 1 << 2;
+#endif
+
+    return score;
+}
+
+GGML_BACKEND_DL_SCORE_IMPL(ggml_backend_cpu_s390x_score)
+
+#endif // __s390x__
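The score above starts at 1 for any s390x build, adds bit 1 (+2) when the variant was compiled with VXE2 and bit 2 (+4) when compiled with NNPA, and drops to 0 whenever a compiled-in requirement is missing at runtime, so the most capable usable variant wins. A hypothetical Python sketch of the same arithmetic:

    HWCAP_VXRS_EXT2 = 1 << 15
    HWCAP_NNPA      = 1 << 20

    def s390x_score(hwcap: int, use_vxe2: bool, use_nnpa: bool) -> int:
        score = 1
        if use_vxe2:
            if not (hwcap & HWCAP_VXRS_EXT2):
                return 0          # required feature missing at runtime -> variant unusable
            score += 1 << 1
        if use_nnpa:
            if not (hwcap & HWCAP_NNPA):
                return 0
            score += 1 << 2
        return score              # e.g. z15 variant scores 3, z16 variant scores 7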
@@ -2499,6 +2499,18 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_UNARY_OP_XIELU:
             ggml_cuda_op_xielu(ctx, dst);
             break;
+        case GGML_UNARY_OP_FLOOR:
+            ggml_cuda_op_floor(ctx, dst);
+            break;
+        case GGML_UNARY_OP_CEIL:
+            ggml_cuda_op_ceil(ctx, dst);
+            break;
+        case GGML_UNARY_OP_ROUND:
+            ggml_cuda_op_round(ctx, dst);
+            break;
+        case GGML_UNARY_OP_TRUNC:
+            ggml_cuda_op_trunc(ctx, dst);
+            break;
         default:
             return false;
     }
@@ -3769,6 +3781,10 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_UNARY_OP_TANH:
         case GGML_UNARY_OP_EXP:
         case GGML_UNARY_OP_ELU:
+        case GGML_UNARY_OP_FLOOR:
+        case GGML_UNARY_OP_CEIL:
+        case GGML_UNARY_OP_ROUND:
+        case GGML_UNARY_OP_TRUNC:
            return ggml_is_contiguous(op->src[0]);
        default:
            return false;
@@ -85,6 +85,22 @@ static __device__ __forceinline__ float op_elu(float x) {
     return (x > 0.f) ? x : expm1f(x);
 }
 
+static __device__ __forceinline__ float op_floor(float x) {
+    return floorf(x);
+}
+
+static __device__ __forceinline__ float op_ceil(float x) {
+    return ceilf(x);
+}
+
+static __device__ __forceinline__ float op_round(float x) {
+    return round(x);
+}
+
+static __device__ __forceinline__ float op_trunc(float x) {
+    return trunc(x);
+}
+
 template <float (*op)(float), typename T>
 static __global__ void unary_op_kernel(const T * x, T * dst, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
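As a quick reference for the four new ops (values assume IEEE semantics; CUDA's round() rounds halfway cases away from zero):

    import math
    x = -1.5
    math.floor(x)   # -2  (op_floor)
    math.ceil(x)    # -1  (op_ceil)
    math.trunc(x)   # -1  (op_trunc)
    # op_round: round(-1.5) == -2.0 and round(2.5) == 3.0 under CUDA semantics;
    # note Python's built-in round() uses banker's rounding, so it is not a drop-in match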
@@ -201,6 +217,22 @@ void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 void ggml_cuda_op_elu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     ggml_cuda_op_unary<op_elu>(ctx, dst);
 }
+
+void ggml_cuda_op_floor(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_cuda_op_unary<op_floor>(ctx, dst);
+}
+
+void ggml_cuda_op_ceil(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_cuda_op_unary<op_ceil>(ctx, dst);
+}
+
+void ggml_cuda_op_round(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_cuda_op_unary<op_round>(ctx, dst);
+}
+
+void ggml_cuda_op_trunc(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_cuda_op_unary<op_trunc>(ctx, dst);
+}
 /* gated ops */
 
 template <float (*op)(float), typename T>
@@ -63,6 +63,14 @@ void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_elu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
+void ggml_cuda_op_floor(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_ceil(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_round(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_trunc(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
 void ggml_cuda_op_reglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_geglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
@@ -707,6 +707,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
                 if (op->src[0]->ne[0] != 32 &&
                     op->src[0]->ne[0] != 40 &&
                     op->src[0]->ne[0] != 64 &&
+                    op->src[0]->ne[0] != 72 &&
                     op->src[0]->ne[0] != 80 &&
                     op->src[0]->ne[0] != 96 &&
                     op->src[0]->ne[0] != 112 &&
@@ -5362,6 +5362,7 @@ typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, hal
 template [[host_name("kernel_flash_attn_ext_f32_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 32, 32>;
 template [[host_name("kernel_flash_attn_ext_f32_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 40, 40>;
 template [[host_name("kernel_flash_attn_ext_f32_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 64, 64>;
+template [[host_name("kernel_flash_attn_ext_f32_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 72, 72>;
 template [[host_name("kernel_flash_attn_ext_f32_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 80, 80>;
 template [[host_name("kernel_flash_attn_ext_f32_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 96, 96>;
 template [[host_name("kernel_flash_attn_ext_f32_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 112, 112>;

@@ -5374,6 +5375,7 @@ template [[host_name("kernel_flash_attn_ext_f32_dk576_dv512")]] kernel flash_at
 template [[host_name("kernel_flash_attn_ext_f16_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 32, 32>;
 template [[host_name("kernel_flash_attn_ext_f16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 40, 40>;
 template [[host_name("kernel_flash_attn_ext_f16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>;
+template [[host_name("kernel_flash_attn_ext_f16_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 72, 72>;
 template [[host_name("kernel_flash_attn_ext_f16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 80, 80>;
 template [[host_name("kernel_flash_attn_ext_f16_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 96, 96>;
 template [[host_name("kernel_flash_attn_ext_f16_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 112, 112>;

@@ -5387,6 +5389,7 @@ template [[host_name("kernel_flash_attn_ext_f16_dk576_dv512")]] kernel flash_at
 template [[host_name("kernel_flash_attn_ext_bf16_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 32, 32>;
 template [[host_name("kernel_flash_attn_ext_bf16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 40, 40>;
 template [[host_name("kernel_flash_attn_ext_bf16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 64, 64>;
+template [[host_name("kernel_flash_attn_ext_bf16_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 72, 72>;
 template [[host_name("kernel_flash_attn_ext_bf16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 80, 80>;
 template [[host_name("kernel_flash_attn_ext_bf16_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 96, 96>;
 template [[host_name("kernel_flash_attn_ext_bf16_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 112, 112>;

@@ -5400,6 +5403,7 @@ template [[host_name("kernel_flash_attn_ext_bf16_dk576_dv512")]] kernel flash_at
 template [[host_name("kernel_flash_attn_ext_q4_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 32, 32>;
 template [[host_name("kernel_flash_attn_ext_q4_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 40, 40>;
 template [[host_name("kernel_flash_attn_ext_q4_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 64, 64>;
+template [[host_name("kernel_flash_attn_ext_q4_0_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 72, 72>;
 template [[host_name("kernel_flash_attn_ext_q4_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 80, 80>;
 template [[host_name("kernel_flash_attn_ext_q4_0_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 96, 96>;
 template [[host_name("kernel_flash_attn_ext_q4_0_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 112, 112>;

@@ -5412,6 +5416,7 @@ template [[host_name("kernel_flash_attn_ext_q4_0_dk576_dv512")]] kernel flash_at
 template [[host_name("kernel_flash_attn_ext_q4_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 32, 32>;
 template [[host_name("kernel_flash_attn_ext_q4_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 40, 40>;
 template [[host_name("kernel_flash_attn_ext_q4_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 64, 64>;
+template [[host_name("kernel_flash_attn_ext_q4_1_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 72, 72>;
 template [[host_name("kernel_flash_attn_ext_q4_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 80, 80>;
 template [[host_name("kernel_flash_attn_ext_q4_1_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 96, 96>;
 template [[host_name("kernel_flash_attn_ext_q4_1_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 112, 112>;

@@ -5424,6 +5429,7 @@ template [[host_name("kernel_flash_attn_ext_q4_1_dk576_dv512")]] kernel flash_at
 template [[host_name("kernel_flash_attn_ext_q5_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 32, 32>;
 template [[host_name("kernel_flash_attn_ext_q5_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 40, 40>;
 template [[host_name("kernel_flash_attn_ext_q5_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 64, 64>;
+template [[host_name("kernel_flash_attn_ext_q5_0_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 72, 72>;
 template [[host_name("kernel_flash_attn_ext_q5_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 80, 80>;
 template [[host_name("kernel_flash_attn_ext_q5_0_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 96, 96>;
 template [[host_name("kernel_flash_attn_ext_q5_0_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 112, 112>;

@@ -5436,6 +5442,7 @@ template [[host_name("kernel_flash_attn_ext_q5_0_dk576_dv512")]] kernel flash_at
 template [[host_name("kernel_flash_attn_ext_q5_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 32, 32>;
 template [[host_name("kernel_flash_attn_ext_q5_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 40, 40>;
 template [[host_name("kernel_flash_attn_ext_q5_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 64, 64>;
+template [[host_name("kernel_flash_attn_ext_q5_1_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 72, 72>;
 template [[host_name("kernel_flash_attn_ext_q5_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 80, 80>;
 template [[host_name("kernel_flash_attn_ext_q5_1_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 96, 96>;
 template [[host_name("kernel_flash_attn_ext_q5_1_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 112, 112>;

@@ -5448,6 +5455,7 @@ template [[host_name("kernel_flash_attn_ext_q5_1_dk576_dv512")]] kernel flash_at
 template [[host_name("kernel_flash_attn_ext_q8_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 32, 32>;
 template [[host_name("kernel_flash_attn_ext_q8_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 40, 40>;
 template [[host_name("kernel_flash_attn_ext_q8_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 64, 64>;
+template [[host_name("kernel_flash_attn_ext_q8_0_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 72, 72>;
 template [[host_name("kernel_flash_attn_ext_q8_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 80, 80>;
 template [[host_name("kernel_flash_attn_ext_q8_0_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 96, 96>;
 template [[host_name("kernel_flash_attn_ext_q8_0_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 112, 112>;
@@ -3186,6 +3186,7 @@ class VisionProjectorType:
     KIMIVL = "kimivl"
     LIGHTONOCR = "lightonocr"
     COGVLM = "cogvlm"
+    JANUS_PRO = "janus_pro"
 
 
 # Items here are (block size, type size)
@@ -1183,6 +1183,7 @@ class TensorNameMap:
             "model.mm_projector.mlp.mlp.{bid}",
             "vision_model.vision_adapter.mlp.fc{bid}", # llama 4
             "mlp1.{bid}", # InternVL
+            "model.aligner.fc1.hidden_layers.{bid}", # Janus Pro
         ),
 
         MODEL_TENSOR.V_MMPROJ_PEG: (
@@ -1291,6 +1292,7 @@ class TensorNameMap:
             "model.vision_tower.encoder.layer.{bid}.attention.projection_layer", # Intern-S1
             "vpm.encoder.layers.{bid}.self_attn.out_proj",
             "model.vision_model.encoder.layers.{bid}.self_attn.out_proj", # SmolVLM
+            "model.vision_model.encoder.layers.{bid}.self_attn.projection_layer", # Janus Pro
             "vision_model.model.layers.{bid}.self_attn.o_proj", # llama4
             "vision_tower.transformer.layers.{bid}.attention.o_proj", # pixtral-hf
             "vision_encoder.transformer.layers.{bid}.attention.wo", # pixtral
@@ -461,7 +461,10 @@ extern "C" {
     LLAMA_API bool llama_supports_gpu_offload(void);
     LLAMA_API bool llama_supports_rpc        (void);
 
+    // NOTE: After creating a llama_context, it is recommended to query the actual values using these functions
+    // In some cases the requested values via llama_context_params may differ from the actual values used by the context
     LLAMA_API uint32_t llama_n_ctx     (const struct llama_context * ctx);
+    LLAMA_API uint32_t llama_n_ctx_seq (const struct llama_context * ctx);
     LLAMA_API uint32_t llama_n_batch   (const struct llama_context * ctx);
     LLAMA_API uint32_t llama_n_ubatch  (const struct llama_context * ctx);
     LLAMA_API uint32_t llama_n_seq_max (const struct llama_context * ctx);
@@ -585,7 +588,7 @@ extern "C" {
     LLAMA_API int32_t llama_adapter_meta_val_str_by_index(const struct llama_adapter_lora * adapter, int32_t i, char * buf, size_t buf_size);
 
     // Manually free a LoRA adapter
-    // Note: loaded adapters will be free when the associated model is deleted
+    // NOTE: loaded adapters will be free when the associated model is deleted
     LLAMA_API void llama_adapter_lora_free(struct llama_adapter_lora * adapter);
 
     // Get the invocation tokens if the current lora is an alora
@@ -1111,8 +1114,6 @@ extern "C" {
     //     // sample from the logits of the last token in the batch
     //     const llama_token id = llama_sampler_sample(smpl, ctx, -1);
     //
-    //     // accepting the token updates the internal state of certain samplers (e.g. grammar, repetition, etc.)
-    //     llama_sampler_accept(smpl, id);
     //     ...
     // }
     //
@@ -0,0 +1,74 @@
+#!/usr/bin/env bash
+
+RESULTS="bench-models-results.txt"
+: > "$RESULTS"
+
+ARGS_BB="-c 270336 -npp 512,4096,8192 -npl 1,2,4,8,16,32 -ntg 32"
+ARGS_B="-d 0,4096,8192,16384,32768 -p 2048 -n 32"
+
+QUICK=0
+while (( "$#" )); do
+    case "$1" in
+        --quick) QUICK=1; shift ;;
+        *) shift ;;
+    esac
+done
+
+if (( QUICK )); then
+    ARGS_BB="-c 20480 -npp 512,4096 -npl 1,2,4 -ntg 32"
+    ARGS_B="-d 0 -p 2048 -n 32"
+fi
+
+run_model() {
+    local HFR=$1
+    local HFF=$2
+
+    printf "## ${HFR}\n" | tee -a "$RESULTS"
+    printf "\n" | tee -a "$RESULTS"
+    printf "Model: https://huggingface.co/${HFR}\n" | tee -a "$RESULTS"
+    printf "\n" | tee -a "$RESULTS"
+
+    printf -- "- \`llama-batched-bench\`\n" | tee -a "$RESULTS"
+    printf "\n" | tee -a "$RESULTS"
+
+    ./bin/llama-batched-bench \
+        -hfr "${HFR}" -hff "${HFF}" \
+        -m "${HFF}" -fa 1 -ub 2048 --no-mmap \
+        ${ARGS_BB} | tee -a "$RESULTS"
+
+    printf "\n" | tee -a "$RESULTS"
+
+    printf -- "- \`llama-bench\`\n" | tee -a "$RESULTS"
+    printf "\n" | tee -a "$RESULTS"
+
+    ./bin/llama-bench \
+        -m "${HFF}" -fa 1 -ub 2048 -mmp 0 \
+        ${ARGS_B} | tee -a "$RESULTS"
+
+    printf "\n" | tee -a "$RESULTS"
+
+    printf "\n"
+}
+
+run_model "ggml-org/gpt-oss-20b-GGUF" "gpt-oss-20b-mxfp4.gguf"
+run_model "ggml-org/gpt-oss-120b-GGUF" "gpt-oss-120b-mxfp4-00001-of-00003.gguf"
+run_model "ggml-org/Qwen3-Coder-30B-A3B-Instruct-Q8_0-GGUF" "qwen3-coder-30b-a3b-instruct-q8_0.gguf"
+run_model "ggml-org/Qwen2.5-Coder-7B-Q8_0-GGUF" "qwen2.5-coder-7b-q8_0.gguf"
+run_model "ggml-org/gemma-3-4b-it-qat-GGUF" "gemma-3-4b-it-qat-Q4_0.gguf"
+
+if [[ -f models-extra.txt ]]; then
+    while read -r HFR HFF; do
+        [[ -z "$HFR" ]] && continue
+        run_model "$HFR" "$HFF"
+    done < models-extra.txt
+fi
+
+printf "\n=====================================\n"
+printf "\n"
+
+cat "$RESULTS"
+
+printf "\n"
+printf "Done! Results are written to $RESULTS\n"
+printf "\n"
@@ -112,11 +112,24 @@ llama_context::llama_context(
             }
         }
 
-    const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
+    if (cparams.kv_unified) {
+        cparams.n_ctx_seq = cparams.n_ctx;
+    } else {
+        cparams.n_ctx_seq = cparams.n_ctx / cparams.n_seq_max;
+
+        if (cparams.n_ctx_seq == 0) {
+            throw std::runtime_error("n_ctx_seq == 0");
+        }
+
+        if (cparams.n_ctx != cparams.n_ctx_seq * cparams.n_seq_max) {
+            cparams.n_ctx = cparams.n_ctx_seq * cparams.n_seq_max;
+            LLAMA_LOG_WARN("%s: n_ctx is not divisible by n_seq_max - rounding down to %u\n", __func__, cparams.n_ctx);
+        }
+    }
 
     LLAMA_LOG_INFO("%s: n_seq_max   = %u\n", __func__, cparams.n_seq_max);
     LLAMA_LOG_INFO("%s: n_ctx       = %u\n", __func__, cparams.n_ctx);
-    LLAMA_LOG_INFO("%s: n_ctx_per_seq = %u\n", __func__, n_ctx_per_seq);
+    LLAMA_LOG_INFO("%s: n_ctx_seq   = %u\n", __func__, cparams.n_ctx_seq);
     LLAMA_LOG_INFO("%s: n_batch     = %u\n", __func__, cparams.n_batch);
     LLAMA_LOG_INFO("%s: n_ubatch    = %u\n", __func__, cparams.n_ubatch);
     LLAMA_LOG_INFO("%s: causal_attn = %d\n", __func__, cparams.causal_attn);
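A hypothetical Python sketch of the context-split rule introduced above: with a unified KV cache every sequence sees the full n_ctx, otherwise n_ctx is divided by n_seq_max and then rounded down to a multiple of it:

    def split_ctx(n_ctx: int, n_seq_max: int, kv_unified: bool) -> tuple[int, int]:
        """Returns (effective n_ctx, n_ctx_seq); mirrors the C++ logic above."""
        if kv_unified:
            return n_ctx, n_ctx
        n_ctx_seq = n_ctx // n_seq_max
        if n_ctx_seq == 0:
            raise ValueError("n_ctx_seq == 0")
        return n_ctx_seq * n_seq_max, n_ctx_seq

    # split_ctx(8192, 4, False) -> (8192, 2048)
    # split_ctx(8193, 4, False) -> (8192, 2048), with a rounding warning in the C++ code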
@@ -125,14 +138,14 @@ llama_context::llama_context(
     LLAMA_LOG_INFO("%s: freq_base   = %.1f\n", __func__, cparams.rope_freq_base);
     LLAMA_LOG_INFO("%s: freq_scale  = %g\n", __func__, cparams.rope_freq_scale);
 
-    if (n_ctx_per_seq < hparams.n_ctx_train) {
-        LLAMA_LOG_WARN("%s: n_ctx_per_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n",
-                __func__, n_ctx_per_seq, hparams.n_ctx_train);
+    if (cparams.n_ctx_seq < hparams.n_ctx_train) {
+        LLAMA_LOG_WARN("%s: n_ctx_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n",
+                __func__, cparams.n_ctx_seq, hparams.n_ctx_train);
     }
 
-    if (n_ctx_per_seq > hparams.n_ctx_train) {
-        LLAMA_LOG_WARN("%s: n_ctx_per_seq (%u) > n_ctx_train (%u) -- possible training context overflow\n",
-                __func__, n_ctx_per_seq, hparams.n_ctx_train);
+    if (cparams.n_ctx_seq > hparams.n_ctx_train) {
+        LLAMA_LOG_WARN("%s: n_ctx_seq (%u) > n_ctx_train (%u) -- possible training context overflow\n",
+                __func__, cparams.n_ctx_seq, hparams.n_ctx_train);
     }
 
     if (!hparams.vocab_only) {
@@ -453,8 +466,8 @@ uint32_t llama_context::n_ctx() const {
     return cparams.n_ctx;
 }
 
-uint32_t llama_context::n_ctx_per_seq() const {
-    return cparams.n_ctx / cparams.n_seq_max;
+uint32_t llama_context::n_ctx_seq() const {
+    return cparams.n_ctx_seq;
 }
 
 uint32_t llama_context::n_batch() const {
@@ -2383,6 +2396,10 @@ uint32_t llama_n_ctx(const llama_context * ctx) {
     return ctx->n_ctx();
 }
 
+uint32_t llama_n_ctx_seq(const llama_context * ctx) {
+    return ctx->n_ctx_seq();
+}
+
 uint32_t llama_n_batch(const llama_context * ctx) {
     return ctx->n_batch();
 }
@@ -44,7 +44,7 @@ struct llama_context {
     ggml_backend_sched_t get_sched() const;
 
     uint32_t n_ctx() const;
-    uint32_t n_ctx_per_seq() const;
+    uint32_t n_ctx_seq() const;
     uint32_t n_batch() const;
     uint32_t n_ubatch() const;
     uint32_t n_seq_max() const;
@@ -8,6 +8,7 @@
 
 struct llama_cparams {
     uint32_t n_ctx;     // context size used during inference
+    uint32_t n_ctx_seq; // context for a single sequence
     uint32_t n_batch;
     uint32_t n_ubatch;
     uint32_t n_seq_max;
@@ -6712,14 +6712,14 @@ float llama_model::get_rope_freq_scale(const llama_cparams & cparams, int il) co
 }
 
 ggml_tensor * llama_model::get_rope_factors(const llama_cparams & cparams, int il) const {
-    const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
+    const uint32_t n_ctx_seq = cparams.n_ctx_seq;
 
     // choose long/short freq factors based on the context size
     if (layers[il].rope_freqs != nullptr) {
         return layers[il].rope_freqs;
     }
 
-    if (n_ctx_per_seq > hparams.n_ctx_orig_yarn) {
+    if (n_ctx_seq > hparams.n_ctx_orig_yarn) {
         return layers[il].rope_long;
     }
@@ -6795,12 +6795,6 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                 /* filter_attn */ std::move(filter_attn),
                 /* filter_recr */ std::move(filter_recr));
         } else {
-            uint32_t n_ctx_per_stream = cparams.n_ctx;
-
-            if (!cparams.kv_unified) {
-                n_ctx_per_stream = (cparams.n_ctx + cparams.n_seq_max - 1)/cparams.n_seq_max;
-            }
-
             llama_memory_i::layer_reuse_cb reuse = nullptr;
 
             if (arch == LLM_ARCH_GEMMA3N) {
@@ -6824,7 +6818,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                         cparams.offload_kqv,
                         params.swa_full,
                         cparams.kv_unified,
-                        n_ctx_per_stream,
+                        cparams.n_ctx_seq,
                         cparams.n_seq_max,
                         cparams.n_ubatch,
                         1,
@@ -6840,7 +6834,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                         !cparams.flash_attn,
                         cparams.offload_kqv,
                         cparams.kv_unified,
-                        n_ctx_per_stream,
+                        cparams.n_ctx_seq,
                         cparams.n_seq_max,
                         1,
                         hparams.n_swa,
@@ -7225,8 +7225,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         test_cases.emplace_back(new test_pad_ext(GGML_TYPE_F32, {11, 22, 33, 44}, 1, 2, 3, 4, 5, 6, 7, 8, v));
     }
 
-    for (int hsk : { 40, 64, 80, 96, 128, 192, 256, 576 }) {
-        for (int hsv : { 40, 64, 80, 96, 128, 192, 256, 512 }) {
+    for (int hsk : { 40, 64, 72, 80, 96, 128, 192, 256, 576 }) {
+        for (int hsv : { 40, 64, 72, 80, 96, 128, 192, 256, 512 }) {
             if (hsk != 192 && hsk != 576 && hsk != hsv) continue;
             if (hsk == 192 && (hsv != 128 && hsv != 192)) continue;
             if (hsk == 576 && hsv != 512) continue; // DeepSeek MLA
@@ -131,7 +131,14 @@ int main(int argc, char ** argv) {
             }

             batch = llama_batch_get_one(&token, 1);
-            if (llama_decode(ctx.get(), batch)) {
+
+            int ret = llama_decode(ctx.get(), batch);
+            if (ret == 1 && i > 0) {
+                LOG_INF("Context full, stopping generation.\n");
+                break;
+            }
+
+            if (ret != 0) {
                 LOG_ERR("Model %d/%d, Context %d/%d: failed to decode\n", m + 1, num_models, c + 1, num_contexts);
                 failed.store(true);
                 return;

@@ -221,7 +221,5 @@ int main(int argc, char ** argv) {

     llama_backend_free();

-    LOG("\n\n");
-
     return 0;
 }
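Note: the new branch relies on llama_decode's documented return convention (see llama.h): 0 means success, 1 means no KV-cache slot could be found for the batch (for example, the context is full), 2 means the call was aborted, and negative values are fatal errors. A reference sketch:

    int ret = llama_decode(ctx, batch);
    if (ret == 1) {
        // not fatal: no space left in the KV cache - stop, or retry with a smaller batch
    } else if (ret != 0) {
        // abort or fatal error - give up on this sequence
    }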
@@ -155,6 +155,7 @@ enum projector_type {
     PROJECTOR_TYPE_KIMIVL,
     PROJECTOR_TYPE_LIGHTONOCR,
     PROJECTOR_TYPE_COGVLM,
+    PROJECTOR_TYPE_JANUS_PRO,
     PROJECTOR_TYPE_UNKNOWN,
 };

@@ -180,6 +181,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
     { PROJECTOR_TYPE_KIMIVL,    "kimivl"},
     { PROJECTOR_TYPE_LIGHTONOCR,"lightonocr"},
     { PROJECTOR_TYPE_COGVLM,    "cogvlm"},
+    { PROJECTOR_TYPE_JANUS_PRO, "janus_pro"},
 };

 static projector_type clip_projector_type_from_string(const std::string & str) {
@@ -6,7 +6,6 @@
 #include "clip-impl.h"
 #include "ggml.h"
 #include "ggml-cpp.h"
-#include "ggml-cpu.h"
 #include "ggml-alloc.h"
 #include "ggml-backend.h"
 #include "gguf.h"

@@ -17,17 +16,15 @@
 #include <cstring>
 #include <fstream>
 #include <map>
-#include <regex>
 #include <stdexcept>
 #include <unordered_set>
 #include <vector>
-#include <sstream>
 #include <cinttypes>
 #include <limits>
 #include <array>
-#include <numeric>
 #include <functional>

+// TODO: allow to pass callback from user code
 struct clip_logger_state g_logger_state = {GGML_LOG_LEVEL_CONT, clip_log_callback_default, NULL};

 enum ffn_op_type {
@@ -431,12 +428,14 @@ struct clip_ctx {

     int max_nodes = 8192;
     ggml_backend_sched_ptr sched;
+    clip_flash_attn_type flash_attn_type = CLIP_FLASH_ATTN_TYPE_AUTO;

     // for debugging
     bool debug_graph = false;
     std::vector<ggml_tensor *> debug_print_tensors;

     clip_ctx(clip_context_params & ctx_params) {
+        flash_attn_type = ctx_params.flash_attn_type;
         debug_graph = std::getenv("MTMD_DEBUG_GRAPH") != nullptr;
         backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr);
         if (!backend_cpu) {
@@ -601,6 +600,15 @@ struct clip_graph {
             cur = ggml_gelu(ctx0, cur);
             cur = ggml_mul_mat(ctx0, model.mm_2_w, cur);
             cur = ggml_add(ctx0, cur, model.mm_2_b);
+
+        } else if (ctx->proj_type() == PROJECTOR_TYPE_JANUS_PRO) {
+            cur = build_ffn(cur,
+                model.mm_0_w, model.mm_0_b,
+                nullptr, nullptr,
+                model.mm_1_w, model.mm_1_b,
+                hparams.ffn_op,
+                -1);
+
         } else {
             GGML_ABORT("SigLIP: Unsupported projector type");
         }
@@ -1742,7 +1750,6 @@
         return gf;
     }
-
     // whisper encoder with custom projector
     ggml_cgraph * build_whisper_enc() {
         const int n_frames = img.nx;
@@ -2272,17 +2279,25 @@ private:
         ggml_tensor * k = ggml_permute(ctx0, k_cur, 0, 2, 1, 3);
         //cb(k, "k", il);

-        ggml_tensor * v = ggml_permute(ctx0, v_cur, 1, 2, 0, 3);
-        v = ggml_cont(ctx0, v);
-        //cb(k, "v", il);
-
         ggml_tensor * cur;

-        // TODO @ngxson : support flash attention
-        {
+        if (ctx->flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) {
+            ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
+
+            k = ggml_cast(ctx0, k, GGML_TYPE_F16);
+            v = ggml_cast(ctx0, v, GGML_TYPE_F16);
+
+            cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, 0.0f, 0.0f);
+            ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
+
+            cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]);
+
+        } else {
+            ggml_tensor * v = ggml_permute(ctx0, v_cur, 1, 2, 0, 3);
+            v = ggml_cont(ctx0, v);
+
             const auto n_tokens = q->ne[1];
             const auto n_head   = q->ne[2];
-            // const auto n_kv = k->ne[1]; // for flash attention

             ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
             // F32 may not needed for vision encoders?
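Note, as code comments (a reading of the change above, not authoritative):

    // flash path:   V is fed to ggml_flash_attn_ext in the same un-transposed
    //               layout as K, so the old transpose-plus-ggml_cont of V is only
    //               needed on the mul_mat fallback; K/V are cast to GGML_TYPE_F16
    //               because most flash-attention kernels expect half-precision
    //               inputs, while ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32)
    //               keeps the accumulation in full precision.
    // mul_mat path: V is transposed and made contiguous so the final
    //               ggml_mul_mat(v, softmax(kq)) produces the attention output.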
@@ -2462,6 +2477,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
             {
                 res = graph.build_kimivl();
             } break;
+        case PROJECTOR_TYPE_JANUS_PRO:
+            {
+                res = graph.build_siglip();
+            } break;
         case PROJECTOR_TYPE_COGVLM:
             {
                 res = graph.build_cogvlm();
@@ -3176,6 +3195,13 @@ struct clip_model_loader {
                     model.mm_boi = get_tensor(TN_TOK_BOI);
                     model.mm_eoi = get_tensor(TN_TOK_EOI);
                 } break;
+            case PROJECTOR_TYPE_JANUS_PRO:
+                {
+                    model.mm_0_w = get_tensor(string_format(TN_LLAVA_PROJ, 0, "weight"));
+                    model.mm_0_b = get_tensor(string_format(TN_LLAVA_PROJ, 0, "bias"));
+                    model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight"));
+                    model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias"));
+                } break;
             default:
                 GGML_ASSERT(false && "unknown projector type");
         }
@@ -3217,7 +3243,87 @@ struct clip_model_loader {
         }
     }

-    void alloc_compute_meta(clip_ctx & ctx_clip) {
+    struct support_info_op {
+        ggml_tensor * op;
+
+        // true if the op runs on the accelerated ctx_clip.backend
+        bool is_accel = true;
+    };
+
+    struct support_info_graph {
+        // whether the clip_ctx.backend supports flash attention
+        bool fattn = true;
+        ggml_tensor * fattn_op = nullptr; // for debugging
+
+        std::vector<support_info_op> ops;
+    };
+
+    static void warmup(clip_ctx & ctx_clip) {
+        support_info_graph info;
+
+        if (ctx_clip.flash_attn_type == CLIP_FLASH_ATTN_TYPE_AUTO) {
+            // try to enable flash attention to see if it's supported
+            ctx_clip.flash_attn_type = CLIP_FLASH_ATTN_TYPE_ENABLED;
+            info = alloc_compute_meta(ctx_clip);
+            if (!info.fattn && info.fattn_op) {
+                auto op = info.fattn_op;
+                LOG_WRN("%s: *****************************************************************\n", __func__);
+                LOG_WRN("%s: WARNING: flash attention not supported by %s, memory usage will increase\n", __func__, ggml_backend_name(ctx_clip.backend));
+                LOG_WRN("%s: op params: \n", __func__);
+                static auto print_shape = [](const char * fn, const char * name, ggml_tensor * t) {
+                    LOG_WRN("%s: %s: type = %s, ne = [%d %d %d %d], nb = [%d %d %d %d]\n", fn,
+                        name, ggml_type_name(t->type),
+                        t->ne[0], t->ne[1], t->ne[2], t->ne[3],
+                        t->nb[0], t->nb[1], t->nb[2], t->nb[3]);
+                };
+                print_shape(__func__, " dst", op);
+                print_shape(__func__, "src0", op->src[0]);
+                print_shape(__func__, "src1", op->src[1]);
+                print_shape(__func__, "src2", op->src[2]);
+                LOG_WRN("%s: please report this on github as an issue\n", __func__);
+                LOG_WRN("%s: *****************************************************************\n", __func__);
+                ctx_clip.flash_attn_type = CLIP_FLASH_ATTN_TYPE_DISABLED;
+                alloc_compute_meta(ctx_clip);
+            }
+        } else {
+            info = alloc_compute_meta(ctx_clip);
+            if (!info.fattn && ctx_clip.flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) {
+                LOG_WRN("%s: flash attention is not supported by the current backend; falling back to CPU (performance will be degraded)\n", __func__);
+            }
+        }
+
+        LOG_INF("%s: flash attention is %s\n", __func__,
+            (ctx_clip.flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) ? "enabled" : "disabled");
+
+        // print ops that are not supported by the GPU backend (if there is one)
+        if (ctx_clip.backend && ctx_clip.backend != ctx_clip.backend_cpu) {
+            std::vector<support_info_op> unsupported_ops;
+            for (const auto & op : info.ops) {
+                if (!op.is_accel) {
+                    unsupported_ops.push_back(op);
+                }
+            }
+            if (!unsupported_ops.empty()) {
+                LOG_WRN("%s: *****************************************************************\n", __func__);
+                LOG_WRN("%s: WARNING: the CLIP graph uses unsupported operators by the backend\n", __func__);
+                LOG_WRN("%s: the performance will be suboptimal \n", __func__);
+                LOG_WRN("%s: list of unsupported ops (backend=%s):\n", __func__, ggml_backend_name(ctx_clip.backend));
+                for (const auto & op : unsupported_ops) {
+                    LOG_WRN("%s: %16s: type = %s, ne = [%d %d %d %d]\n", __func__,
+                        ggml_op_name(op.op->op),
+                        ggml_type_name(op.op->type),
+                        op.op->ne[0], op.op->ne[1], op.op->ne[2], op.op->ne[3]);
+                }
+                LOG_WRN("%s: flash attention is %s\n", __func__,
+                    (ctx_clip.flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) ? "enabled" : "disabled");
+                LOG_WRN("%s: please report this on github as an issue\n", __func__);
+                LOG_WRN("%s: ref: https://github.com/ggml-org/llama.cpp/pull/16837#issuecomment-3461676118\n", __func__);
+                LOG_WRN("%s: *****************************************************************\n", __func__);
+            }
+        }
+    }
+
+    static support_info_graph alloc_compute_meta(clip_ctx & ctx_clip) {
         const auto & hparams = ctx_clip.model.hparams;
         ctx_clip.buf_compute_meta.resize(ctx_clip.max_nodes * ggml_tensor_overhead() + ggml_graph_overhead());
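Note: warmup() implements a probe-and-fallback negotiation for CLIP_FLASH_ATTN_TYPE_AUTO. Condensed sketch of the control flow, using the names from the hunk above:

    if (ctx_clip.flash_attn_type == CLIP_FLASH_ATTN_TYPE_AUTO) {
        ctx_clip.flash_attn_type = CLIP_FLASH_ATTN_TYPE_ENABLED;     // optimistic
        support_info_graph info = alloc_compute_meta(ctx_clip);      // dry-run graph build
        if (!info.fattn) {
            // the backend rejected GGML_OP_FLASH_ATTN_EXT: rebuild without fattn
            ctx_clip.flash_attn_type = CLIP_FLASH_ATTN_TYPE_DISABLED;
            alloc_compute_meta(ctx_clip);
        }
    }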
@@ -3248,57 +3354,95 @@ struct clip_model_loader {
                 size / 1024.0 / 1024.0);
             }
         }

+        const int n_splits = ggml_backend_sched_get_n_splits(ctx_clip.sched.get());
+        const int n_nodes  = ggml_graph_n_nodes(gf);
+
+        LOG_INF("%s: graph splits = %d, nodes = %d\n", __func__, n_splits, n_nodes);
+
+        support_info_graph res {
+            /*.fattn    = */ true,
+            /*.fattn_op = */ nullptr,
+            /*.ops      = */ {},
+        };
+
+        // check op support
+        for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
+            ggml_tensor * node = ggml_graph_node(gf, i);
+            res.ops.push_back({node, true});
+            if (!ggml_backend_supports_op(ctx_clip.backend, node)) {
+                res.ops.back().is_accel = false;
+                if (node->op == GGML_OP_FLASH_ATTN_EXT) {
+                    res.fattn = false;
+                    res.fattn_op = node;
+                }
+            }
+        }
+
+        return res;
     }

-    void get_bool(const std::string & key, bool & output, bool required = true) {
+    void get_bool(const std::string & key, bool & output, bool required = true) const {
         const int i = gguf_find_key(ctx_gguf.get(), key.c_str());
         if (i < 0) {
-            if (required) throw std::runtime_error("Key not found: " + key);
+            if (required) {
+                throw std::runtime_error("Key not found: " + key);
+            }
             return;
         }
         output = gguf_get_val_bool(ctx_gguf.get(), i);
     }

-    void get_i32(const std::string & key, int & output, bool required = true) {
+    void get_i32(const std::string & key, int & output, bool required = true) const {
         const int i = gguf_find_key(ctx_gguf.get(), key.c_str());
         if (i < 0) {
-            if (required) throw std::runtime_error("Key not found: " + key);
+            if (required) {
+                throw std::runtime_error("Key not found: " + key);
+            }
             return;
         }
         output = gguf_get_val_i32(ctx_gguf.get(), i);
     }

-    void get_u32(const std::string & key, int & output, bool required = true) {
+    void get_u32(const std::string & key, int & output, bool required = true) const {
         const int i = gguf_find_key(ctx_gguf.get(), key.c_str());
         if (i < 0) {
-            if (required) throw std::runtime_error("Key not found: " + key);
+            if (required) {
+                throw std::runtime_error("Key not found: " + key);
+            }
             return;
         }
         output = gguf_get_val_u32(ctx_gguf.get(), i);
     }

-    void get_f32(const std::string & key, float & output, bool required = true) {
+    void get_f32(const std::string & key, float & output, bool required = true) const {
         const int i = gguf_find_key(ctx_gguf.get(), key.c_str());
         if (i < 0) {
-            if (required) throw std::runtime_error("Key not found: " + key);
+            if (required) {
+                throw std::runtime_error("Key not found: " + key);
+            }
            return;
         }
         output = gguf_get_val_f32(ctx_gguf.get(), i);
     }

-    void get_string(const std::string & key, std::string & output, bool required = true) {
+    void get_string(const std::string & key, std::string & output, bool required = true) const {
         const int i = gguf_find_key(ctx_gguf.get(), key.c_str());
         if (i < 0) {
-            if (required) throw std::runtime_error("Key not found: " + key);
+            if (required) {
+                throw std::runtime_error("Key not found: " + key);
+            }
             return;
         }
         output = std::string(gguf_get_val_str(ctx_gguf.get(), i));
     }

-    void get_arr_int(const std::string & key, std::vector<int> & output, bool required = true) {
+    void get_arr_int(const std::string & key, std::vector<int> & output, bool required = true) const {
         const int i = gguf_find_key(ctx_gguf.get(), key.c_str());
         if (i < 0) {
-            if (required) throw std::runtime_error("Key not found: " + key);
+            if (required) {
+                throw std::runtime_error("Key not found: " + key);
+            }
             return;
         }
         int n = gguf_get_arr_n(ctx_gguf.get(), i);
@@ -3309,7 +3453,7 @@ struct clip_model_loader {
         }
     }

-    void set_llava_uhd_res_candidates(clip_model & model, const int max_patches_per_side) {
+    static void set_llava_uhd_res_candidates(clip_model & model, const int max_patches_per_side) {
         auto & hparams = model.hparams;
         for (int x = 1; x <= max_patches_per_side; x++) {
             for (int y = 1; y <= max_patches_per_side; y++) {
@@ -3337,24 +3481,22 @@ struct clip_init_result clip_init(const char * fname, struct clip_context_params
         ctx_vision = new clip_ctx(ctx_params);
         loader.load_hparams(ctx_vision->model, CLIP_MODALITY_VISION);
         loader.load_tensors(*ctx_vision);
-        loader.alloc_compute_meta(*ctx_vision);
+        loader.warmup(*ctx_vision);
     }

     if (loader.has_audio) {
         ctx_audio = new clip_ctx(ctx_params);
         loader.load_hparams(ctx_audio->model, CLIP_MODALITY_AUDIO);
         loader.load_tensors(*ctx_audio);
-        loader.alloc_compute_meta(*ctx_audio);
+        loader.warmup(*ctx_audio);
     }

 } catch (const std::exception & e) {
     LOG_ERR("%s: failed to load model '%s': %s\n", __func__, fname, e.what());
-    if (ctx_vision) {
-        delete ctx_vision;
-    }
-    if (ctx_audio) {
-        delete ctx_audio;
-    }
+    delete ctx_vision;
+    delete ctx_audio;
     return {nullptr, nullptr};
 }
@@ -3392,10 +3534,10 @@ void clip_image_size_free(struct clip_image_size * load_image_size) {
     }
     delete load_image_size;
 }
-void clip_image_u8_free(struct clip_image_u8 * img) { if (img) delete img; }
-void clip_image_f32_free(struct clip_image_f32 * img) { if (img) delete img; }
-void clip_image_u8_batch_free(struct clip_image_u8_batch * batch) { if (batch) delete batch; }
-void clip_image_f32_batch_free(struct clip_image_f32_batch * batch) { if (batch) delete batch; }
+void clip_image_u8_free(struct clip_image_u8 * img) { delete img; }
+void clip_image_f32_free(struct clip_image_f32 * img) { delete img; }
+void clip_image_u8_batch_free(struct clip_image_u8_batch * batch) { delete batch; }
+void clip_image_f32_batch_free(struct clip_image_f32_batch * batch) { delete batch; }

 size_t clip_image_f32_batch_n_images(const struct clip_image_f32_batch * batch) {
     return batch->entries.size();
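Note: dropping the null checks in clip_init's catch block and in the clip_image_*_free functions above is safe because C++ guarantees that delete on a null pointer is a no-op:

    clip_image_u8 * img = nullptr;
    clip_image_u8_free(img); // fine: "delete img" does nothing for nullptr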
@@ -4121,6 +4263,18 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
                 res_imgs->entries.push_back(std::move(img_f32));
             } break;

+        case PROJECTOR_TYPE_JANUS_PRO:
+            {
+                // Janus Pro preprocessing: pad to square with gray(127), resize to 384x384
+                const std::array<uint8_t, 3> pad_color = {127, 127, 127};
+                clip_image_u8 resized_image;
+                int sz = params.image_size;
+                img_tool::resize(*img, resized_image, {sz, sz}, img_tool::RESIZE_ALGO_BILINEAR, true, pad_color);
+                clip_image_f32_ptr img_f32(clip_image_f32_init());
+                normalize_image_u8_to_f32(resized_image, *img_f32, params.image_mean, params.image_std);
+                res_imgs->entries.push_back(std::move(img_f32));
+            } break;
+
         case PROJECTOR_TYPE_PIXTRAL:
         case PROJECTOR_TYPE_LIGHTONOCR:
             {
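Note: normalize_image_u8_to_f32 applies the usual CLIP-style per-channel normalization; a sketch of the assumed convention, for each pixel and channel c:

    // f32[c] = (u8[c] / 255.0f - image_mean[c]) / image_std[c];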
@@ -4297,6 +4451,7 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
     switch (proj) {
         case PROJECTOR_TYPE_MLP:
         case PROJECTOR_TYPE_MLP_NORM:
+        case PROJECTOR_TYPE_JANUS_PRO:
             {
                 // do nothing
             } break;

@@ -4807,6 +4962,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
         case PROJECTOR_TYPE_ULTRAVOX:
         case PROJECTOR_TYPE_LFM2:
         case PROJECTOR_TYPE_VOXTRAL:
+        case PROJECTOR_TYPE_JANUS_PRO:
         case PROJECTOR_TYPE_COGVLM:
             {
                 // do nothing
@@ -4895,6 +5051,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
             return ctx->model.mm_model_mlp_3_w->ne[1];
         case PROJECTOR_TYPE_QWEN2VL:
         case PROJECTOR_TYPE_QWEN25VL:
+        case PROJECTOR_TYPE_JANUS_PRO:
             return ctx->model.mm_1_b->ne[0];
         case PROJECTOR_TYPE_QWEN3VL:
             // main path + deepstack paths
@@ -1,6 +1,7 @@
 #pragma once

 #include "ggml.h"
+
 #include <stddef.h>
 #include <stdint.h>

@@ -22,9 +23,16 @@ enum clip_modality {
     CLIP_MODALITY_AUDIO,
 };

+enum clip_flash_attn_type {
+    CLIP_FLASH_ATTN_TYPE_AUTO     = -1,
+    CLIP_FLASH_ATTN_TYPE_DISABLED = 0,
+    CLIP_FLASH_ATTN_TYPE_ENABLED  = 1,
+};
+
 struct clip_context_params {
     bool use_gpu;
     enum ggml_log_level verbosity;
+    enum clip_flash_attn_type flash_attn_type;
     int image_min_tokens;
     int image_max_tokens;
 };
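Note: a hypothetical caller wiring up the new field through this header (the mmproj path is a placeholder):

    clip_context_params params;
    params.use_gpu          = true;
    params.verbosity        = GGML_LOG_LEVEL_INFO;
    params.flash_attn_type  = CLIP_FLASH_ATTN_TYPE_AUTO; // probe during warmup, disable if unsupported
    params.image_min_tokens = -1; // -1: read the limit from model metadata
    params.image_max_tokens = -1;

    clip_init_result res = clip_init("/path/to/mmproj.gguf", params);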
@@ -136,6 +136,7 @@ struct mtmd_cli_context {
     mparams.print_timings = true;
     mparams.n_threads = params.cpuparams.n_threads;
     mparams.verbosity = params.verbosity > 0 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_INFO;
+    mparams.flash_attn_type = params.flash_attn_type;
     mparams.image_min_tokens = params.image_min_tokens;
     mparams.image_max_tokens = params.image_max_tokens;
     ctx_vision.reset(mtmd_init_from_file(clip_path, model, mparams));
@@ -19,7 +19,6 @@
 #include <cstdio>
 #include <cstdlib>
 #include <cstring>
-#include <limits>
 #include <vector>

 // represents raw image data, layout is RGBRGBRGB...
@@ -92,6 +91,15 @@ const char * mtmd_default_marker() {
     return "<__media__>";
 }

+static clip_flash_attn_type mtmd_get_clip_flash_attn_type(enum llama_flash_attn_type flash_attn_type) {
+    switch (flash_attn_type) {
+        case LLAMA_FLASH_ATTN_TYPE_AUTO:     return CLIP_FLASH_ATTN_TYPE_AUTO;
+        case LLAMA_FLASH_ATTN_TYPE_DISABLED: return CLIP_FLASH_ATTN_TYPE_DISABLED;
+        case LLAMA_FLASH_ATTN_TYPE_ENABLED:  return CLIP_FLASH_ATTN_TYPE_ENABLED;
+    }
+    return CLIP_FLASH_ATTN_TYPE_AUTO;
+}
+
 mtmd_context_params mtmd_context_params_default() {
     mtmd_context_params params;
     params.use_gpu = true;

@@ -100,6 +108,7 @@ mtmd_context_params mtmd_context_params_default() {
     params.verbosity = GGML_LOG_LEVEL_INFO;
     params.image_marker = MTMD_DEFAULT_IMAGE_MARKER;
     params.media_marker = mtmd_default_marker();
+    params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_AUTO;
     params.image_min_tokens = -1;
     params.image_max_tokens = -1;
     return params;
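Note: on the mtmd side the same choice is expressed with llama's enum and mapped by the helper above; a sketch of a caller (path hypothetical, `model` assumed loaded):

    mtmd_context_params mparams = mtmd_context_params_default(); // flash_attn_type defaults to AUTO
    mparams.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_ENABLED;     // force fattn in the media encoder
    mtmd_context * mctx = mtmd_init_from_file("/path/to/mmproj.gguf", model, mparams);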
@@ -166,9 +175,11 @@ struct mtmd_context {
     clip_context_params ctx_clip_params;
     ctx_clip_params.use_gpu   = ctx_params.use_gpu;
     ctx_clip_params.verbosity = ctx_params.verbosity;
+    ctx_clip_params.flash_attn_type = mtmd_get_clip_flash_attn_type(ctx_params.flash_attn_type);
     // custom image token limits
     ctx_clip_params.image_min_tokens = ctx_params.image_min_tokens;
     ctx_clip_params.image_max_tokens = ctx_params.image_max_tokens;

     auto res = clip_init(mmproj_fname, ctx_clip_params);
     ctx_v = res.ctx_v;
     ctx_a = res.ctx_a;

@@ -383,9 +394,7 @@ mtmd_context * mtmd_init_from_file(const char * mmproj_fname,
 }

 void mtmd_free(mtmd_context * ctx) {
-    if (ctx) {
-        delete ctx;
-    }
+    delete ctx;
 }

 struct mtmd_tokenizer {

@@ -82,6 +82,7 @@ struct mtmd_context_params {
     enum ggml_log_level verbosity;
     const char * image_marker; // deprecated, use media_marker instead
     const char * media_marker;
+    enum llama_flash_attn_type flash_attn_type;

     // limit number of image tokens, only for vision models with dynamic resolution
     int image_min_tokens; // minimum number of tokens for image input (default: read from metadata)
Binary file not shown.
@@ -2407,7 +2407,7 @@ struct server_context {

     params_dft.devices      = params_base.speculative.devices;
     params_dft.model        = params_base.speculative.model;
-    params_dft.n_ctx        = params_base.speculative.n_ctx == 0 ? params_base.n_ctx / params_base.n_parallel : params_base.speculative.n_ctx;
+    params_dft.n_ctx        = params_base.speculative.n_ctx == 0 ? llama_n_ctx_seq(ctx) : params_base.speculative.n_ctx;
     params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers;
     params_dft.n_parallel   = 1;
     params_dft.cache_type_k = params_base.speculative.cache_type_k;

@@ -2456,6 +2456,7 @@ struct server_context {
     mparams.print_timings = false;
     mparams.n_threads = params_base.cpuparams.n_threads;
     mparams.verbosity = params_base.verbosity > 0 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_INFO;
+    mparams.flash_attn_type = params_base.flash_attn_type;
     mparams.image_min_tokens = params_base.image_min_tokens;
     mparams.image_max_tokens = params_base.image_max_tokens;
     mctx = mtmd_init_from_file(mmproj_path.c_str(), model, mparams);
@@ -2497,10 +2498,16 @@ struct server_context {
     }

     void init() {
-        const int32_t n_ctx_slot = n_ctx / params_base.n_parallel;
-
         SRV_INF("initializing slots, n_slots = %d\n", params_base.n_parallel);

+        const int n_ctx_train = llama_model_n_ctx_train(model);
+
+        int n_ctx_slot = llama_n_ctx_seq(ctx);
+        if (n_ctx_slot > n_ctx_train) {
+            SRV_WRN("the slot context (%d) exceeds the training context of the model (%d) - capping\n", n_ctx_slot, n_ctx_train);
+            n_ctx_slot = n_ctx_train;
+        }
+
         for (int i = 0; i < params_base.n_parallel; i++) {
             server_slot slot;

@@ -2529,7 +2536,7 @@ struct server_context {
             }
         }

-        SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx);
+        SLT_INF(slot, "new slot, n_ctx = %d\n", slot.n_ctx);

         slot.callback_on_release = [this](int) {
             queue_tasks.pop_deferred_task();
@@ -2701,6 +2708,39 @@ struct server_context {
         return ret;
     }

+    // return true if at least one slot has been purged
+    // TODO: improve logic
+    //       - smarter decision which slot to purge (LRU or longest prompt?)
+    //       - move slot to level 2 cache instead of removing?
+    //       - instead of purging, try to store and resume later?
+    bool try_purge_idle_slots() {
+        bool res = false;
+
+        if (!params_base.kv_unified) {
+            return res;
+        }
+
+        for (auto & slot : slots) {
+            if (slot.is_processing()) {
+                continue;
+            }
+
+            if (slot.prompt.n_tokens() > 0) {
+                SRV_WRN("purging slot %d with %zu tokens\n", slot.id, slot.prompt.tokens.size());
+
+                llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);
+                slot.prompt.tokens.clear();
+
+                res = true;
+
+                // purge slots one by one
+                break;
+            }
+        }
+
+        return res;
+    }
+
     bool launch_slot_with_task(server_slot & slot, server_task && task) {
         slot.reset();
@@ -3637,9 +3677,10 @@ struct server_context {
     int32_t n_batch  = llama_n_batch(ctx);
     int32_t n_ubatch = llama_n_ubatch(ctx);

-    // next, batch any pending prompts without exceeding n_batch
     float alora_scale = -1.0f;
     size_t alora_disabled_id = 0;

+    // next, batch any pending prompts without exceeding n_batch
     if (params_base.cont_batching || batch.n_tokens == 0) {
         for (auto & slot : slots) {
             // check if we can batch this slot with the previous one
@@ -3916,8 +3957,11 @@ struct server_context {

     // truncate any tokens that are beyond n_past for this slot
     const llama_pos p0 = slot.prompt.tokens.pos_next();

+    SLT_INF(slot, "n_tokens = %d, memory_seq_rm [%d, end)\n", slot.prompt.n_tokens(), p0);
+
     if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, p0, -1)) {
-        SLT_WRN(slot, "failed to truncate tokens with position >= %d\n", p0);
+        SLT_WRN(slot, "failed to truncate tokens with position >= %d - clearing the memory\n", p0);
         llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);

         // there is no common part left

@@ -3926,8 +3970,6 @@ struct server_context {
         slot.prompt.tokens.clear();
     }

-    SLT_INF(slot, "n_tokens = %d, memory_seq_rm [%d, end)\n", slot.prompt.n_tokens(), p0);
-
     // check if we should process the image
     if (slot.prompt.n_tokens() < slot.task->n_tokens() && input_tokens[slot.prompt.n_tokens()] == LLAMA_TOKEN_NULL) {
         // process the image
@@ -4128,6 +4170,8 @@ struct server_context {
     std::string err;

     if (n_batch == 1 && ret == 1) {
+        // TODO: try to terminate only the largest active slot/sequence and continue with the rest
+        //       need to remove the tokens from the current batch too
         err = "Context size has been exceeded.";
     }

@@ -4143,17 +4187,23 @@ struct server_context {
     // TODO: handle ret == 2 (abort) when we start aborting

     if (!err.empty()) {
-        SRV_ERR("%s, i = %d, n_batch = %d, ret = %d\n", err.c_str(), i, n_batch, ret);
+        SRV_ERR("%s i = %d, n_batch = %d, ret = %d\n", err.c_str(), i, n_batch, ret);
+
         for (auto & slot : slots) {
+            if (slot.is_processing()) {
                 send_error(slot, err);
                 slot.release();
+            }
         }
+
         break;
     }

     // retry with half the batch size to try to find a free slot in the KV cache
+    if (!try_purge_idle_slots()) {
         n_batch /= 2;
+    }

     SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret);
@@ -4393,6 +4443,15 @@ int main(int argc, char ** argv) {
         return 1;
     }

+    // TODO: should we have a separate n_parallel parameter for the server?
+    //       https://github.com/ggml-org/llama.cpp/pull/16736#discussion_r2483763177
+    if (params.n_parallel == 1 && params.kv_unified == false) {
+        LOG_WRN("%s: setting n_parallel = 4 and kv_unified = true\n", __func__);
+
+        params.n_parallel = 4;
+        params.kv_unified = true;
+    }
+
     common_init();

     // struct that contains llama context and inference
@@ -4946,7 +5005,7 @@ int main(int argc, char ** argv) {
             // Everything else, including multimodal completions.
             inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
         }
-        const size_t n_ctx_slot = ctx_server.n_ctx / ctx_server.params_base.n_parallel;
+        const size_t n_ctx_slot = ctx_server.slots.front().n_ctx;
         tasks.reserve(inputs.size());
         for (size_t i = 0; i < inputs.size(); i++) {
             auto n_prompt_tokens = inputs[i].size();
@@ -433,21 +433,21 @@ def test_context_size_exceeded_stream():
 @pytest.mark.parametrize(
     "n_batch,batch_count,reuse_cache",
     [
-        (64, 15, False),
+        (64, 3, False),
         (64, 1, True),
     ]
 )
 def test_return_progress(n_batch, batch_count, reuse_cache):
     global server
     server.n_batch = n_batch
-    server.n_ctx = 2048
+    server.n_ctx = 256
     server.n_slots = 1
     server.start()
     def make_cmpl_request():
         return server.make_stream_request("POST", "/chat/completions", data={
             "max_tokens": 10,
             "messages": [
-                {"role": "user", "content": "This is a test" * 100},
+                {"role": "user", "content": "This is a test" * 10},
             ],
             "stream": True,
             "return_progress": True,
@@ -368,6 +368,37 @@ def test_completion_parallel_slots(n_slots: int, n_requests: int):
     # assert match_regex(re_content, res.body["content"])


+@pytest.mark.parametrize(
+    "n_ctx,n_slots,n_predict_vals,expected_success",
+    [
+        (256, 4, [80, 40, 80, 80], [True, True, True, True]),
+        (256, 4, [70, 70, 70, 70], [False, False, False, False]),
+        (256, 4, [90, 90, 40, 90], [False, False, True, False]),
+        (256, 4, [90, 90, 40, 75], [True, True, True, True]),
+    ],
+)
+def test_completion_unified(n_ctx, n_slots, n_predict_vals, expected_success):
+    global server
+    server.n_slots = n_slots
+    server.kv_unified = True
+    server.n_ctx = n_ctx
+    server.start()
+    prompt = "A"
+    tasks = []
+    for n_predict in n_predict_vals:
+        tasks.append((server.make_request, ("POST", "/completion", {"prompt": prompt, "n_predict": n_predict})))
+    results = parallel_function_calls(tasks)
+    for res, n_predict, expect_ok in zip(results, n_predict_vals, expected_success):
+        if expect_ok:
+            assert res.status_code == 200
+            assert "content" in res.body
+            if "timings" in res.body:
+                assert res.body["timings"]["predicted_n"] == n_predict
+        else:
+            assert res.status_code == 500
+            assert "content" not in res.body
+
+
 @pytest.mark.parametrize(
     "prompt,n_predict,response_fields",
     [
@@ -18,7 +18,7 @@ def test_infill_without_input_extra():
         "input_suffix": "}\n",
     })
     assert res.status_code == 200
-    assert match_regex("(Ann|small|shiny|Daddy)+", res.body["content"])
+    assert match_regex("(Ann|small|shiny|Daddy|Jimmy)+", res.body["content"])


 def test_infill_with_input_extra():

@@ -34,7 +34,7 @@ def test_infill_with_input_extra():
         "input_suffix": "}\n",
     })
     assert res.status_code == 200
-    assert match_regex("(Dad|excited|park)+", res.body["content"])
+    assert match_regex("(Dad|excited|park|Jimmy)+", res.body["content"])


 @pytest.mark.parametrize("input_extra", [
@@ -78,6 +78,7 @@ class ServerProcess:
    server_embeddings: bool | None = False
    server_reranking: bool | None = False
    server_metrics: bool | None = False
+   kv_unified: bool | None = False
    server_slots: bool | None = False
    pooling: str | None = None
    draft: int | None = None

@@ -159,6 +160,8 @@ class ServerProcess:
            server_args.append("--reranking")
        if self.server_metrics:
            server_args.append("--metrics")
+       if self.kv_unified:
+           server_args.append("--kv-unified")
        if self.server_slots:
            server_args.append("--slots")
        else:
@@ -1212,7 +1212,7 @@ public:
         for (auto it = tokens.map_idx_to_media.begin(); it != tokens.map_idx_to_media.end(); ) {
             auto * chunk = tokens.map_idx_to_media[it->first].get();
             mtmd::input_chunk_ptr new_chunk(mtmd_input_chunk_copy(chunk));
-            map_idx_to_media[start_idx+it->first] = std::move(new_chunk);
+            map_idx_to_media[start_idx + it->first] = std::move(new_chunk);
         }
     }
 }

@@ -1244,6 +1244,7 @@ public:
     }

     void clear() {
+        map_idx_to_media.clear();
         tokens.clear();
     }
@@ -85,8 +85,8 @@
     let displayedModel = $derived((): string | null => {
         if (!currentConfig.showModelInfo) return null;

-        if (currentConfig.modelSelectorEnabled) {
-            return message.model ?? null;
+        if (message.model) {
+            return message.model;
         }

         return serverModel;
@@ -54,6 +54,7 @@ export class ChatService {
            onError,
            onReasoningChunk,
            onModel,
+           onFirstValidChunk,
            // Generation parameters
            temperature,
            max_tokens,

@@ -201,6 +202,7 @@ export class ChatService {
                onError,
                onReasoningChunk,
                onModel,
+               onFirstValidChunk,
                conversationId,
                abortController.signal
            );

@@ -267,6 +269,7 @@ export class ChatService {
        onError?: (error: Error) => void,
        onReasoningChunk?: (chunk: string) => void,
        onModel?: (model: string) => void,
+       onFirstValidChunk?: () => void,
        conversationId?: string,
        abortSignal?: AbortSignal
    ): Promise<void> {

@@ -283,6 +286,7 @@ export class ChatService {
        let lastTimings: ChatMessageTimings | undefined;
        let streamFinished = false;
        let modelEmitted = false;
+       let firstValidChunkEmitted = false;

        try {
            let chunk = '';

@@ -311,10 +315,12 @@ export class ChatService {
                    try {
                        const parsed: ApiChatCompletionStreamChunk = JSON.parse(data);

-                       const chunkModel = this.extractModelName(parsed);
-                       if (chunkModel && !modelEmitted) {
-                           modelEmitted = true;
-                           onModel?.(chunkModel);
+                       if (!firstValidChunkEmitted && parsed.object === 'chat.completion.chunk') {
+                           firstValidChunkEmitted = true;
+
+                           if (!abortSignal?.aborted) {
+                               onFirstValidChunk?.();
+                           }
                        }

                        const content = parsed.choices[0]?.delta?.content;

@@ -322,6 +328,12 @@ export class ChatService {
                        const timings = parsed.timings;
                        const promptProgress = parsed.prompt_progress;

+                       const chunkModel = this.extractModelName(parsed);
+
+                       if (chunkModel && !modelEmitted) {
+                           modelEmitted = true;
+                           onModel?.(chunkModel);
+                       }
+
                        if (timings || promptProgress) {
                            this.updateProcessingState(timings, promptProgress, conversationId);
                            if (timings) {
@@ -1,6 +1,7 @@
 import { DatabaseStore } from '$lib/stores/database';
 import { chatService, slotsService } from '$lib/services';
 import { config } from '$lib/stores/settings.svelte';
+import { serverStore } from '$lib/stores/server.svelte';
 import { normalizeModelName } from '$lib/utils/model-names';
 import { filterByLeafNodeId, findLeafNode, findDescendantMessages } from '$lib/utils/branching';
 import { browser } from '$app/environment';

@@ -362,9 +363,41 @@ class ChatStore {

        let resolvedModel: string | null = null;
        let modelPersisted = false;
+       const currentConfig = config();
+       const preferServerPropsModel = !currentConfig.modelSelectorEnabled;
+       let serverPropsRefreshed = false;
+       let updateModelFromServerProps: ((persistImmediately?: boolean) => void) | null = null;

-       const recordModel = (modelName: string, persistImmediately = true): void => {
-           const normalizedModel = normalizeModelName(modelName);
+       const refreshServerPropsOnce = () => {
+           if (serverPropsRefreshed) {
+               return;
+           }
+
+           serverPropsRefreshed = true;
+
+           const hasExistingProps = serverStore.serverProps !== null;
+
+           serverStore
+               .fetchServerProps({ silent: hasExistingProps })
+               .then(() => {
+                   updateModelFromServerProps?.(true);
+               })
+               .catch((error) => {
+                   console.warn('Failed to refresh server props after streaming started:', error);
+               });
+       };
+
+       const recordModel = (modelName: string | null | undefined, persistImmediately = true): void => {
+           const serverModelName = serverStore.modelName;
+           const preferredModelSource = preferServerPropsModel
+               ? (serverModelName ?? modelName ?? null)
+               : (modelName ?? serverModelName ?? null);
+
+           if (!preferredModelSource) {
+               return;
+           }
+
+           const normalizedModel = normalizeModelName(preferredModelSource);
+
            if (!normalizedModel || normalizedModel === resolvedModel) {
                return;

@@ -388,6 +421,20 @@ class ChatStore {
            }
        };

+       if (preferServerPropsModel) {
+           updateModelFromServerProps = (persistImmediately = true) => {
+               const currentServerModel = serverStore.modelName;
+
+               if (!currentServerModel) {
+                   return;
+               }
+
+               recordModel(currentServerModel, persistImmediately);
+           };
+
+           updateModelFromServerProps(false);
+       }
+
        slotsService.startStreaming();
        slotsService.setActiveConversation(assistantMessage.convId);

@@ -396,6 +443,9 @@ class ChatStore {
            {
                ...this.getApiOptions(),
+
+               onFirstValidChunk: () => {
+                   refreshServerPropsOnce();
+               },
                onChunk: (chunk: string) => {
                    streamedContent += chunk;
                    this.setConversationStreaming(
@ -52,6 +52,7 @@ class ServerStore {
|
||||||
private _error = $state<string | null>(null);
|
private _error = $state<string | null>(null);
|
||||||
private _serverWarning = $state<string | null>(null);
|
private _serverWarning = $state<string | null>(null);
|
||||||
private _slotsEndpointAvailable = $state<boolean | null>(null);
|
private _slotsEndpointAvailable = $state<boolean | null>(null);
|
||||||
|
private fetchServerPropsPromise: Promise<void> | null = null;
|
||||||
|
|
||||||
private readCachedServerProps(): ApiLlamaCppServerProps | null {
|
private readCachedServerProps(): ApiLlamaCppServerProps | null {
|
||||||
if (!browser) return null;
|
if (!browser) return null;
|
||||||
|
|
@@ -171,57 +172,63 @@ class ServerStore
 	/**
 	 * Fetches server properties from the server
 	 */
-	async fetchServerProps(): Promise<void> {
+	async fetchServerProps(options: { silent?: boolean } = {}): Promise<void> {
+		const { silent = false } = options;
+		const isSilent = silent && this._serverProps !== null;
+
+		if (this.fetchServerPropsPromise) {
+			return this.fetchServerPropsPromise;
+		}
+
+		if (!isSilent) {
 			this._loading = true;
 			this._error = null;
 			this._serverWarning = null;
+		}
+
+		const hadProps = this._serverProps !== null;
+
+		const fetchPromise = (async () => {
 			try {
-				console.log('Fetching server properties...');
 				const props = await ChatService.getServerProps();
 				this._serverProps = props;
 				this.persistServerProps(props);
-				console.log('Server properties loaded:', props);
-				// Check slots endpoint availability after server props are loaded
+				this._error = null;
+				this._serverWarning = null;
 				await this.checkSlotsEndpointAvailability();
 			} catch (error) {
-				const hadCachedProps = this._serverProps !== null;
-				let errorMessage = 'Failed to connect to server';
-				let isOfflineLikeError = false;
-				let isServerSideError = false;
-
-				if (error instanceof Error) {
-					// Handle specific error types with user-friendly messages
-					if (error.name === 'TypeError' && error.message.includes('fetch')) {
-						errorMessage = 'Server is not running or unreachable';
-						isOfflineLikeError = true;
-					} else if (error.message.includes('ECONNREFUSED')) {
-						errorMessage = 'Connection refused - server may be offline';
-						isOfflineLikeError = true;
-					} else if (error.message.includes('ENOTFOUND')) {
-						errorMessage = 'Server not found - check server address';
-						isOfflineLikeError = true;
-					} else if (error.message.includes('ETIMEDOUT')) {
-						errorMessage = 'Request timed out - the server took too long to respond';
-						isOfflineLikeError = true;
-					} else if (error.message.includes('503')) {
-						errorMessage = 'Server temporarily unavailable - try again shortly';
-						isServerSideError = true;
-					} else if (error.message.includes('500')) {
-						errorMessage = 'Server error - check server logs';
-						isServerSideError = true;
-					} else if (error.message.includes('404')) {
-						errorMessage = 'Server endpoint not found';
-					} else if (error.message.includes('403') || error.message.includes('401')) {
-						errorMessage = 'Access denied';
-					}
+				if (isSilent && hadProps) {
+					console.warn('Silent server props refresh failed, keeping cached data:', error);
+					return;
 				}
+
+				this.handleFetchServerPropsError(error, hadProps);
+			} finally {
+				if (!isSilent) {
+					this._loading = false;
+				}
+
+				this.fetchServerPropsPromise = null;
 			}
+		})();
+
+		this.fetchServerPropsPromise = fetchPromise;
+
+		await fetchPromise;
+	}
+
+	/**
+	 * Handles fetch failures by attempting to recover cached server props and
+	 * updating the user-facing error or warning state appropriately.
+	 */
+	private handleFetchServerPropsError(error: unknown, hadProps: boolean): void {
+		const { errorMessage, isOfflineLikeError, isServerSideError } = this.normalizeFetchError(error);

 		let cachedProps: ApiLlamaCppServerProps | null = null;

-		if (!hadCachedProps) {
+		if (!hadProps) {
 			cachedProps = this.readCachedServerProps();

 			if (cachedProps) {
 				this._serverProps = cachedProps;
 				this._error = null;
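With this refactor, callers can opt into a background refresh that leaves the loading and error state untouched whenever cached props exist. A brief usage sketch (behavior as implemented above):

    // Regular fetch: toggles _loading and surfaces errors to the UI.
    await serverStore.fetchServerProps();

    // Silent refresh: only takes effect once _serverProps is populated; on failure
    // it logs a console.warn and keeps the cached data on screen.
    await serverStore.fetchServerProps({ silent: true });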
@@ -249,10 +256,48 @@ class ServerStore
 					errorMessage
 				);
 			}

 			console.error('Error fetching server properties:', error);
-		} finally {
-			this._loading = false;
 		}

+	private normalizeFetchError(error: unknown): {
+		errorMessage: string;
+		isOfflineLikeError: boolean;
+		isServerSideError: boolean;
+	} {
+		let errorMessage = 'Failed to connect to server';
+		let isOfflineLikeError = false;
+		let isServerSideError = false;
+
+		if (error instanceof Error) {
+			const message = error.message || '';
+
+			if (error.name === 'TypeError' && message.includes('fetch')) {
+				errorMessage = 'Server is not running or unreachable';
+				isOfflineLikeError = true;
+			} else if (message.includes('ECONNREFUSED')) {
+				errorMessage = 'Connection refused - server may be offline';
+				isOfflineLikeError = true;
+			} else if (message.includes('ENOTFOUND')) {
+				errorMessage = 'Server not found - check server address';
+				isOfflineLikeError = true;
+			} else if (message.includes('ETIMEDOUT')) {
+				errorMessage = 'Request timed out - the server took too long to respond';
+				isOfflineLikeError = true;
+			} else if (message.includes('503')) {
+				errorMessage = 'Server temporarily unavailable - try again shortly';
+				isServerSideError = true;
+			} else if (message.includes('500')) {
+				errorMessage = 'Server error - check server logs';
+				isServerSideError = true;
+			} else if (message.includes('404')) {
+				errorMessage = 'Server endpoint not found';
+			} else if (message.includes('403') || message.includes('401')) {
+				errorMessage = 'Access denied';
+			}
+		}
+
+		return { errorMessage, isOfflineLikeError, isServerSideError };
 	}

 	/**
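normalizeFetchError centralizes the message mapping that previously lived inline in the catch block. Note the HTTP-status branches match on substrings of error.message, so they rely on the thrown Error embedding the status code in its text. Hypothetical inputs and the results they would produce (examples, not part of the diff):

    // new TypeError('Failed to fetch')
    //   -> { errorMessage: 'Server is not running or unreachable', isOfflineLikeError: true, isServerSideError: false }
    // new Error('server responded with 503')
    //   -> { errorMessage: 'Server temporarily unavailable - try again shortly', isOfflineLikeError: false, isServerSideError: true }
    // 'not an Error instance'
    //   -> { errorMessage: 'Failed to connect to server', isOfflineLikeError: false, isServerSideError: false }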
@@ -264,6 +309,7 @@ class ServerStore
 		this._serverWarning = null;
 		this._loading = false;
 		this._slotsEndpointAvailable = null;
+		this.fetchServerPropsPromise = null;
 		this.persistServerProps(null);
 	}
 }
@@ -186,6 +186,7 @@ export interface ApiChatCompletionRequest
 }

 export interface ApiChatCompletionStreamChunk {
+	object?: string;
 	model?: string;
 	choices: Array<{
 		model?: string;
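The optional object field mirrors OpenAI-compatible stream payloads, where data chunks carry object: 'chat.completion.chunk'. One plausible consumer-side use (a sketch, not code from this PR) is filtering out non-completion SSE payloads before treating a chunk as valid:

    // Assumption: servers may omit `object`, so undefined is accepted as well.
    function isCompletionChunk(chunk: ApiChatCompletionStreamChunk): boolean {
    	return chunk.object === undefined || chunk.object === 'chat.completion.chunk';
    }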
@@ -42,6 +42,7 @@ export interface SettingsChatServiceOptions
 	onChunk?: (chunk: string) => void;
 	onReasoningChunk?: (chunk: string) => void;
 	onModel?: (model: string) => void;
+	onFirstValidChunk?: () => void;
 	onComplete?: (response: string, reasoningContent?: string, timings?: ChatMessageTimings) => void;
 	onError?: (error: Error) => void;
 }
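A sketch of how a caller might wire the new onFirstValidChunk callback; the chatService.sendMessage entry point and surrounding helpers here are assumptions for illustration, not code from this diff:

    chatService.sendMessage(messages, {
    	onFirstValidChunk: () => {
    		refreshServerPropsOnce(); // e.g. what ChatStore does above
    	},
    	onChunk: (chunk) => {
    		streamedContent += chunk; // append tokens as they arrive
    	},
    	onError: (error) => {
    		console.error('stream failed:', error.message);
    	}
    });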