Merge branch 'ggml-org:master' into Kimi-Linear
This commit is contained in:
commit
a46782c1b7
|
|
@ -1301,7 +1301,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
[](common_params & params, bool value) {
|
||||
params.kv_unified = value;
|
||||
}
|
||||
).set_env("LLAMA_ARG_KV_UNIFIED").set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_BATCHED, LLAMA_EXAMPLE_BENCH}));
|
||||
).set_env("LLAMA_ARG_KV_UNIFIED").set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_BATCHED, LLAMA_EXAMPLE_BENCH, LLAMA_EXAMPLE_PARALLEL}));
|
||||
add_opt(common_arg(
|
||||
{"--context-shift"},
|
||||
{"--no-context-shift"},
|
||||
|
|
|
|||
|
|
@ -1,7 +1,3 @@
|
|||
#if defined(_MSC_VER)
|
||||
#define _SILENCE_CXX17_CODECVT_HEADER_DEPRECATION_WARNING
|
||||
#endif
|
||||
|
||||
#include "ggml.h"
|
||||
#include "gguf.h"
|
||||
|
||||
|
|
@ -9,12 +5,12 @@
|
|||
#include "log.h"
|
||||
#include "llama.h"
|
||||
#include "sampling.h"
|
||||
#include "unicode.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cinttypes>
|
||||
#include <climits>
|
||||
#include <cmath>
|
||||
#include <codecvt>
|
||||
#include <chrono>
|
||||
#include <cstdarg>
|
||||
#include <cstring>
|
||||
|
|
@ -706,45 +702,28 @@ bool fs_validate_filename(const std::string & filename, bool allow_subdirs) {
|
|||
return false;
|
||||
}
|
||||
|
||||
std::u32string filename_utf32;
|
||||
try {
|
||||
#if defined(__clang__)
|
||||
// disable C++17 deprecation warning for std::codecvt_utf8
|
||||
# pragma clang diagnostic push
|
||||
# pragma clang diagnostic ignored "-Wdeprecated-declarations"
|
||||
#elif defined(__GNUC__)
|
||||
# pragma GCC diagnostic push
|
||||
# pragma GCC diagnostic ignored "-Wdeprecated-declarations"
|
||||
#endif
|
||||
size_t offset = 0;
|
||||
while (offset < filename.size()) {
|
||||
utf8_parse_result result = parse_utf8_codepoint(filename, offset);
|
||||
|
||||
std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> converter;
|
||||
|
||||
#if defined(__clang__)
|
||||
# pragma clang diagnostic pop
|
||||
#elif defined(__GNUC__)
|
||||
# pragma GCC diagnostic pop
|
||||
#endif
|
||||
|
||||
filename_utf32 = converter.from_bytes(filename);
|
||||
|
||||
// If the reverse conversion mismatches, it means overlong UTF-8 sequences were used,
|
||||
// or invalid encodings were encountered. Reject such attempts
|
||||
std::string filename_reencoded = converter.to_bytes(filename_utf32);
|
||||
if (filename_reencoded != filename) {
|
||||
if (result.status != utf8_parse_result::SUCCESS) {
|
||||
return false;
|
||||
}
|
||||
} catch (const std::exception &) {
|
||||
return false;
|
||||
}
|
||||
uint32_t c = result.codepoint;
|
||||
|
||||
// Check for forbidden codepoints:
|
||||
// - Control characters
|
||||
// - Unicode equivalents of illegal characters
|
||||
// - UTF-16 surrogate pairs
|
||||
// - UTF-8 replacement character
|
||||
// - Byte order mark (BOM)
|
||||
// - Illegal characters: / \ : * ? " < > |
|
||||
for (char32_t c : filename_utf32) {
|
||||
if ((result.bytes_consumed == 2 && c < 0x80) ||
|
||||
(result.bytes_consumed == 3 && c < 0x800) ||
|
||||
(result.bytes_consumed == 4 && c < 0x10000)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check for forbidden codepoints:
|
||||
// - Control characters
|
||||
// - Unicode equivalents of illegal characters
|
||||
// - UTF-16 surrogate pairs
|
||||
// - UTF-8 replacement character
|
||||
// - Byte order mark (BOM)
|
||||
// - Illegal characters: / \ : * ? " < > |
|
||||
if (c <= 0x1F // Control characters (C0)
|
||||
|| c == 0x7F // Control characters (DEL)
|
||||
|| (c >= 0x80 && c <= 0x9F) // Control characters (C1)
|
||||
|
|
@ -752,6 +731,7 @@ bool fs_validate_filename(const std::string & filename, bool allow_subdirs) {
|
|||
|| c == 0x2215 // Division Slash (forward slash equivalent)
|
||||
|| c == 0x2216 // Set Minus (backslash equivalent)
|
||||
|| (c >= 0xD800 && c <= 0xDFFF) // UTF-16 surrogate pairs
|
||||
|| c > 0x10FFFF // Max Unicode limit
|
||||
|| c == 0xFFFD // Replacement Character (UTF-8)
|
||||
|| c == 0xFEFF // Byte Order Mark (BOM)
|
||||
|| c == ':' || c == '*' // Illegal characters
|
||||
|
|
@ -762,6 +742,7 @@ bool fs_validate_filename(const std::string & filename, bool allow_subdirs) {
|
|||
// Subdirectories not allowed, reject path separators
|
||||
return false;
|
||||
}
|
||||
offset += result.bytes_consumed;
|
||||
}
|
||||
|
||||
// Reject any leading or trailing ' ', or any trailing '.', these are stripped on Windows and will cause a different filename
|
||||
|
|
|
|||
|
|
@ -35,7 +35,7 @@ Adapt below build commands accordingly.
|
|||
Let's build llama.cpp with CPU, OpenCL, and Hexagon backends via CMake presets:
|
||||
|
||||
```
|
||||
[d]/workspace> cp docs/backend/hexagon/CMakeUserPresets.json .
|
||||
[d]/workspace> cp docs/backend/snapdragon/CMakeUserPresets.json .
|
||||
|
||||
[d]/workspace> cmake --preset arm64-android-snapdragon-release -B build-snapdragon
|
||||
Preset CMake variables:
|
||||
|
|
|
|||
|
|
@ -63,11 +63,19 @@ static __global__ void flash_attn_ext_f16(
|
|||
constexpr int frag_m = ncols == 8 ? 32 : 16;
|
||||
constexpr int frag_n = ncols == 8 ? 8 : 16;
|
||||
static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0.");
|
||||
#if defined(GGML_USE_HIP)
|
||||
typedef wmma::fragment<wmma::matrix_a, frag_m, frag_n, 16, _Float16, wmma::row_major> frag_a_K;
|
||||
typedef wmma::fragment<wmma::matrix_a, frag_m, frag_n, 16, _Float16, wmma::col_major> frag_a_V;
|
||||
typedef wmma::fragment<wmma::matrix_b, frag_m, frag_n, 16, _Float16, wmma::col_major> frag_b;
|
||||
typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t> frag_c_KQ;
|
||||
typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, _Float16> frag_c_VKQ;
|
||||
#else
|
||||
typedef wmma::fragment<wmma::matrix_a, frag_m, frag_n, 16, half, wmma::row_major> frag_a_K;
|
||||
typedef wmma::fragment<wmma::matrix_a, frag_m, frag_n, 16, half, wmma::col_major> frag_a_V;
|
||||
typedef wmma::fragment<wmma::matrix_b, frag_m, frag_n, 16, half, wmma::col_major> frag_b;
|
||||
typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t> frag_c_KQ;
|
||||
typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, half> frag_c_VKQ;
|
||||
#endif
|
||||
|
||||
constexpr int KQ_stride_tc = nwarps*frag_m; // Number of KQ rows calculated in parallel.
|
||||
constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy.
|
||||
|
|
@ -126,6 +134,19 @@ static __global__ void flash_attn_ext_f16(
|
|||
|
||||
__shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice.
|
||||
half2 * VKQ2 = (half2 *) VKQ;
|
||||
|
||||
#if defined(GGML_USE_HIP)
|
||||
const _Float16 * K_h_f16 = reinterpret_cast<const _Float16 *>(K_h);
|
||||
const _Float16 * V_h_f16 = reinterpret_cast<const _Float16 *>(V_h);
|
||||
_Float16 * KQ_f16 = reinterpret_cast<_Float16 *>(KQ);
|
||||
_Float16 * VKQ_f16 = reinterpret_cast<_Float16 *>(VKQ);
|
||||
#else
|
||||
const half * K_h_f16 = K_h;
|
||||
const half * V_h_f16 = V_h;
|
||||
half * KQ_f16 = KQ;
|
||||
half * VKQ_f16 = VKQ;
|
||||
#endif
|
||||
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
||||
const int j = j0 + threadIdx.y;
|
||||
|
|
@ -160,7 +181,7 @@ static __global__ void flash_attn_ext_f16(
|
|||
for (int i0 = 0; i0 < D; i0 += 16) {
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
|
||||
wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
|
||||
wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ_f16 + j0*D_padded + i0, D_padded);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -180,7 +201,7 @@ static __global__ void flash_attn_ext_f16(
|
|||
#pragma unroll
|
||||
for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
|
||||
frag_a_K K_a;
|
||||
wmma::load_matrix_sync(K_a, K_h + int64_t(k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
|
||||
wmma::load_matrix_sync(K_a, K_h_f16 + int64_t(k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols/frag_n; ++j) {
|
||||
wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
|
||||
|
|
@ -310,7 +331,7 @@ static __global__ void flash_attn_ext_f16(
|
|||
const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
|
||||
wmma::load_matrix_sync(
|
||||
KQ_b[k0/(VKQ_ratio*16)][j0/frag_n],
|
||||
KQ + j0*(kqar*kqs_padded) + k,
|
||||
KQ_f16 + j0*(kqar*kqs_padded) + k,
|
||||
kqar*kqs_padded);
|
||||
}
|
||||
}
|
||||
|
|
@ -328,7 +349,7 @@ static __global__ void flash_attn_ext_f16(
|
|||
const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
|
||||
|
||||
frag_a_V v_a;
|
||||
wmma::load_matrix_sync(v_a, V_h + int64_t(k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
|
||||
wmma::load_matrix_sync(v_a, V_h_f16 + int64_t(k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols/frag_n; ++j) {
|
||||
wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
|
||||
|
|
@ -344,7 +365,7 @@ static __global__ void flash_attn_ext_f16(
|
|||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
|
||||
wmma::store_matrix_sync(
|
||||
KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
|
||||
KQ_f16 + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
|
||||
VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n],
|
||||
D_padded, wmma::mem_col_major);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -64,25 +64,12 @@ struct htp_ops_context {
|
|||
struct fastdiv_values broadcast_rv2;
|
||||
struct fastdiv_values broadcast_rv3;
|
||||
|
||||
struct fastdiv_values mm_div_ne12_ne1; // fastdiv values for ne12 * ne1
|
||||
struct fastdiv_values mm_div_ne1; // fastdiv values for ne1
|
||||
struct fastdiv_values mm_div_r2; // fastdiv values for ne12 / ne02
|
||||
struct fastdiv_values mm_div_r3; // fastdiv values for ne13 / ne03
|
||||
|
||||
struct fastdiv_values set_rows_div_ne12; // fastdiv values for ne12
|
||||
struct fastdiv_values set_rows_div_ne11; // fastdiv values for ne11
|
||||
|
||||
struct fastdiv_values get_rows_div_ne10; // fastdiv values for ne10
|
||||
struct fastdiv_values get_rows_div_ne10_ne11; // fastdiv values for ne10 * ne11
|
||||
|
||||
struct fastdiv_values cpy_div_ne01; // fastdiv values for ne01
|
||||
struct fastdiv_values cpy_div_ne02; // fastdiv values for ne02
|
||||
struct fastdiv_values cpy_div_ne03; // fastdiv values for ne03
|
||||
|
||||
struct fastdiv_values cpy_rshp_div_n0; // fastdiv values for ne00
|
||||
struct fastdiv_values cpy_rshp_div_n1n0; // fastdiv values for ne00*ne01
|
||||
struct fastdiv_values cpy_rshp_div_n2n1n0; // fastdiv values for ne00*ne01*ne02
|
||||
|
||||
uint32_t flags;
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -189,7 +189,7 @@ static int vtcm_release_callback(unsigned int rctx, void * state) {
|
|||
// otherwise we'll release it once we're done with the current Op.
|
||||
|
||||
if (ctx->vtcm_inuse) {
|
||||
ctx->vtcm_needs_release = false;
|
||||
ctx->vtcm_needs_release = true;
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
|
|
@ -264,15 +264,25 @@ static std::vector<int> ggml_metal_graph_optimize_reorder(const std::vector<node
|
|||
case GGML_OP_NORM:
|
||||
case GGML_OP_RMS_NORM:
|
||||
case GGML_OP_GROUP_NORM:
|
||||
case GGML_OP_L2_NORM:
|
||||
case GGML_OP_SUM_ROWS:
|
||||
case GGML_OP_SSM_CONV:
|
||||
case GGML_OP_SSM_SCAN:
|
||||
case GGML_OP_CLAMP:
|
||||
case GGML_OP_TRI:
|
||||
case GGML_OP_DIAG:
|
||||
case GGML_OP_MUL:
|
||||
case GGML_OP_ADD:
|
||||
case GGML_OP_DIV:
|
||||
case GGML_OP_GLU:
|
||||
case GGML_OP_SCALE:
|
||||
case GGML_OP_UNARY:
|
||||
case GGML_OP_GET_ROWS:
|
||||
case GGML_OP_CPY:
|
||||
case GGML_OP_SET_ROWS:
|
||||
case GGML_OP_SET:
|
||||
case GGML_OP_CPY:
|
||||
case GGML_OP_CONT:
|
||||
case GGML_OP_REPEAT:
|
||||
return true;
|
||||
default:
|
||||
return ggml_op_is_empty(op);
|
||||
|
|
@ -312,7 +322,7 @@ static std::vector<int> ggml_metal_graph_optimize_reorder(const std::vector<node
|
|||
h_add(mrs1, node0);
|
||||
|
||||
// that many nodes forward to search for a concurrent node
|
||||
constexpr int N_FORWARD = 8;
|
||||
constexpr int N_FORWARD = 64;
|
||||
|
||||
for (int i1 = i0 + 1; i1 < i0 + N_FORWARD && i1 < n; i1++) {
|
||||
if (used[i1]) {
|
||||
|
|
|
|||
|
|
@ -328,31 +328,46 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum(ggml_metal_l
|
|||
}
|
||||
|
||||
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum_rows(ggml_metal_library_t lib, const ggml_tensor * op) {
|
||||
GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type));
|
||||
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
|
||||
|
||||
char base[256];
|
||||
char name[256];
|
||||
|
||||
const char * op_str = "undefined";
|
||||
int op_num = -1;
|
||||
|
||||
switch (op->op) {
|
||||
case GGML_OP_SUM_ROWS:
|
||||
op_str = "sum_rows"; break;
|
||||
case GGML_OP_MEAN:
|
||||
op_str = "mean"; break;
|
||||
case GGML_OP_SUM_ROWS: op_num = OP_SUM_ROWS_NUM_SUM_ROWS; break;
|
||||
case GGML_OP_MEAN: op_num = OP_SUM_ROWS_NUM_MEAN; break;
|
||||
default: GGML_ABORT("fatal error");
|
||||
};
|
||||
|
||||
snprintf(base, 256, "kernel_%s_%s", op_str, ggml_type_name(op->src[0]->type));
|
||||
const char * t0_str = ggml_type_name(op->src[0]->type);
|
||||
const char * t_str = ggml_type_name(op->type);
|
||||
|
||||
snprintf(name, 256, "%s", base);
|
||||
const bool is_c4 = op->src[0]->ne[0] % 4 == 0;
|
||||
|
||||
snprintf(base, 256, "kernel_sum_rows_%s_%s%s", t0_str, t_str, is_c4 ? "_4" : "");
|
||||
snprintf(name, 256, "%s_op=%d", base, op_num);
|
||||
|
||||
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
||||
if (!res.pipeline) {
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
||||
ggml_metal_cv_t cv = ggml_metal_cv_init();
|
||||
|
||||
ggml_metal_cv_set_int16(cv, op_num, FC_SUM_ROWS + 0);
|
||||
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
|
||||
|
||||
ggml_metal_cv_free(cv);
|
||||
}
|
||||
|
||||
res.smem = 32*sizeof(float);
|
||||
|
||||
if (is_c4) {
|
||||
res.smem *= 4;
|
||||
}
|
||||
|
||||
res.c4 = is_c4;
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1159,6 +1159,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
|||
case GGML_OP_MUL_MAT:
|
||||
case GGML_OP_MUL_MAT_ID:
|
||||
return has_simdgroup_reduction;
|
||||
case GGML_OP_SET:
|
||||
case GGML_OP_CPY:
|
||||
case GGML_OP_DUP:
|
||||
case GGML_OP_CONT:
|
||||
|
|
|
|||
|
|
@ -82,6 +82,7 @@
|
|||
#define FC_COUNT_EQUAL 1100
|
||||
#define FC_UNARY 1200
|
||||
#define FC_BIN 1300
|
||||
#define FC_SUM_ROWS 1400
|
||||
|
||||
// op-specific constants
|
||||
#define OP_FLASH_ATTN_EXT_NQPSG 8
|
||||
|
|
@ -118,6 +119,8 @@
|
|||
#define OP_UNARY_NUM_SOFTPLUS 115
|
||||
#define OP_UNARY_NUM_EXPM1 116
|
||||
|
||||
#define OP_SUM_ROWS_NUM_SUM_ROWS 10
|
||||
#define OP_SUM_ROWS_NUM_MEAN 11
|
||||
|
||||
// kernel argument structs
|
||||
//
|
||||
|
|
|
|||
|
|
@ -426,6 +426,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
|||
{
|
||||
n_fuse = ggml_metal_op_flash_attn_ext(ctx, idx);
|
||||
} break;
|
||||
case GGML_OP_SET:
|
||||
{
|
||||
n_fuse = ggml_metal_op_set(ctx, idx);
|
||||
} break;
|
||||
case GGML_OP_DUP:
|
||||
case GGML_OP_CPY:
|
||||
case GGML_OP_CONT:
|
||||
|
|
@ -904,6 +908,11 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
|
|||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
|
||||
|
||||
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
|
||||
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
|
||||
|
||||
ggml_metal_kargs_sum_rows args = {
|
||||
/*.ne00 =*/ ne00,
|
||||
/*.ne01 =*/ ne01,
|
||||
|
|
@ -925,21 +934,26 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
|
|||
|
||||
auto pipeline = ggml_metal_library_get_pipeline_sum_rows(lib, op);
|
||||
|
||||
if (pipeline.c4) {
|
||||
args.ne00 = ne00/4;
|
||||
args.ne0 = ne0/4;
|
||||
}
|
||||
|
||||
int nth = 32; // SIMD width
|
||||
|
||||
while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
|
||||
while (nth < args.ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
|
||||
nth *= 2;
|
||||
}
|
||||
|
||||
nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
||||
nth = std::min(nth, ne00);
|
||||
nth = std::min(nth, (int) args.ne00);
|
||||
|
||||
const size_t smem = pipeline.smem;
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
|
||||
|
||||
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
||||
|
||||
|
|
@ -1599,6 +1613,134 @@ int ggml_metal_op_solve_tri(ggml_metal_op_t ctx, int idx) {
|
|||
return 1;
|
||||
}
|
||||
|
||||
int ggml_metal_op_set(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
|
||||
ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]);
|
||||
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
|
||||
|
||||
const size_t pnb1 = ((const int32_t *) op->op_params)[0];
|
||||
const size_t pnb2 = ((const int32_t *) op->op_params)[1];
|
||||
const size_t pnb3 = ((const int32_t *) op->op_params)[2];
|
||||
const size_t offs = ((const int32_t *) op->op_params)[3];
|
||||
|
||||
const bool inplace = (bool) ((const int32_t *) op->op_params)[4];
|
||||
|
||||
if (!inplace) {
|
||||
// run a separete kernel to cpy src->dst
|
||||
// not sure how to avoid this
|
||||
// TODO: make a simpler cpy_bytes kernel
|
||||
|
||||
//const id<MTLComputePipelineState> pipeline = ctx->pipelines[GGML_METAL_PIPELINE_TYPE_CPY_F32_F32].obj;
|
||||
auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
|
||||
|
||||
ggml_metal_kargs_cpy args = {
|
||||
/*.nk0 =*/ ne00,
|
||||
/*.ne00 =*/ ne00,
|
||||
/*.ne01 =*/ ne01,
|
||||
/*.ne02 =*/ ne02,
|
||||
/*.ne03 =*/ ne03,
|
||||
/*.nb00 =*/ nb00,
|
||||
/*.nb01 =*/ nb01,
|
||||
/*.nb02 =*/ nb02,
|
||||
/*.nb03 =*/ nb03,
|
||||
/*.ne0 =*/ ne0,
|
||||
/*.ne1 =*/ ne1,
|
||||
/*.ne2 =*/ ne2,
|
||||
/*.ne3 =*/ ne3,
|
||||
/*.nb0 =*/ nb0,
|
||||
/*.nb1 =*/ nb1,
|
||||
/*.nb2 =*/ nb2,
|
||||
/*.nb3 =*/ nb3,
|
||||
};
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
|
||||
|
||||
const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00);
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
|
||||
|
||||
ggml_metal_op_concurrency_reset(ctx);
|
||||
}
|
||||
|
||||
auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[1]->type, op->type);
|
||||
|
||||
GGML_ASSERT(ne10 % ggml_blck_size(op->src[1]->type) == 0);
|
||||
|
||||
int64_t nk0 = ne10;
|
||||
if (ggml_is_quantized(op->src[1]->type)) {
|
||||
nk0 = ne10/16;
|
||||
} else if (ggml_is_quantized(op->type)) {
|
||||
nk0 = ne10/ggml_blck_size(op->type);
|
||||
}
|
||||
|
||||
int nth = std::min<int>(nk0, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
||||
|
||||
// when rows are small, we can batch them together in a single threadgroup
|
||||
int nrptg = 1;
|
||||
|
||||
// TODO: relax this constraint in the future
|
||||
if (ggml_blck_size(op->src[1]->type) == 1 && ggml_blck_size(op->type) == 1) {
|
||||
if (nth > nk0) {
|
||||
nrptg = (nth + nk0 - 1)/nk0;
|
||||
nth = nk0;
|
||||
|
||||
if (nrptg*nth > ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
|
||||
nrptg--;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
nth = std::min<int>(nth, nk0);
|
||||
|
||||
ggml_metal_kargs_cpy args = {
|
||||
/*.nk0 =*/ nk0,
|
||||
/*.ne00 =*/ ne10,
|
||||
/*.ne01 =*/ ne11,
|
||||
/*.ne02 =*/ ne12,
|
||||
/*.ne03 =*/ ne13,
|
||||
/*.nb00 =*/ nb10,
|
||||
/*.nb01 =*/ nb11,
|
||||
/*.nb02 =*/ nb12,
|
||||
/*.nb03 =*/ nb13,
|
||||
/*.ne0 =*/ ne10,
|
||||
/*.ne1 =*/ ne11,
|
||||
/*.ne2 =*/ ne12,
|
||||
/*.ne3 =*/ ne13,
|
||||
/*.nb0 =*/ ggml_element_size(op),
|
||||
/*.nb1 =*/ pnb1,
|
||||
/*.nb2 =*/ pnb2,
|
||||
/*.nb3 =*/ pnb3,
|
||||
};
|
||||
|
||||
const int nw0 = nrptg == 1 ? (nk0 + nth - 1)/nth : 1;
|
||||
|
||||
bid_dst.offs += offs;
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_src1, 1);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, nw0*(ne11 + nrptg - 1)/nrptg, ne12, ne13, nth, nrptg, 1);
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
|
|
|
|||
|
|
@ -59,6 +59,7 @@ int ggml_metal_op_ssm_conv (ggml_metal_op_t ctx, int idx);
|
|||
int ggml_metal_op_ssm_scan (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_rwkv (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_solve_tri (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_set (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_cpy (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_pool_1d (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_pool_2d (ggml_metal_op_t ctx, int idx);
|
||||
|
|
|
|||
|
|
@ -77,6 +77,14 @@ static inline float dot(float x, float y) {
|
|||
return x*y;
|
||||
}
|
||||
|
||||
static inline float sum(float x) {
|
||||
return x;
|
||||
}
|
||||
|
||||
static inline float sum(float4 x) {
|
||||
return x[0] + x[1] + x[2] + x[3];
|
||||
}
|
||||
|
||||
// NOTE: this is not dequantizing - we are simply fitting the template
|
||||
template <typename type4x4>
|
||||
void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
|
||||
|
|
@ -1501,33 +1509,35 @@ kernel void kernel_op_sum_f32(
|
|||
}
|
||||
}
|
||||
|
||||
template <bool norm>
|
||||
kernel void kernel_sum_rows(
|
||||
constant short FC_sum_rows_op [[function_constant(FC_SUM_ROWS + 0)]];
|
||||
|
||||
template <typename T0, typename T>
|
||||
kernel void kernel_sum_rows_impl(
|
||||
constant ggml_metal_kargs_sum_rows & args,
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
threadgroup float * shmem_f32 [[threadgroup(0)]],
|
||||
device const char * src0,
|
||||
device char * dst,
|
||||
threadgroup char * shmem [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
int64_t i3 = tgpig.z;
|
||||
int64_t i2 = tgpig.y;
|
||||
int64_t i1 = tgpig.x;
|
||||
#define FC_OP FC_sum_rows_op
|
||||
|
||||
if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
|
||||
return;
|
||||
}
|
||||
const int i3 = tgpig.z;
|
||||
const int i2 = tgpig.y;
|
||||
const int i1 = tgpig.x;
|
||||
|
||||
threadgroup T0 * shmem_t = (threadgroup T0 *) shmem;
|
||||
|
||||
if (sgitg == 0) {
|
||||
shmem_f32[tiisg] = 0.0f;
|
||||
shmem_t[tiisg] = 0.0f;
|
||||
}
|
||||
|
||||
device const float * src_row = (device const float *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
|
||||
device float * dst_row = (device float *) ((device char *) dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3);
|
||||
device const T0 * src_row = (device const T0 *) (src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
|
||||
device T * dst_row = (device T *) (dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3);
|
||||
|
||||
float sumf = 0;
|
||||
T0 sumf = T0(0.0f);
|
||||
|
||||
for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
|
||||
sumf += src_row[i0];
|
||||
|
|
@ -1538,23 +1548,33 @@ kernel void kernel_sum_rows(
|
|||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (tiisg == 0) {
|
||||
shmem_f32[sgitg] = sumf;
|
||||
shmem_t[sgitg] = sumf;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
sumf = shmem_f32[tiisg];
|
||||
sumf = shmem_t[tiisg];
|
||||
sumf = simd_sum(sumf);
|
||||
|
||||
if (tpitg.x == 0) {
|
||||
dst_row[0] = norm ? sumf / args.ne00 : sumf;
|
||||
if (FC_OP == OP_SUM_ROWS_NUM_MEAN) {
|
||||
if (is_same<float4, T0>::value) {
|
||||
dst_row[0] = sum(sumf) / (4*args.ne00);
|
||||
} else {
|
||||
dst_row[0] = sum(sumf) / args.ne00;
|
||||
}
|
||||
} else {
|
||||
dst_row[0] = sum(sumf);
|
||||
}
|
||||
}
|
||||
|
||||
#undef FC_OP
|
||||
}
|
||||
|
||||
typedef decltype(kernel_sum_rows<false>) kernel_sum_rows_t;
|
||||
typedef decltype(kernel_sum_rows_impl<float, float>) kernel_sum_rows_t;
|
||||
|
||||
template [[host_name("kernel_sum_rows_f32")]] kernel kernel_sum_rows_t kernel_sum_rows<false>;
|
||||
template [[host_name("kernel_mean_f32")]] kernel kernel_sum_rows_t kernel_sum_rows<true>;
|
||||
template [[host_name("kernel_sum_rows_f32_f32")]] kernel kernel_sum_rows_t kernel_sum_rows_impl<float, float>;
|
||||
template [[host_name("kernel_sum_rows_f32_f32_4")]] kernel kernel_sum_rows_t kernel_sum_rows_impl<float4, float>;
|
||||
|
||||
template<typename T>
|
||||
kernel void kernel_cumsum_blk(
|
||||
|
|
@ -2435,9 +2455,6 @@ kernel void kernel_solve_tri_f32(
|
|||
const short K = FC_solve_tri_k;
|
||||
const short NP = PAD2(N, NW);
|
||||
|
||||
const int32_t ne02 = args.ne02;
|
||||
const int32_t ne03 = args.ne03;
|
||||
|
||||
const int32_t i03 = tgpig.z;
|
||||
const int32_t i02 = tgpig.y;
|
||||
const int32_t i01 = tgpig.x*NSG + sgitg;
|
||||
|
|
@ -5949,7 +5966,7 @@ kernel void kernel_flash_attn_ext_vec(
|
|||
static_assert(DK4 % NL == 0, "DK4 must be divisible by NL");
|
||||
static_assert(DV4 % NL == 0, "DV4 must be divisible by NL");
|
||||
|
||||
const short T = PK + NSG*SH; // shared memory size per query in (half)
|
||||
//const short T = PK + NSG*SH; // shared memory size per query in (half)
|
||||
|
||||
//threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*PK); // holds the query data
|
||||
threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*PK); // same as above but in q4_t
|
||||
|
|
@ -8537,7 +8554,9 @@ kernel void kernel_mul_mm(
|
|||
threadgroup S0 * sa = (threadgroup S0 *)(shmem);
|
||||
threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096);
|
||||
|
||||
#ifdef GGML_METAL_HAS_TENSOR
|
||||
threadgroup float * sc = (threadgroup float *)(shmem);
|
||||
#endif
|
||||
|
||||
constexpr int NR0 = 64;
|
||||
constexpr int NR1 = 32;
|
||||
|
|
@ -8660,8 +8679,8 @@ kernel void kernel_mul_mm(
|
|||
const short sx = (tiitg%NL1);
|
||||
const short sy = (tiitg/NL1)/8;
|
||||
|
||||
const short dx = sx;
|
||||
const short dy = sy;
|
||||
//const short dx = sx;
|
||||
//const short dy = sy;
|
||||
|
||||
const short ly = (tiitg/NL1)%8;
|
||||
|
||||
|
|
@ -8910,7 +8929,9 @@ kernel void kernel_mul_mm_id(
|
|||
threadgroup S0 * sa = (threadgroup S0 *)(shmem);
|
||||
threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096);
|
||||
|
||||
#ifdef GGML_METAL_HAS_TENSOR
|
||||
threadgroup float * sc = (threadgroup float *)(shmem);
|
||||
#endif
|
||||
|
||||
constexpr int NR0 = 64;
|
||||
constexpr int NR1 = 32;
|
||||
|
|
@ -9045,8 +9066,8 @@ kernel void kernel_mul_mm_id(
|
|||
const short sx = (tiitg%NL1);
|
||||
const short sy = (tiitg/NL1)/8;
|
||||
|
||||
const short dx = sx;
|
||||
const short dy = sy;
|
||||
//const short dx = sx;
|
||||
//const short dy = sy;
|
||||
|
||||
const short ly = (tiitg/NL1)%8;
|
||||
|
||||
|
|
|
|||
|
|
@ -85,6 +85,8 @@ set(GGML_OPENCL_KERNELS
|
|||
mul_mv_q4_0_f32_8x_flat
|
||||
mul_mv_q4_0_f32_1d_8x_flat
|
||||
mul_mv_q4_0_f32_1d_16x_flat
|
||||
mul_mv_q4_1_f32
|
||||
mul_mv_q4_1_f32_flat
|
||||
mul_mv_q4_k_f32
|
||||
mul_mv_q6_k_f32
|
||||
mul_mv_q6_k_f32_flat
|
||||
|
|
@ -101,6 +103,8 @@ set(GGML_OPENCL_KERNELS
|
|||
gemv_moe_mxfp4_f32
|
||||
mul_mm_f32_f32_l4_lm
|
||||
mul_mm_f16_f32_l4_lm
|
||||
mul_mm_q4_0_f32_l4_lm
|
||||
mul_mm_q4_1_f32_l4_lm
|
||||
mul_mm_q8_0_f32_l4_lm
|
||||
mul_mm_q6_k_f32_l4_lm
|
||||
mul_mm_q8_0_f32_8x4
|
||||
|
|
|
|||
|
|
@ -525,6 +525,7 @@ struct ggml_backend_opencl_context {
|
|||
cl_kernel kernel_mul_mm_f16_f32_kq;
|
||||
cl_kernel kernel_mul_mat_q4_0_f32, kernel_mul_mat_q4_0_f32_v;
|
||||
cl_kernel kernel_convert_block_q4_0, kernel_restore_block_q4_0;
|
||||
cl_kernel kernel_convert_block_q4_1, kernel_restore_block_q4_1;
|
||||
cl_kernel kernel_convert_block_mxfp4, kernel_convert_block_mxfp4_trans, kernel_restore_block_mxfp4, kernel_restore_block_mxfp4_trans;
|
||||
cl_kernel kernel_convert_block_q8_0, kernel_restore_block_q8_0, kernel_restore_block_q8_0_trans;
|
||||
cl_kernel kernel_mul_mat_q4_0_f32_8x_flat;
|
||||
|
|
@ -532,6 +533,8 @@ struct ggml_backend_opencl_context {
|
|||
cl_kernel kernel_restore_block_q4_0_noshuffle;
|
||||
cl_kernel kernel_convert_block_q6_K, kernel_restore_block_q6_K;
|
||||
cl_kernel kernel_mul_mat_q4_0_f32_1d_8x_flat, kernel_mul_mat_q4_0_f32_1d_16x_flat;
|
||||
cl_kernel kernel_mul_mv_q4_1_f32;
|
||||
cl_kernel kernel_mul_mv_q4_1_f32_flat;
|
||||
cl_kernel kernel_mul_mv_q4_K_f32;
|
||||
cl_kernel kernel_mul_mv_q6_K_f32;
|
||||
cl_kernel kernel_mul_mv_q6_K_f32_flat;
|
||||
|
|
@ -564,6 +567,8 @@ struct ggml_backend_opencl_context {
|
|||
cl_kernel kernel_mul_mv_id_mxfp4_f32_flat;
|
||||
cl_kernel kernel_mul_mm_f32_f32_l4_lm;
|
||||
cl_kernel kernel_mul_mm_f16_f32_l4_lm;
|
||||
cl_kernel kernel_mul_mm_q4_0_f32_l4_lm;
|
||||
cl_kernel kernel_mul_mm_q4_1_f32_l4_lm;
|
||||
cl_kernel kernel_mul_mm_q8_0_f32_l4_lm;
|
||||
cl_kernel kernel_mul_mm_q6_k_f32_l4_lm;
|
||||
|
||||
|
|
@ -888,6 +893,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
|
|||
CL_CHECK((backend_ctx->kernel_restore_block_q4_0_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_0_noshuffle", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_convert_block_q4_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_0", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_restore_block_q4_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_0", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_convert_block_q4_1 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_1", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_restore_block_q4_1 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_1", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_convert_block_mxfp4 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_convert_block_mxfp4_trans = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4_trans", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_restore_block_mxfp4_trans = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_mxfp4_trans", &err), err));
|
||||
|
|
@ -1119,6 +1126,40 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
|
|||
GGML_LOG_CONT(".");
|
||||
}
|
||||
|
||||
// mul_mv_q4_1_f32
|
||||
{
|
||||
#ifdef GGML_OPENCL_EMBED_KERNELS
|
||||
const std::string kernel_src {
|
||||
#include "mul_mv_q4_1_f32.cl.h"
|
||||
};
|
||||
#else
|
||||
const std::string kernel_src = read_file("mul_mv_q4_1_f32.cl");
|
||||
#endif
|
||||
cl_program prog =
|
||||
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
|
||||
|
||||
CL_CHECK((backend_ctx->kernel_mul_mv_q4_1_f32 = clCreateKernel(prog, "kernel_mul_mv_q4_1_f32", &err), err));
|
||||
CL_CHECK(clReleaseProgram(prog));
|
||||
GGML_LOG_CONT(".");
|
||||
}
|
||||
|
||||
// mul_mv_q4_1_f32_flat
|
||||
{
|
||||
#ifdef GGML_OPENCL_EMBED_KERNELS
|
||||
const std::string kernel_src {
|
||||
#include "mul_mv_q4_1_f32_flat.cl.h"
|
||||
};
|
||||
#else
|
||||
const std::string kernel_src = read_file("mul_mv_q4_1_f32_flat.cl");
|
||||
#endif
|
||||
cl_program prog =
|
||||
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
|
||||
|
||||
CL_CHECK((backend_ctx->kernel_mul_mv_q4_1_f32_flat = clCreateKernel(prog, "kernel_mul_mv_q4_1_f32_flat", &err), err));
|
||||
CL_CHECK(clReleaseProgram(prog));
|
||||
GGML_LOG_CONT(".");
|
||||
}
|
||||
|
||||
// mul_mv_q4_k_f32
|
||||
{
|
||||
#ifdef GGML_OPENCL_EMBED_KERNELS
|
||||
|
|
@ -1361,6 +1402,38 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
|
|||
GGML_LOG_CONT(".");
|
||||
}
|
||||
|
||||
// mul_mm_q4_0_f32_l4_lm
|
||||
{
|
||||
#ifdef GGML_OPENCL_EMBED_KERNELS
|
||||
const std::string kernel_src {
|
||||
#include "mul_mm_q4_0_f32_l4_lm.cl.h"
|
||||
};
|
||||
#else
|
||||
const std::string kernel_src = read_file("mul_mm_q4_0_f32_l4_lm.cl");
|
||||
#endif
|
||||
cl_program prog =
|
||||
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
|
||||
|
||||
CL_CHECK((backend_ctx->kernel_mul_mm_q4_0_f32_l4_lm = clCreateKernel(prog, "kernel_mul_mm_q4_0_f32_l4_lm", &err), err));
|
||||
GGML_LOG_CONT(".");
|
||||
}
|
||||
|
||||
// mul_mm_q4_1_f32_l4_lm
|
||||
{
|
||||
#ifdef GGML_OPENCL_EMBED_KERNELS
|
||||
const std::string kernel_src {
|
||||
#include "mul_mm_q4_1_f32_l4_lm.cl.h"
|
||||
};
|
||||
#else
|
||||
const std::string kernel_src = read_file("mul_mm_q4_1_f32_l4_lm.cl");
|
||||
#endif
|
||||
cl_program prog =
|
||||
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
|
||||
|
||||
CL_CHECK((backend_ctx->kernel_mul_mm_q4_1_f32_l4_lm = clCreateKernel(prog, "kernel_mul_mm_q4_1_f32_l4_lm", &err), err));
|
||||
GGML_LOG_CONT(".");
|
||||
}
|
||||
|
||||
// mul_mm_q8_0_f32_l4_lm
|
||||
{
|
||||
#ifdef GGML_OPENCL_EMBED_KERNELS
|
||||
|
|
@ -2923,6 +2996,59 @@ struct ggml_tensor_extra_cl_q4_0 {
|
|||
}
|
||||
};
|
||||
|
||||
struct ggml_tensor_extra_cl_q4_1 {
|
||||
// Quantized values.
|
||||
cl_mem q = nullptr;
|
||||
// Quantized values in image1d_buffer_t.
|
||||
cl_mem q_img = nullptr;
|
||||
// Scales.
|
||||
cl_mem d = nullptr;
|
||||
// Scales in image1d_buffer_t.
|
||||
cl_mem d_img = nullptr;
|
||||
// Min
|
||||
cl_mem m = nullptr;
|
||||
// Min in image1d_buffer_t.
|
||||
cl_mem m_img = nullptr;
|
||||
// Size of quantized values.
|
||||
size_t size_q = 0;
|
||||
// Size of scales.
|
||||
size_t size_d = 0;
|
||||
// Size of min values.
|
||||
size_t size_m = 0;
|
||||
|
||||
~ggml_tensor_extra_cl_q4_1() {
|
||||
reset();
|
||||
}
|
||||
|
||||
void reset() {
|
||||
// q and d are subbuffers into the bigger buffer allocated in ggml_backend_buffer.
|
||||
// They must be properly released so that the original buffer can be
|
||||
// properly released to avoid memory leak.
|
||||
if (q != nullptr) {
|
||||
CL_CHECK(clReleaseMemObject(q));
|
||||
q = nullptr;
|
||||
}
|
||||
if (d != nullptr) {
|
||||
CL_CHECK(clReleaseMemObject(d));
|
||||
d = nullptr;
|
||||
}
|
||||
if (m != nullptr) {
|
||||
CL_CHECK(clReleaseMemObject(m));
|
||||
m = nullptr;
|
||||
}
|
||||
// Currently, q_img and d_img are only initialized when SMALL_ALLOC is
|
||||
// enabled. They point to the images in ggml_backend_opencl_buffer_context.
|
||||
// So, there is no need to release them here.
|
||||
// TODO: initialize them for non SMALL_PATH path, or remove them.
|
||||
q_img = nullptr;
|
||||
d_img = nullptr;
|
||||
m_img = nullptr;
|
||||
size_q = 0;
|
||||
size_d = 0;
|
||||
size_m = 0;
|
||||
}
|
||||
};
|
||||
|
||||
struct ggml_tensor_extra_cl_mxfp4 {
|
||||
// Quantized values.
|
||||
cl_mem q = nullptr;
|
||||
|
|
@ -3399,8 +3525,9 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
|
|||
return true;
|
||||
} else if (op->src[0]->type == GGML_TYPE_F32) {
|
||||
return op->src[1]->type == GGML_TYPE_F32;
|
||||
} else if (op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_MXFP4 ||
|
||||
op->src[0]->type == GGML_TYPE_Q4_K ||
|
||||
} else if (op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_Q4_1 ||
|
||||
op->src[0]->type == GGML_TYPE_MXFP4 ||
|
||||
op->src[0]->type == GGML_TYPE_Q4_K ||
|
||||
op->src[0]->type == GGML_TYPE_Q6_K) {
|
||||
return op->src[1]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]);
|
||||
} else if (op->src[0]->type == GGML_TYPE_Q8_0) {
|
||||
|
|
@ -3629,6 +3756,21 @@ struct ggml_backend_opencl_buffer_context {
|
|||
return extra;
|
||||
}
|
||||
|
||||
ggml_tensor_extra_cl_q4_1 * ggml_opencl_alloc_temp_tensor_extra_q4_1() {
|
||||
ggml_tensor_extra_cl_q4_1 * extra;
|
||||
if (temp_tensor_extras_q4_1.empty()) {
|
||||
extra = new ggml_tensor_extra_cl_q4_1();
|
||||
} else {
|
||||
extra = temp_tensor_extras_q4_1.back();
|
||||
temp_tensor_extras_q4_1.pop_back();
|
||||
}
|
||||
|
||||
temp_tensor_extras_q4_1_in_use.push_back(extra);
|
||||
|
||||
extra->reset();
|
||||
return extra;
|
||||
}
|
||||
|
||||
ggml_tensor_extra_cl_mxfp4 * ggml_opencl_alloc_temp_tensor_extra_mxfp4() {
|
||||
ggml_tensor_extra_cl_mxfp4 * extra;
|
||||
if (temp_tensor_extras_mxfp4.empty()) {
|
||||
|
|
@ -3685,6 +3827,11 @@ struct ggml_backend_opencl_buffer_context {
|
|||
}
|
||||
temp_tensor_extras_q4_0_in_use.clear();
|
||||
|
||||
for (ggml_tensor_extra_cl_q4_1 * e : temp_tensor_extras_q4_1_in_use) {
|
||||
temp_tensor_extras_q4_1.push_back(e);
|
||||
}
|
||||
temp_tensor_extras_q4_1_in_use.clear();
|
||||
|
||||
for (ggml_tensor_extra_cl_mxfp4 * e : temp_tensor_extras_mxfp4_in_use) {
|
||||
temp_tensor_extras_mxfp4.push_back(e);
|
||||
}
|
||||
|
|
@ -3710,6 +3857,8 @@ struct ggml_backend_opencl_buffer_context {
|
|||
std::vector<ggml_tensor_extra_cl *> temp_tensor_extras_in_use;
|
||||
std::vector<ggml_tensor_extra_cl_q4_0 *> temp_tensor_extras_q4_0;
|
||||
std::vector<ggml_tensor_extra_cl_q4_0 *> temp_tensor_extras_q4_0_in_use;
|
||||
std::vector<ggml_tensor_extra_cl_q4_1 *> temp_tensor_extras_q4_1;
|
||||
std::vector<ggml_tensor_extra_cl_q4_1 *> temp_tensor_extras_q4_1_in_use;
|
||||
std::vector<ggml_tensor_extra_cl_mxfp4 *> temp_tensor_extras_mxfp4;
|
||||
std::vector<ggml_tensor_extra_cl_mxfp4 *> temp_tensor_extras_mxfp4_in_use;
|
||||
std::vector<ggml_tensor_extra_cl_q8_0 *> temp_tensor_extras_q8_0;
|
||||
|
|
@ -4079,6 +4228,75 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer,
|
|||
return;
|
||||
|
||||
}
|
||||
if (tensor->type == GGML_TYPE_Q4_1) {
|
||||
ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra;
|
||||
GGML_ASSERT(extra_orig && "Tesnors in OpenCL backend should have been allocated and initialized");
|
||||
|
||||
// Allocate the new extra and create aliases from the original.
|
||||
ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context;
|
||||
ggml_tensor_extra_cl_q4_1 * extra = ctx->ggml_opencl_alloc_temp_tensor_extra_q4_1();
|
||||
|
||||
size_t size_d = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t);
|
||||
size_t size_m = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t);
|
||||
size_t size_q = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/2;
|
||||
GGML_ASSERT(size_d + size_m + size_q == ggml_nbytes(tensor) && "Incorrect tensor size");
|
||||
|
||||
cl_int err;
|
||||
cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE,
|
||||
ggml_nbytes(tensor), NULL, &err);
|
||||
CL_CHECK(err);
|
||||
CL_CHECK(clEnqueueWriteBuffer(
|
||||
queue, data_device, CL_TRUE, 0,
|
||||
ggml_nbytes(tensor), data, 0, NULL, NULL));
|
||||
|
||||
cl_buffer_region region;
|
||||
|
||||
// The original tensor memory is divided into scales and quants, i.e.,
|
||||
// we first store scales, mins, then quants.
|
||||
// Create subbuffer for scales.
|
||||
region.origin = align_to(extra_orig->offset + tensor->view_offs + offset, backend_ctx->alignment);
|
||||
region.size = size_d;
|
||||
extra->d = clCreateSubBuffer(
|
||||
extra_orig->data_device, CL_MEM_READ_WRITE,
|
||||
CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err);
|
||||
CL_CHECK(err);
|
||||
auto previous_origin = region.origin;
|
||||
|
||||
// Create subbuffer for mins.
|
||||
region.origin = align_to(previous_origin + size_d, backend_ctx->alignment);
|
||||
region.size = size_m;
|
||||
extra->m = clCreateSubBuffer(
|
||||
extra_orig->data_device, CL_MEM_READ_WRITE,
|
||||
CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err);
|
||||
CL_CHECK(err);
|
||||
previous_origin = region.origin;
|
||||
|
||||
// Create subbuffer for quants.
|
||||
region.origin = align_to(previous_origin + size_m, backend_ctx->alignment);
|
||||
region.size = size_q;
|
||||
extra->q = clCreateSubBuffer(
|
||||
extra_orig->data_device, CL_MEM_READ_WRITE,
|
||||
CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err);
|
||||
CL_CHECK(err);
|
||||
|
||||
cl_kernel kernel = backend_ctx->kernel_convert_block_q4_1;
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q));
|
||||
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d));
|
||||
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->m));
|
||||
|
||||
size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1};
|
||||
size_t local_work_size[] = {64, 1, 1};
|
||||
|
||||
cl_event evt;
|
||||
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
|
||||
CL_CHECK(clWaitForEvents(1, &evt));
|
||||
CL_CHECK(clReleaseMemObject(data_device));
|
||||
|
||||
tensor->extra = extra;
|
||||
|
||||
return;
|
||||
}
|
||||
if (tensor->type == GGML_TYPE_MXFP4) {
|
||||
ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra;
|
||||
GGML_ASSERT(extra_orig && "Tesnors in OpenCL backend should have been allocated and initialized");
|
||||
|
|
@ -4581,7 +4799,35 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer,
|
|||
size, data, 0, NULL, NULL));
|
||||
CL_CHECK(clReleaseMemObject(data_device));
|
||||
return;
|
||||
} else if (tensor->type == GGML_TYPE_MXFP4) {
|
||||
}
|
||||
if (tensor->type == GGML_TYPE_Q4_1) {
|
||||
ggml_tensor_extra_cl_q4_1 * extra = (ggml_tensor_extra_cl_q4_1 *)tensor->extra;
|
||||
|
||||
cl_int err;
|
||||
cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE,
|
||||
ggml_nbytes(tensor), NULL, &err);
|
||||
CL_CHECK(err);
|
||||
|
||||
cl_kernel kernel = backend_ctx->kernel_restore_block_q4_1;
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->d));
|
||||
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->m));
|
||||
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &data_device));
|
||||
|
||||
size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1};
|
||||
size_t local_work_size[] = {1, 1, 1};
|
||||
|
||||
cl_event evt;
|
||||
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL,
|
||||
global_work_size, local_work_size, 0, NULL, &evt));
|
||||
CL_CHECK(clWaitForEvents(1, &evt));
|
||||
CL_CHECK(clEnqueueReadBuffer(
|
||||
queue, data_device, CL_TRUE, offset,
|
||||
size, data, 0, NULL, NULL));
|
||||
CL_CHECK(clReleaseMemObject(data_device));
|
||||
return;
|
||||
}
|
||||
if (tensor->type == GGML_TYPE_MXFP4) {
|
||||
ggml_tensor_extra_cl_mxfp4 * extra = (ggml_tensor_extra_cl_mxfp4 *)tensor->extra;
|
||||
|
||||
cl_int err;
|
||||
|
|
@ -8409,6 +8655,7 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
|
|||
|
||||
#ifdef GGML_OPENCL_SOA_Q
|
||||
ggml_tensor_extra_cl_q4_0 * extra0_q4_0 = (ggml_tensor_extra_cl_q4_0 *)src0->extra;
|
||||
ggml_tensor_extra_cl_q4_1 * extra0_q4_1 = (ggml_tensor_extra_cl_q4_1 *)src0->extra;
|
||||
ggml_tensor_extra_cl_mxfp4 * extra0_mxfp4 = (ggml_tensor_extra_cl_mxfp4 *)src0->extra;
|
||||
ggml_tensor_extra_cl_q8_0 * extra0_q8_0 = (ggml_tensor_extra_cl_q8_0 *)src0->extra;
|
||||
ggml_tensor_extra_cl_q6_K * extra0_q6_K = (ggml_tensor_extra_cl_q6_K *)src0->extra;
|
||||
|
|
@ -8922,6 +9169,91 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
|
|||
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
|
||||
return;
|
||||
}
|
||||
case GGML_TYPE_Q4_0: {
|
||||
if (ne11 < 32) {
|
||||
break;
|
||||
}
|
||||
if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1)) {
|
||||
break;
|
||||
}
|
||||
|
||||
kernel = backend_ctx->kernel_mul_mm_q4_0_f32_l4_lm;
|
||||
nth0 = 128; // calculated as (BM*BN)/(TM*TN)
|
||||
|
||||
int batch_stride_a = ne00*ne01;
|
||||
int batch_stride_b = ne10*ne11;
|
||||
int batch_stride_d = ne0*ne1;
|
||||
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q4_0->q));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q4_0->d));
|
||||
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
|
||||
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
|
||||
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
|
||||
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01));
|
||||
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02));
|
||||
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne11));
|
||||
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12));
|
||||
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne10)); // stride_a
|
||||
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne10)); // stride_b
|
||||
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne01)); // stride_d
|
||||
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &batch_stride_a));
|
||||
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &batch_stride_b));
|
||||
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &batch_stride_d));
|
||||
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &r2));
|
||||
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &r3));
|
||||
|
||||
// 64 is block tile size BM and BN - change here when BM and BN in the kernel are changed.
|
||||
size_t global_work_size[] = {(size_t)(CEIL_DIV(ne01, 64)*nth0), (size_t)(CEIL_DIV(ne11, 64)), (size_t)ne12*ne13};
|
||||
size_t local_work_size[] = {(size_t)nth0, 1, 1};
|
||||
|
||||
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
|
||||
return;
|
||||
}
|
||||
case GGML_TYPE_Q4_1: {
|
||||
if (ne11 < 32) {
|
||||
break;
|
||||
}
|
||||
if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1)) {
|
||||
break;
|
||||
}
|
||||
|
||||
kernel = backend_ctx->kernel_mul_mm_q4_1_f32_l4_lm;
|
||||
nth0 = 128; // calculated as (BM*BN)/(TM*TN)
|
||||
|
||||
int batch_stride_a = ne00*ne01;
|
||||
int batch_stride_b = ne10*ne11;
|
||||
int batch_stride_d = ne0*ne1;
|
||||
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q4_1->q));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q4_1->d));
|
||||
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q4_1->m));
|
||||
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra1->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_ulong), &offset1));
|
||||
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &extrad->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &offsetd));
|
||||
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne00));
|
||||
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne01));
|
||||
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne02));
|
||||
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne11));
|
||||
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne12));
|
||||
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne10)); // stride_a
|
||||
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne10)); // stride_b
|
||||
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne01)); // stride_d
|
||||
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &batch_stride_a));
|
||||
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &batch_stride_b));
|
||||
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &batch_stride_d));
|
||||
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &r2));
|
||||
CL_CHECK(clSetKernelArg(kernel, 19, sizeof(int), &r3));
|
||||
|
||||
// 64 is block tile size BM and BN - change here when BM and BN in the kernel are changed.
|
||||
size_t global_work_size[] = {(size_t)(CEIL_DIV(ne01, 64)*nth0), (size_t)(CEIL_DIV(ne11, 64)), (size_t)ne12*ne13};
|
||||
size_t local_work_size[] = {(size_t)nth0, 1, 1};
|
||||
|
||||
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
|
||||
return;
|
||||
}
|
||||
case GGML_TYPE_Q8_0: {
|
||||
if (ne11 < 32) {
|
||||
break;
|
||||
|
|
@ -9262,7 +9594,71 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
|
|||
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &r3));
|
||||
#endif // GGML_OPENCL_SOA_Q
|
||||
break;
|
||||
case GGML_TYPE_Q4_1:
|
||||
case GGML_TYPE_Q4_1: {
|
||||
#ifdef GGML_OPENCL_SOA_Q
|
||||
if (backend_ctx->gpu_family == INTEL) {
|
||||
nth0 = 16;
|
||||
nth1 = 1;
|
||||
ndst = 4;
|
||||
} else if (backend_ctx->gpu_family == ADRENO) {
|
||||
nth0 = 64;
|
||||
nth1 = 1;
|
||||
ndst = 4;
|
||||
} else {
|
||||
GGML_ASSERT(false && "TODO: Unknown GPU");
|
||||
}
|
||||
|
||||
kernel = backend_ctx->kernel_mul_mv_q4_1_f32_flat;
|
||||
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q4_1->q));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q4_1->d));
|
||||
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q4_1->m));
|
||||
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra1->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_ulong), &offset1));
|
||||
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &extrad->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &offsetd));
|
||||
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne00));
|
||||
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne01));
|
||||
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne02));
|
||||
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne10));
|
||||
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne12));
|
||||
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne0));
|
||||
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne1));
|
||||
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &r2));
|
||||
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &r3));
|
||||
#else
|
||||
if (backend_ctx->gpu_family == INTEL) {
|
||||
nth0 = 16;
|
||||
nth1 = 1;
|
||||
ndst = 4;
|
||||
} else if (backend_ctx->gpu_family == ADRENO) {
|
||||
nth0 = 64;
|
||||
nth1 = 1;
|
||||
ndst = 4;
|
||||
} else {
|
||||
GGML_ASSERT(false && "TODO: Unknown GPU");
|
||||
}
|
||||
|
||||
kernel = backend_ctx->kernel_mul_mv_q4_1_f32;
|
||||
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
|
||||
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
|
||||
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
|
||||
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
|
||||
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01));
|
||||
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02));
|
||||
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne10));
|
||||
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12));
|
||||
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne0));
|
||||
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne1));
|
||||
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &r2));
|
||||
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &r3));
|
||||
#endif // GGML_OPENCL_SOA_Q
|
||||
break;
|
||||
}
|
||||
case GGML_TYPE_Q8_0: {
|
||||
#ifdef GGML_OPENCL_SOA_Q
|
||||
kernel = backend_ctx->kernel_mul_mv_q8_0_f32_flat;
|
||||
|
|
|
|||
|
|
@ -46,6 +46,15 @@ struct block_q4_0
|
|||
uint8_t qs[QK4_0 / 2];
|
||||
};
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// block_q4_1
|
||||
//------------------------------------------------------------------------------
|
||||
struct block_q4_1 {
|
||||
half d; // delta
|
||||
half m; // min
|
||||
uchar qs[QK4_1 / 2]; // nibbles / quants
|
||||
};
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// block_q6_K
|
||||
//------------------------------------------------------------------------------
|
||||
|
|
@ -148,6 +157,48 @@ kernel void kernel_restore_block_q4_0_noshuffle(
|
|||
}
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// kernel_convert_block_q4_1
|
||||
// Convert the block_q4_1 format to 2 separate arrays (AOS -> SOA).
|
||||
// This kernel does not deshuffle the bits.
|
||||
//------------------------------------------------------------------------------
|
||||
kernel void kernel_convert_block_q4_1(
|
||||
global struct block_q4_1 * src0,
|
||||
global uchar * dst_q,
|
||||
global half * dst_d,
|
||||
global half * dst_m
|
||||
) {
|
||||
global struct block_q4_1 * b = (global struct block_q4_1 *) src0 + get_global_id(0);
|
||||
global uchar * q = (global uchar *) dst_q + QK4_1/2*get_global_id(0);
|
||||
global half * d = (global half *) dst_d + get_global_id(0);
|
||||
global half * m = (global half *) dst_m + get_global_id(0);
|
||||
|
||||
*d = b->d;
|
||||
*m = b->m;
|
||||
|
||||
for (int i = 0; i < QK4_1/2; ++i) {
|
||||
q[i] = b->qs[i];
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_restore_block_q4_1(
|
||||
global uchar * src_q,
|
||||
global half * src_d,
|
||||
global half * src_m,
|
||||
global struct block_q4_1 * dst
|
||||
) {
|
||||
global struct block_q4_1 * b = (global struct block_q4_1 *) dst + get_global_id(0);
|
||||
global uchar * q = (global uchar *) src_q + QK4_1/2*get_global_id(0);
|
||||
global half * d = (global half *) src_d + get_global_id(0);
|
||||
global half * m = (global half *) src_m + get_global_id(0);
|
||||
|
||||
b->d = *d;
|
||||
b->m = *m;
|
||||
for (int i = 0; i < QK4_1/2; ++i) {
|
||||
b->qs[i] = q[i];
|
||||
}
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// block_mxfp4
|
||||
//------------------------------------------------------------------------------
|
||||
|
|
|
|||
|
|
@ -0,0 +1,163 @@
|
|||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
|
||||
#define LOAD_VEC_A 8
|
||||
#define LOAD_VEC_B 4
|
||||
|
||||
#define BM 64
|
||||
#define BN 64
|
||||
#define BK 32
|
||||
#define TM 4
|
||||
#define TN 8
|
||||
|
||||
kernel void kernel_mul_mm_q4_0_f32_l4_lm(
|
||||
global uchar4 * src0_q,
|
||||
global half * src0_d,
|
||||
global float4 * src1,
|
||||
ulong offset1,
|
||||
global float * dst,
|
||||
ulong offsetd,
|
||||
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne02,
|
||||
int ne11,
|
||||
int ne12,
|
||||
|
||||
int stride_a,
|
||||
int stride_b,
|
||||
int stride_d,
|
||||
|
||||
int batch_stride_a,
|
||||
int batch_stride_b,
|
||||
int batch_stride_d,
|
||||
|
||||
int r2,
|
||||
int r3
|
||||
) {
|
||||
src1 = (global float4*)((global char*)src1 + offset1);
|
||||
dst = (global float *)((global char*)dst + offsetd);
|
||||
|
||||
local float buf_a[BM * BK];
|
||||
local float buf_b[BN * BK];
|
||||
|
||||
const int batch_idx = get_global_id(2);
|
||||
|
||||
const int i13 = batch_idx / ne12;
|
||||
const int i12 = batch_idx % ne12;
|
||||
|
||||
const int i03 = i13 / r3;
|
||||
const int i02 = i12 / r2;
|
||||
|
||||
const int batch_idx_a = i03 * ne02 + i02;
|
||||
|
||||
const int ir = get_group_id(0);
|
||||
const int ic = get_group_id(1);
|
||||
|
||||
const int tid = get_local_id(0);
|
||||
const int th_r = tid % (BM / TM);
|
||||
const int th_c = tid / (BM / TM);
|
||||
|
||||
const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A);
|
||||
const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A);
|
||||
const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B);
|
||||
const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B);
|
||||
|
||||
const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK;
|
||||
const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK;
|
||||
|
||||
int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A;
|
||||
int pos_b = (batch_idx * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B;
|
||||
|
||||
float sums[TM * TN];
|
||||
float cache_a[TM];
|
||||
float cache_b[TN];
|
||||
|
||||
for (int i = 0; i < TM * TN; i++) {
|
||||
sums[i] = 0.0f;
|
||||
}
|
||||
|
||||
for (int block = 0; block < ne00; block += BK) {
|
||||
for (int l = 0; l < BM; l += loadstride_a) {
|
||||
if (ir*BM + loadc_a + l < ne01) {
|
||||
int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
|
||||
int ib = idx / 4;
|
||||
int iqs = idx % 4;
|
||||
|
||||
float d = (float)src0_d[ib];
|
||||
global uchar4 * qs = src0_q + ib*4 + iqs;
|
||||
uchar4 q = *qs;
|
||||
float4 v1 = (convert_float4((uchar4)((q.s0 )&0x0F, (q.s1 )&0x0F, (q.s2 )&0x0F, (q.s3 )&0x0F)) - 8.0f)*d;
|
||||
float4 v2 = (convert_float4((uchar4)((q.s0>>4)&0x0F, (q.s1>>4)&0x0F, (q.s2>>4)&0x0F, (q.s3>>4)&0x0F)) - 8.0f)*d;
|
||||
|
||||
buf_a[(loadr_a * 4 + 0) * BM + loadc_a + l] = v1.s0;
|
||||
buf_a[(loadr_a * 4 + 1) * BM + loadc_a + l] = v1.s1;
|
||||
buf_a[(loadr_a * 4 + 2) * BM + loadc_a + l] = v1.s2;
|
||||
buf_a[(loadr_a * 4 + 3) * BM + loadc_a + l] = v1.s3;
|
||||
buf_a[(loadr_a * 4 + 16) * BM + loadc_a + l] = v2.s0;
|
||||
buf_a[(loadr_a * 4 + 17) * BM + loadc_a + l] = v2.s1;
|
||||
buf_a[(loadr_a * 4 + 18) * BM + loadc_a + l] = v2.s2;
|
||||
buf_a[(loadr_a * 4 + 19) * BM + loadc_a + l] = v2.s3;
|
||||
} else {
|
||||
buf_a[(loadr_a * 4 + 0) * BM + loadc_a + l] = 0.0f;
|
||||
buf_a[(loadr_a * 4 + 1) * BM + loadc_a + l] = 0.0f;
|
||||
buf_a[(loadr_a * 4 + 2) * BM + loadc_a + l] = 0.0f;
|
||||
buf_a[(loadr_a * 4 + 3) * BM + loadc_a + l] = 0.0f;
|
||||
buf_a[(loadr_a * 4 + 16) * BM + loadc_a + l] = 0.0f;
|
||||
buf_a[(loadr_a * 4 + 17) * BM + loadc_a + l] = 0.0f;
|
||||
buf_a[(loadr_a * 4 + 18) * BM + loadc_a + l] = 0.0f;
|
||||
buf_a[(loadr_a * 4 + 19) * BM + loadc_a + l] = 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
for (int l = 0; l < BN; l += loadstride_b) {
|
||||
if (ic*BN + loadc_b + l < ne11) {
|
||||
int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3;
|
||||
} else {
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
|
||||
pos_a += BK / LOAD_VEC_A;
|
||||
pos_b += BK / LOAD_VEC_B;
|
||||
|
||||
for (int i = 0; i < BK; i++) {
|
||||
for (int j = 0; j < TM; j++) {
|
||||
cache_a[j] = buf_a[(i) * BM + th_r * TM + j];
|
||||
}
|
||||
|
||||
for (int j = 0; j < TN; j++) {
|
||||
cache_b[j] = buf_b[(i) * BN + th_c * TN + j];
|
||||
}
|
||||
|
||||
for (int cc = 0; cc < TN; cc++) {
|
||||
for (int cr = 0; cr < TM; cr++) {
|
||||
const int sums_idx = cc*TM + cr;
|
||||
sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]);
|
||||
}
|
||||
}
|
||||
}
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
}
|
||||
|
||||
const int dr = ir * BM + th_r * TM;
|
||||
const int dc = ic * BN + th_c * TN;
|
||||
|
||||
const int offsets = batch_idx * batch_stride_d;
|
||||
|
||||
for (int cc = 0; cc < TN; cc++) {
|
||||
for (int cr = 0; cr < TM; cr++) {
|
||||
if (dr + cr < ne01 && dc + cc < ne11) {
|
||||
dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,165 @@
|
|||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
|
||||
#define LOAD_VEC_A 8
|
||||
#define LOAD_VEC_B 4
|
||||
|
||||
#define BM 64
|
||||
#define BN 64
|
||||
#define BK 32
|
||||
#define TM 4
|
||||
#define TN 8
|
||||
|
||||
kernel void kernel_mul_mm_q4_1_f32_l4_lm(
|
||||
global uchar4 * src0_q,
|
||||
global half * src0_d,
|
||||
global half * src0_m,
|
||||
global float4 * src1,
|
||||
ulong offset1,
|
||||
global float * dst,
|
||||
ulong offsetd,
|
||||
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne02,
|
||||
int ne11,
|
||||
int ne12,
|
||||
|
||||
int stride_a,
|
||||
int stride_b,
|
||||
int stride_d,
|
||||
|
||||
int batch_stride_a,
|
||||
int batch_stride_b,
|
||||
int batch_stride_d,
|
||||
|
||||
int r2,
|
||||
int r3
|
||||
) {
|
||||
src1 = (global float4*)((global char*)src1 + offset1);
|
||||
dst = (global float *)((global char*)dst + offsetd);
|
||||
|
||||
local float buf_a[BM * BK];
|
||||
local float buf_b[BN * BK];
|
||||
|
||||
const int batch_idx = get_global_id(2);
|
||||
|
||||
const int i13 = batch_idx / ne12;
|
||||
const int i12 = batch_idx % ne12;
|
||||
|
||||
const int i03 = i13 / r3;
|
||||
const int i02 = i12 / r2;
|
||||
|
||||
const int batch_idx_a = i03 * ne02 + i02;
|
||||
|
||||
const int ir = get_group_id(0);
|
||||
const int ic = get_group_id(1);
|
||||
|
||||
const int tid = get_local_id(0);
|
||||
const int th_r = tid % (BM / TM);
|
||||
const int th_c = tid / (BM / TM);
|
||||
|
||||
const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A);
|
||||
const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A);
|
||||
const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B);
|
||||
const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B);
|
||||
|
||||
const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK;
|
||||
const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK;
|
||||
|
||||
int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A;
|
||||
int pos_b = (batch_idx * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B;
|
||||
|
||||
float sums[TM * TN];
|
||||
float cache_a[TM];
|
||||
float cache_b[TN];
|
||||
|
||||
for (int i = 0; i < TM * TN; i++) {
|
||||
sums[i] = 0.0f;
|
||||
}
|
||||
|
||||
for (int block = 0; block < ne00; block += BK) {
|
||||
for (int l = 0; l < BM; l += loadstride_a) {
|
||||
if (ir*BM + loadc_a + l < ne01) {
|
||||
int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
|
||||
int ib = idx / 4;
|
||||
int iqs = idx % 4;
|
||||
|
||||
float d = (float)src0_d[ib];
|
||||
float m = (float)src0_m[ib];
|
||||
global uchar4 * qs = src0_q + ib*4 + iqs;
|
||||
uchar4 q = *qs;
|
||||
float4 v1 = (convert_float4((uchar4)((q.s0 )&0x0F, (q.s1 )&0x0F, (q.s2 )&0x0F, (q.s3 )&0x0F)))*d + m;
|
||||
float4 v2 = (convert_float4((uchar4)((q.s0>>4)&0x0F, (q.s1>>4)&0x0F, (q.s2>>4)&0x0F, (q.s3>>4)&0x0F)))*d + m;
|
||||
|
||||
buf_a[(loadr_a * 4 + 0) * BM + loadc_a + l] = v1.s0;
|
||||
buf_a[(loadr_a * 4 + 1) * BM + loadc_a + l] = v1.s1;
|
||||
buf_a[(loadr_a * 4 + 2) * BM + loadc_a + l] = v1.s2;
|
||||
buf_a[(loadr_a * 4 + 3) * BM + loadc_a + l] = v1.s3;
|
||||
buf_a[(loadr_a * 4 + 16) * BM + loadc_a + l] = v2.s0;
|
||||
buf_a[(loadr_a * 4 + 17) * BM + loadc_a + l] = v2.s1;
|
||||
buf_a[(loadr_a * 4 + 18) * BM + loadc_a + l] = v2.s2;
|
||||
buf_a[(loadr_a * 4 + 19) * BM + loadc_a + l] = v2.s3;
|
||||
} else {
|
||||
buf_a[(loadr_a * 4 + 0) * BM + loadc_a + l] = 0.0f;
|
||||
buf_a[(loadr_a * 4 + 1) * BM + loadc_a + l] = 0.0f;
|
||||
buf_a[(loadr_a * 4 + 2) * BM + loadc_a + l] = 0.0f;
|
||||
buf_a[(loadr_a * 4 + 3) * BM + loadc_a + l] = 0.0f;
|
||||
buf_a[(loadr_a * 4 + 16) * BM + loadc_a + l] = 0.0f;
|
||||
buf_a[(loadr_a * 4 + 17) * BM + loadc_a + l] = 0.0f;
|
||||
buf_a[(loadr_a * 4 + 18) * BM + loadc_a + l] = 0.0f;
|
||||
buf_a[(loadr_a * 4 + 19) * BM + loadc_a + l] = 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
for (int l = 0; l < BN; l += loadstride_b) {
|
||||
if (ic*BN + loadc_b + l < ne11) {
|
||||
int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3;
|
||||
} else {
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
|
||||
pos_a += BK / LOAD_VEC_A;
|
||||
pos_b += BK / LOAD_VEC_B;
|
||||
|
||||
for (int i = 0; i < BK; i++) {
|
||||
for (int j = 0; j < TM; j++) {
|
||||
cache_a[j] = buf_a[(i) * BM + th_r * TM + j];
|
||||
}
|
||||
|
||||
for (int j = 0; j < TN; j++) {
|
||||
cache_b[j] = buf_b[(i) * BN + th_c * TN + j];
|
||||
}
|
||||
|
||||
for (int cc = 0; cc < TN; cc++) {
|
||||
for (int cr = 0; cr < TM; cr++) {
|
||||
const int sums_idx = cc*TM + cr;
|
||||
sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]);
|
||||
}
|
||||
}
|
||||
}
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
}
|
||||
|
||||
const int dr = ir * BM + th_r * TM;
|
||||
const int dc = ic * BN + th_c * TN;
|
||||
|
||||
const int offsets = batch_idx * batch_stride_d;
|
||||
|
||||
for (int cc = 0; cc < TN; cc++) {
|
||||
for (int cr = 0; cr < TM; cr++) {
|
||||
if (dr + cr < ne01 && dc + cc < ne11) {
|
||||
dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,219 @@
|
|||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
|
||||
#ifdef cl_intel_subgroups
|
||||
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
|
||||
#else
|
||||
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
|
||||
#endif
|
||||
|
||||
#ifdef cl_intel_required_subgroup_size
|
||||
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
|
||||
#define INTEL_GPU 1
|
||||
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
|
||||
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
|
||||
#elif defined(cl_qcom_reqd_sub_group_size)
|
||||
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
|
||||
#define ADRENO_GPU 1
|
||||
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
|
||||
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
|
||||
#endif
|
||||
|
||||
#define QK4_1 32
|
||||
|
||||
struct block_q4_1 {
|
||||
half d; // delta
|
||||
half m; // min
|
||||
uchar qs[QK4_1 / 2]; // nibbles / quants
|
||||
};
|
||||
|
||||
inline float block_q4_1_dot_y(
|
||||
global const struct block_q4_1 * qb_curr,
|
||||
float sumy,
|
||||
float16 yl,
|
||||
int il
|
||||
) {
|
||||
float d = qb_curr->d;
|
||||
float m = qb_curr->m;
|
||||
|
||||
float4 acc = (float4)(0.0f, 0.0f, 0.0f, 0.0f);
|
||||
|
||||
global const ushort * qs = ((global const ushort *) qb_curr + 2 + il/2);
|
||||
|
||||
acc.s0 += yl.s0 * (qs[0] & 0x000F);
|
||||
acc.s0 += yl.s1 * (qs[0] & 0x0F00);
|
||||
acc.s0 += yl.s8 * (qs[0] & 0x00F0);
|
||||
acc.s3 += yl.s9 * (qs[0] & 0xF000);
|
||||
|
||||
acc.s0 += yl.s2 * (qs[1] & 0x000F);
|
||||
acc.s1 += yl.s3 * (qs[1] & 0x0F00);
|
||||
acc.s2 += yl.sa * (qs[1] & 0x00F0);
|
||||
acc.s3 += yl.sb * (qs[1] & 0xF000);
|
||||
|
||||
acc.s0 += yl.s4 * (qs[2] & 0x000F);
|
||||
acc.s1 += yl.s5 * (qs[2] & 0x0F00);
|
||||
acc.s2 += yl.sc * (qs[2] & 0x00F0);
|
||||
acc.s3 += yl.sd * (qs[2] & 0xF000);
|
||||
|
||||
acc.s0 += yl.s6 * (qs[3] & 0x000F);
|
||||
acc.s1 += yl.s7 * (qs[3] & 0x0F00);
|
||||
acc.s2 += yl.se * (qs[3] & 0x00F0);
|
||||
acc.s3 += yl.sf * (qs[3] & 0xF000);
|
||||
|
||||
return d * (acc.s0 + acc.s1 + acc.s2 + acc.s3) + sumy * m;
|
||||
}
|
||||
|
||||
#undef N_DST
|
||||
#undef N_SIMDGROUP
|
||||
#undef N_SIMDWIDTH
|
||||
|
||||
#ifdef INTEL_GPU
|
||||
#define N_DST 4 // each subgroup works on 4 rows
|
||||
#define N_SIMDGROUP 1 // number of subgroups in a thread group
|
||||
#define N_SIMDWIDTH 16 // assuming subgroup size is 16
|
||||
#elif defined (ADRENO_GPU)
|
||||
#define N_DST 4
|
||||
#define N_SIMDGROUP 1
|
||||
#define N_SIMDWIDTH 64
|
||||
#endif
|
||||
|
||||
inline void mul_vec_q_n_f32(
|
||||
global void * src0,
|
||||
global float * src1,
|
||||
global float * dst,
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne02,
|
||||
int ne10,
|
||||
int ne12,
|
||||
int ne0,
|
||||
int ne1,
|
||||
int r2,
|
||||
int r3
|
||||
) {
|
||||
const ulong nb = ne00/QK4_1;
|
||||
|
||||
int r0 = get_group_id(0);
|
||||
int r1 = get_group_id(1);
|
||||
int im = get_group_id(2);
|
||||
|
||||
int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST;
|
||||
|
||||
int i12 = im%ne12;
|
||||
int i13 = im/ne12;
|
||||
|
||||
ulong offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
||||
|
||||
global struct block_q4_1 * x = (global struct block_q4_1 *) src0 + offset0;
|
||||
global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1;
|
||||
|
||||
float16 yl;
|
||||
float4 sumf = (float4)(0.f, 0.f, 0.f, 0.f);
|
||||
|
||||
int ix = get_sub_group_local_id()/2;
|
||||
int il = 8*(get_sub_group_local_id()%2);
|
||||
|
||||
global float * yb = y + ix * QK4_1 + il;
|
||||
|
||||
for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) {
|
||||
float sumy = 0;
|
||||
|
||||
sumy += yb[0];
|
||||
sumy += yb[1];
|
||||
sumy += yb[2];
|
||||
sumy += yb[3];
|
||||
sumy += yb[4];
|
||||
sumy += yb[5];
|
||||
sumy += yb[6];
|
||||
sumy += yb[7];
|
||||
|
||||
sumy += yb[16];
|
||||
sumy += yb[17];
|
||||
sumy += yb[18];
|
||||
sumy += yb[19];
|
||||
sumy += yb[20];
|
||||
sumy += yb[21];
|
||||
sumy += yb[22];
|
||||
sumy += yb[23];
|
||||
|
||||
|
||||
yl.s0 = yb[0];
|
||||
yl.s1 = yb[1]/256.f;
|
||||
|
||||
yl.s2 = yb[2];
|
||||
yl.s3 = yb[3]/256.f;
|
||||
|
||||
yl.s4 = yb[4];
|
||||
yl.s5 = yb[5]/256.f;
|
||||
|
||||
yl.s6 = yb[6];
|
||||
yl.s7 = yb[7]/256.f;
|
||||
|
||||
yl.s8 = yb[16]/16.f;
|
||||
yl.s9 = yb[17]/4096.f;
|
||||
|
||||
yl.sa = yb[18]/16.f;
|
||||
yl.sb = yb[19]/4096.f;
|
||||
|
||||
yl.sc = yb[20]/16.f;
|
||||
yl.sd = yb[21]/4096.f;
|
||||
|
||||
yl.se = yb[22]/16.f;
|
||||
yl.sf = yb[23]/4096.f;
|
||||
|
||||
sumf.s0 += block_q4_1_dot_y(x+ib+0*nb, sumy, yl, il);
|
||||
sumf.s1 += block_q4_1_dot_y(x+ib+1*nb, sumy, yl, il);
|
||||
sumf.s2 += block_q4_1_dot_y(x+ib+2*nb, sumy, yl, il);
|
||||
sumf.s3 += block_q4_1_dot_y(x+ib+3*nb, sumy, yl, il);
|
||||
|
||||
yb += QK4_1 * (N_SIMDWIDTH/2);
|
||||
}
|
||||
|
||||
float4 tot = (float4)(
|
||||
sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1),
|
||||
sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3)
|
||||
);
|
||||
|
||||
if (get_sub_group_local_id() == 0) {
|
||||
if (first_row + 0 < ne01) {
|
||||
dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0;
|
||||
}
|
||||
if (first_row + 1 < ne01) {
|
||||
dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1;
|
||||
}
|
||||
if (first_row + 2 < ne01) {
|
||||
dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2;
|
||||
}
|
||||
if (first_row + 3 < ne01) {
|
||||
dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef INTEL_GPU
|
||||
REQD_SUBGROUP_SIZE_16
|
||||
#elif defined (ADRENO_GPU)
|
||||
REQD_SUBGROUP_SIZE_64
|
||||
#endif
|
||||
kernel void kernel_mul_mv_q4_1_f32(
|
||||
global void * src0,
|
||||
ulong offset0,
|
||||
global float * src1,
|
||||
ulong offset1,
|
||||
global float * dst,
|
||||
ulong offsetd,
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne02,
|
||||
int ne10,
|
||||
int ne12,
|
||||
int ne0,
|
||||
int ne1,
|
||||
int r2,
|
||||
int r3
|
||||
) {
|
||||
src0 = (global void*)((global char*)src0 + offset0);
|
||||
src1 = (global float*)((global char*)src1 + offset1);
|
||||
dst = (global float*)((global char*)dst + offsetd);
|
||||
|
||||
mul_vec_q_n_f32(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3);
|
||||
}
|
||||
|
|
@ -0,0 +1,229 @@
|
|||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
|
||||
#ifdef cl_intel_subgroups
|
||||
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
|
||||
#else
|
||||
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
|
||||
#endif
|
||||
|
||||
#ifdef cl_intel_required_subgroup_size
|
||||
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
|
||||
#define INTEL_GPU 1
|
||||
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
|
||||
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
|
||||
#elif defined(cl_qcom_reqd_sub_group_size)
|
||||
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
|
||||
#define ADRENO_GPU 1
|
||||
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
|
||||
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
|
||||
#endif
|
||||
|
||||
#define QK4_1 32
|
||||
|
||||
struct block_q4_1 {
|
||||
half d; // delta
|
||||
half m; // min
|
||||
uchar qs[QK4_1 / 2]; // nibbles / quants
|
||||
};
|
||||
|
||||
inline float block_q4_1_dot_y_flat(
|
||||
global const uchar * x,
|
||||
global const half * dh,
|
||||
global const half * mh,
|
||||
float sumy,
|
||||
float16 yl,
|
||||
int il
|
||||
) {
|
||||
float d = *dh;
|
||||
float m = *mh;
|
||||
global const ushort * qs = ((global const ushort *) x + il/2);
|
||||
|
||||
float4 acc = (float4)(0.0f, 0.0f, 0.0f, 0.0f);
|
||||
|
||||
acc.s0 += yl.s0 * (qs[0] & 0x000F);
|
||||
acc.s0 += yl.s1 * (qs[0] & 0x0F00);
|
||||
acc.s0 += yl.s8 * (qs[0] & 0x00F0);
|
||||
acc.s3 += yl.s9 * (qs[0] & 0xF000);
|
||||
|
||||
acc.s0 += yl.s2 * (qs[1] & 0x000F);
|
||||
acc.s1 += yl.s3 * (qs[1] & 0x0F00);
|
||||
acc.s2 += yl.sa * (qs[1] & 0x00F0);
|
||||
acc.s3 += yl.sb * (qs[1] & 0xF000);
|
||||
|
||||
acc.s0 += yl.s4 * (qs[2] & 0x000F);
|
||||
acc.s1 += yl.s5 * (qs[2] & 0x0F00);
|
||||
acc.s2 += yl.sc * (qs[2] & 0x00F0);
|
||||
acc.s3 += yl.sd * (qs[2] & 0xF000);
|
||||
|
||||
acc.s0 += yl.s6 * (qs[3] & 0x000F);
|
||||
acc.s1 += yl.s7 * (qs[3] & 0x0F00);
|
||||
acc.s2 += yl.se * (qs[3] & 0x00F0);
|
||||
acc.s3 += yl.sf * (qs[3] & 0xF000);
|
||||
|
||||
return d * (acc.s0 + acc.s1 + acc.s2 + acc.s3) + sumy * m;
|
||||
}
|
||||
|
||||
#undef N_DST
|
||||
#undef N_SIMDGROUP
|
||||
#undef N_SIMDWIDTH
|
||||
|
||||
#ifdef INTEL_GPU
|
||||
#define N_DST 4 // each subgroup works on 4 rows
|
||||
#define N_SIMDGROUP 1 // number of subgroups in a thread group
|
||||
#define N_SIMDWIDTH 16 // assuming subgroup size is 16
|
||||
#elif defined (ADRENO_GPU)
|
||||
#define N_DST 4
|
||||
#define N_SIMDGROUP 1
|
||||
#define N_SIMDWIDTH 64
|
||||
#endif
|
||||
|
||||
inline void mul_vec_q_n_f32_flat(
|
||||
global void * src0_q,
|
||||
global void * src0_d,
|
||||
global void * src0_m,
|
||||
global float * src1,
|
||||
global float * dst,
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne02,
|
||||
int ne10,
|
||||
int ne12,
|
||||
int ne0,
|
||||
int ne1,
|
||||
int r2,
|
||||
int r3
|
||||
) {
|
||||
const ulong nb = ne00/QK4_1;
|
||||
|
||||
int r0 = get_group_id(0);
|
||||
int r1 = get_group_id(1);
|
||||
int im = get_group_id(2);
|
||||
|
||||
int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST;
|
||||
|
||||
int i12 = im%ne12;
|
||||
int i13 = im/ne12;
|
||||
|
||||
ulong offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
||||
|
||||
// The number of scales/mins is the same as the number of blocks.
|
||||
ulong offset0_dm = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02));
|
||||
// Each block contains QK4_1/2 uchars, hence offset for qs is as follows.
|
||||
ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_1/2;
|
||||
|
||||
global uchar * x = (global uchar *) src0_q + offset0_q;
|
||||
global half * d = (global half *) src0_d + offset0_dm;
|
||||
global half * m = (global half *) src0_m + offset0_dm;
|
||||
global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1;
|
||||
|
||||
float16 yl;
|
||||
float4 sumf = (float4)(0.f, 0.f, 0.f, 0.f);
|
||||
|
||||
int ix = get_sub_group_local_id()/2;
|
||||
int il = 8*(get_sub_group_local_id()%2);
|
||||
|
||||
global float * yb = y + ix * QK4_1 + il;
|
||||
|
||||
for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) {
|
||||
float sumy = 0;
|
||||
|
||||
sumy += yb[0];
|
||||
sumy += yb[1];
|
||||
sumy += yb[2];
|
||||
sumy += yb[3];
|
||||
sumy += yb[4];
|
||||
sumy += yb[5];
|
||||
sumy += yb[6];
|
||||
sumy += yb[7];
|
||||
|
||||
sumy += yb[16];
|
||||
sumy += yb[17];
|
||||
sumy += yb[18];
|
||||
sumy += yb[19];
|
||||
sumy += yb[20];
|
||||
sumy += yb[21];
|
||||
sumy += yb[22];
|
||||
sumy += yb[23];
|
||||
|
||||
|
||||
yl.s0 = yb[0];
|
||||
yl.s1 = yb[1]/256.f;
|
||||
|
||||
yl.s2 = yb[2];
|
||||
yl.s3 = yb[3]/256.f;
|
||||
|
||||
yl.s4 = yb[4];
|
||||
yl.s5 = yb[5]/256.f;
|
||||
|
||||
yl.s6 = yb[6];
|
||||
yl.s7 = yb[7]/256.f;
|
||||
|
||||
yl.s8 = yb[16]/16.f;
|
||||
yl.s9 = yb[17]/4096.f;
|
||||
|
||||
yl.sa = yb[18]/16.f;
|
||||
yl.sb = yb[19]/4096.f;
|
||||
|
||||
yl.sc = yb[20]/16.f;
|
||||
yl.sd = yb[21]/4096.f;
|
||||
|
||||
yl.se = yb[22]/16.f;
|
||||
yl.sf = yb[23]/4096.f;
|
||||
|
||||
sumf.s0 += block_q4_1_dot_y_flat(x + ib*QK4_1/2 + 0*nb*QK4_1/2, d + ib + 0*nb, m + ib + 0*nb, sumy, yl, il);
|
||||
sumf.s1 += block_q4_1_dot_y_flat(x + ib*QK4_1/2 + 1*nb*QK4_1/2, d + ib + 1*nb, m + ib + 1*nb, sumy, yl, il);
|
||||
sumf.s2 += block_q4_1_dot_y_flat(x + ib*QK4_1/2 + 2*nb*QK4_1/2, d + ib + 2*nb, m + ib + 2*nb, sumy, yl, il);
|
||||
sumf.s3 += block_q4_1_dot_y_flat(x + ib*QK4_1/2 + 3*nb*QK4_1/2, d + ib + 3*nb, m + ib + 3*nb, sumy, yl, il);
|
||||
|
||||
yb += QK4_1 * (N_SIMDWIDTH/2);
|
||||
}
|
||||
|
||||
float4 tot = (float4)(
|
||||
sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1),
|
||||
sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3)
|
||||
);
|
||||
|
||||
if (get_sub_group_local_id() == 0) {
|
||||
if (first_row + 0 < ne01) {
|
||||
dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0;
|
||||
}
|
||||
if (first_row + 1 < ne01) {
|
||||
dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1;
|
||||
}
|
||||
if (first_row + 2 < ne01) {
|
||||
dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2;
|
||||
}
|
||||
if (first_row + 3 < ne01) {
|
||||
dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef INTEL_GPU
|
||||
REQD_SUBGROUP_SIZE_16
|
||||
#elif defined (ADRENO_GPU)
|
||||
REQD_SUBGROUP_SIZE_64
|
||||
#endif
|
||||
kernel void kernel_mul_mv_q4_1_f32_flat(
|
||||
global void * src0_q,
|
||||
global void * src0_d,
|
||||
global void * src0_m,
|
||||
global float * src1,
|
||||
ulong offset1,
|
||||
global float * dst,
|
||||
ulong offsetd,
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne02,
|
||||
int ne10,
|
||||
int ne12,
|
||||
int ne0,
|
||||
int ne1,
|
||||
int r2,
|
||||
int r3
|
||||
) {
|
||||
src1 = (global float*)((global char*)src1 + offset1);
|
||||
dst = (global float*)((global char*)dst + offsetd);
|
||||
|
||||
mul_vec_q_n_f32_flat(src0_q, src0_d, src0_m, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3);
|
||||
}
|
||||
|
|
@ -1150,9 +1150,9 @@ extern "C" {
|
|||
//
|
||||
|
||||
/// Apply chat template. Inspired by hf apply_chat_template() on python.
|
||||
/// Both "model" and "custom_template" are optional, but at least one is required. "custom_template" has higher precedence than "model"
|
||||
///
|
||||
/// NOTE: This function does not use a jinja parser. It only support a pre-defined list of template. See more: https://github.com/ggml-org/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template
|
||||
/// @param tmpl A Jinja template to use for this chat. If this is nullptr, the model’s default chat template will be used instead.
|
||||
/// @param tmpl A Jinja template to use for this chat.
|
||||
/// @param chat Pointer to a list of multiple llama_chat_message
|
||||
/// @param n_msg Number of llama_chat_message in this chat
|
||||
/// @param add_ass Whether to end the prompt with the token(s) that indicate the start of an assistant message.
|
||||
|
|
|
|||
|
|
@ -30,12 +30,18 @@ fi
|
|||
PR=$1
|
||||
[[ "$PR" =~ ^[0-9]+$ ]] || { echo "error: PR number must be numeric"; exit 1; }
|
||||
|
||||
url_origin=$(git config --get remote.upstream.url 2>/dev/null) || \
|
||||
url_origin=$(git config --get remote.origin.url) || {
|
||||
echo "error: no remote named 'origin' in this repository"
|
||||
echo "error: no remote named 'upstream' or 'origin' in this repository"
|
||||
exit 1
|
||||
}
|
||||
|
||||
org_repo=$(echo $url_origin | cut -d/ -f4-)
|
||||
# Extract org/repo from either https or ssh format.
|
||||
if [[ $url_origin =~ ^git@ ]]; then
|
||||
org_repo=$(echo $url_origin | cut -d: -f2)
|
||||
else
|
||||
org_repo=$(echo $url_origin | cut -d/ -f4-)
|
||||
fi
|
||||
org_repo=${org_repo%.git}
|
||||
|
||||
echo "org/repo: $org_repo"
|
||||
|
|
|
|||
|
|
@ -2,6 +2,8 @@
|
|||
|
||||
import urllib.request
|
||||
|
||||
HTTPLIB_VERSION = "f80864ca031932351abef49b74097c67f14719c6"
|
||||
|
||||
vendor = {
|
||||
"https://github.com/nlohmann/json/releases/latest/download/json.hpp": "vendor/nlohmann/json.hpp",
|
||||
"https://github.com/nlohmann/json/releases/latest/download/json_fwd.hpp": "vendor/nlohmann/json_fwd.hpp",
|
||||
|
|
@ -12,8 +14,8 @@ vendor = {
|
|||
# "https://github.com/mackron/miniaudio/raw/refs/tags/0.11.23/miniaudio.h": "vendor/miniaudio/miniaudio.h",
|
||||
"https://github.com/mackron/miniaudio/raw/669ed3e844524fcd883231b13095baee9f6de304/miniaudio.h": "vendor/miniaudio/miniaudio.h",
|
||||
|
||||
"https://raw.githubusercontent.com/yhirose/cpp-httplib/refs/tags/v0.30.2/httplib.h": "vendor/cpp-httplib/httplib.h",
|
||||
"https://raw.githubusercontent.com/yhirose/cpp-httplib/refs/tags/v0.30.2/LICENSE": "vendor/cpp-httplib/LICENSE",
|
||||
f"https://raw.githubusercontent.com/yhirose/cpp-httplib/{HTTPLIB_VERSION}/httplib.h": "vendor/cpp-httplib/httplib.h",
|
||||
f"https://raw.githubusercontent.com/yhirose/cpp-httplib/{HTTPLIB_VERSION}/LICENSE": "vendor/cpp-httplib/LICENSE",
|
||||
|
||||
"https://raw.githubusercontent.com/sheredom/subprocess.h/b49c56e9fe214488493021017bf3954b91c7c1f5/subprocess.h": "vendor/sheredom/subprocess.h",
|
||||
}
|
||||
|
|
|
|||
|
|
@ -7965,7 +7965,6 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
|
|||
cparams.n_seq_max,
|
||||
nullptr);
|
||||
} else if (llm_arch_is_hybrid(arch)) {
|
||||
|
||||
// The main difference between hybrid architectures is the
|
||||
// layer filters, so pick the right one here
|
||||
llama_memory_hybrid::layer_filter_cb filter_attn = nullptr;
|
||||
|
|
@ -7990,7 +7989,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
|
|||
/* attn_type_v */ params.type_v,
|
||||
/* attn_v_trans */ !cparams.flash_attn,
|
||||
/* attn_swa_full */ params.swa_full,
|
||||
/* attn_kv_size */ cparams.n_ctx,
|
||||
/* attn_kv_size */ cparams.n_ctx_seq,
|
||||
/* attn_n_ubatch */ cparams.n_ubatch,
|
||||
/* attn_n_pad */ 1,
|
||||
/* recurrent_type_r */ GGML_TYPE_F32,
|
||||
|
|
@ -8007,7 +8006,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
|
|||
/* attn_type_k */ params.type_k,
|
||||
/* attn_type_v */ params.type_v,
|
||||
/* attn_v_trans */ !cparams.flash_attn,
|
||||
/* attn_kv_size */ cparams.n_ctx,
|
||||
/* attn_kv_size */ cparams.n_ctx_seq,
|
||||
/* attn_n_pad */ 1,
|
||||
/* attn_n_swa */ hparams.n_swa,
|
||||
/* attn_swa_type */ hparams.swa_type,
|
||||
|
|
|
|||
|
|
@ -1,16 +1,10 @@
|
|||
#if defined(_MSC_VER)
|
||||
#define _SILENCE_CXX17_CODECVT_HEADER_DEPRECATION_WARNING
|
||||
#endif
|
||||
|
||||
#include "unicode.h"
|
||||
#include "unicode-data.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <codecvt>
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <locale>
|
||||
#include <map>
|
||||
#include <regex>
|
||||
#include <stdexcept>
|
||||
|
|
@ -199,27 +193,6 @@ static std::unordered_map<std::string, uint8_t> unicode_utf8_to_byte_map() {
|
|||
return map;
|
||||
}
|
||||
|
||||
static inline std::wstring unicode_wstring_from_utf8(const std::string & s) {
|
||||
#if defined(__clang__)
|
||||
// disable C++17 deprecation warning for std::codecvt_utf8
|
||||
# pragma clang diagnostic push
|
||||
# pragma clang diagnostic ignored "-Wdeprecated-declarations"
|
||||
#elif defined(__GNUC__)
|
||||
# pragma GCC diagnostic push
|
||||
# pragma GCC diagnostic ignored "-Wdeprecated-declarations"
|
||||
#endif
|
||||
|
||||
std::wstring_convert<std::codecvt_utf8<wchar_t>> conv;
|
||||
|
||||
#if defined(__clang__)
|
||||
# pragma clang diagnostic pop
|
||||
#elif defined(__GNUC__)
|
||||
# pragma GCC diagnostic pop
|
||||
#endif
|
||||
|
||||
return conv.from_bytes(s);
|
||||
}
|
||||
|
||||
static std::vector<std::string> unicode_byte_encoding_process(const std::vector<std::string> & bpe_words) {
|
||||
std::vector<std::string> bpe_encoded_words;
|
||||
for (const auto & word : bpe_words) {
|
||||
|
|
@ -1028,10 +1001,10 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
|
|||
break;
|
||||
}
|
||||
}
|
||||
const auto cpts_regex = unicode_cpts_from_utf8(regex_expr);
|
||||
|
||||
if (use_collapsed) {
|
||||
// sanity-check that the original regex does not contain any non-ASCII characters
|
||||
const auto cpts_regex = unicode_cpts_from_utf8(regex_expr);
|
||||
for (size_t i = 0; i < cpts_regex.size(); ++i) {
|
||||
if (cpts_regex[i] >= 128) {
|
||||
throw std::runtime_error("Regex includes both unicode categories and non-ASCII characters - not supported");
|
||||
|
|
@ -1087,7 +1060,7 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
|
|||
bpe_offsets = unicode_regex_split_stl(text_collapsed, regex_expr_collapsed, bpe_offsets);
|
||||
} else {
|
||||
// no unicode category used, we can use std::wregex directly
|
||||
const std::wstring wregex_expr = unicode_wstring_from_utf8(regex_expr);
|
||||
std::wstring wregex_expr(cpts_regex.begin(), cpts_regex.end());
|
||||
|
||||
// std::wregex \s does not mach non-ASCII whitespaces, using 0x0B as fallback
|
||||
std::wstring wtext(cpts.begin(), cpts.end());
|
||||
|
|
|
|||
|
|
@ -2786,9 +2786,10 @@ struct test_set : public test_case {
|
|||
const ggml_type type_dst;
|
||||
const std::array<int64_t, 4> ne;
|
||||
const int dim;
|
||||
const bool inplace;
|
||||
|
||||
std::string vars() override {
|
||||
return VARS_TO_STR4(type_src, type_dst, ne, dim);
|
||||
return VARS_TO_STR5(type_src, type_dst, ne, dim, inplace);
|
||||
}
|
||||
|
||||
size_t op_size(ggml_tensor * t) override {
|
||||
|
|
@ -2796,8 +2797,8 @@ struct test_set : public test_case {
|
|||
}
|
||||
|
||||
test_set(ggml_type type_src = GGML_TYPE_F32, ggml_type type_dst = GGML_TYPE_F32,
|
||||
std::array<int64_t, 4> ne = {6, 5, 4, 3}, int dim = 1)
|
||||
: type_src(type_src), type_dst(type_dst), ne(ne), dim(dim) {}
|
||||
std::array<int64_t, 4> ne = {6, 5, 4, 3}, int dim = 1, bool inplace = false)
|
||||
: type_src(type_src), type_dst(type_dst), ne(ne), dim(dim), inplace(inplace) {}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
ggml_tensor * src = ggml_new_tensor(ctx, type_src, 4, ne.data());
|
||||
|
|
@ -2808,7 +2809,7 @@ struct test_set : public test_case {
|
|||
for (int i = 0; i < dim; ++i) {
|
||||
ne_dst[i] *= 2;
|
||||
}
|
||||
ggml_tensor* dst = ggml_new_tensor(ctx, type_dst, 4, ne_dst.data());
|
||||
ggml_tensor * dst = ggml_new_tensor(ctx, type_dst, 4, ne_dst.data());
|
||||
ggml_set_param(dst);
|
||||
ggml_set_name(dst, "dst");
|
||||
|
||||
|
|
@ -2816,9 +2817,16 @@ struct test_set : public test_case {
|
|||
for (int i = 0; i < dim; ++i) {
|
||||
offset += ((ne_dst[i] - ne[i])/2)*dst->nb[i];
|
||||
}
|
||||
ggml_tensor * out = ggml_set(ctx, dst, src,
|
||||
// The backward pass requires setting a contiguous region:
|
||||
src->nb[1], src->nb[2], src->nb[3], offset);
|
||||
ggml_tensor * out;
|
||||
if (inplace) {
|
||||
out = ggml_set_inplace(ctx, dst, src,
|
||||
// The backward pass requires setting a contiguous region:
|
||||
src->nb[1], src->nb[2], src->nb[3], offset);
|
||||
} else {
|
||||
out = ggml_set(ctx, dst, src,
|
||||
// The backward pass requires setting a contiguous region:
|
||||
src->nb[1], src->nb[2], src->nb[3], offset);
|
||||
}
|
||||
ggml_set_name(out, "out");
|
||||
|
||||
return out;
|
||||
|
|
@ -7428,11 +7436,13 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
|||
test_cases.emplace_back(new test_dup(GGML_TYPE_I16, {10, 8, 3, 1}, {1, 2, 0, 3}));
|
||||
|
||||
for (int dim = 1; dim < GGML_MAX_DIMS; ++dim) {
|
||||
test_cases.emplace_back(new test_set(GGML_TYPE_F32, GGML_TYPE_F32, {6, 5, 4, 3}, dim));
|
||||
test_cases.emplace_back(new test_set(GGML_TYPE_F32, GGML_TYPE_F32, {6, 5, 4, 3}, dim, false));
|
||||
test_cases.emplace_back(new test_set(GGML_TYPE_F32, GGML_TYPE_F32, {6, 5, 4, 3}, dim, true));
|
||||
}
|
||||
|
||||
for (int dim = 1; dim < GGML_MAX_DIMS; ++dim) {
|
||||
test_cases.emplace_back(new test_set(GGML_TYPE_I32, GGML_TYPE_I32, {6, 5, 4, 3}, dim));
|
||||
test_cases.emplace_back(new test_set(GGML_TYPE_I32, GGML_TYPE_I32, {6, 5, 4, 3}, dim, false));
|
||||
test_cases.emplace_back(new test_set(GGML_TYPE_I32, GGML_TYPE_I32, {6, 5, 4, 3}, dim, true));
|
||||
}
|
||||
|
||||
// same-type copy
|
||||
|
|
@ -8132,24 +8142,30 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
|||
}
|
||||
|
||||
test_cases.emplace_back(new test_sum());
|
||||
test_cases.emplace_back(new test_sum_rows());
|
||||
test_cases.emplace_back(new test_sum(GGML_TYPE_F32, {11, 5, 6, 3}, {0, 2, 1, 3})); // row-contiguous but non-contiguous
|
||||
test_cases.emplace_back(new test_sum(GGML_TYPE_F32, {11, 5, 6, 3}, {0, 3, 2, 1}));
|
||||
test_cases.emplace_back(new test_sum(GGML_TYPE_F32, {11, 5, 6, 3}, {0, 1, 3, 2}));
|
||||
test_cases.emplace_back(new test_mean());
|
||||
test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 33, 1, 1, 1 }));
|
||||
test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 33, 256, 1, 1 }));
|
||||
test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 32769, 1, 1, 1 }));
|
||||
test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 32, 1, 1, 1 }));
|
||||
test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 32, 256, 1, 1 }));
|
||||
test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 32768, 1, 1, 1 }));
|
||||
test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 1, 1, 1 }));
|
||||
test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 1024, 1, 1 }));
|
||||
test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 256, 1, 1 }));
|
||||
test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 256, 1, 1 }, { 1, 0, 2, 3 })); // sum dst not-contiguous
|
||||
test_cases.emplace_back(new test_sum_rows());
|
||||
test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 11, 5, 6, 3 }, true, false));
|
||||
test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 11, 5, 6, 3 }, false, true));
|
||||
test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 11, 5, 6, 3 }, true, true));
|
||||
test_cases.emplace_back(new test_mean());
|
||||
test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 1, 1, 1 }));
|
||||
test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 16, 5, 6, 3 }, true, false));
|
||||
test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 16, 5, 6, 3 }, false, true));
|
||||
test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 16, 5, 6, 3 }, true, true));
|
||||
test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 33, 1, 1, 1 }));
|
||||
test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 33, 1, 1, 1 }));
|
||||
test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 1024, 1, 1 }));
|
||||
test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 33, 1024, 1, 1 }));
|
||||
test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 256, 1, 1 }));
|
||||
test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 256, 1, 1 }, { 1, 0, 2, 3 })); // sum dst not-contiguous
|
||||
test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 33, 256, 1, 1 }));
|
||||
test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 33, 256, 1, 1 }));
|
||||
test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 32769, 1, 1, 1 }));
|
||||
test_cases.emplace_back(new test_group_norm(GGML_TYPE_F32, {64, 64, 320, 1}));
|
||||
test_cases.emplace_back(new test_group_norm(GGML_TYPE_F32, {9, 9, 1280, 1}));
|
||||
test_cases.emplace_back(new test_group_norm_mul_add(GGML_TYPE_F32, {64, 64, 320, 1}));
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ Set of LLM REST APIs and a web UI to interact with llama.cpp.
|
|||
* Speculative decoding
|
||||
* Easy-to-use web UI
|
||||
|
||||
For the ful list of features, please refer to [server's changelog](https://github.com/ggml-org/llama.cpp/issues/9291)
|
||||
For the full list of features, please refer to [server's changelog](https://github.com/ggml-org/llama.cpp/issues/9291)
|
||||
|
||||
## Usage
|
||||
|
||||
|
|
|
|||
Binary file not shown.
|
|
@ -139,6 +139,6 @@ sequenceDiagram
|
|||
|
||||
Note over settingsStore: UI-only (not synced):
|
||||
rect rgb(255, 240, 240)
|
||||
Note over settingsStore: systemMessage, custom (JSON)<br/>showStatistics, enableContinueGeneration<br/>autoMicOnEmpty, disableAutoScroll<br/>apiKey, pdfAsImage, disableReasoningFormat
|
||||
Note over settingsStore: systemMessage, custom (JSON)<br/>showStatistics, enableContinueGeneration<br/>autoMicOnEmpty, disableAutoScroll<br/>apiKey, pdfAsImage, disableReasoningParsing, showRawOutputSwitch
|
||||
end
|
||||
```
|
||||
|
|
|
|||
|
|
@ -14,11 +14,11 @@
|
|||
--popover-foreground: oklch(0.145 0 0);
|
||||
--primary: oklch(0.205 0 0);
|
||||
--primary-foreground: oklch(0.985 0 0);
|
||||
--secondary: oklch(0.97 0 0);
|
||||
--secondary: oklch(0.95 0 0);
|
||||
--secondary-foreground: oklch(0.205 0 0);
|
||||
--muted: oklch(0.97 0 0);
|
||||
--muted-foreground: oklch(0.556 0 0);
|
||||
--accent: oklch(0.97 0 0);
|
||||
--accent: oklch(0.95 0 0);
|
||||
--accent-foreground: oklch(0.205 0 0);
|
||||
--destructive: oklch(0.577 0.245 27.325);
|
||||
--border: oklch(0.875 0 0);
|
||||
|
|
@ -37,7 +37,7 @@
|
|||
--sidebar-accent-foreground: oklch(0.205 0 0);
|
||||
--sidebar-border: oklch(0.922 0 0);
|
||||
--sidebar-ring: oklch(0.708 0 0);
|
||||
--code-background: oklch(0.975 0 0);
|
||||
--code-background: oklch(0.985 0 0);
|
||||
--code-foreground: oklch(0.145 0 0);
|
||||
--layer-popover: 1000000;
|
||||
}
|
||||
|
|
@ -51,7 +51,7 @@
|
|||
--popover-foreground: oklch(0.985 0 0);
|
||||
--primary: oklch(0.922 0 0);
|
||||
--primary-foreground: oklch(0.205 0 0);
|
||||
--secondary: oklch(0.269 0 0);
|
||||
--secondary: oklch(0.29 0 0);
|
||||
--secondary-foreground: oklch(0.985 0 0);
|
||||
--muted: oklch(0.269 0 0);
|
||||
--muted-foreground: oklch(0.708 0 0);
|
||||
|
|
@ -116,12 +116,62 @@
|
|||
--color-sidebar-ring: var(--sidebar-ring);
|
||||
}
|
||||
|
||||
:root {
|
||||
--chat-form-area-height: 8rem;
|
||||
--chat-form-area-offset: 2rem;
|
||||
--max-message-height: max(24rem, min(80dvh, calc(100dvh - var(--chat-form-area-height) - 12rem)));
|
||||
}
|
||||
|
||||
@media (min-width: 640px) {
|
||||
:root {
|
||||
--chat-form-area-height: 24rem;
|
||||
--chat-form-area-offset: 12rem;
|
||||
}
|
||||
}
|
||||
|
||||
@layer base {
|
||||
* {
|
||||
@apply border-border outline-ring/50;
|
||||
}
|
||||
|
||||
body {
|
||||
@apply bg-background text-foreground;
|
||||
scrollbar-width: thin;
|
||||
scrollbar-gutter: stable;
|
||||
}
|
||||
|
||||
/* Global scrollbar styling - visible only on hover */
|
||||
* {
|
||||
scrollbar-width: thin;
|
||||
scrollbar-color: transparent transparent;
|
||||
transition: scrollbar-color 0.2s ease;
|
||||
}
|
||||
|
||||
*:hover {
|
||||
scrollbar-color: hsl(var(--muted-foreground) / 0.3) transparent;
|
||||
}
|
||||
|
||||
*::-webkit-scrollbar {
|
||||
width: 6px;
|
||||
height: 6px;
|
||||
}
|
||||
|
||||
*::-webkit-scrollbar-track {
|
||||
background: transparent;
|
||||
}
|
||||
|
||||
*::-webkit-scrollbar-thumb {
|
||||
background: transparent;
|
||||
border-radius: 3px;
|
||||
transition: background 0.2s ease;
|
||||
}
|
||||
|
||||
*:hover::-webkit-scrollbar-thumb {
|
||||
background: hsl(var(--muted-foreground) / 0.3);
|
||||
}
|
||||
|
||||
*::-webkit-scrollbar-thumb:hover {
|
||||
background: hsl(var(--muted-foreground) / 0.5);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,48 @@
|
|||
<script lang="ts">
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
import * as Tooltip from '$lib/components/ui/tooltip';
|
||||
import type { Component } from 'svelte';
|
||||
|
||||
interface Props {
|
||||
icon: Component;
|
||||
tooltip: string;
|
||||
variant?: 'default' | 'destructive' | 'outline' | 'secondary' | 'ghost' | 'link';
|
||||
size?: 'default' | 'sm' | 'lg' | 'icon';
|
||||
class?: string;
|
||||
disabled?: boolean;
|
||||
onclick: () => void;
|
||||
'aria-label'?: string;
|
||||
}
|
||||
|
||||
let {
|
||||
icon,
|
||||
tooltip,
|
||||
variant = 'ghost',
|
||||
size = 'sm',
|
||||
class: className = '',
|
||||
disabled = false,
|
||||
onclick,
|
||||
'aria-label': ariaLabel
|
||||
}: Props = $props();
|
||||
</script>
|
||||
|
||||
<Tooltip.Root>
|
||||
<Tooltip.Trigger>
|
||||
<Button
|
||||
{variant}
|
||||
{size}
|
||||
{disabled}
|
||||
{onclick}
|
||||
class="h-6 w-6 p-0 {className} flex"
|
||||
aria-label={ariaLabel || tooltip}
|
||||
>
|
||||
{@const IconComponent = icon}
|
||||
|
||||
<IconComponent class="h-3 w-3" />
|
||||
</Button>
|
||||
</Tooltip.Trigger>
|
||||
|
||||
<Tooltip.Content>
|
||||
<p>{tooltip}</p>
|
||||
</Tooltip.Content>
|
||||
</Tooltip.Root>
|
||||
|
|
@ -0,0 +1,18 @@
|
|||
<script lang="ts">
|
||||
import { Copy } from '@lucide/svelte';
|
||||
import { copyToClipboard } from '$lib/utils';
|
||||
|
||||
interface Props {
|
||||
ariaLabel?: string;
|
||||
canCopy?: boolean;
|
||||
text: string;
|
||||
}
|
||||
|
||||
let { ariaLabel = 'Copy to clipboard', canCopy = true, text }: Props = $props();
|
||||
</script>
|
||||
|
||||
<Copy
|
||||
class="h-3 w-3 flex-shrink-0 cursor-{canCopy ? 'pointer' : 'not-allowed'}"
|
||||
aria-label={ariaLabel}
|
||||
onclick={() => canCopy && copyToClipboard(text)}
|
||||
/>
|
||||
|
|
@ -0,0 +1,26 @@
|
|||
<script lang="ts">
|
||||
import { X } from '@lucide/svelte';
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
|
||||
interface Props {
|
||||
id: string;
|
||||
onRemove?: (id: string) => void;
|
||||
class?: string;
|
||||
}
|
||||
|
||||
let { id, onRemove, class: className = '' }: Props = $props();
|
||||
</script>
|
||||
|
||||
<Button
|
||||
type="button"
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
class="h-6 w-6 bg-white/20 p-0 hover:bg-white/30 {className}"
|
||||
onclick={(e: MouseEvent) => {
|
||||
e.stopPropagation();
|
||||
onRemove?.(id);
|
||||
}}
|
||||
aria-label="Remove file"
|
||||
>
|
||||
<X class="h-3 w-3" />
|
||||
</Button>
|
||||
|
|
@ -0,0 +1,46 @@
|
|||
<script lang="ts">
|
||||
import { Eye } from '@lucide/svelte';
|
||||
import ActionIconCopyToClipboard from '$lib/components/app/actions/ActionIconCopyToClipboard.svelte';
|
||||
import { FileTypeText } from '$lib/enums';
|
||||
|
||||
interface Props {
|
||||
code: string;
|
||||
language: string;
|
||||
disabled?: boolean;
|
||||
onPreview?: (code: string, language: string) => void;
|
||||
}
|
||||
|
||||
let { code, language, disabled = false, onPreview }: Props = $props();
|
||||
|
||||
const showPreview = $derived(language?.toLowerCase() === FileTypeText.HTML);
|
||||
|
||||
function handlePreview() {
|
||||
if (disabled) return;
|
||||
onPreview?.(code, language);
|
||||
}
|
||||
</script>
|
||||
|
||||
<div class="code-block-actions">
|
||||
<div class="copy-code-btn" class:opacity-50={disabled} class:!cursor-not-allowed={disabled}>
|
||||
<ActionIconCopyToClipboard
|
||||
text={code}
|
||||
canCopy={!disabled}
|
||||
ariaLabel={disabled ? 'Code incomplete' : 'Copy code'}
|
||||
/>
|
||||
</div>
|
||||
|
||||
{#if showPreview}
|
||||
<button
|
||||
class="preview-code-btn"
|
||||
class:opacity-50={disabled}
|
||||
class:!cursor-not-allowed={disabled}
|
||||
title={disabled ? 'Code incomplete' : 'Preview code'}
|
||||
aria-label="Preview code"
|
||||
aria-disabled={disabled}
|
||||
type="button"
|
||||
onclick={handlePreview}
|
||||
>
|
||||
<Eye size={16} />
|
||||
</button>
|
||||
{/if}
|
||||
</div>
|
||||
|
|
@ -0,0 +1,19 @@
|
|||
/**
|
||||
*
|
||||
* ACTIONS
|
||||
*
|
||||
* Small interactive components for user actions.
|
||||
*
|
||||
*/
|
||||
|
||||
/** Styled icon button for action triggers with tooltip. */
|
||||
export { default as ActionIcon } from './ActionIcon.svelte';
|
||||
|
||||
/** Code block actions component (copy, preview). */
|
||||
export { default as ActionIconsCodeBlock } from './ActionIconsCodeBlock.svelte';
|
||||
|
||||
/** Copy-to-clipboard icon button with click handler. */
|
||||
export { default as ActionIconCopyToClipboard } from './ActionIconCopyToClipboard.svelte';
|
||||
|
||||
/** Remove/delete icon button with X icon. */
|
||||
export { default as ActionIconRemove } from './ActionIconRemove.svelte';
|
||||
|
|
@ -0,0 +1,44 @@
|
|||
<script lang="ts">
|
||||
import { BadgeInfo } from '$lib/components/app';
|
||||
import * as Tooltip from '$lib/components/ui/tooltip';
|
||||
import { copyToClipboard } from '$lib/utils';
|
||||
import type { Component } from 'svelte';
|
||||
|
||||
interface Props {
|
||||
class?: string;
|
||||
icon: Component;
|
||||
value: string | number;
|
||||
tooltipLabel?: string;
|
||||
}
|
||||
|
||||
let { class: className = '', icon: Icon, value, tooltipLabel }: Props = $props();
|
||||
|
||||
function handleClick() {
|
||||
void copyToClipboard(String(value));
|
||||
}
|
||||
</script>
|
||||
|
||||
{#if tooltipLabel}
|
||||
<Tooltip.Root>
|
||||
<Tooltip.Trigger>
|
||||
<BadgeInfo class={className} onclick={handleClick}>
|
||||
{#snippet icon()}
|
||||
<Icon class="h-3 w-3" />
|
||||
{/snippet}
|
||||
|
||||
{value}
|
||||
</BadgeInfo>
|
||||
</Tooltip.Trigger>
|
||||
<Tooltip.Content>
|
||||
<p>{tooltipLabel}</p>
|
||||
</Tooltip.Content>
|
||||
</Tooltip.Root>
|
||||
{:else}
|
||||
<BadgeInfo class={className} onclick={handleClick}>
|
||||
{#snippet icon()}
|
||||
<Icon class="h-3 w-3" />
|
||||
{/snippet}
|
||||
|
||||
{value}
|
||||
</BadgeInfo>
|
||||
{/if}
|
||||
|
|
@ -0,0 +1,27 @@
|
|||
<script lang="ts">
|
||||
import { cn } from '$lib/components/ui/utils';
|
||||
import type { Snippet } from 'svelte';
|
||||
|
||||
interface Props {
|
||||
children: Snippet;
|
||||
class?: string;
|
||||
icon?: Snippet;
|
||||
onclick?: () => void;
|
||||
}
|
||||
|
||||
let { children, class: className = '', icon, onclick }: Props = $props();
|
||||
</script>
|
||||
|
||||
<button
|
||||
class={cn(
|
||||
'inline-flex cursor-pointer items-center gap-1 rounded-sm bg-muted-foreground/15 px-1.5 py-0.75',
|
||||
className
|
||||
)}
|
||||
{onclick}
|
||||
>
|
||||
{#if icon}
|
||||
{@render icon()}
|
||||
{/if}
|
||||
|
||||
{@render children()}
|
||||
</button>
|
||||
|
|
@ -0,0 +1,39 @@
|
|||
<script lang="ts">
|
||||
import { ModelModality } from '$lib/enums';
|
||||
import { MODALITY_ICONS, MODALITY_LABELS } from '$lib/constants/icons';
|
||||
import { cn } from '$lib/components/ui/utils';
|
||||
|
||||
type DisplayableModality = ModelModality.VISION | ModelModality.AUDIO;
|
||||
|
||||
interface Props {
|
||||
modalities: ModelModality[];
|
||||
class?: string;
|
||||
}
|
||||
|
||||
let { modalities, class: className = '' }: Props = $props();
|
||||
|
||||
// Filter to only modalities that have icons (VISION, AUDIO)
|
||||
const displayableModalities = $derived(
|
||||
modalities.filter(
|
||||
(m): m is DisplayableModality => m === ModelModality.VISION || m === ModelModality.AUDIO
|
||||
)
|
||||
);
|
||||
</script>
|
||||
|
||||
{#each displayableModalities as modality, index (index)}
|
||||
{@const IconComponent = MODALITY_ICONS[modality]}
|
||||
{@const label = MODALITY_LABELS[modality]}
|
||||
|
||||
<span
|
||||
class={cn(
|
||||
'inline-flex items-center gap-1 rounded-md bg-muted px-2 py-1 text-xs font-medium',
|
||||
className
|
||||
)}
|
||||
>
|
||||
{#if IconComponent}
|
||||
<IconComponent class="h-3 w-3" />
|
||||
{/if}
|
||||
|
||||
{label}
|
||||
</span>
|
||||
{/each}
|
||||
|
|
@ -0,0 +1,16 @@
|
|||
/**
|
||||
*
|
||||
* BADGES & INDICATORS
|
||||
*
|
||||
* Small visual indicators for status and metadata.
|
||||
*
|
||||
*/
|
||||
|
||||
/** Badge displaying chat statistics (tokens, timing). */
|
||||
export { default as BadgeChatStatistic } from './BadgeChatStatistic.svelte';
|
||||
|
||||
/** Generic info badge with optional tooltip and click handler. */
|
||||
export { default as BadgeInfo } from './BadgeInfo.svelte';
|
||||
|
||||
/** Badge indicating model modality (vision, audio, tools). */
|
||||
export { default as BadgeModality } from './BadgeModality.svelte';
|
||||
|
|
@ -27,11 +27,13 @@
|
|||
interface Props {
|
||||
class?: string;
|
||||
disabled?: boolean;
|
||||
initialMessage?: string;
|
||||
isLoading?: boolean;
|
||||
onFileRemove?: (fileId: string) => void;
|
||||
onFileUpload?: (files: File[]) => void;
|
||||
onSend?: (message: string, files?: ChatUploadedFile[]) => Promise<boolean>;
|
||||
onStop?: () => void;
|
||||
onSystemPromptAdd?: (draft: { message: string; files: ChatUploadedFile[] }) => void;
|
||||
showHelperText?: boolean;
|
||||
uploadedFiles?: ChatUploadedFile[];
|
||||
}
|
||||
|
|
@ -39,11 +41,13 @@
|
|||
let {
|
||||
class: className,
|
||||
disabled = false,
|
||||
initialMessage = '',
|
||||
isLoading = false,
|
||||
onFileRemove,
|
||||
onFileUpload,
|
||||
onSend,
|
||||
onStop,
|
||||
onSystemPromptAdd,
|
||||
showHelperText = true,
|
||||
uploadedFiles = $bindable([])
|
||||
}: Props = $props();
|
||||
|
|
@ -53,15 +57,28 @@
|
|||
let currentConfig = $derived(config());
|
||||
let fileInputRef: ChatFormFileInputInvisible | undefined = $state(undefined);
|
||||
let isRecording = $state(false);
|
||||
let message = $state('');
|
||||
let message = $state(initialMessage);
|
||||
let pasteLongTextToFileLength = $derived.by(() => {
|
||||
const n = Number(currentConfig.pasteLongTextToFileLen);
|
||||
return Number.isNaN(n) ? Number(SETTING_CONFIG_DEFAULT.pasteLongTextToFileLen) : n;
|
||||
});
|
||||
let previousIsLoading = $state(isLoading);
|
||||
let previousInitialMessage = $state(initialMessage);
|
||||
let recordingSupported = $state(false);
|
||||
let textareaRef: ChatFormTextarea | undefined = $state(undefined);
|
||||
|
||||
// Sync message when initialMessage prop changes (e.g., after draft restoration)
|
||||
$effect(() => {
|
||||
if (initialMessage !== previousInitialMessage) {
|
||||
message = initialMessage;
|
||||
previousInitialMessage = initialMessage;
|
||||
}
|
||||
});
|
||||
|
||||
function handleSystemPromptClick() {
|
||||
onSystemPromptAdd?.({ message, files: uploadedFiles });
|
||||
}
|
||||
|
||||
// Check if model is selected (in ROUTER mode)
|
||||
let conversationModel = $derived(
|
||||
chatStore.getConversationModel(activeMessages() as DatabaseMessage[])
|
||||
|
|
@ -308,6 +325,7 @@
|
|||
onFileUpload={handleFileUpload}
|
||||
onMicClick={handleMicClick}
|
||||
onStop={handleStop}
|
||||
onSystemPromptClick={handleSystemPromptClick}
|
||||
/>
|
||||
</div>
|
||||
</form>
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
<script lang="ts">
|
||||
import { Paperclip } from '@lucide/svelte';
|
||||
import { MessageSquare } from '@lucide/svelte';
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
import * as DropdownMenu from '$lib/components/ui/dropdown-menu';
|
||||
import * as Tooltip from '$lib/components/ui/tooltip';
|
||||
|
|
@ -11,6 +12,7 @@
|
|||
hasAudioModality?: boolean;
|
||||
hasVisionModality?: boolean;
|
||||
onFileUpload?: () => void;
|
||||
onSystemPromptClick?: () => void;
|
||||
}
|
||||
|
||||
let {
|
||||
|
|
@ -18,7 +20,8 @@
|
|||
disabled = false,
|
||||
hasAudioModality = false,
|
||||
hasVisionModality = false,
|
||||
onFileUpload
|
||||
onFileUpload,
|
||||
onSystemPromptClick
|
||||
}: Props = $props();
|
||||
|
||||
const fileUploadTooltipText = $derived.by(() => {
|
||||
|
|
@ -118,6 +121,23 @@
|
|||
</Tooltip.Content>
|
||||
{/if}
|
||||
</Tooltip.Root>
|
||||
<DropdownMenu.Separator />
|
||||
<Tooltip.Root>
|
||||
<Tooltip.Trigger class="w-full">
|
||||
<DropdownMenu.Item
|
||||
class="flex cursor-pointer items-center gap-2"
|
||||
onclick={() => onSystemPromptClick?.()}
|
||||
>
|
||||
<MessageSquare class="h-4 w-4" />
|
||||
|
||||
<span>System Prompt</span>
|
||||
</DropdownMenu.Item>
|
||||
</Tooltip.Trigger>
|
||||
|
||||
<Tooltip.Content>
|
||||
<p>Add a custom system message for this conversation</p>
|
||||
</Tooltip.Content>
|
||||
</Tooltip.Root>
|
||||
</DropdownMenu.Content>
|
||||
</DropdownMenu.Root>
|
||||
</div>
|
||||
|
|
|
|||
|
|
@ -27,6 +27,7 @@
|
|||
onFileUpload?: () => void;
|
||||
onMicClick?: () => void;
|
||||
onStop?: () => void;
|
||||
onSystemPromptClick?: () => void;
|
||||
}
|
||||
|
||||
let {
|
||||
|
|
@ -39,7 +40,8 @@
|
|||
uploadedFiles = [],
|
||||
onFileUpload,
|
||||
onMicClick,
|
||||
onStop
|
||||
onStop,
|
||||
onSystemPromptClick
|
||||
}: Props = $props();
|
||||
|
||||
let currentConfig = $derived(config());
|
||||
|
|
@ -170,6 +172,7 @@
|
|||
{hasAudioModality}
|
||||
{hasVisionModality}
|
||||
{onFileUpload}
|
||||
{onSystemPromptClick}
|
||||
/>
|
||||
|
||||
<ModelsSelector
|
||||
|
|
|
|||
|
|
@ -1,6 +1,15 @@
|
|||
<script lang="ts">
|
||||
import { chatStore } from '$lib/stores/chat.svelte';
|
||||
import { goto } from '$app/navigation';
|
||||
import {
|
||||
chatStore,
|
||||
pendingEditMessageId,
|
||||
clearPendingEditMessageId,
|
||||
removeSystemPromptPlaceholder
|
||||
} from '$lib/stores/chat.svelte';
|
||||
import { conversationsStore } from '$lib/stores/conversations.svelte';
|
||||
import { DatabaseService } from '$lib/services';
|
||||
import { config } from '$lib/stores/settings.svelte';
|
||||
import { SYSTEM_MESSAGE_PLACEHOLDER } from '$lib/constants/ui';
|
||||
import { copyToClipboard, isIMEComposing, formatMessageForClipboard } from '$lib/utils';
|
||||
import ChatMessageAssistant from './ChatMessageAssistant.svelte';
|
||||
import ChatMessageUser from './ChatMessageUser.svelte';
|
||||
|
|
@ -92,8 +101,30 @@
|
|||
return null;
|
||||
});
|
||||
|
||||
function handleCancelEdit() {
|
||||
// Auto-start edit mode if this message is the pending edit target
|
||||
$effect(() => {
|
||||
const pendingId = pendingEditMessageId();
|
||||
|
||||
if (pendingId && pendingId === message.id && !isEditing) {
|
||||
handleEdit();
|
||||
clearPendingEditMessageId();
|
||||
}
|
||||
});
|
||||
|
||||
async function handleCancelEdit() {
|
||||
isEditing = false;
|
||||
|
||||
// If canceling a new system message with placeholder content, remove it without deleting children
|
||||
if (message.role === 'system') {
|
||||
const conversationDeleted = await removeSystemPromptPlaceholder(message.id);
|
||||
|
||||
if (conversationDeleted) {
|
||||
goto('/');
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
editedContent = message.content;
|
||||
editedExtras = message.extra ? [...message.extra] : [];
|
||||
editedUploadedFiles = [];
|
||||
|
|
@ -114,8 +145,17 @@
|
|||
onCopy?.(message);
|
||||
}
|
||||
|
||||
function handleConfirmDelete() {
|
||||
onDelete?.(message);
|
||||
async function handleConfirmDelete() {
|
||||
if (message.role === 'system') {
|
||||
const conversationDeleted = await removeSystemPromptPlaceholder(message.id);
|
||||
|
||||
if (conversationDeleted) {
|
||||
goto('/');
|
||||
}
|
||||
} else {
|
||||
onDelete?.(message);
|
||||
}
|
||||
|
||||
showDeleteDialog = false;
|
||||
}
|
||||
|
||||
|
|
@ -126,7 +166,12 @@
|
|||
|
||||
function handleEdit() {
|
||||
isEditing = true;
|
||||
editedContent = message.content;
|
||||
// Clear placeholder content for system messages
|
||||
editedContent =
|
||||
message.role === 'system' && message.content === SYSTEM_MESSAGE_PLACEHOLDER
|
||||
? ''
|
||||
: message.content;
|
||||
textareaElement?.focus();
|
||||
editedExtras = message.extra ? [...message.extra] : [];
|
||||
editedUploadedFiles = [];
|
||||
|
||||
|
|
@ -166,7 +211,26 @@
|
|||
}
|
||||
|
||||
async function handleSaveEdit() {
|
||||
if (message.role === 'user' || message.role === 'system') {
|
||||
if (message.role === 'system') {
|
||||
// System messages: update in place without branching
|
||||
const newContent = editedContent.trim();
|
||||
|
||||
// If content is empty or still the placeholder, remove without deleting children
|
||||
if (!newContent) {
|
||||
const conversationDeleted = await removeSystemPromptPlaceholder(message.id);
|
||||
isEditing = false;
|
||||
if (conversationDeleted) {
|
||||
goto('/');
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
await DatabaseService.updateMessage(message.id, { content: newContent });
|
||||
const index = conversationsStore.findMessageIndex(message.id);
|
||||
if (index !== -1) {
|
||||
conversationsStore.updateMessageAtIndex(index, { content: newContent });
|
||||
}
|
||||
} else if (message.role === 'user') {
|
||||
const finalExtras = await getMergedExtras();
|
||||
onEditWithBranching?.(message, editedContent.trim(), finalExtras);
|
||||
} else {
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@
|
|||
ChatMessageBranchingControls,
|
||||
DialogConfirmation
|
||||
} from '$lib/components/app';
|
||||
import { Switch } from '$lib/components/ui/switch';
|
||||
|
||||
interface Props {
|
||||
role: 'user' | 'assistant';
|
||||
|
|
@ -26,6 +27,9 @@
|
|||
onConfirmDelete: () => void;
|
||||
onNavigateToSibling?: (siblingId: string) => void;
|
||||
onShowDeleteDialogChange: (show: boolean) => void;
|
||||
showRawOutputSwitch?: boolean;
|
||||
rawOutputEnabled?: boolean;
|
||||
onRawOutputToggle?: (enabled: boolean) => void;
|
||||
}
|
||||
|
||||
let {
|
||||
|
|
@ -42,7 +46,10 @@
|
|||
onRegenerate,
|
||||
role,
|
||||
siblingInfo = null,
|
||||
showDeleteDialog
|
||||
showDeleteDialog,
|
||||
showRawOutputSwitch = false,
|
||||
rawOutputEnabled = false,
|
||||
onRawOutputToggle
|
||||
}: Props = $props();
|
||||
|
||||
function handleConfirmDelete() {
|
||||
|
|
@ -51,9 +58,9 @@
|
|||
}
|
||||
</script>
|
||||
|
||||
<div class="relative {justify === 'start' ? 'mt-2' : ''} flex h-6 items-center justify-{justify}">
|
||||
<div class="relative {justify === 'start' ? 'mt-2' : ''} flex h-6 items-center justify-between">
|
||||
<div
|
||||
class="absolute top-0 {actionsPosition === 'left'
|
||||
class="{actionsPosition === 'left'
|
||||
? 'left-0'
|
||||
: 'right-0'} flex items-center gap-2 opacity-100 transition-opacity"
|
||||
>
|
||||
|
|
@ -81,6 +88,16 @@
|
|||
<ActionButton icon={Trash2} tooltip="Delete" onclick={onDelete} />
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{#if showRawOutputSwitch}
|
||||
<div class="flex items-center gap-2">
|
||||
<span class="text-xs text-muted-foreground">Show raw output</span>
|
||||
<Switch
|
||||
checked={rawOutputEnabled}
|
||||
onCheckedChange={(checked) => onRawOutputToggle?.(checked)}
|
||||
/>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
<DialogConfirmation
|
||||
|
|
|
|||
|
|
@ -90,6 +90,9 @@
|
|||
|
||||
const processingState = useProcessingState();
|
||||
|
||||
// Local state for raw output toggle (per message)
|
||||
let showRawOutput = $state(false);
|
||||
|
||||
let currentConfig = $derived(config());
|
||||
let isRouter = $derived(isRouterMode());
|
||||
let displayedModel = $derived((): string | null => {
|
||||
|
|
@ -238,7 +241,7 @@
|
|||
</div>
|
||||
</div>
|
||||
{:else if message.role === 'assistant'}
|
||||
{#if config().disableReasoningFormat}
|
||||
{#if showRawOutput}
|
||||
<pre class="raw-output">{messageContent || ''}</pre>
|
||||
{:else}
|
||||
<MarkdownContent content={messageContent || ''} />
|
||||
|
|
@ -352,6 +355,9 @@
|
|||
{onConfirmDelete}
|
||||
{onNavigateToSibling}
|
||||
{onShowDeleteDialogChange}
|
||||
showRawOutputSwitch={currentConfig.showRawOutputSwitch}
|
||||
rawOutputEnabled={showRawOutput}
|
||||
onRawOutputToggle={(enabled) => (showRawOutput = enabled)}
|
||||
/>
|
||||
{/if}
|
||||
</div>
|
||||
|
|
|
|||
|
|
@ -116,7 +116,7 @@
|
|||
|
||||
<Button class="h-8 px-3" onclick={onSaveEdit} disabled={!editedContent.trim()} size="sm">
|
||||
<Check class="mr-1 h-3 w-3" />
|
||||
Send
|
||||
Save
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@
|
|||
chatStore,
|
||||
errorDialog,
|
||||
isLoading,
|
||||
isChatStreaming,
|
||||
isEditing,
|
||||
getAddFilesHandler
|
||||
} from '$lib/stores/chat.svelte';
|
||||
|
|
@ -71,6 +72,8 @@
|
|||
|
||||
let emptyFileNames = $state<string[]>([]);
|
||||
|
||||
let initialMessage = $state('');
|
||||
|
||||
let isEmpty = $derived(
|
||||
showCenteredEmpty && !activeConversation() && activeMessages().length === 0 && !isLoading()
|
||||
);
|
||||
|
|
@ -79,7 +82,7 @@
|
|||
let isServerLoading = $derived(serverLoading());
|
||||
let hasPropsError = $derived(!!serverError());
|
||||
|
||||
let isCurrentConversationLoading = $derived(isLoading());
|
||||
let isCurrentConversationLoading = $derived(isLoading() || isChatStreaming());
|
||||
|
||||
let isRouter = $derived(isRouterMode());
|
||||
|
||||
|
|
@ -221,6 +224,14 @@
|
|||
}
|
||||
}
|
||||
|
||||
async function handleSystemPromptAdd(draft: { message: string; files: ChatUploadedFile[] }) {
|
||||
if (draft.message || draft.files.length > 0) {
|
||||
chatStore.savePendingDraft(draft.message, draft.files);
|
||||
}
|
||||
|
||||
await chatStore.addSystemPrompt();
|
||||
}
|
||||
|
||||
function handleScroll() {
|
||||
if (disableAutoScroll || !chatScrollContainer) return;
|
||||
|
||||
|
|
@ -343,6 +354,12 @@
|
|||
if (!disableAutoScroll) {
|
||||
setTimeout(() => scrollChatToBottom('instant'), INITIAL_SCROLL_DELAY);
|
||||
}
|
||||
|
||||
const pendingDraft = chatStore.consumePendingDraft();
|
||||
if (pendingDraft) {
|
||||
initialMessage = pendingDraft.message;
|
||||
uploadedFiles = pendingDraft.files;
|
||||
}
|
||||
});
|
||||
|
||||
$effect(() => {
|
||||
|
|
@ -428,11 +445,13 @@
|
|||
<div class="conversation-chat-form pointer-events-auto rounded-t-3xl pb-4">
|
||||
<ChatForm
|
||||
disabled={hasPropsError || isEditing()}
|
||||
{initialMessage}
|
||||
isLoading={isCurrentConversationLoading}
|
||||
onFileRemove={handleFileRemove}
|
||||
onFileUpload={handleFileUpload}
|
||||
onSend={handleSendMessage}
|
||||
onStop={() => chatStore.stopGeneration()}
|
||||
onSystemPromptAdd={handleSystemPromptAdd}
|
||||
showHelperText={false}
|
||||
bind:uploadedFiles
|
||||
/>
|
||||
|
|
@ -486,11 +505,13 @@
|
|||
<div in:fly={{ y: 10, duration: 250, delay: hasPropsError ? 0 : 300 }}>
|
||||
<ChatForm
|
||||
disabled={hasPropsError}
|
||||
{initialMessage}
|
||||
isLoading={isCurrentConversationLoading}
|
||||
onFileRemove={handleFileRemove}
|
||||
onFileUpload={handleFileUpload}
|
||||
onSend={handleSendMessage}
|
||||
onStop={() => chatStore.stopGeneration()}
|
||||
onSystemPromptAdd={handleSystemPromptAdd}
|
||||
showHelperText={true}
|
||||
bind:uploadedFiles
|
||||
/>
|
||||
|
|
|
|||
|
|
@ -254,8 +254,13 @@
|
|||
type: 'checkbox'
|
||||
},
|
||||
{
|
||||
key: 'disableReasoningFormat',
|
||||
label: 'Show raw LLM output',
|
||||
key: 'disableReasoningParsing',
|
||||
label: 'Disable reasoning content parsing',
|
||||
type: 'checkbox'
|
||||
},
|
||||
{
|
||||
key: 'showRawOutputSwitch',
|
||||
label: 'Enable raw output toggle',
|
||||
type: 'checkbox'
|
||||
},
|
||||
{
|
||||
|
|
|
|||
|
|
@ -0,0 +1,97 @@
|
|||
<script lang="ts">
|
||||
import ChevronsUpDownIcon from '@lucide/svelte/icons/chevrons-up-down';
|
||||
import * as Collapsible from '$lib/components/ui/collapsible/index.js';
|
||||
import { buttonVariants } from '$lib/components/ui/button/index.js';
|
||||
import { Card } from '$lib/components/ui/card';
|
||||
import { createAutoScrollController } from '$lib/hooks/use-auto-scroll.svelte';
|
||||
import type { Snippet } from 'svelte';
|
||||
import type { Component } from 'svelte';
|
||||
|
||||
interface Props {
|
||||
open?: boolean;
|
||||
class?: string;
|
||||
icon?: Component;
|
||||
iconClass?: string;
|
||||
title: string;
|
||||
subtitle?: string;
|
||||
isStreaming?: boolean;
|
||||
onToggle?: () => void;
|
||||
children: Snippet;
|
||||
}
|
||||
|
||||
let {
|
||||
open = $bindable(false),
|
||||
class: className = '',
|
||||
icon: Icon,
|
||||
iconClass = 'h-4 w-4',
|
||||
title,
|
||||
subtitle,
|
||||
isStreaming = false,
|
||||
onToggle,
|
||||
children
|
||||
}: Props = $props();
|
||||
|
||||
let contentContainer: HTMLDivElement | undefined = $state();
|
||||
const autoScroll = createAutoScrollController();
|
||||
|
||||
$effect(() => {
|
||||
autoScroll.setContainer(contentContainer);
|
||||
});
|
||||
|
||||
$effect(() => {
|
||||
// Only auto-scroll when open and streaming
|
||||
autoScroll.updateInterval(open && isStreaming);
|
||||
});
|
||||
|
||||
function handleScroll() {
|
||||
autoScroll.handleScroll();
|
||||
}
|
||||
</script>
|
||||
|
||||
<Collapsible.Root
|
||||
{open}
|
||||
onOpenChange={(value) => {
|
||||
open = value;
|
||||
onToggle?.();
|
||||
}}
|
||||
class={className}
|
||||
>
|
||||
<Card class="gap-0 border-muted bg-muted/30 py-0">
|
||||
<Collapsible.Trigger class="flex w-full cursor-pointer items-center justify-between p-3">
|
||||
<div class="flex items-center gap-2 text-muted-foreground">
|
||||
{#if Icon}
|
||||
<Icon class={iconClass} />
|
||||
{/if}
|
||||
|
||||
<span class="font-mono text-sm font-medium">{title}</span>
|
||||
|
||||
{#if subtitle}
|
||||
<span class="text-xs italic">{subtitle}</span>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
<div
|
||||
class={buttonVariants({
|
||||
variant: 'ghost',
|
||||
size: 'sm',
|
||||
class: 'h-6 w-6 p-0 text-muted-foreground hover:text-foreground'
|
||||
})}
|
||||
>
|
||||
<ChevronsUpDownIcon class="h-4 w-4" />
|
||||
|
||||
<span class="sr-only">Toggle content</span>
|
||||
</div>
|
||||
</Collapsible.Trigger>
|
||||
|
||||
<Collapsible.Content>
|
||||
<div
|
||||
bind:this={contentContainer}
|
||||
class="overflow-y-auto border-t border-muted px-3 pb-3"
|
||||
onscroll={handleScroll}
|
||||
style="min-height: var(--min-message-height); max-height: var(--max-message-height);"
|
||||
>
|
||||
{@render children()}
|
||||
</div>
|
||||
</Collapsible.Content>
|
||||
</Card>
|
||||
</Collapsible.Root>
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,95 @@
|
|||
<script lang="ts">
|
||||
import hljs from 'highlight.js';
|
||||
import { browser } from '$app/environment';
|
||||
import { mode } from 'mode-watcher';
|
||||
|
||||
import githubDarkCss from 'highlight.js/styles/github-dark.css?inline';
|
||||
import githubLightCss from 'highlight.js/styles/github.css?inline';
|
||||
|
||||
interface Props {
|
||||
code: string;
|
||||
language?: string;
|
||||
class?: string;
|
||||
maxHeight?: string;
|
||||
maxWidth?: string;
|
||||
}
|
||||
|
||||
let {
|
||||
code,
|
||||
language = 'text',
|
||||
class: className = '',
|
||||
maxHeight = '60vh',
|
||||
maxWidth = ''
|
||||
}: Props = $props();
|
||||
|
||||
let highlightedHtml = $state('');
|
||||
|
||||
function loadHighlightTheme(isDark: boolean) {
|
||||
if (!browser) return;
|
||||
|
||||
const existingThemes = document.querySelectorAll('style[data-highlight-theme-preview]');
|
||||
existingThemes.forEach((style) => style.remove());
|
||||
|
||||
const style = document.createElement('style');
|
||||
style.setAttribute('data-highlight-theme-preview', 'true');
|
||||
style.textContent = isDark ? githubDarkCss : githubLightCss;
|
||||
|
||||
document.head.appendChild(style);
|
||||
}
|
||||
|
||||
$effect(() => {
|
||||
const currentMode = mode.current;
|
||||
const isDark = currentMode === 'dark';
|
||||
|
||||
loadHighlightTheme(isDark);
|
||||
});
|
||||
|
||||
$effect(() => {
|
||||
if (!code) {
|
||||
highlightedHtml = '';
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
// Check if the language is supported
|
||||
const lang = language.toLowerCase();
|
||||
const isSupported = hljs.getLanguage(lang);
|
||||
|
||||
if (isSupported) {
|
||||
const result = hljs.highlight(code, { language: lang });
|
||||
highlightedHtml = result.value;
|
||||
} else {
|
||||
// Try auto-detection or fallback to plain text
|
||||
const result = hljs.highlightAuto(code);
|
||||
highlightedHtml = result.value;
|
||||
}
|
||||
} catch {
|
||||
// Fallback to escaped plain text
|
||||
highlightedHtml = code.replace(/&/g, '&').replace(/</g, '<').replace(/>/g, '>');
|
||||
}
|
||||
});
|
||||
</script>
|
||||
|
||||
<div
|
||||
class="code-preview-wrapper rounded-lg border border-border bg-muted {className}"
|
||||
style="max-height: {maxHeight}; max-width: {maxWidth};"
|
||||
>
|
||||
<!-- Needs to be formatted as single line for proper rendering -->
|
||||
<pre class="m-0"><code class="hljs text-sm leading-relaxed">{@html highlightedHtml}</code></pre>
|
||||
</div>
|
||||
|
||||
<style>
|
||||
.code-preview-wrapper {
|
||||
font-family:
|
||||
ui-monospace, SFMono-Regular, 'SF Mono', Monaco, 'Cascadia Code', 'Roboto Mono', Consolas,
|
||||
'Liberation Mono', Menlo, monospace;
|
||||
}
|
||||
|
||||
.code-preview-wrapper pre {
|
||||
background: transparent;
|
||||
}
|
||||
|
||||
.code-preview-wrapper code {
|
||||
background: transparent;
|
||||
}
|
||||
</style>
|
||||
|
|
@ -0,0 +1,79 @@
|
|||
/**
|
||||
*
|
||||
* CONTENT RENDERING
|
||||
*
|
||||
* Components for rendering rich content: markdown, code, and previews.
|
||||
*
|
||||
*/
|
||||
|
||||
/**
|
||||
* **MarkdownContent** - Rich markdown renderer
|
||||
*
|
||||
* Renders markdown content with syntax highlighting, LaTeX math,
|
||||
* tables, links, and code blocks. Optimized for streaming with
|
||||
* incremental block-based rendering.
|
||||
*
|
||||
* **Features:**
|
||||
* - GFM (GitHub Flavored Markdown): tables, task lists, strikethrough
|
||||
* - LaTeX math via KaTeX (`$inline$` and `$$block$$`)
|
||||
* - Syntax highlighting (highlight.js) with language detection
|
||||
* - Code copy buttons with click feedback
|
||||
* - External links open in new tab with security attrs
|
||||
* - Image attachment resolution from message extras
|
||||
* - Dark/light theme support (auto-switching)
|
||||
* - Streaming-optimized incremental rendering
|
||||
* - Code preview dialog for large blocks
|
||||
*
|
||||
* @example
|
||||
* ```svelte
|
||||
* <MarkdownContent content={message.content} attachments={message.extra} />
|
||||
* ```
|
||||
*/
|
||||
export { default as MarkdownContent } from './MarkdownContent.svelte';
|
||||
|
||||
/**
|
||||
* **SyntaxHighlightedCode** - Code syntax highlighting
|
||||
*
|
||||
* Renders code with syntax highlighting using highlight.js.
|
||||
* Supports theme switching and scrollable containers.
|
||||
*
|
||||
* **Features:**
|
||||
* - Auto language detection with fallback
|
||||
* - Dark/light theme auto-switching
|
||||
* - Scrollable container with configurable max dimensions
|
||||
* - Monospace font styling
|
||||
* - Preserves whitespace and formatting
|
||||
*
|
||||
* @example
|
||||
* ```svelte
|
||||
* <SyntaxHighlightedCode code={jsonString} language="json" />
|
||||
* ```
|
||||
*/
|
||||
export { default as SyntaxHighlightedCode } from './SyntaxHighlightedCode.svelte';
|
||||
|
||||
/**
|
||||
* **CollapsibleContentBlock** - Expandable content card
|
||||
*
|
||||
* Reusable collapsible card with header, icon, and auto-scroll.
|
||||
* Used for tool calls and reasoning blocks in chat messages.
|
||||
*
|
||||
* **Features:**
|
||||
* - Collapsible content with smooth animation
|
||||
* - Custom icon and title display
|
||||
* - Optional subtitle/status text
|
||||
* - Auto-scroll during streaming (pauses on user scroll)
|
||||
* - Configurable max height with overflow scroll
|
||||
*
|
||||
* @example
|
||||
* ```svelte
|
||||
* <CollapsibleContentBlock
|
||||
* bind:open
|
||||
* icon={BrainIcon}
|
||||
* title="Thinking..."
|
||||
* isStreaming={true}
|
||||
* >
|
||||
* {reasoningContent}
|
||||
* </CollapsibleContentBlock>
|
||||
* ```
|
||||
*/
|
||||
export { default as CollapsibleContentBlock } from './CollapsibleContentBlock.svelte';
|
||||
|
|
@ -17,9 +17,13 @@
|
|||
let { conversations, messageCountMap = new Map(), mode, onCancel, onConfirm }: Props = $props();
|
||||
|
||||
let searchQuery = $state('');
|
||||
let selectedIds = $state.raw<SvelteSet<string>>(new SvelteSet(conversations.map((c) => c.id)));
|
||||
let selectedIds = $state.raw<SvelteSet<string>>(getInitialSelectedIds());
|
||||
let lastClickedId = $state<string | null>(null);
|
||||
|
||||
function getInitialSelectedIds(): SvelteSet<string> {
|
||||
return new SvelteSet(conversations.map((c) => c.id));
|
||||
}
|
||||
|
||||
let filteredConversations = $derived(
|
||||
conversations.filter((conv) => {
|
||||
const name = conv.name || 'Untitled conversation';
|
||||
|
|
@ -92,7 +96,7 @@
|
|||
}
|
||||
|
||||
function handleCancel() {
|
||||
selectedIds = new SvelteSet(conversations.map((c) => c.id));
|
||||
selectedIds = getInitialSelectedIds();
|
||||
searchQuery = '';
|
||||
lastClickedId = null;
|
||||
|
||||
|
|
@ -100,7 +104,7 @@
|
|||
}
|
||||
|
||||
export function reset() {
|
||||
selectedIds = new SvelteSet(conversations.map((c) => c.id));
|
||||
selectedIds = getInitialSelectedIds();
|
||||
searchQuery = '';
|
||||
lastClickedId = null;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,93 @@
|
|||
<script lang="ts">
|
||||
import { ChevronLeft, ChevronRight } from '@lucide/svelte';
|
||||
|
||||
interface Props {
|
||||
class?: string;
|
||||
children?: import('svelte').Snippet;
|
||||
gapSize?: string;
|
||||
onScrollableChange?: (isScrollable: boolean) => void;
|
||||
}
|
||||
|
||||
let { class: className = '', children, gapSize = '3', onScrollableChange }: Props = $props();
|
||||
|
||||
let canScrollLeft = $state(false);
|
||||
let canScrollRight = $state(false);
|
||||
let scrollContainer: HTMLDivElement | undefined = $state();
|
||||
|
||||
function scrollLeft(event?: MouseEvent) {
|
||||
event?.stopPropagation();
|
||||
event?.preventDefault();
|
||||
|
||||
if (!scrollContainer) return;
|
||||
|
||||
scrollContainer.scrollBy({ left: scrollContainer.clientWidth * -0.67, behavior: 'smooth' });
|
||||
}
|
||||
|
||||
function scrollRight(event?: MouseEvent) {
|
||||
event?.stopPropagation();
|
||||
event?.preventDefault();
|
||||
|
||||
if (!scrollContainer) return;
|
||||
|
||||
scrollContainer.scrollBy({ left: scrollContainer.clientWidth * 0.67, behavior: 'smooth' });
|
||||
}
|
||||
|
||||
function updateScrollButtons() {
|
||||
if (!scrollContainer) return;
|
||||
|
||||
const { scrollLeft, scrollWidth, clientWidth } = scrollContainer;
|
||||
|
||||
canScrollLeft = scrollLeft > 0;
|
||||
canScrollRight = scrollLeft < scrollWidth - clientWidth - 1;
|
||||
|
||||
const isScrollable = scrollWidth > clientWidth;
|
||||
onScrollableChange?.(isScrollable);
|
||||
}
|
||||
|
||||
export function resetScroll() {
|
||||
if (scrollContainer) {
|
||||
scrollContainer.scrollLeft = 0;
|
||||
setTimeout(() => {
|
||||
updateScrollButtons();
|
||||
}, 0);
|
||||
}
|
||||
}
|
||||
|
||||
$effect(() => {
|
||||
if (scrollContainer) {
|
||||
setTimeout(() => {
|
||||
updateScrollButtons();
|
||||
}, 0);
|
||||
}
|
||||
});
|
||||
</script>
|
||||
|
||||
<div class="relative {className}">
|
||||
<button
|
||||
class="absolute top-1/2 left-4 z-10 flex h-6 w-6 -translate-y-1/2 items-center justify-center rounded-full bg-foreground/15 shadow-md backdrop-blur-xs transition-opacity hover:bg-foreground/35 {canScrollLeft
|
||||
? 'opacity-100'
|
||||
: 'pointer-events-none opacity-0'}"
|
||||
onclick={scrollLeft}
|
||||
aria-label="Scroll left"
|
||||
>
|
||||
<ChevronLeft class="h-4 w-4" />
|
||||
</button>
|
||||
|
||||
<div
|
||||
class="scrollbar-hide flex items-start gap-{gapSize} overflow-x-auto"
|
||||
bind:this={scrollContainer}
|
||||
onscroll={updateScrollButtons}
|
||||
>
|
||||
{@render children?.()}
|
||||
</div>
|
||||
|
||||
<button
|
||||
class="absolute top-1/2 right-4 z-10 flex h-6 w-6 -translate-y-1/2 items-center justify-center rounded-full bg-foreground/15 shadow-md backdrop-blur-xs transition-opacity hover:bg-foreground/35 {canScrollRight
|
||||
? 'opacity-100'
|
||||
: 'pointer-events-none opacity-0'}"
|
||||
onclick={scrollRight}
|
||||
aria-label="Scroll right"
|
||||
>
|
||||
<ChevronRight class="h-4 w-4" />
|
||||
</button>
|
||||
</div>
|
||||
|
|
@ -11,7 +11,9 @@
|
|||
|
||||
let baseClasses =
|
||||
'px-1 pointer-events-none inline-flex select-none items-center gap-0.5 font-sans text-md font-medium opacity-0 transition-opacity -my-1';
|
||||
let variantClasses = variant === 'destructive' ? 'text-destructive' : 'text-muted-foreground';
|
||||
let variantClasses = $derived(
|
||||
variant === 'destructive' ? 'text-destructive' : 'text-muted-foreground'
|
||||
);
|
||||
</script>
|
||||
|
||||
<kbd class="{baseClasses} {variantClasses} {className}">
|
||||
|
|
|
|||
|
|
@ -0,0 +1,48 @@
|
|||
<script lang="ts">
|
||||
import * as Tooltip from '$lib/components/ui/tooltip';
|
||||
|
||||
interface Props {
|
||||
text: string;
|
||||
class?: string;
|
||||
}
|
||||
|
||||
let { text, class: className = '' }: Props = $props();
|
||||
|
||||
let textElement: HTMLSpanElement | undefined = $state();
|
||||
let isTruncated = $state(false);
|
||||
|
||||
function checkTruncation() {
|
||||
if (textElement) {
|
||||
isTruncated = textElement.scrollWidth > textElement.clientWidth;
|
||||
}
|
||||
}
|
||||
|
||||
$effect(() => {
|
||||
if (textElement) {
|
||||
checkTruncation();
|
||||
|
||||
const observer = new ResizeObserver(checkTruncation);
|
||||
observer.observe(textElement);
|
||||
|
||||
return () => observer.disconnect();
|
||||
}
|
||||
});
|
||||
</script>
|
||||
|
||||
{#if isTruncated}
|
||||
<Tooltip.Root>
|
||||
<Tooltip.Trigger class={className}>
|
||||
<span bind:this={textElement} class="block truncate">
|
||||
{text}
|
||||
</span>
|
||||
</Tooltip.Trigger>
|
||||
|
||||
<Tooltip.Content class="z-[9999]">
|
||||
<p>{text}</p>
|
||||
</Tooltip.Content>
|
||||
</Tooltip.Root>
|
||||
{:else}
|
||||
<span bind:this={textElement} class="{className} block truncate">
|
||||
{text}
|
||||
</span>
|
||||
{/if}
|
||||
|
|
@ -0,0 +1,45 @@
|
|||
/**
|
||||
*
|
||||
* MISC
|
||||
*
|
||||
* Miscellaneous utility components.
|
||||
*
|
||||
*/
|
||||
|
||||
/**
|
||||
* **ConversationSelection** - Multi-select conversation picker
|
||||
*
|
||||
* List of conversations with checkboxes for multi-selection.
|
||||
* Used in import/export dialogs for selecting conversations.
|
||||
*
|
||||
* **Features:**
|
||||
* - Search/filter conversations by name
|
||||
* - Select all / deselect all controls
|
||||
* - Shift-click for range selection
|
||||
* - Message count display per conversation
|
||||
* - Mode-specific UI (export vs import)
|
||||
*/
|
||||
export { default as ConversationSelection } from './ConversationSelection.svelte';
|
||||
|
||||
/**
|
||||
* Horizontal scrollable carousel with navigation arrows.
|
||||
* Used for displaying items in a horizontally scrollable container
|
||||
* with left/right navigation buttons that appear on hover.
|
||||
*/
|
||||
export { default as HorizontalScrollCarousel } from './HorizontalScrollCarousel.svelte';
|
||||
|
||||
/**
|
||||
* **TruncatedText** - Text with ellipsis and tooltip
|
||||
*
|
||||
* Displays text with automatic truncation and full content in tooltip.
|
||||
* Useful for long names or paths in constrained spaces.
|
||||
*/
|
||||
export { default as TruncatedText } from './TruncatedText.svelte';
|
||||
|
||||
/**
|
||||
* **KeyboardShortcutInfo** - Keyboard shortcut hint display
|
||||
*
|
||||
* Displays keyboard shortcut hints (e.g., "⌘ + Enter").
|
||||
* Supports special keys like shift, cmd, and custom text.
|
||||
*/
|
||||
export { default as KeyboardShortcutInfo } from './KeyboardShortcutInfo.svelte';
|
||||
|
|
@ -0,0 +1,86 @@
|
|||
<script lang="ts">
|
||||
import * as DropdownMenu from '$lib/components/ui/dropdown-menu';
|
||||
import * as Tooltip from '$lib/components/ui/tooltip';
|
||||
import { KeyboardShortcutInfo } from '$lib/components/app';
|
||||
import type { Component } from 'svelte';
|
||||
|
||||
interface ActionItem {
|
||||
icon: Component;
|
||||
label: string;
|
||||
onclick: (event: Event) => void;
|
||||
variant?: 'default' | 'destructive';
|
||||
disabled?: boolean;
|
||||
shortcut?: string[];
|
||||
separator?: boolean;
|
||||
}
|
||||
|
||||
interface Props {
|
||||
triggerIcon: Component;
|
||||
triggerTooltip?: string;
|
||||
triggerClass?: string;
|
||||
actions: ActionItem[];
|
||||
align?: 'start' | 'center' | 'end';
|
||||
open?: boolean;
|
||||
}
|
||||
|
||||
let {
|
||||
triggerIcon,
|
||||
triggerTooltip,
|
||||
triggerClass = '',
|
||||
actions,
|
||||
align = 'end',
|
||||
open = $bindable(false)
|
||||
}: Props = $props();
|
||||
</script>
|
||||
|
||||
<DropdownMenu.Root bind:open>
|
||||
<DropdownMenu.Trigger
|
||||
class="flex h-6 w-6 cursor-pointer items-center justify-center rounded-md p-0 text-sm font-medium transition-colors hover:bg-accent hover:text-accent-foreground focus:bg-accent focus:text-accent-foreground focus:outline-none disabled:pointer-events-none disabled:opacity-50 data-[state=open]:bg-accent data-[state=open]:text-accent-foreground {triggerClass}"
|
||||
onclick={(e) => e.stopPropagation()}
|
||||
>
|
||||
{#if triggerTooltip}
|
||||
<Tooltip.Root>
|
||||
<Tooltip.Trigger>
|
||||
{@render iconComponent(triggerIcon, 'h-3 w-3')}
|
||||
<span class="sr-only">{triggerTooltip}</span>
|
||||
</Tooltip.Trigger>
|
||||
<Tooltip.Content>
|
||||
<p>{triggerTooltip}</p>
|
||||
</Tooltip.Content>
|
||||
</Tooltip.Root>
|
||||
{:else}
|
||||
{@render iconComponent(triggerIcon, 'h-3 w-3')}
|
||||
{/if}
|
||||
</DropdownMenu.Trigger>
|
||||
|
||||
<DropdownMenu.Content {align} class="z-[999999] w-48">
|
||||
{#each actions as action, index (action.label)}
|
||||
{#if action.separator && index > 0}
|
||||
<DropdownMenu.Separator />
|
||||
{/if}
|
||||
|
||||
<DropdownMenu.Item
|
||||
onclick={action.onclick}
|
||||
variant={action.variant}
|
||||
disabled={action.disabled}
|
||||
class="flex items-center justify-between hover:[&>kbd]:opacity-100"
|
||||
>
|
||||
<div class="flex items-center gap-2">
|
||||
{@render iconComponent(
|
||||
action.icon,
|
||||
`h-4 w-4 ${action.variant === 'destructive' ? 'text-destructive' : ''}`
|
||||
)}
|
||||
{action.label}
|
||||
</div>
|
||||
|
||||
{#if action.shortcut}
|
||||
<KeyboardShortcutInfo keys={action.shortcut} variant={action.variant} />
|
||||
{/if}
|
||||
</DropdownMenu.Item>
|
||||
{/each}
|
||||
</DropdownMenu.Content>
|
||||
</DropdownMenu.Root>
|
||||
|
||||
{#snippet iconComponent(IconComponent: Component, className: string)}
|
||||
<IconComponent class={className} />
|
||||
{/snippet}
|
||||
|
|
@ -0,0 +1,50 @@
|
|||
<script lang="ts">
|
||||
import type { Snippet } from 'svelte';
|
||||
import * as DropdownMenu from '$lib/components/ui/dropdown-menu';
|
||||
import { SearchInput } from '$lib/components/app';
|
||||
|
||||
interface Props {
|
||||
placeholder?: string;
|
||||
searchValue?: string;
|
||||
onSearchChange?: (value: string) => void;
|
||||
onSearchKeyDown?: (event: KeyboardEvent) => void;
|
||||
emptyMessage?: string;
|
||||
isEmpty?: boolean;
|
||||
children: Snippet;
|
||||
footer?: Snippet;
|
||||
}
|
||||
|
||||
let {
|
||||
placeholder = 'Search...',
|
||||
searchValue = $bindable(''),
|
||||
onSearchChange,
|
||||
onSearchKeyDown,
|
||||
emptyMessage = 'No items found',
|
||||
isEmpty = false,
|
||||
children,
|
||||
footer
|
||||
}: Props = $props();
|
||||
</script>
|
||||
|
||||
<div class="sticky top-0 z-10 mb-2 bg-popover p-1 pt-2">
|
||||
<SearchInput
|
||||
{placeholder}
|
||||
bind:value={searchValue}
|
||||
onInput={onSearchChange}
|
||||
onKeyDown={onSearchKeyDown}
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div class="overflow-y-auto">
|
||||
{@render children()}
|
||||
|
||||
{#if isEmpty}
|
||||
<div class="px-2 py-3 text-center text-sm text-muted-foreground">{emptyMessage}</div>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
{#if footer}
|
||||
<DropdownMenu.Separator />
|
||||
|
||||
{@render footer()}
|
||||
{/if}
|
||||
|
|
@ -0,0 +1,65 @@
|
|||
/**
|
||||
*
|
||||
* NAVIGATION & MENUS
|
||||
*
|
||||
* Components for dropdown menus and action selection.
|
||||
*
|
||||
*/
|
||||
|
||||
/**
|
||||
* **DropdownMenuSearchable** - Searchable content for dropdown menus
|
||||
*
|
||||
* Renders a search input with filtered content area, empty state, and optional footer.
|
||||
* Designed to be injected into any dropdown container (DropdownMenu.Content,
|
||||
* DropdownMenu.SubContent, etc.) without providing its own Root.
|
||||
*
|
||||
* **Features:**
|
||||
* - Search/filter input
|
||||
* - Keyboard navigation support
|
||||
* - Custom content and footer via snippets
|
||||
* - Empty state message
|
||||
*
|
||||
* @example
|
||||
* ```svelte
|
||||
* <DropdownMenu.Root>
|
||||
* <DropdownMenu.Trigger>...</DropdownMenu.Trigger>
|
||||
* <DropdownMenu.Content class="pt-0">
|
||||
* <DropdownMenuSearchable
|
||||
* bind:searchValue
|
||||
* placeholder="Search..."
|
||||
* isEmpty={filteredItems.length === 0}
|
||||
* >
|
||||
* {#each items as item}<Item {item} />{/each}
|
||||
* </DropdownMenuSearchable>
|
||||
* </DropdownMenu.Content>
|
||||
* </DropdownMenu.Root>
|
||||
* ```
|
||||
*/
|
||||
export { default as DropdownMenuSearchable } from './DropdownMenuSearchable.svelte';
|
||||
|
||||
/**
|
||||
* **DropdownMenuActions** - Multi-action dropdown menu
|
||||
*
|
||||
* Dropdown menu for multiple action options with icons and shortcuts.
|
||||
* Supports destructive variants and keyboard shortcut hints.
|
||||
*
|
||||
* **Features:**
|
||||
* - Configurable trigger icon with tooltip
|
||||
* - Action items with icons and labels
|
||||
* - Destructive variant styling
|
||||
* - Keyboard shortcut display
|
||||
* - Separator support between groups
|
||||
*
|
||||
* @example
|
||||
* ```svelte
|
||||
* <DropdownMenuActions
|
||||
* triggerIcon={MoreHorizontal}
|
||||
* triggerTooltip="More actions"
|
||||
* actions={[
|
||||
* { icon: Edit, label: 'Edit', onclick: handleEdit },
|
||||
* { icon: Trash, label: 'Delete', onclick: handleDelete, variant: 'destructive' }
|
||||
* ]}
|
||||
* />
|
||||
* ```
|
||||
*/
|
||||
export { default as DropdownMenuActions } from './DropdownMenuActions.svelte';
|
||||
|
|
@ -8,6 +8,7 @@
|
|||
import { serverStore, serverLoading } from '$lib/stores/server.svelte';
|
||||
import { config, settingsStore } from '$lib/stores/settings.svelte';
|
||||
import { fade, fly, scale } from 'svelte/transition';
|
||||
import { KeyboardKey } from '$lib/enums/keyboard';
|
||||
|
||||
interface Props {
|
||||
class?: string;
|
||||
|
|
@ -117,7 +118,7 @@
|
|||
}
|
||||
|
||||
function handleApiKeyKeydown(event: KeyboardEvent) {
|
||||
if (event.key === 'Enter') {
|
||||
if (event.key === KeyboardKey.ENTER) {
|
||||
handleSaveApiKey();
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -48,7 +48,7 @@
|
|||
{model || 'Unknown Model'}
|
||||
</Badge>
|
||||
|
||||
{#if serverData.default_generation_settings.n_ctx}
|
||||
{#if serverData?.default_generation_settings?.n_ctx}
|
||||
<Badge variant="secondary" class="text-xs">
|
||||
ctx: {serverData.default_generation_settings.n_ctx.toLocaleString()}
|
||||
</Badge>
|
||||
|
|
|
|||
|
|
@ -0,0 +1,80 @@
|
|||
/**
|
||||
*
|
||||
* SERVER
|
||||
*
|
||||
* Components for displaying server connection state and handling
|
||||
* connection errors. Integrates with serverStore for state management.
|
||||
*
|
||||
*/
|
||||
|
||||
/**
|
||||
* **ServerStatus** - Server connection status indicator
|
||||
*
|
||||
* Compact status display showing connection state, model name,
|
||||
* and context size. Used in headers and loading screens.
|
||||
*
|
||||
* **Architecture:**
|
||||
* - Reads state from serverStore (props, loading, error)
|
||||
* - Displays model name from modelsStore
|
||||
*
|
||||
* **Features:**
|
||||
* - Status dot: green (connected), yellow (connecting), red (error), gray (unknown)
|
||||
* - Status text label
|
||||
* - Model name badge with icon
|
||||
* - Context size badge
|
||||
* - Optional error action button
|
||||
*
|
||||
* @example
|
||||
* ```svelte
|
||||
* <ServerStatus showActions />
|
||||
* ```
|
||||
*/
|
||||
export { default as ServerStatus } from './ServerStatus.svelte';
|
||||
|
||||
/**
|
||||
* **ServerErrorSplash** - Full-screen connection error display
|
||||
*
|
||||
* Blocking error screen shown when server connection fails.
|
||||
* Provides retry options and API key input for authentication errors.
|
||||
*
|
||||
* **Architecture:**
|
||||
* - Detects access denied errors for API key flow
|
||||
* - Validates API key against server before saving
|
||||
* - Integrates with settingsStore for API key persistence
|
||||
*
|
||||
* **Features:**
|
||||
* - Error message display with icon
|
||||
* - Retry connection button with loading state
|
||||
* - API key input for authentication errors
|
||||
* - API key validation with success/error feedback
|
||||
* - Troubleshooting section with server start commands
|
||||
* - Animated transitions for UI elements
|
||||
*
|
||||
* @example
|
||||
* ```svelte
|
||||
* <ServerErrorSplash
|
||||
* error={serverError}
|
||||
* onRetry={handleRetry}
|
||||
* showTroubleshooting
|
||||
* />
|
||||
* ```
|
||||
*/
|
||||
export { default as ServerErrorSplash } from './ServerErrorSplash.svelte';
|
||||
|
||||
/**
|
||||
* **ServerLoadingSplash** - Full-screen loading display
|
||||
*
|
||||
* Shown during initial server connection. Displays loading animation
|
||||
* with ServerStatus component for real-time connection state.
|
||||
*
|
||||
* **Features:**
|
||||
* - Animated server icon
|
||||
* - Customizable loading message
|
||||
* - Embedded ServerStatus for live updates
|
||||
*
|
||||
* @example
|
||||
* ```svelte
|
||||
* <ServerLoadingSplash message="Connecting to server..." />
|
||||
* ```
|
||||
*/
|
||||
export { default as ServerLoadingSplash } from './ServerLoadingSplash.svelte';
|
||||
|
|
@ -42,7 +42,7 @@
|
|||
bind:this={ref}
|
||||
data-slot="badge"
|
||||
{href}
|
||||
class={cn(badgeVariants({ variant }), className)}
|
||||
class={cn(badgeVariants({ variant }), className, 'backdrop-blur-sm')}
|
||||
{...restProps}
|
||||
>
|
||||
{@render children?.()}
|
||||
|
|
|
|||
|
|
@ -12,8 +12,9 @@
|
|||
'bg-destructive shadow-xs hover:bg-destructive/90 focus-visible:ring-destructive/20 dark:focus-visible:ring-destructive/40 dark:bg-destructive/60 text-white',
|
||||
outline:
|
||||
'bg-background shadow-xs hover:bg-accent hover:text-accent-foreground dark:bg-input/30 dark:border-input dark:hover:bg-input/50 border',
|
||||
secondary: 'bg-secondary text-secondary-foreground shadow-xs hover:bg-secondary/80',
|
||||
ghost: 'hover:bg-accent hover:text-accent-foreground dark:hover:bg-accent/50',
|
||||
secondary:
|
||||
'dark:bg-secondary dark:text-secondary-foreground bg-background shadow-sm text-foreground hover:bg-muted-foreground/20',
|
||||
ghost: 'hover:text-accent-foreground hover:bg-muted-foreground/10',
|
||||
link: 'text-primary underline-offset-4 hover:underline'
|
||||
},
|
||||
size: {
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
<script lang="ts">
|
||||
import type { HTMLAttributes } from 'svelte/elements';
|
||||
import { cn, type WithElementRef } from '$lib/components/ui/utils';
|
||||
import { BOX_BORDER } from '$lib/constants/css-classes';
|
||||
|
||||
let {
|
||||
ref = $bindable(null),
|
||||
|
|
@ -14,7 +15,8 @@
|
|||
bind:this={ref}
|
||||
data-slot="card"
|
||||
class={cn(
|
||||
'flex flex-col gap-6 rounded-xl border bg-card py-6 text-card-foreground shadow-sm',
|
||||
'flex flex-col gap-6 rounded-xl bg-card py-6 text-card-foreground shadow-sm',
|
||||
BOX_BORDER,
|
||||
className
|
||||
)}
|
||||
{...restProps}
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@
|
|||
data-slot="dropdown-menu-content"
|
||||
{sideOffset}
|
||||
class={cn(
|
||||
'z-50 max-h-(--bits-dropdown-menu-content-available-height) min-w-[8rem] origin-(--bits-dropdown-menu-content-transform-origin) overflow-x-hidden overflow-y-auto rounded-md border border-border bg-popover p-1 text-popover-foreground shadow-md outline-none data-[side=bottom]:slide-in-from-top-2 data-[side=left]:slide-in-from-right-2 data-[side=right]:slide-in-from-left-2 data-[side=top]:slide-in-from-bottom-2 data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=closed]:zoom-out-95 data-[state=open]:animate-in data-[state=open]:fade-in-0 data-[state=open]:zoom-in-95 dark:border-border/20',
|
||||
'z-50 max-h-(--bits-dropdown-menu-content-available-height) min-w-[8rem] origin-(--bits-dropdown-menu-content-transform-origin) overflow-x-hidden overflow-y-auto rounded-md border border-border bg-popover p-1.5 text-popover-foreground shadow-md outline-none data-[side=bottom]:slide-in-from-top-2 data-[side=left]:slide-in-from-right-2 data-[side=right]:slide-in-from-left-2 data-[side=top]:slide-in-from-bottom-2 data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=closed]:zoom-out-95 data-[state=open]:animate-in data-[state=open]:fade-in-0 data-[state=open]:zoom-in-95 dark:border-border/20',
|
||||
className
|
||||
)}
|
||||
{...restProps}
|
||||
|
|
|
|||
|
|
@ -44,6 +44,7 @@
|
|||
'aria-invalid:border-destructive aria-invalid:ring-destructive/20 dark:aria-invalid:ring-destructive/40',
|
||||
className
|
||||
)}
|
||||
style="backdrop-filter: blur(0.5rem);"
|
||||
{type}
|
||||
bind:value
|
||||
{...restProps}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
<script lang="ts">
|
||||
import { Button } from '$lib/components/ui/button/index.js';
|
||||
import { cn } from '$lib/components/ui/utils.js';
|
||||
import PanelLeftIcon from '@lucide/svelte/icons/panel-left';
|
||||
import type { ComponentProps } from 'svelte';
|
||||
import { useSidebar } from './context.svelte.js';
|
||||
|
|
@ -22,7 +21,7 @@
|
|||
data-slot="sidebar-trigger"
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
class={cn('size-7', className)}
|
||||
class="rounded-full backdrop-blur-lg {className} h-9! w-9!"
|
||||
type="button"
|
||||
onclick={(e) => {
|
||||
onclick?.(e);
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@
|
|||
bind:checked
|
||||
data-slot="switch"
|
||||
class={cn(
|
||||
'peer inline-flex h-[1.15rem] w-8 shrink-0 items-center rounded-full border border-transparent shadow-xs transition-all outline-none focus-visible:border-ring focus-visible:ring-[3px] focus-visible:ring-ring/50 disabled:cursor-not-allowed disabled:opacity-50 data-[state=checked]:bg-primary data-[state=unchecked]:bg-input dark:data-[state=unchecked]:bg-input/80',
|
||||
'peer inline-flex h-[1.15rem] w-8 shrink-0 cursor-pointer items-center rounded-full border border-transparent shadow-xs transition-all outline-none focus-visible:border-ring focus-visible:ring-[3px] focus-visible:ring-ring/50 disabled:cursor-not-allowed disabled:opacity-50 data-[state=checked]:bg-primary data-[state=unchecked]:bg-input dark:data-[state=unchecked]:bg-input/80',
|
||||
className
|
||||
)}
|
||||
{...restProps}
|
||||
|
|
|
|||
|
|
@ -9,22 +9,28 @@
|
|||
side = 'top',
|
||||
children,
|
||||
arrowClasses,
|
||||
noPortal = false,
|
||||
...restProps
|
||||
}: TooltipPrimitive.ContentProps & {
|
||||
arrowClasses?: string;
|
||||
noPortal?: boolean;
|
||||
} = $props();
|
||||
|
||||
const contentClass = $derived(
|
||||
cn(
|
||||
'z-50 w-fit origin-(--bits-tooltip-content-transform-origin) animate-in rounded-md bg-primary px-3 py-1.5 text-xs text-balance text-primary-foreground fade-in-0 zoom-in-95 data-[side=bottom]:slide-in-from-top-2 data-[side=left]:slide-in-from-right-2 data-[side=right]:slide-in-from-left-2 data-[side=top]:slide-in-from-bottom-2 data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=closed]:zoom-out-95',
|
||||
className
|
||||
)
|
||||
);
|
||||
</script>
|
||||
|
||||
<TooltipPrimitive.Portal>
|
||||
{#snippet tooltipContent()}
|
||||
<TooltipPrimitive.Content
|
||||
bind:ref
|
||||
data-slot="tooltip-content"
|
||||
{sideOffset}
|
||||
{side}
|
||||
class={cn(
|
||||
'z-50 w-fit origin-(--bits-tooltip-content-transform-origin) animate-in rounded-md bg-primary px-3 py-1.5 text-xs text-balance text-primary-foreground fade-in-0 zoom-in-95 data-[side=bottom]:slide-in-from-top-2 data-[side=left]:slide-in-from-right-2 data-[side=right]:slide-in-from-left-2 data-[side=top]:slide-in-from-bottom-2 data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=closed]:zoom-out-95',
|
||||
className
|
||||
)}
|
||||
class={contentClass}
|
||||
{...restProps}
|
||||
>
|
||||
{@render children?.()}
|
||||
|
|
@ -44,4 +50,12 @@
|
|||
{/snippet}
|
||||
</TooltipPrimitive.Arrow>
|
||||
</TooltipPrimitive.Content>
|
||||
</TooltipPrimitive.Portal>
|
||||
{/snippet}
|
||||
|
||||
{#if noPortal}
|
||||
{@render tooltipContent()}
|
||||
{:else}
|
||||
<TooltipPrimitive.Portal>
|
||||
{@render tooltipContent()}
|
||||
</TooltipPrimitive.Portal>
|
||||
{/if}
|
||||
|
|
|
|||
|
|
@ -1,9 +1,6 @@
|
|||
export interface BinaryDetectionOptions {
|
||||
/** Number of characters to check from the beginning of the file */
|
||||
prefixLength: number;
|
||||
/** Maximum ratio of suspicious characters allowed (0.0 to 1.0) */
|
||||
suspiciousCharThresholdRatio: number;
|
||||
/** Maximum absolute number of null bytes allowed */
|
||||
maxAbsoluteNullBytes: number;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,3 @@
|
|||
export const INITIAL_FILE_SIZE = 0;
|
||||
export const PROMPT_CONTENT_SEPARATOR = '\n\n';
|
||||
export const CLIPBOARD_CONTENT_QUOTE_PREFIX = '"';
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
export const CODE_BLOCK_SCROLL_CONTAINER_CLASS = 'code-block-scroll-container';
|
||||
export const CODE_BLOCK_WRAPPER_CLASS = 'code-block-wrapper';
|
||||
export const CODE_BLOCK_HEADER_CLASS = 'code-block-header';
|
||||
export const CODE_BLOCK_ACTIONS_CLASS = 'code-block-actions';
|
||||
export const CODE_LANGUAGE_CLASS = 'code-language';
|
||||
export const COPY_CODE_BTN_CLASS = 'copy-code-btn';
|
||||
export const PREVIEW_CODE_BTN_CLASS = 'preview-code-btn';
|
||||
export const RELATIVE_CLASS = 'relative';
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
export const NEWLINE = '\n';
|
||||
export const DEFAULT_LANGUAGE = 'text';
|
||||
export const LANG_PATTERN = /^(\w*)\n?/;
|
||||
export const AMPERSAND_REGEX = /&/g;
|
||||
export const LT_REGEX = /</g;
|
||||
export const GT_REGEX = />/g;
|
||||
export const FENCE_PATTERN = /^```|\n```/g;
|
||||
|
|
@ -0,0 +1,10 @@
|
|||
export const BOX_BORDER =
|
||||
'border border-border/30 focus-within:border-border dark:border-border/20 dark:focus-within:border-border';
|
||||
|
||||
export const INPUT_CLASSES = `
|
||||
bg-muted/60 dark:bg-muted/75
|
||||
${BOX_BORDER}
|
||||
shadow-sm
|
||||
outline-none
|
||||
text-foreground
|
||||
`;
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
export const MS_PER_SECOND = 1000;
|
||||
export const SECONDS_PER_MINUTE = 60;
|
||||
export const SECONDS_PER_HOUR = 3600;
|
||||
export const SHORT_DURATION_THRESHOLD = 1;
|
||||
export const MEDIUM_DURATION_THRESHOLD = 10;
|
||||
|
||||
/** Default display value when no performance time is available */
|
||||
export const DEFAULT_PERFORMANCE_TIME = '0s';
|
||||
|
|
@ -0,0 +1,4 @@
|
|||
export const IMAGE_NOT_ERROR_BOUND_SELECTOR = 'img:not([data-error-bound])';
|
||||
export const DATA_ERROR_BOUND_ATTR = 'errorBound';
|
||||
export const DATA_ERROR_HANDLED_ATTR = 'errorHandled';
|
||||
export const BOOL_TRUE_STRING = 'true';
|
||||
|
|
@ -1 +1,8 @@
|
|||
export const PROCESSING_INFO_TIMEOUT = 2000;
|
||||
|
||||
/**
|
||||
* Statistics units labels
|
||||
*/
|
||||
export const STATS_UNITS = {
|
||||
TOKENS_PER_SECOND: 't/s'
|
||||
} as const;
|
||||
|
|
|
|||
|
|
@ -7,7 +7,8 @@ export const SETTING_CONFIG_DEFAULT: Record<string, string | number | boolean> =
|
|||
theme: 'system',
|
||||
showThoughtInProgress: false,
|
||||
showToolCalls: false,
|
||||
disableReasoningFormat: false,
|
||||
disableReasoningParsing: false,
|
||||
showRawOutputSwitch: false,
|
||||
keepStatsVisible: false,
|
||||
showMessageStats: true,
|
||||
askForTitleConfirmation: false,
|
||||
|
|
@ -92,8 +93,10 @@ export const SETTING_CONFIG_INFO: Record<string, string> = {
|
|||
showThoughtInProgress: 'Expand thought process by default when generating messages.',
|
||||
showToolCalls:
|
||||
'Display tool call labels and payloads from Harmony-compatible delta.tool_calls data below assistant messages.',
|
||||
disableReasoningFormat:
|
||||
'Show raw LLM output without backend parsing and frontend Markdown rendering to inspect streaming across different models.',
|
||||
disableReasoningParsing:
|
||||
'Send reasoning_format=none to prevent server-side extraction of reasoning tokens into separate field',
|
||||
showRawOutputSwitch:
|
||||
'Show toggle button to display messages as plain text instead of Markdown-formatted content',
|
||||
keepStatsVisible: 'Keep processing statistics visible after generation finishes.',
|
||||
showMessageStats:
|
||||
'Display generation statistics (tokens/second, token count, duration) below each assistant message.',
|
||||
|
|
|
|||
|
|
@ -0,0 +1,33 @@
|
|||
/**
|
||||
* List of all numeric fields in settings configuration.
|
||||
* These fields will be converted from strings to numbers during save.
|
||||
*/
|
||||
export const NUMERIC_FIELDS = [
|
||||
'temperature',
|
||||
'top_k',
|
||||
'top_p',
|
||||
'min_p',
|
||||
'max_tokens',
|
||||
'pasteLongTextToFileLen',
|
||||
'dynatemp_range',
|
||||
'dynatemp_exponent',
|
||||
'typ_p',
|
||||
'xtc_probability',
|
||||
'xtc_threshold',
|
||||
'repeat_last_n',
|
||||
'repeat_penalty',
|
||||
'presence_penalty',
|
||||
'frequency_penalty',
|
||||
'dry_multiplier',
|
||||
'dry_base',
|
||||
'dry_allowed_length',
|
||||
'dry_penalty_last_n',
|
||||
'agenticMaxTurns',
|
||||
'agenticMaxToolPreviewLines'
|
||||
] as const;
|
||||
|
||||
/**
|
||||
* Fields that must be positive integers (>= 1).
|
||||
* These will be clamped to minimum 1 and rounded during save.
|
||||
*/
|
||||
export const POSITIVE_INTEGER_FIELDS = ['agenticMaxTurns', 'agenticMaxToolPreviewLines'] as const;
|
||||
|
|
@ -1 +1 @@
|
|||
export const TOOLTIP_DELAY_DURATION = 100;
|
||||
export const TOOLTIP_DELAY_DURATION = 500;
|
||||
|
|
|
|||
|
|
@ -0,0 +1 @@
|
|||
export const SYSTEM_MESSAGE_PLACEHOLDER = 'System message';
|
||||
|
|
@ -0,0 +1,34 @@
|
|||
import { getContext, setContext } from 'svelte';
|
||||
|
||||
export interface ChatActionsContext {
|
||||
copy: (message: DatabaseMessage) => void;
|
||||
delete: (message: DatabaseMessage) => void;
|
||||
navigateToSibling: (siblingId: string) => void;
|
||||
editWithBranching: (
|
||||
message: DatabaseMessage,
|
||||
newContent: string,
|
||||
newExtras?: DatabaseMessageExtra[]
|
||||
) => void;
|
||||
editWithReplacement: (
|
||||
message: DatabaseMessage,
|
||||
newContent: string,
|
||||
shouldBranch: boolean
|
||||
) => void;
|
||||
editUserMessagePreserveResponses: (
|
||||
message: DatabaseMessage,
|
||||
newContent: string,
|
||||
newExtras?: DatabaseMessageExtra[]
|
||||
) => void;
|
||||
regenerateWithBranching: (message: DatabaseMessage, modelOverride?: string) => void;
|
||||
continueAssistantMessage: (message: DatabaseMessage) => void;
|
||||
}
|
||||
|
||||
const CHAT_ACTIONS_KEY = Symbol.for('chat-actions');
|
||||
|
||||
export function setChatActionsContext(ctx: ChatActionsContext): ChatActionsContext {
|
||||
return setContext(CHAT_ACTIONS_KEY, ctx);
|
||||
}
|
||||
|
||||
export function getChatActionsContext(): ChatActionsContext {
|
||||
return getContext(CHAT_ACTIONS_KEY);
|
||||
}
|
||||
|
|
@ -0,0 +1,13 @@
|
|||
export {
|
||||
getMessageEditContext,
|
||||
setMessageEditContext,
|
||||
type MessageEditContext,
|
||||
type MessageEditState,
|
||||
type MessageEditActions
|
||||
} from './message-edit.context';
|
||||
|
||||
export {
|
||||
getChatActionsContext,
|
||||
setChatActionsContext,
|
||||
type ChatActionsContext
|
||||
} from './chat-actions.context';
|
||||
|
|
@ -0,0 +1,39 @@
|
|||
import { getContext, setContext } from 'svelte';
|
||||
|
||||
export interface MessageEditState {
|
||||
readonly isEditing: boolean;
|
||||
readonly editedContent: string;
|
||||
readonly editedExtras: DatabaseMessageExtra[];
|
||||
readonly editedUploadedFiles: ChatUploadedFile[];
|
||||
readonly originalContent: string;
|
||||
readonly originalExtras: DatabaseMessageExtra[];
|
||||
readonly showSaveOnlyOption: boolean;
|
||||
}
|
||||
|
||||
export interface MessageEditActions {
|
||||
setContent: (content: string) => void;
|
||||
setExtras: (extras: DatabaseMessageExtra[]) => void;
|
||||
setUploadedFiles: (files: ChatUploadedFile[]) => void;
|
||||
save: () => void;
|
||||
saveOnly: () => void;
|
||||
cancel: () => void;
|
||||
startEdit: () => void;
|
||||
}
|
||||
|
||||
export type MessageEditContext = MessageEditState & MessageEditActions;
|
||||
|
||||
const MESSAGE_EDIT_KEY = Symbol.for('chat-message-edit');
|
||||
|
||||
/**
|
||||
* Sets the message edit context. Call this in the parent component (ChatMessage.svelte).
|
||||
*/
|
||||
export function setMessageEditContext(ctx: MessageEditContext): MessageEditContext {
|
||||
return setContext(MESSAGE_EDIT_KEY, ctx);
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the message edit context. Call this in child components.
|
||||
*/
|
||||
export function getMessageEditContext(): MessageEditContext {
|
||||
return getContext(MESSAGE_EDIT_KEY);
|
||||
}
|
||||
|
|
@ -1,4 +1,51 @@
|
|||
export enum ChatMessageStatsView {
|
||||
GENERATION = 'generation',
|
||||
READING = 'reading'
|
||||
READING = 'reading',
|
||||
TOOLS = 'tools',
|
||||
SUMMARY = 'summary'
|
||||
}
|
||||
|
||||
/**
|
||||
* Reasoning format options for API requests.
|
||||
*/
|
||||
export enum ReasoningFormat {
|
||||
NONE = 'none',
|
||||
AUTO = 'auto'
|
||||
}
|
||||
|
||||
/**
|
||||
* Message roles for chat messages.
|
||||
*/
|
||||
export enum MessageRole {
|
||||
USER = 'user',
|
||||
ASSISTANT = 'assistant',
|
||||
SYSTEM = 'system',
|
||||
TOOL = 'tool'
|
||||
}
|
||||
|
||||
/**
|
||||
* Message types for different content kinds.
|
||||
*/
|
||||
export enum MessageType {
|
||||
ROOT = 'root',
|
||||
TEXT = 'text',
|
||||
THINK = 'think',
|
||||
SYSTEM = 'system'
|
||||
}
|
||||
|
||||
/**
|
||||
* Content part types for API chat message content.
|
||||
*/
|
||||
export enum ContentPartType {
|
||||
TEXT = 'text',
|
||||
IMAGE_URL = 'image_url',
|
||||
INPUT_AUDIO = 'input_audio'
|
||||
}
|
||||
|
||||
/**
|
||||
* Error dialog types for displaying server/timeout errors.
|
||||
*/
|
||||
export enum ErrorDialogType {
|
||||
TIMEOUT = 'timeout',
|
||||
SERVER = 'server'
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,15 @@
|
|||
/**
|
||||
* Keyboard key names for event handling
|
||||
*/
|
||||
export enum KeyboardKey {
|
||||
ENTER = 'Enter',
|
||||
ESCAPE = 'Escape',
|
||||
ARROW_UP = 'ArrowUp',
|
||||
ARROW_DOWN = 'ArrowDown',
|
||||
TAB = 'Tab',
|
||||
D_LOWER = 'd',
|
||||
D_UPPER = 'D',
|
||||
E_UPPER = 'E',
|
||||
K_LOWER = 'k',
|
||||
O_UPPER = 'O'
|
||||
}
|
||||
|
|
@ -0,0 +1,26 @@
|
|||
/**
|
||||
* Parameter source - indicates whether a parameter uses default or custom value
|
||||
*/
|
||||
export enum ParameterSource {
|
||||
DEFAULT = 'default',
|
||||
CUSTOM = 'custom'
|
||||
}
|
||||
|
||||
/**
|
||||
* Syncable parameter type - data types for parameters that can be synced with server
|
||||
*/
|
||||
export enum SyncableParameterType {
|
||||
NUMBER = 'number',
|
||||
STRING = 'string',
|
||||
BOOLEAN = 'boolean'
|
||||
}
|
||||
|
||||
/**
|
||||
* Settings field type - defines the input type for settings fields
|
||||
*/
|
||||
export enum SettingsFieldType {
|
||||
INPUT = 'input',
|
||||
TEXTAREA = 'textarea',
|
||||
CHECKBOX = 'checkbox',
|
||||
SELECT = 'select'
|
||||
}
|
||||
|
|
@ -0,0 +1,165 @@
|
|||
import { AUTO_SCROLL_AT_BOTTOM_THRESHOLD, AUTO_SCROLL_INTERVAL } from '$lib/constants/auto-scroll';
|
||||
|
||||
export interface AutoScrollOptions {
|
||||
/** Whether auto-scroll is disabled globally (e.g., from settings) */
|
||||
disabled?: boolean;
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates an auto-scroll controller for a scrollable container.
|
||||
*
|
||||
* Features:
|
||||
* - Auto-scrolls to bottom during streaming/loading
|
||||
* - Stops auto-scroll when user manually scrolls up
|
||||
* - Resumes auto-scroll when user scrolls back to bottom
|
||||
*/
|
||||
export class AutoScrollController {
|
||||
private _autoScrollEnabled = $state(true);
|
||||
private _userScrolledUp = $state(false);
|
||||
private _lastScrollTop = $state(0);
|
||||
private _scrollInterval: ReturnType<typeof setInterval> | undefined;
|
||||
private _scrollTimeout: ReturnType<typeof setTimeout> | undefined;
|
||||
private _container: HTMLElement | undefined;
|
||||
private _disabled: boolean;
|
||||
|
||||
constructor(options: AutoScrollOptions = {}) {
|
||||
this._disabled = options.disabled ?? false;
|
||||
}
|
||||
|
||||
get autoScrollEnabled(): boolean {
|
||||
return this._autoScrollEnabled;
|
||||
}
|
||||
|
||||
get userScrolledUp(): boolean {
|
||||
return this._userScrolledUp;
|
||||
}
|
||||
|
||||
/**
|
||||
* Binds the controller to a scrollable container element.
|
||||
*/
|
||||
setContainer(container: HTMLElement | undefined): void {
|
||||
this._container = container;
|
||||
}
|
||||
|
||||
/**
|
||||
* Updates the disabled state.
|
||||
*/
|
||||
setDisabled(disabled: boolean): void {
|
||||
this._disabled = disabled;
|
||||
if (disabled) {
|
||||
this._autoScrollEnabled = false;
|
||||
this.stopInterval();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Handles scroll events to detect user scroll direction and toggle auto-scroll.
|
||||
*/
|
||||
handleScroll(): void {
|
||||
if (this._disabled || !this._container) return;
|
||||
|
||||
const { scrollTop, scrollHeight, clientHeight } = this._container;
|
||||
const distanceFromBottom = scrollHeight - scrollTop - clientHeight;
|
||||
const isAtBottom = distanceFromBottom < AUTO_SCROLL_AT_BOTTOM_THRESHOLD;
|
||||
|
||||
if (scrollTop < this._lastScrollTop && !isAtBottom) {
|
||||
this._userScrolledUp = true;
|
||||
this._autoScrollEnabled = false;
|
||||
} else if (isAtBottom && this._userScrolledUp) {
|
||||
this._userScrolledUp = false;
|
||||
this._autoScrollEnabled = true;
|
||||
}
|
||||
|
||||
if (this._scrollTimeout) {
|
||||
clearTimeout(this._scrollTimeout);
|
||||
}
|
||||
|
||||
this._scrollTimeout = setTimeout(() => {
|
||||
if (isAtBottom) {
|
||||
this._userScrolledUp = false;
|
||||
this._autoScrollEnabled = true;
|
||||
}
|
||||
}, AUTO_SCROLL_INTERVAL);
|
||||
|
||||
this._lastScrollTop = scrollTop;
|
||||
}
|
||||
|
||||
/**
|
||||
* Scrolls the container to the bottom.
|
||||
*/
|
||||
scrollToBottom(behavior: ScrollBehavior = 'smooth'): void {
|
||||
if (this._disabled || !this._container) return;
|
||||
|
||||
this._container.scrollTo({
|
||||
top: this._container.scrollHeight,
|
||||
behavior
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Enables auto-scroll (e.g., when user sends a message).
|
||||
*/
|
||||
enable(): void {
|
||||
if (this._disabled) return;
|
||||
this._userScrolledUp = false;
|
||||
this._autoScrollEnabled = true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Starts the auto-scroll interval for continuous scrolling during streaming.
|
||||
*/
|
||||
startInterval(): void {
|
||||
if (this._disabled || this._scrollInterval) return;
|
||||
|
||||
this._scrollInterval = setInterval(() => {
|
||||
this.scrollToBottom();
|
||||
}, AUTO_SCROLL_INTERVAL);
|
||||
}
|
||||
|
||||
/**
|
||||
* Stops the auto-scroll interval.
|
||||
*/
|
||||
stopInterval(): void {
|
||||
if (this._scrollInterval) {
|
||||
clearInterval(this._scrollInterval);
|
||||
this._scrollInterval = undefined;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Updates the auto-scroll interval based on streaming state.
|
||||
* Call this in a $effect to automatically manage the interval.
|
||||
*/
|
||||
updateInterval(isStreaming: boolean): void {
|
||||
if (this._disabled) {
|
||||
this.stopInterval();
|
||||
return;
|
||||
}
|
||||
|
||||
if (isStreaming && this._autoScrollEnabled) {
|
||||
if (!this._scrollInterval) {
|
||||
this.startInterval();
|
||||
}
|
||||
} else {
|
||||
this.stopInterval();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Cleans up resources. Call this in onDestroy or when the component unmounts.
|
||||
*/
|
||||
destroy(): void {
|
||||
this.stopInterval();
|
||||
if (this._scrollTimeout) {
|
||||
clearTimeout(this._scrollTimeout);
|
||||
this._scrollTimeout = undefined;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a new AutoScrollController instance.
|
||||
*/
|
||||
export function createAutoScrollController(options: AutoScrollOptions = {}): AutoScrollController {
|
||||
return new AutoScrollController(options);
|
||||
}
|
||||
|
|
@ -1,7 +1,9 @@
|
|||
import { activeProcessingState } from '$lib/stores/chat.svelte';
|
||||
import { config } from '$lib/stores/settings.svelte';
|
||||
import { STATS_UNITS } from '$lib/constants/processing-info';
|
||||
import type { ApiProcessingState } from '$lib/types';
|
||||
|
||||
export interface LiveProcessingStats {
|
||||
interface LiveProcessingStats {
|
||||
tokensProcessed: number;
|
||||
totalTokens: number;
|
||||
timeMs: number;
|
||||
|
|
@ -9,7 +11,7 @@ export interface LiveProcessingStats {
|
|||
etaSecs?: number;
|
||||
}
|
||||
|
||||
export interface LiveGenerationStats {
|
||||
interface LiveGenerationStats {
|
||||
tokensGenerated: number;
|
||||
timeMs: number;
|
||||
tokensPerSecond: number;
|
||||
|
|
@ -18,6 +20,7 @@ export interface LiveGenerationStats {
|
|||
export interface UseProcessingStateReturn {
|
||||
readonly processingState: ApiProcessingState | null;
|
||||
getProcessingDetails(): string[];
|
||||
getTechnicalDetails(): string[];
|
||||
getProcessingMessage(): string;
|
||||
getPromptProgressText(): string | null;
|
||||
getLiveProcessingStats(): LiveProcessingStats | null;
|
||||
|
|
@ -138,8 +141,31 @@ export function useProcessingState(): UseProcessingStateReturn {
|
|||
|
||||
const details: string[] = [];
|
||||
|
||||
// Show prompt processing progress with ETA during preparation phase
|
||||
if (stateToUse.promptProgress) {
|
||||
const { processed, total, time_ms, cache } = stateToUse.promptProgress;
|
||||
const actualProcessed = processed - cache;
|
||||
const actualTotal = total - cache;
|
||||
|
||||
if (actualProcessed < actualTotal && actualProcessed > 0) {
|
||||
const percent = Math.round((actualProcessed / actualTotal) * 100);
|
||||
const eta = getETASecs(actualProcessed, actualTotal, time_ms);
|
||||
|
||||
if (eta !== undefined) {
|
||||
const etaSecs = Math.ceil(eta);
|
||||
details.push(`Processing ${percent}% (ETA: ${etaSecs}s)`);
|
||||
} else {
|
||||
details.push(`Processing ${percent}%`);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Always show context info when we have valid data
|
||||
if (stateToUse.contextUsed >= 0 && stateToUse.contextTotal > 0) {
|
||||
if (
|
||||
typeof stateToUse.contextTotal === 'number' &&
|
||||
stateToUse.contextUsed >= 0 &&
|
||||
stateToUse.contextTotal > 0
|
||||
) {
|
||||
const contextPercent = Math.round((stateToUse.contextUsed / stateToUse.contextTotal) * 100);
|
||||
|
||||
details.push(
|
||||
|
|
@ -163,7 +189,57 @@ export function useProcessingState(): UseProcessingStateReturn {
|
|||
}
|
||||
|
||||
if (stateToUse.tokensPerSecond && stateToUse.tokensPerSecond > 0) {
|
||||
details.push(`${stateToUse.tokensPerSecond.toFixed(1)} tokens/sec`);
|
||||
details.push(`${stateToUse.tokensPerSecond.toFixed(1)} ${STATS_UNITS.TOKENS_PER_SECOND}`);
|
||||
}
|
||||
|
||||
if (stateToUse.speculative) {
|
||||
details.push('Speculative decoding enabled');
|
||||
}
|
||||
|
||||
return details;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns technical details without the progress message (for bottom bar)
|
||||
*/
|
||||
function getTechnicalDetails(): string[] {
|
||||
const stateToUse = processingState || lastKnownState;
|
||||
if (!stateToUse) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const details: string[] = [];
|
||||
|
||||
// Always show context info when we have valid data
|
||||
if (
|
||||
typeof stateToUse.contextTotal === 'number' &&
|
||||
stateToUse.contextUsed >= 0 &&
|
||||
stateToUse.contextTotal > 0
|
||||
) {
|
||||
const contextPercent = Math.round((stateToUse.contextUsed / stateToUse.contextTotal) * 100);
|
||||
|
||||
details.push(
|
||||
`Context: ${stateToUse.contextUsed}/${stateToUse.contextTotal} (${contextPercent}%)`
|
||||
);
|
||||
}
|
||||
|
||||
if (stateToUse.outputTokensUsed > 0) {
|
||||
// Handle infinite max_tokens (-1) case
|
||||
if (stateToUse.outputTokensMax <= 0) {
|
||||
details.push(`Output: ${stateToUse.outputTokensUsed}/∞`);
|
||||
} else {
|
||||
const outputPercent = Math.round(
|
||||
(stateToUse.outputTokensUsed / stateToUse.outputTokensMax) * 100
|
||||
);
|
||||
|
||||
details.push(
|
||||
`Output: ${stateToUse.outputTokensUsed}/${stateToUse.outputTokensMax} (${outputPercent}%)`
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
if (stateToUse.tokensPerSecond && stateToUse.tokensPerSecond > 0) {
|
||||
details.push(`${stateToUse.tokensPerSecond.toFixed(1)} ${STATS_UNITS.TOKENS_PER_SECOND}`);
|
||||
}
|
||||
|
||||
if (stateToUse.speculative) {
|
||||
|
|
@ -251,6 +327,7 @@ export function useProcessingState(): UseProcessingStateReturn {
|
|||
return processingState;
|
||||
},
|
||||
getProcessingDetails,
|
||||
getTechnicalDetails,
|
||||
getProcessingMessage,
|
||||
getPromptProgressText,
|
||||
getLiveProcessingStats,
|
||||
|
|
|
|||
|
|
@ -13,6 +13,16 @@
|
|||
import type { Plugin } from 'unified';
|
||||
import type { Root, Element, ElementContent } from 'hast';
|
||||
import { visit } from 'unist-util-visit';
|
||||
import {
|
||||
CODE_BLOCK_SCROLL_CONTAINER_CLASS,
|
||||
CODE_BLOCK_WRAPPER_CLASS,
|
||||
CODE_BLOCK_HEADER_CLASS,
|
||||
CODE_BLOCK_ACTIONS_CLASS,
|
||||
CODE_LANGUAGE_CLASS,
|
||||
COPY_CODE_BTN_CLASS,
|
||||
PREVIEW_CODE_BTN_CLASS,
|
||||
RELATIVE_CLASS
|
||||
} from '$lib/constants/code-blocks';
|
||||
|
||||
declare global {
|
||||
interface Window {
|
||||
|
|
@ -42,7 +52,7 @@ function createCopyButton(codeId: string): Element {
|
|||
type: 'element',
|
||||
tagName: 'button',
|
||||
properties: {
|
||||
className: ['copy-code-btn'],
|
||||
className: [COPY_CODE_BTN_CLASS],
|
||||
'data-code-id': codeId,
|
||||
title: 'Copy code',
|
||||
type: 'button'
|
||||
|
|
@ -56,7 +66,7 @@ function createPreviewButton(codeId: string): Element {
|
|||
type: 'element',
|
||||
tagName: 'button',
|
||||
properties: {
|
||||
className: ['preview-code-btn'],
|
||||
className: [PREVIEW_CODE_BTN_CLASS],
|
||||
'data-code-id': codeId,
|
||||
title: 'Preview code',
|
||||
type: 'button'
|
||||
|
|
@ -75,30 +85,39 @@ function createHeader(language: string, codeId: string): Element {
|
|||
return {
|
||||
type: 'element',
|
||||
tagName: 'div',
|
||||
properties: { className: ['code-block-header'] },
|
||||
properties: { className: [CODE_BLOCK_HEADER_CLASS] },
|
||||
children: [
|
||||
{
|
||||
type: 'element',
|
||||
tagName: 'span',
|
||||
properties: { className: ['code-language'] },
|
||||
properties: { className: [CODE_LANGUAGE_CLASS] },
|
||||
children: [{ type: 'text', value: language }]
|
||||
},
|
||||
{
|
||||
type: 'element',
|
||||
tagName: 'div',
|
||||
properties: { className: ['code-block-actions'] },
|
||||
properties: { className: [CODE_BLOCK_ACTIONS_CLASS] },
|
||||
children: actions
|
||||
}
|
||||
]
|
||||
};
|
||||
}
|
||||
|
||||
function createScrollContainer(preElement: Element): Element {
|
||||
return {
|
||||
type: 'element',
|
||||
tagName: 'div',
|
||||
properties: { className: [CODE_BLOCK_SCROLL_CONTAINER_CLASS] },
|
||||
children: [preElement]
|
||||
};
|
||||
}
|
||||
|
||||
function createWrapper(header: Element, preElement: Element): Element {
|
||||
return {
|
||||
type: 'element',
|
||||
tagName: 'div',
|
||||
properties: { className: ['code-block-wrapper'] },
|
||||
children: [header, preElement]
|
||||
properties: { className: [CODE_BLOCK_WRAPPER_CLASS, RELATIVE_CLASS] },
|
||||
children: [header, createScrollContainer(preElement)]
|
||||
};
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -90,7 +90,7 @@ export class ChatService {
|
|||
custom,
|
||||
timings_per_token,
|
||||
// Config options
|
||||
disableReasoningFormat
|
||||
disableReasoningParsing
|
||||
} = options;
|
||||
|
||||
const normalizedMessages: ApiChatMessageData[] = messages
|
||||
|
|
@ -127,7 +127,7 @@ export class ChatService {
|
|||
requestBody.model = options.model;
|
||||
}
|
||||
|
||||
requestBody.reasoning_format = disableReasoningFormat ? 'none' : 'auto';
|
||||
requestBody.reasoning_format = disableReasoningParsing ? 'none' : 'auto';
|
||||
|
||||
if (temperature !== undefined) requestBody.temperature = temperature;
|
||||
if (max_tokens !== undefined) {
|
||||
|
|
|
|||
|
|
@ -0,0 +1,368 @@
|
|||
import Dexie, { type EntityTable } from 'dexie';
|
||||
import { findDescendantMessages } from '$lib/utils';
|
||||
|
||||
class LlamacppDatabase extends Dexie {
|
||||
conversations!: EntityTable<DatabaseConversation, string>;
|
||||
messages!: EntityTable<DatabaseMessage, string>;
|
||||
|
||||
constructor() {
|
||||
super('LlamacppWebui');
|
||||
|
||||
this.version(1).stores({
|
||||
conversations: 'id, lastModified, currNode, name',
|
||||
messages: 'id, convId, type, role, timestamp, parent, children'
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
const db = new LlamacppDatabase();
|
||||
import { v4 as uuid } from 'uuid';
|
||||
import { MessageRole } from '$lib/enums/chat';
|
||||
|
||||
export class DatabaseService {
|
||||
/**
|
||||
*
|
||||
*
|
||||
* Conversations
|
||||
*
|
||||
*
|
||||
*/
|
||||
|
||||
/**
|
||||
* Creates a new conversation.
|
||||
*
|
||||
* @param name - Name of the conversation
|
||||
* @returns The created conversation
|
||||
*/
|
||||
static async createConversation(name: string): Promise<DatabaseConversation> {
|
||||
const conversation: DatabaseConversation = {
|
||||
id: uuid(),
|
||||
name,
|
||||
lastModified: Date.now(),
|
||||
currNode: ''
|
||||
};
|
||||
|
||||
await db.conversations.add(conversation);
|
||||
return conversation;
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
*
|
||||
* Messages
|
||||
*
|
||||
*
|
||||
*/
|
||||
|
||||
/**
|
||||
* Creates a new message branch by adding a message and updating parent/child relationships.
|
||||
* Also updates the conversation's currNode to point to the new message.
|
||||
*
|
||||
* @param message - Message to add (without id)
|
||||
* @param parentId - Parent message ID to attach to
|
||||
* @returns The created message
|
||||
*/
|
||||
static async createMessageBranch(
|
||||
message: Omit<DatabaseMessage, 'id'>,
|
||||
parentId: string | null
|
||||
): Promise<DatabaseMessage> {
|
||||
return await db.transaction('rw', [db.conversations, db.messages], async () => {
|
||||
// Handle null parent (root message case)
|
||||
if (parentId !== null) {
|
||||
const parentMessage = await db.messages.get(parentId);
|
||||
if (!parentMessage) {
|
||||
throw new Error(`Parent message ${parentId} not found`);
|
||||
}
|
||||
}
|
||||
|
||||
const newMessage: DatabaseMessage = {
|
||||
...message,
|
||||
id: uuid(),
|
||||
parent: parentId,
|
||||
toolCalls: message.toolCalls ?? '',
|
||||
children: []
|
||||
};
|
||||
|
||||
await db.messages.add(newMessage);
|
||||
|
||||
// Update parent's children array if parent exists
|
||||
if (parentId !== null) {
|
||||
const parentMessage = await db.messages.get(parentId);
|
||||
if (parentMessage) {
|
||||
await db.messages.update(parentId, {
|
||||
children: [...parentMessage.children, newMessage.id]
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
await this.updateConversation(message.convId, {
|
||||
currNode: newMessage.id
|
||||
});
|
||||
|
||||
return newMessage;
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a root message for a new conversation.
|
||||
* Root messages are not displayed but serve as the tree root for branching.
|
||||
*
|
||||
* @param convId - Conversation ID
|
||||
* @returns The created root message
|
||||
*/
|
||||
static async createRootMessage(convId: string): Promise<string> {
|
||||
const rootMessage: DatabaseMessage = {
|
||||
id: uuid(),
|
||||
convId,
|
||||
type: 'root',
|
||||
timestamp: Date.now(),
|
||||
role: MessageRole.SYSTEM,
|
||||
content: '',
|
||||
parent: null,
|
||||
toolCalls: '',
|
||||
children: []
|
||||
};
|
||||
|
||||
await db.messages.add(rootMessage);
|
||||
return rootMessage.id;
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a system prompt message for a conversation.
|
||||
*
|
||||
* @param convId - Conversation ID
|
||||
* @param systemPrompt - The system prompt content (must be non-empty)
|
||||
* @param parentId - Parent message ID (typically the root message)
|
||||
* @returns The created system message
|
||||
* @throws Error if systemPrompt is empty
|
||||
*/
|
||||
static async createSystemMessage(
|
||||
convId: string,
|
||||
systemPrompt: string,
|
||||
parentId: string
|
||||
): Promise<DatabaseMessage> {
|
||||
const trimmedPrompt = systemPrompt.trim();
|
||||
if (!trimmedPrompt) {
|
||||
throw new Error('Cannot create system message with empty content');
|
||||
}
|
||||
|
||||
const systemMessage: DatabaseMessage = {
|
||||
id: uuid(),
|
||||
convId,
|
||||
type: MessageRole.SYSTEM,
|
||||
timestamp: Date.now(),
|
||||
role: MessageRole.SYSTEM,
|
||||
content: trimmedPrompt,
|
||||
parent: parentId,
|
||||
children: []
|
||||
};
|
||||
|
||||
await db.messages.add(systemMessage);
|
||||
|
||||
const parentMessage = await db.messages.get(parentId);
|
||||
if (parentMessage) {
|
||||
await db.messages.update(parentId, {
|
||||
children: [...parentMessage.children, systemMessage.id]
|
||||
});
|
||||
}
|
||||
|
||||
return systemMessage;
|
||||
}
|
||||
|
||||
/**
|
||||
* Deletes a conversation and all its messages.
|
||||
*
|
||||
* @param id - Conversation ID
|
||||
*/
|
||||
static async deleteConversation(id: string): Promise<void> {
|
||||
await db.transaction('rw', [db.conversations, db.messages], async () => {
|
||||
await db.conversations.delete(id);
|
||||
await db.messages.where('convId').equals(id).delete();
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Deletes a message and removes it from its parent's children array.
|
||||
*
|
||||
* @param messageId - ID of the message to delete
|
||||
*/
|
||||
static async deleteMessage(messageId: string): Promise<void> {
|
||||
await db.transaction('rw', db.messages, async () => {
|
||||
const message = await db.messages.get(messageId);
|
||||
if (!message) return;
|
||||
|
||||
// Remove this message from its parent's children array
|
||||
if (message.parent) {
|
||||
const parent = await db.messages.get(message.parent);
|
||||
if (parent) {
|
||||
parent.children = parent.children.filter((childId: string) => childId !== messageId);
|
||||
await db.messages.put(parent);
|
||||
}
|
||||
}
|
||||
|
||||
// Delete the message
|
||||
await db.messages.delete(messageId);
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Deletes a message and all its descendant messages (cascading deletion).
|
||||
* This removes the entire branch starting from the specified message.
|
||||
*
|
||||
* @param conversationId - ID of the conversation containing the message
|
||||
* @param messageId - ID of the root message to delete (along with all descendants)
|
||||
* @returns Array of all deleted message IDs
|
||||
*/
|
||||
static async deleteMessageCascading(
|
||||
conversationId: string,
|
||||
messageId: string
|
||||
): Promise<string[]> {
|
||||
return await db.transaction('rw', db.messages, async () => {
|
||||
// Get all messages in the conversation to find descendants
|
||||
const allMessages = await db.messages.where('convId').equals(conversationId).toArray();
|
||||
|
||||
// Find all descendant messages
|
||||
const descendants = findDescendantMessages(allMessages, messageId);
|
||||
const allToDelete = [messageId, ...descendants];
|
||||
|
||||
// Get the message to delete for parent cleanup
|
||||
const message = await db.messages.get(messageId);
|
||||
if (message && message.parent) {
|
||||
const parent = await db.messages.get(message.parent);
|
||||
if (parent) {
|
||||
parent.children = parent.children.filter((childId: string) => childId !== messageId);
|
||||
await db.messages.put(parent);
|
||||
}
|
||||
}
|
||||
|
||||
// Delete all messages in the branch
|
||||
await db.messages.bulkDelete(allToDelete);
|
||||
|
||||
return allToDelete;
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets all conversations, sorted by last modified time (newest first).
|
||||
*
|
||||
* @returns Array of conversations
|
||||
*/
|
||||
static async getAllConversations(): Promise<DatabaseConversation[]> {
|
||||
return await db.conversations.orderBy('lastModified').reverse().toArray();
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets a conversation by ID.
|
||||
*
|
||||
* @param id - Conversation ID
|
||||
* @returns The conversation if found, otherwise undefined
|
||||
*/
|
||||
static async getConversation(id: string): Promise<DatabaseConversation | undefined> {
|
||||
return await db.conversations.get(id);
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets all messages in a conversation, sorted by timestamp (oldest first).
|
||||
*
|
||||
* @param convId - Conversation ID
|
||||
* @returns Array of messages in the conversation
|
||||
*/
|
||||
static async getConversationMessages(convId: string): Promise<DatabaseMessage[]> {
|
||||
return await db.messages.where('convId').equals(convId).sortBy('timestamp');
|
||||
}
|
||||
|
||||
/**
|
||||
* Updates a conversation.
|
||||
*
|
||||
* @param id - Conversation ID
|
||||
* @param updates - Partial updates to apply
|
||||
* @returns Promise that resolves when the conversation is updated
|
||||
*/
|
||||
static async updateConversation(
|
||||
id: string,
|
||||
updates: Partial<Omit<DatabaseConversation, 'id'>>
|
||||
): Promise<void> {
|
||||
await db.conversations.update(id, {
|
||||
...updates,
|
||||
lastModified: Date.now()
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
*
|
||||
* Navigation
|
||||
*
|
||||
*
|
||||
*/
|
||||
|
||||
/**
|
||||
* Updates the conversation's current node (active branch).
|
||||
* This determines which conversation path is currently being viewed.
|
||||
*
|
||||
* @param convId - Conversation ID
|
||||
* @param nodeId - Message ID to set as current node
|
||||
*/
|
||||
static async updateCurrentNode(convId: string, nodeId: string): Promise<void> {
|
||||
await this.updateConversation(convId, {
|
||||
currNode: nodeId
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Updates a message.
|
||||
*
|
||||
* @param id - Message ID
|
||||
* @param updates - Partial updates to apply
|
||||
* @returns Promise that resolves when the message is updated
|
||||
*/
|
||||
static async updateMessage(
|
||||
id: string,
|
||||
updates: Partial<Omit<DatabaseMessage, 'id'>>
|
||||
): Promise<void> {
|
||||
await db.messages.update(id, updates);
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
*
|
||||
* Import
|
||||
*
|
||||
*
|
||||
*/
|
||||
|
||||
/**
|
||||
* Imports multiple conversations and their messages.
|
||||
* Skips conversations that already exist.
|
||||
*
|
||||
* @param data - Array of { conv, messages } objects
|
||||
*/
|
||||
static async importConversations(
|
||||
data: { conv: DatabaseConversation; messages: DatabaseMessage[] }[]
|
||||
): Promise<{ imported: number; skipped: number }> {
|
||||
let importedCount = 0;
|
||||
let skippedCount = 0;
|
||||
|
||||
return await db.transaction('rw', [db.conversations, db.messages], async () => {
|
||||
for (const item of data) {
|
||||
const { conv, messages } = item;
|
||||
|
||||
const existing = await db.conversations.get(conv.id);
|
||||
if (existing) {
|
||||
console.warn(`Conversation "${conv.name}" already exists, skipping...`);
|
||||
skippedCount++;
|
||||
continue;
|
||||
}
|
||||
|
||||
await db.conversations.add(conv);
|
||||
for (const msg of messages) {
|
||||
await db.messages.put(msg);
|
||||
}
|
||||
|
||||
importedCount++;
|
||||
}
|
||||
|
||||
return { imported: importedCount, skipped: skippedCount };
|
||||
});
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,99 @@
|
|||
import { ServerModelStatus } from '$lib/enums';
|
||||
import { apiFetch, apiPost } from '$lib/utils/api-fetch';
|
||||
|
||||
export class ModelsService {
|
||||
/**
|
||||
*
|
||||
*
|
||||
* Listing
|
||||
*
|
||||
*
|
||||
*/
|
||||
|
||||
/**
|
||||
* Fetch list of models from OpenAI-compatible endpoint.
|
||||
* Works in both MODEL and ROUTER modes.
|
||||
*
|
||||
* @returns List of available models with basic metadata
|
||||
*/
|
||||
static async list(): Promise<ApiModelListResponse> {
|
||||
return apiFetch<ApiModelListResponse>('/v1/models');
|
||||
}
|
||||
|
||||
/**
|
||||
* Fetch list of all models with detailed metadata (ROUTER mode).
|
||||
* Returns models with load status, paths, and other metadata
|
||||
* beyond what the OpenAI-compatible endpoint provides.
|
||||
*
|
||||
* @returns List of models with detailed status and configuration info
|
||||
*/
|
||||
static async listRouter(): Promise<ApiRouterModelsListResponse> {
|
||||
return apiFetch<ApiRouterModelsListResponse>('/v1/models');
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
*
|
||||
* Load/Unload
|
||||
*
|
||||
*
|
||||
*/
|
||||
|
||||
/**
|
||||
* Load a model (ROUTER mode only).
|
||||
* Sends POST request to `/models/load`. Note: the endpoint returns success
|
||||
* before loading completes — use polling to await actual load status.
|
||||
*
|
||||
* @param modelId - Model identifier to load
|
||||
* @param extraArgs - Optional additional arguments to pass to the model instance
|
||||
* @returns Load response from the server
|
||||
*/
|
||||
static async load(modelId: string, extraArgs?: string[]): Promise<ApiRouterModelsLoadResponse> {
|
||||
const payload: { model: string; extra_args?: string[] } = { model: modelId };
|
||||
if (extraArgs && extraArgs.length > 0) {
|
||||
payload.extra_args = extraArgs;
|
||||
}
|
||||
|
||||
return apiPost<ApiRouterModelsLoadResponse>('/models/load', payload);
|
||||
}
|
||||
|
||||
/**
|
||||
* Unload a model (ROUTER mode only).
|
||||
* Sends POST request to `/models/unload`. Note: the endpoint returns success
|
||||
* before unloading completes — use polling to await actual unload status.
|
||||
*
|
||||
* @param modelId - Model identifier to unload
|
||||
* @returns Unload response from the server
|
||||
*/
|
||||
static async unload(modelId: string): Promise<ApiRouterModelsUnloadResponse> {
|
||||
return apiPost<ApiRouterModelsUnloadResponse>('/models/unload', { model: modelId });
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
*
|
||||
* Status
|
||||
*
|
||||
*
|
||||
*/
|
||||
|
||||
/**
|
||||
* Check if a model is loaded based on its metadata.
|
||||
*
|
||||
* @param model - Model data entry from the API response
|
||||
* @returns True if the model status is LOADED
|
||||
*/
|
||||
static isModelLoaded(model: ApiModelDataEntry): boolean {
|
||||
return model.status.value === ServerModelStatus.LOADED;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a model is currently loading.
|
||||
*
|
||||
* @param model - Model data entry from the API response
|
||||
* @returns True if the model status is LOADING
|
||||
*/
|
||||
static isModelLoading(model: ApiModelDataEntry): boolean {
|
||||
return model.status.value === ServerModelStatus.LOADING;
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,148 @@
|
|||
import { describe, it, expect } from 'vitest';
|
||||
import { ParameterSyncService } from './parameter-sync.service';
|
||||
|
||||
describe('ParameterSyncService', () => {
|
||||
describe('roundFloatingPoint', () => {
|
||||
it('should fix JavaScript floating-point precision issues', () => {
|
||||
// Test the specific values from the screenshot
|
||||
const mockServerParams = {
|
||||
top_p: 0.949999988079071,
|
||||
min_p: 0.009999999776482582,
|
||||
temperature: 0.800000011920929,
|
||||
top_k: 40,
|
||||
samplers: ['top_k', 'typ_p', 'top_p', 'min_p', 'temperature']
|
||||
};
|
||||
|
||||
const result = ParameterSyncService.extractServerDefaults({
|
||||
...mockServerParams,
|
||||
// Add other required fields to match the API type
|
||||
n_predict: 512,
|
||||
seed: -1,
|
||||
dynatemp_range: 0.0,
|
||||
dynatemp_exponent: 1.0,
|
||||
xtc_probability: 0.0,
|
||||
xtc_threshold: 0.1,
|
||||
typ_p: 1.0,
|
||||
repeat_last_n: 64,
|
||||
repeat_penalty: 1.0,
|
||||
presence_penalty: 0.0,
|
||||
frequency_penalty: 0.0,
|
||||
dry_multiplier: 0.0,
|
||||
dry_base: 1.75,
|
||||
dry_allowed_length: 2,
|
||||
dry_penalty_last_n: -1,
|
||||
mirostat: 0,
|
||||
mirostat_tau: 5.0,
|
||||
mirostat_eta: 0.1,
|
||||
stop: [],
|
||||
max_tokens: -1,
|
||||
n_keep: 0,
|
||||
n_discard: 0,
|
||||
ignore_eos: false,
|
||||
stream: true,
|
||||
logit_bias: [],
|
||||
n_probs: 0,
|
||||
min_keep: 0,
|
||||
grammar: '',
|
||||
grammar_lazy: false,
|
||||
grammar_triggers: [],
|
||||
preserved_tokens: [],
|
||||
chat_format: '',
|
||||
reasoning_format: '',
|
||||
reasoning_in_content: false,
|
||||
thinking_forced_open: false,
|
||||
'speculative.n_max': 0,
|
||||
'speculative.n_min': 0,
|
||||
'speculative.p_min': 0.0,
|
||||
timings_per_token: false,
|
||||
post_sampling_probs: false,
|
||||
lora: [],
|
||||
top_n_sigma: 0.0,
|
||||
dry_sequence_breakers: []
|
||||
} as ApiLlamaCppServerProps['default_generation_settings']['params']);
|
||||
|
||||
// Check that the problematic floating-point values are rounded correctly
|
||||
expect(result.top_p).toBe(0.95);
|
||||
expect(result.min_p).toBe(0.01);
|
||||
expect(result.temperature).toBe(0.8);
|
||||
expect(result.top_k).toBe(40); // Integer should remain unchanged
|
||||
expect(result.samplers).toBe('top_k;typ_p;top_p;min_p;temperature');
|
||||
});
|
||||
|
||||
it('should preserve non-numeric values', () => {
|
||||
const mockServerParams = {
|
||||
samplers: ['top_k', 'temperature'],
|
||||
max_tokens: -1,
|
||||
temperature: 0.7
|
||||
};
|
||||
|
||||
const result = ParameterSyncService.extractServerDefaults({
|
||||
...mockServerParams,
|
||||
// Minimal required fields
|
||||
n_predict: 512,
|
||||
seed: -1,
|
||||
dynatemp_range: 0.0,
|
||||
dynatemp_exponent: 1.0,
|
||||
top_k: 40,
|
||||
top_p: 0.95,
|
||||
min_p: 0.05,
|
||||
xtc_probability: 0.0,
|
||||
xtc_threshold: 0.1,
|
||||
typ_p: 1.0,
|
||||
repeat_last_n: 64,
|
||||
repeat_penalty: 1.0,
|
||||
presence_penalty: 0.0,
|
||||
frequency_penalty: 0.0,
|
||||
dry_multiplier: 0.0,
|
||||
dry_base: 1.75,
|
||||
dry_allowed_length: 2,
|
||||
dry_penalty_last_n: -1,
|
||||
mirostat: 0,
|
||||
mirostat_tau: 5.0,
|
||||
mirostat_eta: 0.1,
|
||||
stop: [],
|
||||
n_keep: 0,
|
||||
n_discard: 0,
|
||||
ignore_eos: false,
|
||||
stream: true,
|
||||
logit_bias: [],
|
||||
n_probs: 0,
|
||||
min_keep: 0,
|
||||
grammar: '',
|
||||
grammar_lazy: false,
|
||||
grammar_triggers: [],
|
||||
preserved_tokens: [],
|
||||
chat_format: '',
|
||||
reasoning_format: '',
|
||||
reasoning_in_content: false,
|
||||
thinking_forced_open: false,
|
||||
'speculative.n_max': 0,
|
||||
'speculative.n_min': 0,
|
||||
'speculative.p_min': 0.0,
|
||||
timings_per_token: false,
|
||||
post_sampling_probs: false,
|
||||
lora: [],
|
||||
top_n_sigma: 0.0,
|
||||
dry_sequence_breakers: []
|
||||
} as ApiLlamaCppServerProps['default_generation_settings']['params']);
|
||||
|
||||
expect(result.samplers).toBe('top_k;temperature');
|
||||
expect(result.max_tokens).toBe(-1);
|
||||
expect(result.temperature).toBe(0.7);
|
||||
});
|
||||
|
||||
it('should merge webui settings from props when provided', () => {
|
||||
const result = ParameterSyncService.extractServerDefaults(null, {
|
||||
pasteLongTextToFileLen: 0,
|
||||
pdfAsImage: true,
|
||||
renderUserContentAsMarkdown: false,
|
||||
theme: 'dark'
|
||||
});
|
||||
|
||||
expect(result.pasteLongTextToFileLen).toBe(0);
|
||||
expect(result.pdfAsImage).toBe(true);
|
||||
expect(result.renderUserContentAsMarkdown).toBe(false);
|
||||
expect(result.theme).toBeUndefined();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
@ -0,0 +1,400 @@
|
|||
import { normalizeFloatingPoint } from '$lib/utils';
|
||||
import { SyncableParameterType, ParameterSource } from '$lib/enums/settings';
|
||||
|
||||
type ParameterValue = string | number | boolean;
|
||||
type ParameterRecord = Record<string, ParameterValue>;
|
||||
|
||||
interface ParameterInfo {
|
||||
value: string | number | boolean;
|
||||
source: ParameterSource;
|
||||
serverDefault?: string | number | boolean;
|
||||
userOverride?: string | number | boolean;
|
||||
}
|
||||
|
||||
interface SyncableParameter {
|
||||
key: string;
|
||||
serverKey: string;
|
||||
type: SyncableParameterType;
|
||||
canSync: boolean;
|
||||
}
|
||||
|
||||
/**
|
||||
* Mapping of webui setting keys to server parameter keys.
|
||||
* Only parameters listed here can be synced from the server `/props` endpoint.
|
||||
* Each entry defines the webui key, corresponding server key, value type,
|
||||
* and whether sync is enabled.
|
||||
*/
|
||||
export const SYNCABLE_PARAMETERS: SyncableParameter[] = [
|
||||
{
|
||||
key: 'temperature',
|
||||
serverKey: 'temperature',
|
||||
type: SyncableParameterType.NUMBER,
|
||||
canSync: true
|
||||
},
|
||||
{ key: 'top_k', serverKey: 'top_k', type: SyncableParameterType.NUMBER, canSync: true },
|
||||
{ key: 'top_p', serverKey: 'top_p', type: SyncableParameterType.NUMBER, canSync: true },
|
||||
{ key: 'min_p', serverKey: 'min_p', type: SyncableParameterType.NUMBER, canSync: true },
|
||||
{
|
||||
key: 'dynatemp_range',
|
||||
serverKey: 'dynatemp_range',
|
||||
type: SyncableParameterType.NUMBER,
|
||||
canSync: true
|
||||
},
|
||||
{
|
||||
key: 'dynatemp_exponent',
|
||||
serverKey: 'dynatemp_exponent',
|
||||
type: SyncableParameterType.NUMBER,
|
||||
canSync: true
|
||||
},
|
||||
{
|
||||
key: 'xtc_probability',
|
||||
serverKey: 'xtc_probability',
|
||||
type: SyncableParameterType.NUMBER,
|
||||
canSync: true
|
||||
},
|
||||
{
|
||||
key: 'xtc_threshold',
|
||||
serverKey: 'xtc_threshold',
|
||||
type: SyncableParameterType.NUMBER,
|
||||
canSync: true
|
||||
},
|
||||
{ key: 'typ_p', serverKey: 'typ_p', type: SyncableParameterType.NUMBER, canSync: true },
|
||||
{
|
||||
key: 'repeat_last_n',
|
||||
serverKey: 'repeat_last_n',
|
||||
type: SyncableParameterType.NUMBER,
|
||||
canSync: true
|
||||
},
|
||||
{
|
||||
key: 'repeat_penalty',
|
||||
serverKey: 'repeat_penalty',
|
||||
type: SyncableParameterType.NUMBER,
|
||||
canSync: true
|
||||
},
|
||||
{
|
||||
key: 'presence_penalty',
|
||||
serverKey: 'presence_penalty',
|
||||
type: SyncableParameterType.NUMBER,
|
||||
canSync: true
|
||||
},
|
||||
{
|
||||
key: 'frequency_penalty',
|
||||
serverKey: 'frequency_penalty',
|
||||
type: SyncableParameterType.NUMBER,
|
||||
canSync: true
|
||||
},
|
||||
{
|
||||
key: 'dry_multiplier',
|
||||
serverKey: 'dry_multiplier',
|
||||
type: SyncableParameterType.NUMBER,
|
||||
canSync: true
|
||||
},
|
||||
{ key: 'dry_base', serverKey: 'dry_base', type: SyncableParameterType.NUMBER, canSync: true },
|
||||
{
|
||||
key: 'dry_allowed_length',
|
||||
serverKey: 'dry_allowed_length',
|
||||
type: SyncableParameterType.NUMBER,
|
||||
canSync: true
|
||||
},
|
||||
{
|
||||
key: 'dry_penalty_last_n',
|
||||
serverKey: 'dry_penalty_last_n',
|
||||
type: SyncableParameterType.NUMBER,
|
||||
canSync: true
|
||||
},
|
||||
{ key: 'max_tokens', serverKey: 'max_tokens', type: SyncableParameterType.NUMBER, canSync: true },
|
||||
{ key: 'samplers', serverKey: 'samplers', type: SyncableParameterType.STRING, canSync: true },
|
||||
{
|
||||
key: 'pasteLongTextToFileLen',
|
||||
serverKey: 'pasteLongTextToFileLen',
|
||||
type: SyncableParameterType.NUMBER,
|
||||
canSync: true
|
||||
},
|
||||
{
|
||||
key: 'pdfAsImage',
|
||||
serverKey: 'pdfAsImage',
|
||||
type: SyncableParameterType.BOOLEAN,
|
||||
canSync: true
|
||||
},
|
||||
{
|
||||
key: 'showThoughtInProgress',
|
||||
serverKey: 'showThoughtInProgress',
|
||||
type: SyncableParameterType.BOOLEAN,
|
||||
canSync: true
|
||||
},
|
||||
{
|
||||
key: 'keepStatsVisible',
|
||||
serverKey: 'keepStatsVisible',
|
||||
type: SyncableParameterType.BOOLEAN,
|
||||
canSync: true
|
||||
},
|
||||
{
|
||||
key: 'showMessageStats',
|
||||
serverKey: 'showMessageStats',
|
||||
type: SyncableParameterType.BOOLEAN,
|
||||
canSync: true
|
||||
},
|
||||
{
|
||||
key: 'askForTitleConfirmation',
|
||||
serverKey: 'askForTitleConfirmation',
|
||||
type: SyncableParameterType.BOOLEAN,
|
||||
canSync: true
|
||||
},
|
||||
{
|
||||
key: 'disableAutoScroll',
|
||||
serverKey: 'disableAutoScroll',
|
||||
type: SyncableParameterType.BOOLEAN,
|
||||
canSync: true
|
||||
},
|
||||
{
|
||||
key: 'renderUserContentAsMarkdown',
|
||||
serverKey: 'renderUserContentAsMarkdown',
|
||||
type: SyncableParameterType.BOOLEAN,
|
||||
canSync: true
|
||||
},
|
||||
{
|
||||
key: 'autoMicOnEmpty',
|
||||
serverKey: 'autoMicOnEmpty',
|
||||
type: SyncableParameterType.BOOLEAN,
|
||||
canSync: true
|
||||
},
|
||||
{
|
||||
key: 'pyInterpreterEnabled',
|
||||
serverKey: 'pyInterpreterEnabled',
|
||||
type: SyncableParameterType.BOOLEAN,
|
||||
canSync: true
|
||||
},
|
||||
{
|
||||
key: 'enableContinueGeneration',
|
||||
serverKey: 'enableContinueGeneration',
|
||||
type: SyncableParameterType.BOOLEAN,
|
||||
canSync: true
|
||||
}
|
||||
];
|
||||
|
||||
export class ParameterSyncService {
|
||||
/**
|
||||
*
|
||||
*
|
||||
* Extraction
|
||||
*
|
||||
*
|
||||
*/
|
||||
|
||||
/**
|
||||
* Round floating-point numbers to avoid JavaScript precision issues.
|
||||
* E.g., 0.1 + 0.2 = 0.30000000000000004 → 0.3
|
||||
*
|
||||
* @param value - Parameter value to normalize
|
||||
* @returns Precision-normalized value
|
||||
*/
|
||||
private static roundFloatingPoint(value: ParameterValue): ParameterValue {
|
||||
return normalizeFloatingPoint(value) as ParameterValue;
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract server default parameters that can be synced from `/props` response.
|
||||
* Handles both generation settings parameters and webui-specific settings.
|
||||
* Converts samplers array to semicolon-delimited string for UI display.
|
||||
*
|
||||
* @param serverParams - Raw generation settings from server `/props` endpoint
|
||||
* @param webuiSettings - Optional webui-specific settings from server
|
||||
* @returns Record of extracted parameter key-value pairs with normalized precision
|
||||
*/
|
||||
static extractServerDefaults(
|
||||
serverParams: ApiLlamaCppServerProps['default_generation_settings']['params'] | null,
|
||||
webuiSettings?: Record<string, string | number | boolean>
|
||||
): ParameterRecord {
|
||||
const extracted: ParameterRecord = {};
|
||||
|
||||
if (serverParams) {
|
||||
for (const param of SYNCABLE_PARAMETERS) {
|
||||
if (param.canSync && param.serverKey in serverParams) {
|
||||
const value = (serverParams as unknown as Record<string, ParameterValue>)[
|
||||
param.serverKey
|
||||
];
|
||||
if (value !== undefined) {
|
||||
// Apply precision rounding to avoid JavaScript floating-point issues
|
||||
extracted[param.key] = this.roundFloatingPoint(value);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle samplers array conversion to string
|
||||
if (serverParams.samplers && Array.isArray(serverParams.samplers)) {
|
||||
extracted.samplers = serverParams.samplers.join(';');
|
||||
}
|
||||
}
|
||||
|
||||
if (webuiSettings) {
|
||||
for (const param of SYNCABLE_PARAMETERS) {
|
||||
if (param.canSync && param.serverKey in webuiSettings) {
|
||||
const value = webuiSettings[param.serverKey];
|
||||
if (value !== undefined) {
|
||||
extracted[param.key] = this.roundFloatingPoint(value);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return extracted;
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
*
|
||||
* Merging
|
||||
*
|
||||
*
|
||||
*/
|
||||
|
||||
/**
|
||||
* Merge server defaults with current user settings.
|
||||
* User overrides always take priority — only parameters not in `userOverrides`
|
||||
* set will be updated from server defaults.
|
||||
*
|
||||
* @param currentSettings - Current parameter values in the settings store
|
||||
* @param serverDefaults - Default values extracted from server props
|
||||
* @param userOverrides - Set of parameter keys explicitly overridden by the user
|
||||
* @returns Merged parameter record with user overrides preserved
|
||||
*/
|
||||
static mergeWithServerDefaults(
|
||||
currentSettings: ParameterRecord,
|
||||
serverDefaults: ParameterRecord,
|
||||
userOverrides: Set<string> = new Set()
|
||||
): ParameterRecord {
|
||||
const merged = { ...currentSettings };
|
||||
|
||||
for (const [key, serverValue] of Object.entries(serverDefaults)) {
|
||||
// Only update if user hasn't explicitly overridden this parameter
|
||||
if (!userOverrides.has(key)) {
|
||||
merged[key] = this.roundFloatingPoint(serverValue);
|
||||
}
|
||||
}
|
||||
|
||||
return merged;
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
*
|
||||
* Info
|
||||
*
|
||||
*
|
||||
*/
|
||||
|
||||
/**
|
||||
* Get parameter information including source and values.
|
||||
* Used by ChatSettingsParameterSourceIndicator to display the correct badge
|
||||
* (Custom vs Default) for each parameter in the settings UI.
|
||||
*
|
||||
* @param key - The parameter key to get info for
|
||||
* @param currentValue - The current value of the parameter
|
||||
* @param propsDefaults - Server default values from `/props`
|
||||
* @param userOverrides - Set of parameter keys explicitly overridden by the user
|
||||
* @returns Parameter info with source, server default, and user override values
|
||||
*/
|
||||
static getParameterInfo(
|
||||
key: string,
|
||||
currentValue: ParameterValue,
|
||||
propsDefaults: ParameterRecord,
|
||||
userOverrides: Set<string>
|
||||
): ParameterInfo {
|
||||
const hasPropsDefault = propsDefaults[key] !== undefined;
|
||||
const isUserOverride = userOverrides.has(key);
|
||||
|
||||
// Simple logic: either using default (from props) or custom (user override)
|
||||
const source = isUserOverride ? ParameterSource.CUSTOM : ParameterSource.DEFAULT;
|
||||
|
||||
return {
|
||||
value: currentValue,
|
||||
source,
|
||||
serverDefault: hasPropsDefault ? propsDefaults[key] : undefined, // Keep same field name for compatibility
|
||||
userOverride: isUserOverride ? currentValue : undefined
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a parameter can be synced from server.
|
||||
*
|
||||
* @param key - The parameter key to check
|
||||
* @returns True if the parameter is in the syncable parameters list
|
||||
*/
|
||||
static canSyncParameter(key: string): boolean {
|
||||
return SYNCABLE_PARAMETERS.some((param) => param.key === key && param.canSync);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get all syncable parameter keys.
|
||||
*
|
||||
* @returns Array of parameter keys that can be synced from server
|
||||
*/
|
||||
static getSyncableParameterKeys(): string[] {
|
||||
return SYNCABLE_PARAMETERS.filter((param) => param.canSync).map((param) => param.key);
|
||||
}
|
||||
|
||||
/**
|
||||
* Validate a server parameter value against its expected type.
|
||||
*
|
||||
* @param key - The parameter key to validate
|
||||
* @param value - The value to validate
|
||||
* @returns True if value matches the expected type for this parameter
|
||||
*/
|
||||
static validateServerParameter(key: string, value: ParameterValue): boolean {
|
||||
const param = SYNCABLE_PARAMETERS.find((p) => p.key === key);
|
||||
if (!param) return false;
|
||||
|
||||
switch (param.type) {
|
||||
case SyncableParameterType.NUMBER:
|
||||
return typeof value === 'number' && !isNaN(value);
|
||||
case SyncableParameterType.STRING:
|
||||
return typeof value === 'string';
|
||||
case SyncableParameterType.BOOLEAN:
|
||||
return typeof value === 'boolean';
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
*
|
||||
* Diff
|
||||
*
|
||||
*
|
||||
*/
|
||||
|
||||
/**
|
||||
* Create a diff between current settings and server defaults.
|
||||
* Shows which parameters differ from server values, useful for debugging
|
||||
* and for the "Reset to defaults" functionality.
|
||||
*
|
||||
* @param currentSettings - Current parameter values in the settings store
|
||||
* @param serverDefaults - Default values extracted from server props
|
||||
* @returns Record of parameter diffs with current value, server value, and whether they differ
|
||||
*/
|
||||
static createParameterDiff(
|
||||
currentSettings: ParameterRecord,
|
||||
serverDefaults: ParameterRecord
|
||||
): Record<string, { current: ParameterValue; server: ParameterValue; differs: boolean }> {
|
||||
const diff: Record<
|
||||
string,
|
||||
{ current: ParameterValue; server: ParameterValue; differs: boolean }
|
||||
> = {};
|
||||
|
||||
for (const key of this.getSyncableParameterKeys()) {
|
||||
const currentValue = currentSettings[key];
|
||||
const serverValue = serverDefaults[key];
|
||||
|
||||
if (serverValue !== undefined) {
|
||||
diff[key] = {
|
||||
current: currentValue,
|
||||
server: serverValue,
|
||||
differs: currentValue !== serverValue
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
return diff;
|
||||
}
|
||||
}
|
||||
|
|
@ -70,12 +70,6 @@ export const SYNCABLE_PARAMETERS: SyncableParameter[] = [
|
|||
canSync: true
|
||||
},
|
||||
{ key: 'showToolCalls', serverKey: 'showToolCalls', type: 'boolean', canSync: true },
|
||||
{
|
||||
key: 'disableReasoningFormat',
|
||||
serverKey: 'disableReasoningFormat',
|
||||
type: 'boolean',
|
||||
canSync: true
|
||||
},
|
||||
{ key: 'keepStatsVisible', serverKey: 'keepStatsVisible', type: 'boolean', canSync: true },
|
||||
{ key: 'showMessageStats', serverKey: 'showMessageStats', type: 'boolean', canSync: true },
|
||||
{
|
||||
|
|
|
|||
|
|
@ -0,0 +1,47 @@
|
|||
import { apiFetchWithParams } from '$lib/utils/api-fetch';
|
||||
|
||||
export class PropsService {
|
||||
/**
|
||||
*
|
||||
*
|
||||
* Fetching
|
||||
*
|
||||
*
|
||||
*/
|
||||
|
||||
/**
|
||||
* Fetches global server properties from the `/props` endpoint.
|
||||
* In MODEL mode, returns modalities for the single loaded model.
|
||||
* In ROUTER mode, returns server-wide settings without model-specific modalities.
|
||||
*
|
||||
* @param autoload - If false, prevents automatic model loading (default: false)
|
||||
* @returns Server properties including default generation settings and capabilities
|
||||
* @throws {Error} If the request fails or returns invalid data
|
||||
*/
|
||||
static async fetch(autoload = false): Promise<ApiLlamaCppServerProps> {
|
||||
const params: Record<string, string> = {};
|
||||
if (!autoload) {
|
||||
params.autoload = 'false';
|
||||
}
|
||||
|
||||
return apiFetchWithParams<ApiLlamaCppServerProps>('./props', params, { authOnly: true });
|
||||
}
|
||||
|
||||
/**
|
||||
* Fetches server properties for a specific model (ROUTER mode only).
|
||||
* Required in ROUTER mode because global `/props` does not include per-model modalities.
|
||||
*
|
||||
* @param modelId - The model ID to fetch properties for
|
||||
* @param autoload - If false, prevents automatic model loading (default: false)
|
||||
* @returns Server properties specific to the requested model
|
||||
* @throws {Error} If the request fails, model not found, or model not loaded
|
||||
*/
|
||||
static async fetchForModel(modelId: string, autoload = false): Promise<ApiLlamaCppServerProps> {
|
||||
const params: Record<string, string> = { model: modelId };
|
||||
if (!autoload) {
|
||||
params.autoload = 'false';
|
||||
}
|
||||
|
||||
return apiFetchWithParams<ApiLlamaCppServerProps>('./props', params, { authOnly: true });
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue