Merge branch 'ggml-org:master' into power-law-sampler
This commit is contained in:
commit
dedbe36735
|
|
@ -873,7 +873,9 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||||
sampler_type_chars += common_sampler_type_to_chr(sampler);
|
sampler_type_chars += common_sampler_type_to_chr(sampler);
|
||||||
sampler_type_names += common_sampler_type_to_str(sampler) + ";";
|
sampler_type_names += common_sampler_type_to_str(sampler) + ";";
|
||||||
}
|
}
|
||||||
sampler_type_names.pop_back();
|
if (!sampler_type_names.empty()) {
|
||||||
|
sampler_type_names.pop_back(); // remove last semicolon
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
||||||
|
|
@ -189,10 +189,10 @@ class ModelBase:
|
||||||
return tensors
|
return tensors
|
||||||
|
|
||||||
prefix = "model" if not self.is_mistral_format else "consolidated"
|
prefix = "model" if not self.is_mistral_format else "consolidated"
|
||||||
part_names: set[str] = set(ModelBase.get_model_part_names(self.dir_model, prefix, ".safetensors"))
|
part_names: list[str] = ModelBase.get_model_part_names(self.dir_model, prefix, ".safetensors")
|
||||||
is_safetensors: bool = len(part_names) > 0
|
is_safetensors: bool = len(part_names) > 0
|
||||||
if not is_safetensors:
|
if not is_safetensors:
|
||||||
part_names = set(ModelBase.get_model_part_names(self.dir_model, "pytorch_model", ".bin"))
|
part_names = ModelBase.get_model_part_names(self.dir_model, "pytorch_model", ".bin")
|
||||||
|
|
||||||
tensor_names_from_index: set[str] = set()
|
tensor_names_from_index: set[str] = set()
|
||||||
|
|
||||||
|
|
@ -209,7 +209,8 @@ class ModelBase:
|
||||||
if weight_map is None or not isinstance(weight_map, dict):
|
if weight_map is None or not isinstance(weight_map, dict):
|
||||||
raise ValueError(f"Can't load 'weight_map' from {index_name!r}")
|
raise ValueError(f"Can't load 'weight_map' from {index_name!r}")
|
||||||
tensor_names_from_index.update(weight_map.keys())
|
tensor_names_from_index.update(weight_map.keys())
|
||||||
part_names |= set(weight_map.values())
|
part_dict: dict[str, None] = dict.fromkeys(weight_map.values(), None)
|
||||||
|
part_names = sorted(part_dict.keys())
|
||||||
else:
|
else:
|
||||||
weight_map = {}
|
weight_map = {}
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
|
|
@ -458,6 +458,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||||
if (GGML_RV_ZFH)
|
if (GGML_RV_ZFH)
|
||||||
string(APPEND MARCH_STR "_zfh")
|
string(APPEND MARCH_STR "_zfh")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if (GGML_XTHEADVECTOR)
|
if (GGML_XTHEADVECTOR)
|
||||||
string(APPEND MARCH_STR "_xtheadvector")
|
string(APPEND MARCH_STR "_xtheadvector")
|
||||||
elseif (GGML_RVV)
|
elseif (GGML_RVV)
|
||||||
|
|
@ -465,6 +466,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||||
if (GGML_RV_ZVFH)
|
if (GGML_RV_ZVFH)
|
||||||
string(APPEND MARCH_STR "_zvfh")
|
string(APPEND MARCH_STR "_zvfh")
|
||||||
endif()
|
endif()
|
||||||
|
if (GGML_RV_ZVFBFWMA)
|
||||||
|
string(APPEND MARCH_STR "_zvfbfwma")
|
||||||
|
endif()
|
||||||
endif()
|
endif()
|
||||||
if (GGML_RV_ZICBOP)
|
if (GGML_RV_ZICBOP)
|
||||||
string(APPEND MARCH_STR "_zicbop")
|
string(APPEND MARCH_STR "_zicbop")
|
||||||
|
|
|
||||||
|
|
@ -3320,13 +3320,33 @@ void ggml_cpu_fp16_to_fp32(const ggml_fp16_t * x, float * y, int64_t n) {
|
||||||
__m128 y_vec = _mm_cvtph_ps(x_vec);
|
__m128 y_vec = _mm_cvtph_ps(x_vec);
|
||||||
_mm_storeu_ps(y + i, y_vec);
|
_mm_storeu_ps(y + i, y_vec);
|
||||||
}
|
}
|
||||||
#elif defined(__riscv_zvfh)
|
|
||||||
for (int vl; i < n; i += vl) {
|
#elif defined(__riscv_v_intrinsic) && defined(__riscv_zvfhmin)
|
||||||
vl = __riscv_vsetvl_e16m1(n - i);
|
// calculate step size
|
||||||
vfloat16m1_t vx = __riscv_vle16_v_f16m1((_Float16 *)&x[i], vl);
|
const int epr = __riscv_vsetvlmax_e16m2();
|
||||||
vfloat32m2_t vy = __riscv_vfwcvt_f_f_v_f32m2(vx, vl);
|
const int step = epr * 2;
|
||||||
__riscv_vse32_v_f32m2(&y[i], vy, vl);
|
const int np = (n & ~(step - 1));
|
||||||
|
|
||||||
|
// unroll by 2
|
||||||
|
for (; i < np; i += step) {
|
||||||
|
vfloat16m2_t ax0 = __riscv_vle16_v_f16m2((const _Float16*)x + i, epr);
|
||||||
|
vfloat32m4_t ay0 = __riscv_vfwcvt_f_f_v_f32m4(ax0, epr);
|
||||||
|
__riscv_vse32_v_f32m4(y + i, ay0, epr);
|
||||||
|
|
||||||
|
vfloat16m2_t ax1 = __riscv_vle16_v_f16m2((const _Float16*)x + i + epr, epr);
|
||||||
|
vfloat32m4_t ay1 = __riscv_vfwcvt_f_f_v_f32m4(ax1, epr);
|
||||||
|
__riscv_vse32_v_f32m4(y + i + epr, ay1, epr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// leftovers
|
||||||
|
int vl;
|
||||||
|
for (i = np; i < n; i += vl) {
|
||||||
|
vl = __riscv_vsetvl_e16m2(n - i);
|
||||||
|
vfloat16m2_t ax0 = __riscv_vle16_v_f16m2((const _Float16*)x + i, vl);
|
||||||
|
vfloat32m4_t ay0 = __riscv_vfwcvt_f_f_v_f32m4(ax0, vl);
|
||||||
|
__riscv_vse32_v_f32m4(y + i, ay0, vl);
|
||||||
|
}
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
for (; i < n; ++i) {
|
for (; i < n; ++i) {
|
||||||
|
|
@ -3371,6 +3391,31 @@ void ggml_cpu_bf16_to_fp32(const ggml_bf16_t * x, float * y, int64_t n) {
|
||||||
(const __m128i *)(x + i))),
|
(const __m128i *)(x + i))),
|
||||||
16)));
|
16)));
|
||||||
}
|
}
|
||||||
|
#elif defined(__riscv_v_intrinsic) && defined(__riscv_zvfbfmin)
|
||||||
|
// calculate step size
|
||||||
|
const int epr = __riscv_vsetvlmax_e16m2();
|
||||||
|
const int step = epr * 2;
|
||||||
|
const int np = (n & ~(step - 1));
|
||||||
|
|
||||||
|
// unroll by 2
|
||||||
|
for (; i < np; i += step) {
|
||||||
|
vbfloat16m2_t ax0 = __riscv_vle16_v_bf16m2((const __bf16*)x + i, epr);
|
||||||
|
vfloat32m4_t ay0 = __riscv_vfwcvtbf16_f_f_v_f32m4(ax0, epr);
|
||||||
|
__riscv_vse32_v_f32m4(y + i, ay0, epr);
|
||||||
|
|
||||||
|
vbfloat16m2_t ax1 = __riscv_vle16_v_bf16m2((const __bf16*)x + i + epr, epr);
|
||||||
|
vfloat32m4_t ay1 = __riscv_vfwcvtbf16_f_f_v_f32m4(ax1, epr);
|
||||||
|
__riscv_vse32_v_f32m4(y + i + epr, ay1, epr);
|
||||||
|
}
|
||||||
|
|
||||||
|
// leftovers
|
||||||
|
int vl;
|
||||||
|
for (i = np; i < n; i += vl) {
|
||||||
|
vl = __riscv_vsetvl_e16m2(n - i);
|
||||||
|
vbfloat16m2_t ax0 = __riscv_vle16_v_bf16m2((const __bf16*)x + i, vl);
|
||||||
|
vfloat32m4_t ay0 = __riscv_vfwcvtbf16_f_f_v_f32m4(ax0, vl);
|
||||||
|
__riscv_vse32_v_f32m4(y + i, ay0, vl);
|
||||||
|
}
|
||||||
#endif
|
#endif
|
||||||
for (; i < n; i++) {
|
for (; i < n; i++) {
|
||||||
y[i] = GGML_BF16_TO_FP32(x[i]);
|
y[i] = GGML_BF16_TO_FP32(x[i]);
|
||||||
|
|
|
||||||
|
|
@ -195,8 +195,48 @@ void ggml_vec_dot_bf16(int n, float * GGML_RESTRICT s, size_t bs, ggml_bf16_t *
|
||||||
sumf += (ggml_float)_mm_cvtss_f32(g);
|
sumf += (ggml_float)_mm_cvtss_f32(g);
|
||||||
|
|
||||||
#undef LOAD
|
#undef LOAD
|
||||||
#endif
|
#elif defined(__riscv_v_intrinsic) && defined(__riscv_zvfbfwma)
|
||||||
|
size_t vl = __riscv_vsetvlmax_e32m4();
|
||||||
|
|
||||||
|
// initialize accumulators to all zeroes
|
||||||
|
vfloat32m4_t vsum0 = __riscv_vfmv_v_f_f32m4(0.0f, vl);
|
||||||
|
vfloat32m4_t vsum1 = __riscv_vfmv_v_f_f32m4(0.0f, vl);
|
||||||
|
|
||||||
|
// calculate step size
|
||||||
|
const size_t epr = __riscv_vsetvlmax_e16m2();
|
||||||
|
const size_t step = epr * 2;
|
||||||
|
const int np = (n & ~(step - 1));
|
||||||
|
|
||||||
|
// unroll by 2
|
||||||
|
for (; i < np; i += step) {
|
||||||
|
vbfloat16m2_t ax0 = __riscv_vle16_v_bf16m2((const __bf16 *)&x[i], epr);
|
||||||
|
vbfloat16m2_t ay0 = __riscv_vle16_v_bf16m2((const __bf16 *)&y[i], epr);
|
||||||
|
vsum0 = __riscv_vfwmaccbf16_vv_f32m4(vsum0, ax0, ay0, epr);
|
||||||
|
__asm__ __volatile__ ("" ::: "memory");
|
||||||
|
|
||||||
|
vbfloat16m2_t ax1 = __riscv_vle16_v_bf16m2((const __bf16 *)&x[i + epr], epr);
|
||||||
|
vbfloat16m2_t ay1 = __riscv_vle16_v_bf16m2((const __bf16 *)&y[i + epr], epr);
|
||||||
|
vsum1 = __riscv_vfwmaccbf16_vv_f32m4(vsum1, ax1, ay1, epr);
|
||||||
|
__asm__ __volatile__ ("" ::: "memory");
|
||||||
|
}
|
||||||
|
|
||||||
|
// accumulate in 1 register
|
||||||
|
vsum0 = __riscv_vfadd_vv_f32m4(vsum0, vsum1, vl);
|
||||||
|
|
||||||
|
// leftovers
|
||||||
|
for (i = np; i < n; i += vl) {
|
||||||
|
vl = __riscv_vsetvl_e16m2(n - i);
|
||||||
|
vbfloat16m2_t ax0 = __riscv_vle16_v_bf16m2((const __bf16 *)&x[i], vl);
|
||||||
|
vbfloat16m2_t ay0 = __riscv_vle16_v_bf16m2((const __bf16 *)&y[i], vl);
|
||||||
|
vsum0 = __riscv_vfwmaccbf16_vv_f32m4(vsum0, ax0, ay0, vl);
|
||||||
|
}
|
||||||
|
|
||||||
|
// reduce
|
||||||
|
vl = __riscv_vsetvlmax_e32m4();
|
||||||
|
vfloat32m1_t redsum = __riscv_vfredusum_vs_f32m4_f32m1(vsum0, __riscv_vfmv_v_f_f32m1(0.0f, 1), vl);
|
||||||
|
sumf += __riscv_vfmv_f_s_f32m1_f32(redsum);
|
||||||
|
|
||||||
|
#endif
|
||||||
for (; i < n; ++i) {
|
for (; i < n; ++i) {
|
||||||
sumf += (ggml_float)(GGML_BF16_TO_FP32(x[i]) *
|
sumf += (ggml_float)(GGML_BF16_TO_FP32(x[i]) *
|
||||||
GGML_BF16_TO_FP32(y[i]));
|
GGML_BF16_TO_FP32(y[i]));
|
||||||
|
|
|
||||||
|
|
@ -224,13 +224,71 @@ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * GG
|
||||||
}
|
}
|
||||||
GGML_F16x_VEC_REDUCE(sumf[0], sum_00, sum_01, sum_02, sum_03);
|
GGML_F16x_VEC_REDUCE(sumf[0], sum_00, sum_01, sum_02, sum_03);
|
||||||
GGML_F16x_VEC_REDUCE(sumf[1], sum_10, sum_11, sum_12, sum_13);
|
GGML_F16x_VEC_REDUCE(sumf[1], sum_10, sum_11, sum_12, sum_13);
|
||||||
#elif defined(__riscv_v_intrinsic)
|
|
||||||
// todo: RVV impl
|
#elif defined(__riscv_v_intrinsic) && defined(__riscv_zvfh)
|
||||||
for (int i = 0; i < n; ++i) {
|
size_t vl = __riscv_vsetvlmax_e32m4();
|
||||||
for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
|
|
||||||
sumf[j] += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[j][i])*GGML_CPU_FP16_TO_FP32(y[i]));
|
// initialize accumulators to all zeroes
|
||||||
}
|
vfloat32m4_t vsum0_0 = __riscv_vfmv_v_f_f32m4(0.0f, vl);
|
||||||
}
|
vfloat32m4_t vsum0_1 = __riscv_vfmv_v_f_f32m4(0.0f, vl);
|
||||||
|
vfloat32m4_t vsum1_0 = __riscv_vfmv_v_f_f32m4(0.0f, vl);
|
||||||
|
vfloat32m4_t vsum1_1 = __riscv_vfmv_v_f_f32m4(0.0f, vl);
|
||||||
|
|
||||||
|
// calculate step size
|
||||||
|
const size_t epr = __riscv_vsetvlmax_e16m2();
|
||||||
|
const size_t step = epr * 2;
|
||||||
|
const int np = (n & ~(step - 1));
|
||||||
|
|
||||||
|
// unroll by 2 along the row dimension
|
||||||
|
for (int i = 0; i < np; i += step) {
|
||||||
|
vfloat16m2_t ay0 = __riscv_vle16_v_f16m2((const _Float16 *)(y + i), epr);
|
||||||
|
vfloat16m2_t ax0_0 = __riscv_vle16_v_f16m2((const _Float16 *)(x[0] + i), epr);
|
||||||
|
vfloat16m2_t ax1_0 = __riscv_vle16_v_f16m2((const _Float16 *)(x[1] + i), epr);
|
||||||
|
vsum0_0 = __riscv_vfwmacc_vv_f32m4(vsum0_0, ax0_0, ay0, epr);
|
||||||
|
vsum1_0 = __riscv_vfwmacc_vv_f32m4(vsum1_0, ax1_0, ay0, epr);
|
||||||
|
|
||||||
|
vfloat16m2_t ay1 = __riscv_vle16_v_f16m2((const _Float16 *)(y + i + epr), epr);
|
||||||
|
vfloat16m2_t ax0_1 = __riscv_vle16_v_f16m2((const _Float16 *)(x[0] + i + epr), epr);
|
||||||
|
vfloat16m2_t ax1_1 = __riscv_vle16_v_f16m2((const _Float16 *)(x[1] + i + epr), epr);
|
||||||
|
vsum0_1 = __riscv_vfwmacc_vv_f32m4(vsum0_1, ax0_1, ay1, epr);
|
||||||
|
vsum1_1 = __riscv_vfwmacc_vv_f32m4(vsum1_1, ax1_1, ay1, epr);
|
||||||
|
}
|
||||||
|
|
||||||
|
vfloat32m4_t vsum0 = __riscv_vfadd_vv_f32m4(vsum0_0, vsum0_1, vl);
|
||||||
|
vfloat32m4_t vsum1 = __riscv_vfadd_vv_f32m4(vsum1_0, vsum1_1, vl);
|
||||||
|
|
||||||
|
// leftovers
|
||||||
|
for (int i = np; i < n; i += vl) {
|
||||||
|
vl = __riscv_vsetvl_e16m2(n - i);
|
||||||
|
vfloat16m2_t ay = __riscv_vle16_v_f16m2((const _Float16 *)(y + i), vl);
|
||||||
|
vfloat16m2_t ax0 = __riscv_vle16_v_f16m2((const _Float16 *)(x[0] + i), vl);
|
||||||
|
vfloat16m2_t ax1 = __riscv_vle16_v_f16m2((const _Float16 *)(x[1] + i), vl);
|
||||||
|
|
||||||
|
vsum0 = __riscv_vfwmacc_vv_f32m4(vsum0, ax0, ay, vl);
|
||||||
|
vsum1 = __riscv_vfwmacc_vv_f32m4(vsum1, ax1, ay, vl);
|
||||||
|
}
|
||||||
|
|
||||||
|
// reduce
|
||||||
|
vl = __riscv_vsetvlmax_e32m2();
|
||||||
|
vfloat32m2_t acc0_0 = __riscv_vfadd_vv_f32m2(__riscv_vget_v_f32m4_f32m2(vsum0, 0),
|
||||||
|
__riscv_vget_v_f32m4_f32m2(vsum0, 1), vl);
|
||||||
|
vl = __riscv_vsetvlmax_e32m1();
|
||||||
|
vfloat32m1_t acc0_1 = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m2_f32m1(acc0_0, 0),
|
||||||
|
__riscv_vget_v_f32m2_f32m1(acc0_0, 1), vl);
|
||||||
|
vfloat32m1_t redsum0 = __riscv_vfredusum_vs_f32m1_f32m1(
|
||||||
|
acc0_1, __riscv_vfmv_v_f_f32m1(0.0f, 1), vl);
|
||||||
|
|
||||||
|
vl = __riscv_vsetvlmax_e32m2();
|
||||||
|
vfloat32m2_t acc1_0 = __riscv_vfadd_vv_f32m2(__riscv_vget_v_f32m4_f32m2(vsum1, 0),
|
||||||
|
__riscv_vget_v_f32m4_f32m2(vsum1, 1), vl);
|
||||||
|
vl = __riscv_vsetvlmax_e32m1();
|
||||||
|
vfloat32m1_t acc1_1 = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m2_f32m1(acc1_0, 0),
|
||||||
|
__riscv_vget_v_f32m2_f32m1(acc1_0, 1), vl);
|
||||||
|
vfloat32m1_t redsum1 = __riscv_vfredusum_vs_f32m1_f32m1(
|
||||||
|
acc1_1, __riscv_vfmv_v_f_f32m1(0.0f, 1), vl);
|
||||||
|
sumf[0] = __riscv_vfmv_f_s_f32m1_f32(redsum0);
|
||||||
|
sumf[1] = __riscv_vfmv_f_s_f32m1_f32(redsum1);
|
||||||
|
|
||||||
#else
|
#else
|
||||||
const int np = (n & ~(GGML_F16_STEP - 1));
|
const int np = (n & ~(GGML_F16_STEP - 1));
|
||||||
|
|
||||||
|
|
@ -475,15 +533,39 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * GGML_RESTRICT y,
|
||||||
}
|
}
|
||||||
np = n;
|
np = n;
|
||||||
#elif defined(__riscv_zvfh) // implies __riscv_v_intrinsic
|
#elif defined(__riscv_zvfh) // implies __riscv_v_intrinsic
|
||||||
const int np = n;
|
const ggml_fp16_t s = GGML_CPU_FP32_TO_FP16(v);
|
||||||
_Float16 hv = (_Float16)v;
|
const _Float16 scale = *(const _Float16*)(&s);
|
||||||
for (int i = 0, avl; i < n; i += avl) {
|
|
||||||
avl = __riscv_vsetvl_e16m8(n - i);
|
// calculate step size
|
||||||
vfloat16m8_t ax = __riscv_vle16_v_f16m8((const _Float16 *)&x[i], avl);
|
const int epr = __riscv_vsetvlmax_e16m4();
|
||||||
vfloat16m8_t ay = __riscv_vle16_v_f16m8((_Float16 *)&y[i], avl);
|
const int step = epr * 2;
|
||||||
vfloat16m8_t ny = __riscv_vfmadd_vf_f16m8(ax, hv, ay, avl);
|
int np = (n & ~(step - 1));
|
||||||
__riscv_vse16_v_f16m8((_Float16 *)&y[i], ny, avl);
|
|
||||||
|
// unroll by 2
|
||||||
|
for (int i = 0; i < np; i += step) {
|
||||||
|
vfloat16m4_t ax0 = __riscv_vle16_v_f16m4((const _Float16*)x + i, epr);
|
||||||
|
vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, epr);
|
||||||
|
ay0 = __riscv_vfmacc_vf_f16m4(ay0, scale, ax0, epr);
|
||||||
|
__riscv_vse16_v_f16m4((_Float16*)y + i, ay0, epr);
|
||||||
|
__asm__ __volatile__ ("" ::: "memory");
|
||||||
|
|
||||||
|
vfloat16m4_t ax1 = __riscv_vle16_v_f16m4((const _Float16*)x + i + epr, epr);
|
||||||
|
vfloat16m4_t ay1 = __riscv_vle16_v_f16m4((const _Float16*)y + i + epr, epr);
|
||||||
|
ay1 = __riscv_vfmacc_vf_f16m4(ay1, scale, ax1, epr);
|
||||||
|
__riscv_vse16_v_f16m4((_Float16*)y + i + epr, ay1, epr);
|
||||||
|
__asm__ __volatile__ ("" ::: "memory");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// leftovers
|
||||||
|
int vl;
|
||||||
|
for (int i = np; i < n; i += vl) {
|
||||||
|
vl = __riscv_vsetvl_e16m4(n - i);
|
||||||
|
vfloat16m4_t ax0 = __riscv_vle16_v_f16m4((const _Float16*)x + i, vl);
|
||||||
|
vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, vl);
|
||||||
|
ay0 = __riscv_vfmacc_vf_f16m4(ay0, scale, ax0, vl);
|
||||||
|
__riscv_vse16_v_f16m4((_Float16*)y + i, ay0, vl);
|
||||||
|
}
|
||||||
|
np = n;
|
||||||
#elif defined(GGML_SIMD)
|
#elif defined(GGML_SIMD)
|
||||||
const int np = (n & ~(GGML_F16_STEP - 1));
|
const int np = (n & ~(GGML_F16_STEP - 1));
|
||||||
|
|
||||||
|
|
@ -724,13 +806,34 @@ inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float
|
||||||
svst1_f16(pg, (__fp16 *)(y + np), out);
|
svst1_f16(pg, (__fp16 *)(y + np), out);
|
||||||
}
|
}
|
||||||
#elif defined(__riscv_v_intrinsic) && defined(__riscv_zvfh)
|
#elif defined(__riscv_v_intrinsic) && defined(__riscv_zvfh)
|
||||||
for (int i = 0, vl; i < n; i += vl) {
|
const ggml_fp16_t s = GGML_CPU_FP32_TO_FP16(v);
|
||||||
vl = __riscv_vsetvl_e16m2(n - i);
|
const _Float16 scale = *(const _Float16*)(&s);
|
||||||
vfloat16m2_t vy = __riscv_vle16_v_f16m2((_Float16 *)&y[i], vl);
|
|
||||||
vfloat32m4_t vy32 = __riscv_vfwcvt_f_f_v_f32m4(vy, vl);
|
// calculate step size
|
||||||
vy32 = __riscv_vfmul_vf_f32m4(vy32, v, vl);
|
const int epr = __riscv_vsetvlmax_e16m4();
|
||||||
vy = __riscv_vfncvt_f_f_w_f16m2(vy32, vl);
|
const int step = epr * 2;
|
||||||
__riscv_vse16_v_f16m2((_Float16 *)&y[i], vy, vl);
|
const int np = (n & ~(step - 1));
|
||||||
|
|
||||||
|
// unroll by 2
|
||||||
|
for (int i = 0; i < np; i += step) {
|
||||||
|
vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, epr);
|
||||||
|
ay0 = __riscv_vfmul_vf_f16m4(ay0, scale, epr);
|
||||||
|
__riscv_vse16_v_f16m4((_Float16*)y + i, ay0, epr);
|
||||||
|
__asm__ __volatile__ ("" ::: "memory");
|
||||||
|
|
||||||
|
vfloat16m4_t ay1 = __riscv_vle16_v_f16m4((const _Float16*)y + i + epr, epr);
|
||||||
|
ay1 = __riscv_vfmul_vf_f16m4(ay1, scale, epr);
|
||||||
|
__riscv_vse16_v_f16m4((_Float16*)y + i + epr, ay1, epr);
|
||||||
|
__asm__ __volatile__ ("" ::: "memory");
|
||||||
|
}
|
||||||
|
|
||||||
|
// leftovers
|
||||||
|
int vl;
|
||||||
|
for (int i = np; i < n; i += vl) {
|
||||||
|
vl = __riscv_vsetvl_e16m4(n - i);
|
||||||
|
vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, vl);
|
||||||
|
ay0 = __riscv_vfmul_vf_f16m4(ay0, scale, vl);
|
||||||
|
__riscv_vse16_v_f16m4((_Float16*)y + i, ay0, vl);
|
||||||
}
|
}
|
||||||
#elif defined(GGML_SIMD)
|
#elif defined(GGML_SIMD)
|
||||||
const int np = (n & ~(GGML_F16_STEP - 1));
|
const int np = (n & ~(GGML_F16_STEP - 1));
|
||||||
|
|
|
||||||
|
|
@ -78,27 +78,25 @@ namespace ggml_cuda_mma {
|
||||||
// MIRRORED == Each data value is held exactly once per thread subgroup.
|
// MIRRORED == Each data value is held exactly once per thread subgroup.
|
||||||
DATA_LAYOUT_I_MAJOR = 0, // Always used for Turing, Ampere, Ada Lovelace, consumer Blackwell, matrix A&B for RDNA4 and CDNA.
|
DATA_LAYOUT_I_MAJOR = 0, // Always used for Turing, Ampere, Ada Lovelace, consumer Blackwell, matrix A&B for RDNA4 and CDNA.
|
||||||
DATA_LAYOUT_J_MAJOR = 10, // Matrix C for CDNA and RDNA4, int and float matrix C for RDNA3.
|
DATA_LAYOUT_J_MAJOR = 10, // Matrix C for CDNA and RDNA4, int and float matrix C for RDNA3.
|
||||||
DATA_LAYOUT_I_MAJOR_MIRRORED = 20,
|
DATA_LAYOUT_I_MAJOR_MIRRORED = 20, // Volta, matrix A&B for RDNA3.
|
||||||
DATA_LAYOUT_J_MAJOR_MIRRORED = 30,
|
DATA_LAYOUT_J_MAJOR_MIRRORED = 30,
|
||||||
DATA_LAYOUT_I_MAJOR_DUAL = 40, // Matrix A&B for RDNA3.
|
|
||||||
};
|
};
|
||||||
// Implemented mma combinations are:
|
// Implemented mma combinations are:
|
||||||
// - (I_MAJOR, I_MAJOR) -> I_MAJOR
|
// - (I_MAJOR, I_MAJOR) -> I_MAJOR
|
||||||
// - (I_MAJOR, I_MAJOR_MIRRORED) -> I_MAJOR
|
// - (I_MAJOR, I_MAJOR_MIRRORED) -> I_MAJOR
|
||||||
// - (I_MAJOR, J_MAJOR_MIRRORED) -> I_MAJOR
|
// - (I_MAJOR, J_MAJOR_MIRRORED) -> I_MAJOR
|
||||||
|
|
||||||
constexpr bool is_i_major(const data_layout dl) {
|
static constexpr bool is_i_major(const data_layout dl) {
|
||||||
return dl == DATA_LAYOUT_I_MAJOR ||
|
return dl == DATA_LAYOUT_I_MAJOR ||
|
||||||
dl == DATA_LAYOUT_I_MAJOR_MIRRORED ||
|
dl == DATA_LAYOUT_I_MAJOR_MIRRORED;
|
||||||
dl == DATA_LAYOUT_I_MAJOR_DUAL;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
constexpr data_layout get_input_data_layout() {
|
static constexpr __device__ data_layout get_input_data_layout() {
|
||||||
#if defined(RDNA3)
|
#if defined(RDNA3) || __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
||||||
return DATA_LAYOUT_I_MAJOR_DUAL;
|
return DATA_LAYOUT_I_MAJOR_MIRRORED;
|
||||||
#else
|
#else
|
||||||
return DATA_LAYOUT_I_MAJOR;
|
return DATA_LAYOUT_I_MAJOR;
|
||||||
#endif // defined(RDNA3)
|
#endif // defined(RDNA3) || __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
||||||
}
|
}
|
||||||
|
|
||||||
template <int I_, int J_, typename T, data_layout ds_=DATA_LAYOUT_I_MAJOR>
|
template <int I_, int J_, typename T, data_layout ds_=DATA_LAYOUT_I_MAJOR>
|
||||||
|
|
@ -462,11 +460,65 @@ namespace ggml_cuda_mma {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <int I_, int J_, typename T>
|
||||||
|
struct tile<I_, J_, T, DATA_LAYOUT_I_MAJOR_MIRRORED> {
|
||||||
|
static constexpr int I = I_;
|
||||||
|
static constexpr int J = J_;
|
||||||
|
static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_MIRRORED;
|
||||||
|
|
||||||
|
// RDNA3
|
||||||
|
static constexpr int ne = I * J / 32 * 2;
|
||||||
|
|
||||||
|
T x[ne] = {0};
|
||||||
|
|
||||||
|
static constexpr __device__ bool supported() {
|
||||||
|
if (I == 16 && J == 16) return true;
|
||||||
|
if (I == 16 && J == 8) return true;
|
||||||
|
if (I == 16 && J == 4) return true;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
static __device__ __forceinline__ int get_i(const int /*l*/) {
|
||||||
|
if constexpr (supported()) {
|
||||||
|
return threadIdx.x % 16;
|
||||||
|
} else {
|
||||||
|
NO_DEVICE_CODE;
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static __device__ __forceinline__ int get_j(const int l) {
|
||||||
|
if constexpr (supported()) {
|
||||||
|
return l;
|
||||||
|
} else {
|
||||||
|
NO_DEVICE_CODE;
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
template <int I_, int J_>
|
template <int I_, int J_>
|
||||||
struct tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> {
|
struct tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> {
|
||||||
static constexpr int I = I_;
|
static constexpr int I = I_;
|
||||||
static constexpr int J = J_;
|
static constexpr int J = J_;
|
||||||
static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_MIRRORED;
|
static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_MIRRORED;
|
||||||
|
#if defined(RDNA3)
|
||||||
|
static constexpr int ne = tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::ne;
|
||||||
|
|
||||||
|
half2 x[ne] = {{0.0f, 0.0f}};
|
||||||
|
|
||||||
|
static constexpr __device__ bool supported() {
|
||||||
|
return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::supported();
|
||||||
|
}
|
||||||
|
|
||||||
|
static __device__ __forceinline__ int get_i(const int l) {
|
||||||
|
return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::get_i(l);
|
||||||
|
}
|
||||||
|
|
||||||
|
static __device__ __forceinline__ int get_j(const int l) {
|
||||||
|
return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::get_j(l);
|
||||||
|
}
|
||||||
|
#else // Volta
|
||||||
static constexpr int ne = I * J / (WARP_SIZE/4);
|
static constexpr int ne = I * J / (WARP_SIZE/4);
|
||||||
|
|
||||||
half2 x[ne] = {{0.0f, 0.0f}};
|
half2 x[ne] = {{0.0f, 0.0f}};
|
||||||
|
|
@ -493,6 +545,29 @@ namespace ggml_cuda_mma {
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
#endif // defined(RDNA3)
|
||||||
|
};
|
||||||
|
|
||||||
|
template <int I_, int J_>
|
||||||
|
struct tile<I_, J_, nv_bfloat162, DATA_LAYOUT_I_MAJOR_MIRRORED> {
|
||||||
|
static constexpr int I = I_;
|
||||||
|
static constexpr int J = J_;
|
||||||
|
static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_MIRRORED;
|
||||||
|
static constexpr int ne = tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::ne;
|
||||||
|
|
||||||
|
nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
|
||||||
|
|
||||||
|
static constexpr __device__ bool supported() {
|
||||||
|
return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::supported();
|
||||||
|
}
|
||||||
|
|
||||||
|
static __device__ __forceinline__ int get_i(const int l) {
|
||||||
|
return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::get_i(l);
|
||||||
|
}
|
||||||
|
|
||||||
|
static __device__ __forceinline__ int get_j(const int l) {
|
||||||
|
return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::get_j(l);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template <int I_, int J_>
|
template <int I_, int J_>
|
||||||
|
|
@ -528,42 +603,6 @@ namespace ggml_cuda_mma {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template <int I_, int J_, typename T>
|
|
||||||
struct tile<I_, J_, T, DATA_LAYOUT_I_MAJOR_DUAL> {
|
|
||||||
static constexpr int I = I_;
|
|
||||||
static constexpr int J = J_;
|
|
||||||
static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_DUAL;
|
|
||||||
|
|
||||||
static constexpr int ne = I * J / 32 * 2;
|
|
||||||
|
|
||||||
T x[ne] = {0};
|
|
||||||
|
|
||||||
static constexpr __device__ bool supported() {
|
|
||||||
if (I == 16 && J == 16) return true;
|
|
||||||
if (I == 16 && J == 8) return true;
|
|
||||||
if (I == 16 && J == 4) return true;
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
static __device__ __forceinline__ int get_i(const int l) {
|
|
||||||
if constexpr (supported()) {
|
|
||||||
return threadIdx.x % 16;
|
|
||||||
} else {
|
|
||||||
NO_DEVICE_CODE;
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
static __device__ __forceinline__ int get_j(const int l) {
|
|
||||||
if constexpr (supported()) {
|
|
||||||
return l;
|
|
||||||
} else {
|
|
||||||
NO_DEVICE_CODE;
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
#if defined(TURING_MMA_AVAILABLE)
|
#if defined(TURING_MMA_AVAILABLE)
|
||||||
template <int I, int J>
|
template <int I, int J>
|
||||||
static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) {
|
static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) {
|
||||||
|
|
|
||||||
|
|
@ -288,7 +288,7 @@ class LocalTensor:
|
||||||
data_range: LocalTensorRange
|
data_range: LocalTensorRange
|
||||||
|
|
||||||
def mmap_bytes(self) -> np.ndarray:
|
def mmap_bytes(self) -> np.ndarray:
|
||||||
return np.memmap(self.data_range.filename, mode='r', offset=self.data_range.offset, shape=self.data_range.size)
|
return np.memmap(self.data_range.filename, mode='c', offset=self.data_range.offset, shape=self.data_range.size)
|
||||||
|
|
||||||
|
|
||||||
class SafetensorsLocal:
|
class SafetensorsLocal:
|
||||||
|
|
|
||||||
|
|
@ -13,9 +13,10 @@
|
||||||
#ifdef __has_include
|
#ifdef __has_include
|
||||||
#if __has_include(<unistd.h>)
|
#if __has_include(<unistd.h>)
|
||||||
#include <unistd.h>
|
#include <unistd.h>
|
||||||
|
#include <fcntl.h>
|
||||||
|
#include <sys/stat.h>
|
||||||
#if defined(_POSIX_MAPPED_FILES)
|
#if defined(_POSIX_MAPPED_FILES)
|
||||||
#include <sys/mman.h>
|
#include <sys/mman.h>
|
||||||
#include <fcntl.h>
|
|
||||||
#endif
|
#endif
|
||||||
#if defined(_POSIX_MEMLOCK_RANGE)
|
#if defined(_POSIX_MEMLOCK_RANGE)
|
||||||
#include <sys/resource.h>
|
#include <sys/resource.h>
|
||||||
|
|
@ -74,7 +75,7 @@ struct llama_file::impl {
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
impl(const char * fname, const char * mode) {
|
impl(const char * fname, const char * mode, [[maybe_unused]] const bool use_direct_io = false) {
|
||||||
fp = ggml_fopen(fname, mode);
|
fp = ggml_fopen(fname, mode);
|
||||||
if (fp == NULL) {
|
if (fp == NULL) {
|
||||||
throw std::runtime_error(format("failed to open %s: %s", fname, strerror(errno)));
|
throw std::runtime_error(format("failed to open %s: %s", fname, strerror(errno)));
|
||||||
|
|
@ -153,13 +154,40 @@ struct llama_file::impl {
|
||||||
write_raw(&val, sizeof(val));
|
write_raw(&val, sizeof(val));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void read_aligned_chunk(size_t offset, void * dest, size_t size) const {
|
||||||
|
throw std::runtime_error("DirectIO is not implemented on Windows.");
|
||||||
|
}
|
||||||
|
|
||||||
~impl() {
|
~impl() {
|
||||||
if (fp) {
|
if (fp) {
|
||||||
std::fclose(fp);
|
std::fclose(fp);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
impl(const char * fname, const char * mode) {
|
impl(const char * fname, const char * mode, [[maybe_unused]] const bool use_direct_io = false) {
|
||||||
|
#ifdef __linux__
|
||||||
|
// Try unbuffered I/O for read only
|
||||||
|
if (use_direct_io && std::strcmp(mode, "rb") == 0) {
|
||||||
|
fd = open(fname, O_RDONLY | O_DIRECT);
|
||||||
|
|
||||||
|
if (fd != -1) {
|
||||||
|
struct stat file_stats{};
|
||||||
|
fstat(fd, &file_stats);
|
||||||
|
|
||||||
|
size = file_stats.st_size;
|
||||||
|
alignment = file_stats.st_blksize;
|
||||||
|
|
||||||
|
off_t ret = lseek(fd, 0, SEEK_SET);
|
||||||
|
if (ret == -1) {
|
||||||
|
throw std::runtime_error(format("seek error: %s", strerror(errno)));
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
LLAMA_LOG_WARN("Failed to open model %s with error: %s. Falling back to buffered I/O",
|
||||||
|
fname, strerror(errno));
|
||||||
|
}
|
||||||
|
#endif
|
||||||
fp = ggml_fopen(fname, mode);
|
fp = ggml_fopen(fname, mode);
|
||||||
if (fp == NULL) {
|
if (fp == NULL) {
|
||||||
throw std::runtime_error(format("failed to open %s: %s", fname, strerror(errno)));
|
throw std::runtime_error(format("failed to open %s: %s", fname, strerror(errno)));
|
||||||
|
|
@ -170,27 +198,30 @@ struct llama_file::impl {
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t tell() const {
|
size_t tell() const {
|
||||||
// TODO: this ifdef is never true?
|
if (fd == -1) {
|
||||||
#ifdef _WIN32
|
long ret = std::ftell(fp);
|
||||||
__int64 ret = _ftelli64(fp);
|
if (ret == -1) {
|
||||||
#else
|
throw std::runtime_error(format("ftell error: %s", strerror(errno)));
|
||||||
long ret = std::ftell(fp);
|
}
|
||||||
#endif
|
|
||||||
if (ret == -1) {
|
return (size_t) ret;
|
||||||
throw std::runtime_error(format("ftell error: %s", strerror(errno)));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return (size_t) ret;
|
off_t pos = lseek(fd, 0, SEEK_CUR);
|
||||||
|
if (pos == -1) {
|
||||||
|
throw std::runtime_error(format("lseek error: %s", strerror(errno)));
|
||||||
|
}
|
||||||
|
return (size_t) pos;
|
||||||
}
|
}
|
||||||
|
|
||||||
void seek(size_t offset, int whence) const {
|
void seek(size_t offset, int whence) const {
|
||||||
// TODO: this ifdef is never true?
|
off_t ret = 0;
|
||||||
#ifdef _WIN32
|
if (fd == -1) {
|
||||||
int ret = _fseeki64(fp, (__int64) offset, whence);
|
ret = std::fseek(fp, (long) offset, whence);
|
||||||
#else
|
} else {
|
||||||
int ret = std::fseek(fp, (long) offset, whence);
|
ret = lseek(fd, offset, whence);
|
||||||
#endif
|
}
|
||||||
if (ret != 0) {
|
if (ret == -1) {
|
||||||
throw std::runtime_error(format("seek error: %s", strerror(errno)));
|
throw std::runtime_error(format("seek error: %s", strerror(errno)));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -200,13 +231,55 @@ struct llama_file::impl {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
errno = 0;
|
errno = 0;
|
||||||
std::size_t ret = std::fread(ptr, len, 1, fp);
|
if (fd == -1) {
|
||||||
if (ferror(fp)) {
|
std::size_t ret = std::fread(ptr, len, 1, fp);
|
||||||
throw std::runtime_error(format("read error: %s", strerror(errno)));
|
if (ferror(fp)) {
|
||||||
|
throw std::runtime_error(format("read error: %s", strerror(errno)));
|
||||||
|
}
|
||||||
|
if (ret != 1) {
|
||||||
|
throw std::runtime_error("unexpectedly reached end of file");
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
bool successful = false;
|
||||||
|
while (!successful) {
|
||||||
|
off_t ret = read(fd, ptr, len);
|
||||||
|
|
||||||
|
if (ret == -1) {
|
||||||
|
if (errno == EINTR) {
|
||||||
|
continue; // Interrupted by signal, retry
|
||||||
|
}
|
||||||
|
throw std::runtime_error(format("read error: %s", strerror(errno)));
|
||||||
|
}
|
||||||
|
if (ret == 0) {
|
||||||
|
throw std::runtime_error("unexpectedly reached end of file");
|
||||||
|
}
|
||||||
|
|
||||||
|
successful = true;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if (ret != 1) {
|
}
|
||||||
throw std::runtime_error("unexpectedly reached end of file");
|
|
||||||
|
void read_aligned_chunk(size_t offset, void * dest, size_t size) const {
|
||||||
|
off_t aligned_offset = offset & ~(alignment - 1);
|
||||||
|
off_t offset_from_alignment = offset - aligned_offset;
|
||||||
|
size_t bytes_to_read = (offset_from_alignment + size + alignment - 1) & ~(alignment - 1);
|
||||||
|
|
||||||
|
void * raw_buffer = nullptr;
|
||||||
|
int ret = posix_memalign(&raw_buffer, alignment, bytes_to_read);
|
||||||
|
if (ret != 0) {
|
||||||
|
throw std::runtime_error(format("posix_memalign failed with error %d", ret));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct aligned_buffer_deleter {
|
||||||
|
void operator()(void * p) const { free(p); }
|
||||||
|
};
|
||||||
|
std::unique_ptr<void, aligned_buffer_deleter> buffer(raw_buffer);
|
||||||
|
|
||||||
|
seek(aligned_offset, SEEK_SET);
|
||||||
|
read_raw(buffer.get(), bytes_to_read);
|
||||||
|
|
||||||
|
uintptr_t actual_data = reinterpret_cast<uintptr_t>(buffer.get()) + offset_from_alignment;
|
||||||
|
memcpy(dest, reinterpret_cast<void *>(actual_data), size);
|
||||||
}
|
}
|
||||||
|
|
||||||
uint32_t read_u32() const {
|
uint32_t read_u32() const {
|
||||||
|
|
@ -231,22 +304,43 @@ struct llama_file::impl {
|
||||||
}
|
}
|
||||||
|
|
||||||
~impl() {
|
~impl() {
|
||||||
if (fp) {
|
if (fd != -1) {
|
||||||
|
close(fd);
|
||||||
|
} else {
|
||||||
std::fclose(fp);
|
std::fclose(fp);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
int fd = -1;
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
FILE * fp;
|
void read_raw_at(void * ptr, size_t len, size_t offset) const {
|
||||||
size_t size;
|
if (alignment != 1) {
|
||||||
|
read_aligned_chunk(offset, ptr, len);
|
||||||
|
} else {
|
||||||
|
seek(offset, SEEK_SET);
|
||||||
|
read_raw(ptr, len);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t read_alignment() const {
|
||||||
|
return alignment;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t alignment = 1;
|
||||||
|
|
||||||
|
FILE * fp{};
|
||||||
|
size_t size{};
|
||||||
};
|
};
|
||||||
|
|
||||||
llama_file::llama_file(const char * fname, const char * mode) : pimpl(std::make_unique<impl>(fname, mode)) {}
|
llama_file::llama_file(const char * fname, const char * mode, const bool use_direct_io) :
|
||||||
|
pimpl(std::make_unique<impl>(fname, mode, use_direct_io)) {}
|
||||||
llama_file::~llama_file() = default;
|
llama_file::~llama_file() = default;
|
||||||
|
|
||||||
size_t llama_file::tell() const { return pimpl->tell(); }
|
size_t llama_file::tell() const { return pimpl->tell(); }
|
||||||
size_t llama_file::size() const { return pimpl->size; }
|
size_t llama_file::size() const { return pimpl->size; }
|
||||||
|
|
||||||
|
size_t llama_file::read_alignment() const { return pimpl->read_alignment(); }
|
||||||
|
|
||||||
int llama_file::file_id() const {
|
int llama_file::file_id() const {
|
||||||
#ifdef _WIN32
|
#ifdef _WIN32
|
||||||
return _fileno(pimpl->fp);
|
return _fileno(pimpl->fp);
|
||||||
|
|
@ -261,6 +355,7 @@ int llama_file::file_id() const {
|
||||||
|
|
||||||
void llama_file::seek(size_t offset, int whence) const { pimpl->seek(offset, whence); }
|
void llama_file::seek(size_t offset, int whence) const { pimpl->seek(offset, whence); }
|
||||||
void llama_file::read_raw(void * ptr, size_t len) const { pimpl->read_raw(ptr, len); }
|
void llama_file::read_raw(void * ptr, size_t len) const { pimpl->read_raw(ptr, len); }
|
||||||
|
void llama_file::read_raw_at(void * ptr, size_t len, size_t offset) const { pimpl->read_raw_at(ptr, len, offset); }
|
||||||
|
|
||||||
uint32_t llama_file::read_u32() const { return pimpl->read_u32(); }
|
uint32_t llama_file::read_u32() const { return pimpl->read_u32(); }
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
#include <cstdio>
|
||||||
|
|
||||||
struct llama_file;
|
struct llama_file;
|
||||||
struct llama_mmap;
|
struct llama_mmap;
|
||||||
|
|
@ -13,7 +14,7 @@ using llama_mmaps = std::vector<std::unique_ptr<llama_mmap>>;
|
||||||
using llama_mlocks = std::vector<std::unique_ptr<llama_mlock>>;
|
using llama_mlocks = std::vector<std::unique_ptr<llama_mlock>>;
|
||||||
|
|
||||||
struct llama_file {
|
struct llama_file {
|
||||||
llama_file(const char * fname, const char * mode);
|
llama_file(const char * fname, const char * mode, bool use_direct_io = false);
|
||||||
~llama_file();
|
~llama_file();
|
||||||
|
|
||||||
size_t tell() const;
|
size_t tell() const;
|
||||||
|
|
@ -24,11 +25,14 @@ struct llama_file {
|
||||||
void seek(size_t offset, int whence) const;
|
void seek(size_t offset, int whence) const;
|
||||||
|
|
||||||
void read_raw(void * ptr, size_t len) const;
|
void read_raw(void * ptr, size_t len) const;
|
||||||
|
void read_raw_at(void * ptr, size_t len, size_t offset) const;
|
||||||
|
void read_aligned_chunk(size_t offset, void * dest, size_t size) const;
|
||||||
uint32_t read_u32() const;
|
uint32_t read_u32() const;
|
||||||
|
|
||||||
void write_raw(const void * ptr, size_t len) const;
|
void write_raw(const void * ptr, size_t len) const;
|
||||||
void write_u32(uint32_t val) const;
|
void write_u32(uint32_t val) const;
|
||||||
|
|
||||||
|
size_t read_alignment() const;
|
||||||
private:
|
private:
|
||||||
struct impl;
|
struct impl;
|
||||||
std::unique_ptr<impl> pimpl;
|
std::unique_ptr<impl> pimpl;
|
||||||
|
|
|
||||||
|
|
@ -504,7 +504,7 @@ llama_model_loader::llama_model_loader(
|
||||||
get_key(llm_kv(LLM_KV_GENERAL_ARCHITECTURE), arch_name, false);
|
get_key(llm_kv(LLM_KV_GENERAL_ARCHITECTURE), arch_name, false);
|
||||||
llm_kv = LLM_KV(llm_arch_from_string(arch_name));
|
llm_kv = LLM_KV(llm_arch_from_string(arch_name));
|
||||||
|
|
||||||
files.emplace_back(new llama_file(fname.c_str(), "rb"));
|
files.emplace_back(new llama_file(fname.c_str(), "rb", !use_mmap));
|
||||||
contexts.emplace_back(ctx);
|
contexts.emplace_back(ctx);
|
||||||
|
|
||||||
// Save tensors data offset of the main file.
|
// Save tensors data offset of the main file.
|
||||||
|
|
@ -572,7 +572,7 @@ llama_model_loader::llama_model_loader(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
files.emplace_back(new llama_file(fname_split, "rb"));
|
files.emplace_back(new llama_file(fname_split, "rb", !use_mmap));
|
||||||
contexts.emplace_back(ctx);
|
contexts.emplace_back(ctx);
|
||||||
|
|
||||||
// Save tensors data offset info of the shard.
|
// Save tensors data offset info of the shard.
|
||||||
|
|
@ -935,7 +935,15 @@ bool llama_model_loader::load_all_data(
|
||||||
// 4 staging buffers for async uploads, each sized 1MB seems to be a good default for single NVMe drives.
|
// 4 staging buffers for async uploads, each sized 1MB seems to be a good default for single NVMe drives.
|
||||||
// NVMe raid configurations might require more / larger buffers.
|
// NVMe raid configurations might require more / larger buffers.
|
||||||
constexpr size_t n_buffers = 4;
|
constexpr size_t n_buffers = 4;
|
||||||
constexpr size_t buffer_size = 1 * 1024 * 1024; // 1MB
|
|
||||||
|
size_t alignment = 1;
|
||||||
|
for (const auto & file : files) {
|
||||||
|
alignment = std::max(file->read_alignment(), alignment);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Buffer size: balance between memory usage and I/O efficiency
|
||||||
|
// 64MB works well for NVMe drives
|
||||||
|
const size_t buffer_size = alignment != 1 ? 64 * 1024 * 1024 + 2 * alignment : 1 * 1024 * 1024;
|
||||||
|
|
||||||
std::vector<ggml_backend_buffer_t> host_buffers;
|
std::vector<ggml_backend_buffer_t> host_buffers;
|
||||||
std::vector<ggml_backend_event_t> events;
|
std::vector<ggml_backend_event_t> events;
|
||||||
|
|
@ -985,6 +993,7 @@ bool llama_model_loader::load_all_data(
|
||||||
// If the backend is supported, create pinned memory buffers and events for synchronisation.
|
// If the backend is supported, create pinned memory buffers and events for synchronisation.
|
||||||
for (size_t idx = 0; idx < n_buffers; ++idx) {
|
for (size_t idx = 0; idx < n_buffers; ++idx) {
|
||||||
auto * buf = ggml_backend_buft_alloc_buffer(host_buft, buffer_size);
|
auto * buf = ggml_backend_buft_alloc_buffer(host_buft, buffer_size);
|
||||||
|
|
||||||
if (!buf) {
|
if (!buf) {
|
||||||
LLAMA_LOG_DEBUG("%s: failed to allocate host buffer for async uploads for device %s\n", func,
|
LLAMA_LOG_DEBUG("%s: failed to allocate host buffer for async uploads for device %s\n", func,
|
||||||
ggml_backend_dev_name(dev));
|
ggml_backend_dev_name(dev));
|
||||||
|
|
@ -1066,9 +1075,9 @@ bool llama_model_loader::load_all_data(
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
const auto & file = files.at(weight->idx);
|
const auto & file = files.at(weight->idx);
|
||||||
|
|
||||||
if (ggml_backend_buffer_is_host(cur->buffer)) {
|
if (ggml_backend_buffer_is_host(cur->buffer)) {
|
||||||
file->seek(weight->offs, SEEK_SET);
|
file->read_raw_at(cur->data, n_size, weight->offs);
|
||||||
file->read_raw(cur->data, n_size);
|
|
||||||
if (check_tensors) {
|
if (check_tensors) {
|
||||||
validation_result.emplace_back(std::async(std::launch::async, [cur, n_size] {
|
validation_result.emplace_back(std::async(std::launch::async, [cur, n_size] {
|
||||||
return std::make_pair(cur, ggml_validate_row_data(cur->type, cur->data, n_size));
|
return std::make_pair(cur, ggml_validate_row_data(cur->type, cur->data, n_size));
|
||||||
|
|
@ -1077,26 +1086,60 @@ bool llama_model_loader::load_all_data(
|
||||||
} else {
|
} else {
|
||||||
// If upload_backend is valid load the tensor in chunks to pinned memory and upload the buffers asynchronously to the GPU.
|
// If upload_backend is valid load the tensor in chunks to pinned memory and upload the buffers asynchronously to the GPU.
|
||||||
if (upload_backend) {
|
if (upload_backend) {
|
||||||
file->seek(weight->offs, SEEK_SET);
|
auto offset = (off_t) weight->offs;
|
||||||
|
alignment = file->read_alignment();
|
||||||
|
off_t aligned_offset = offset & ~(alignment - 1);
|
||||||
|
off_t offset_from_alignment = offset - aligned_offset;
|
||||||
|
file->seek(aligned_offset, SEEK_SET);
|
||||||
|
|
||||||
|
// Calculate aligned read boundaries
|
||||||
|
size_t read_start = aligned_offset;
|
||||||
|
size_t read_end = (offset + n_size + alignment - 1) & ~(alignment - 1);
|
||||||
|
|
||||||
size_t bytes_read = 0;
|
size_t bytes_read = 0;
|
||||||
|
size_t data_read = 0; // Actual tensor data copied (excluding padding)
|
||||||
|
|
||||||
while (bytes_read < n_size) {
|
while (bytes_read < read_end - read_start) {
|
||||||
size_t read_iteration = std::min<size_t>(buffer_size, n_size - bytes_read);
|
size_t read_size = std::min<size_t>(buffer_size, read_end - read_start - bytes_read);
|
||||||
|
|
||||||
|
// Align the destination pointer within the pinned buffer
|
||||||
|
uintptr_t ptr_dest_aligned = (reinterpret_cast<uintptr_t>(host_ptrs[buffer_idx]) + alignment - 1) & ~(alignment - 1);
|
||||||
|
|
||||||
|
// Wait for previous upload to complete before reusing buffer
|
||||||
ggml_backend_event_synchronize(events[buffer_idx]);
|
ggml_backend_event_synchronize(events[buffer_idx]);
|
||||||
file->read_raw(host_ptrs[buffer_idx], read_iteration);
|
|
||||||
ggml_backend_tensor_set_async(upload_backend, cur, host_ptrs[buffer_idx], bytes_read, read_iteration);
|
// Read aligned chunk from file
|
||||||
|
file->read_raw(reinterpret_cast<void *>(ptr_dest_aligned), read_size);
|
||||||
|
|
||||||
|
// Calculate actual data portion (excluding alignment padding)
|
||||||
|
uintptr_t ptr_data = ptr_dest_aligned;
|
||||||
|
size_t data_to_copy = read_size;
|
||||||
|
|
||||||
|
// Skip alignment padding at start of first chunk
|
||||||
|
if (bytes_read == 0) {
|
||||||
|
ptr_data += offset_from_alignment;
|
||||||
|
data_to_copy -= offset_from_alignment;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Trim alignment padding at end of last chunk
|
||||||
|
if (aligned_offset + bytes_read + read_size > offset + n_size) {
|
||||||
|
data_to_copy -= (read_end - (offset + n_size));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Async upload actual data to GPU
|
||||||
|
ggml_backend_tensor_set_async(upload_backend, cur,
|
||||||
|
reinterpret_cast<void *>(ptr_data), data_read, data_to_copy);
|
||||||
ggml_backend_event_record(events[buffer_idx], upload_backend);
|
ggml_backend_event_record(events[buffer_idx], upload_backend);
|
||||||
|
|
||||||
bytes_read += read_iteration;
|
data_read += data_to_copy;
|
||||||
|
bytes_read += read_size;
|
||||||
|
|
||||||
++buffer_idx;
|
++buffer_idx;
|
||||||
buffer_idx %= n_buffers;
|
buffer_idx %= n_buffers;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
read_buf.resize(n_size);
|
read_buf.resize(n_size);
|
||||||
file->seek(weight->offs, SEEK_SET);
|
file->read_raw_at(read_buf.data(), n_size, weight->offs);
|
||||||
file->read_raw(read_buf.data(), n_size);
|
|
||||||
ggml_backend_tensor_set(cur, read_buf.data(), 0, n_size);
|
ggml_backend_tensor_set(cur, read_buf.data(), 0, n_size);
|
||||||
if (check_tensors && !ggml_validate_row_data(cur->type, read_buf.data(), n_size)) {
|
if (check_tensors && !ggml_validate_row_data(cur->type, read_buf.data(), n_size)) {
|
||||||
throw std::runtime_error(format("tensor '%s' has invalid data", ggml_get_name(cur)));
|
throw std::runtime_error(format("tensor '%s' has invalid data", ggml_get_name(cur)));
|
||||||
|
|
|
||||||
|
|
@ -2378,10 +2378,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||||
if (cpu_dev == nullptr) {
|
if (cpu_dev == nullptr) {
|
||||||
throw std::runtime_error(format("%s: no CPU backend found", __func__));
|
throw std::runtime_error(format("%s: no CPU backend found", __func__));
|
||||||
}
|
}
|
||||||
const int i_gpu_start = std::max((int) hparams.n_layer - n_gpu_layers, (int) 0);
|
const int i_gpu_start = std::max(int(hparams.n_layer) + 1 - n_gpu_layers, 0);
|
||||||
const int act_gpu_layers = devices.empty() ? 0 : std::min(n_gpu_layers, (int)n_layer + 1);
|
const int act_gpu_layers = devices.empty() ? 0 : std::min(n_gpu_layers, int(n_layer) + 1);
|
||||||
auto get_layer_buft_list = [&](int il) -> llama_model::impl::layer_dev {
|
auto get_layer_buft_list = [&](int il) -> llama_model::impl::layer_dev {
|
||||||
const bool is_swa = il < (int) hparams.n_layer && hparams.is_swa(il);
|
const bool is_swa = il < int(hparams.n_layer) && hparams.is_swa(il);
|
||||||
if (il < i_gpu_start || (il - i_gpu_start) >= act_gpu_layers) {
|
if (il < i_gpu_start || (il - i_gpu_start) >= act_gpu_layers) {
|
||||||
LLAMA_LOG_DEBUG("load_tensors: layer %3d assigned to device %s, is_swa = %d\n", il, ggml_backend_dev_name(cpu_dev), is_swa);
|
LLAMA_LOG_DEBUG("load_tensors: layer %3d assigned to device %s, is_swa = %d\n", il, ggml_backend_dev_name(cpu_dev), is_swa);
|
||||||
return {cpu_dev, &pimpl->cpu_buft_list};
|
return {cpu_dev, &pimpl->cpu_buft_list};
|
||||||
|
|
@ -6693,10 +6693,12 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||||
if (llama_supports_gpu_offload()) {
|
if (llama_supports_gpu_offload()) {
|
||||||
const int n_gpu = std::min(n_gpu_layers, int(hparams.n_layer));
|
const int n_gpu = std::min(n_gpu_layers, int(hparams.n_layer));
|
||||||
|
|
||||||
LLAMA_LOG_INFO("%s: offloading %d repeating layers to GPU\n", __func__, n_gpu);
|
int n_repeating = n_gpu;
|
||||||
if (n_gpu_layers > (int) hparams.n_layer) {
|
if (n_repeating > 0) {
|
||||||
LLAMA_LOG_INFO("%s: offloading output layer to GPU\n", __func__);
|
LLAMA_LOG_INFO("%s: offloading output layer to GPU\n", __func__);
|
||||||
|
n_repeating--;
|
||||||
}
|
}
|
||||||
|
LLAMA_LOG_INFO("%s: offloading %d repeating layers to GPU\n", __func__, n_repeating);
|
||||||
|
|
||||||
const int max_backend_supported_layers = hparams.n_layer + 1;
|
const int max_backend_supported_layers = hparams.n_layer + 1;
|
||||||
const int max_offloadable_layers = hparams.n_layer + 1;
|
const int max_offloadable_layers = hparams.n_layer + 1;
|
||||||
|
|
|
||||||
|
|
@ -292,10 +292,6 @@ static void llama_params_fit_impl(
|
||||||
if (mparams->split_mode == LLAMA_SPLIT_MODE_ROW) {
|
if (mparams->split_mode == LLAMA_SPLIT_MODE_ROW) {
|
||||||
throw std::runtime_error("changing weight allocation for LLAMA_SPLIT_MODE_ROW not implemented, abort");
|
throw std::runtime_error("changing weight allocation for LLAMA_SPLIT_MODE_ROW not implemented, abort");
|
||||||
}
|
}
|
||||||
if (hp_ngl < 2*nd) {
|
|
||||||
throw std::runtime_error("model has only " + std::to_string(hp_ngl) + " layers but need at least "
|
|
||||||
+ std::to_string(2*nd) + " to fit memory for " + std::to_string(nd) + " devices, abort");
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
if (!tensor_buft_overrides) {
|
if (!tensor_buft_overrides) {
|
||||||
throw std::runtime_error("did not provide buffer to set tensor_buft_overrides, abort");
|
throw std::runtime_error("did not provide buffer to set tensor_buft_overrides, abort");
|
||||||
|
|
@ -362,8 +358,7 @@ static void llama_params_fit_impl(
|
||||||
auto set_ngl_tensor_split_tbo = [&](
|
auto set_ngl_tensor_split_tbo = [&](
|
||||||
const std::vector<ngl_t> & ngl_per_device,
|
const std::vector<ngl_t> & ngl_per_device,
|
||||||
const std::vector<ggml_backend_buffer_type_t> & overflow_bufts,
|
const std::vector<ggml_backend_buffer_type_t> & overflow_bufts,
|
||||||
llama_model_params & mparams,
|
llama_model_params & mparams) {
|
||||||
const bool add_nonrepeating) {
|
|
||||||
mparams.n_gpu_layers = 0;
|
mparams.n_gpu_layers = 0;
|
||||||
for (size_t id = 0; id < nd; id++) {
|
for (size_t id = 0; id < nd; id++) {
|
||||||
mparams.n_gpu_layers += ngl_per_device[id].n_layer;
|
mparams.n_gpu_layers += ngl_per_device[id].n_layer;
|
||||||
|
|
@ -371,13 +366,9 @@ static void llama_params_fit_impl(
|
||||||
tensor_split[id] = ngl_per_device[id].n_layer;
|
tensor_split[id] = ngl_per_device[id].n_layer;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
assert(uint32_t(mparams.n_gpu_layers) <= hp_ngl);
|
assert(uint32_t(mparams.n_gpu_layers) <= hp_ngl + 1);
|
||||||
uint32_t il0 = hp_ngl - mparams.n_gpu_layers; // start index for tensor buft overrides
|
uint32_t il0 = hp_ngl + 1 - mparams.n_gpu_layers; // start index for tensor buft overrides
|
||||||
|
|
||||||
if (add_nonrepeating) {
|
|
||||||
mparams.n_gpu_layers += 1;
|
|
||||||
tensor_split[nd - 1] += 1;
|
|
||||||
}
|
|
||||||
mparams.tensor_split = tensor_split;
|
mparams.tensor_split = tensor_split;
|
||||||
|
|
||||||
size_t itbo = 0;
|
size_t itbo = 0;
|
||||||
|
|
@ -408,10 +399,9 @@ static void llama_params_fit_impl(
|
||||||
auto get_memory_for_layers = [&](
|
auto get_memory_for_layers = [&](
|
||||||
const char * func_name,
|
const char * func_name,
|
||||||
const std::vector<ngl_t> & ngl_per_device,
|
const std::vector<ngl_t> & ngl_per_device,
|
||||||
const std::vector<ggml_backend_buffer_type_t> & overflow_bufts,
|
const std::vector<ggml_backend_buffer_type_t> & overflow_bufts) -> std::vector<int64_t> {
|
||||||
const bool add_nonrepeating) -> std::vector<int64_t> {
|
|
||||||
llama_model_params mparams_copy = *mparams;
|
llama_model_params mparams_copy = *mparams;
|
||||||
set_ngl_tensor_split_tbo(ngl_per_device, overflow_bufts, mparams_copy, add_nonrepeating);
|
set_ngl_tensor_split_tbo(ngl_per_device, overflow_bufts, mparams_copy);
|
||||||
|
|
||||||
const dmds_t dmd_nl = llama_get_device_memory_data(
|
const dmds_t dmd_nl = llama_get_device_memory_data(
|
||||||
path_model, &mparams_copy, cparams, devs, hp_ngl, hp_nct, hp_nex, log_level);
|
path_model, &mparams_copy, cparams, devs, hp_ngl, hp_nct, hp_nex, log_level);
|
||||||
|
|
@ -469,9 +459,6 @@ static void llama_params_fit_impl(
|
||||||
LLAMA_LOG_DEBUG("%s: id=%zu, target=%" PRId64 " MiB\n", __func__, id, targets[id]/MiB);
|
LLAMA_LOG_DEBUG("%s: id=%zu, target=%" PRId64 " MiB\n", __func__, id, targets[id]/MiB);
|
||||||
}
|
}
|
||||||
|
|
||||||
// whether for the optimal memory use we expect to load at least some MoE tensors:
|
|
||||||
const bool partial_moe = hp_nex > 0 && global_surplus_cpu_moe > 0;
|
|
||||||
|
|
||||||
std::vector<ggml_backend_buffer_type_t> overflow_bufts; // which bufts the partial layers of a device overflow to:
|
std::vector<ggml_backend_buffer_type_t> overflow_bufts; // which bufts the partial layers of a device overflow to:
|
||||||
overflow_bufts.reserve(nd);
|
overflow_bufts.reserve(nd);
|
||||||
for (size_t id = 0; id < nd - 1; ++id) {
|
for (size_t id = 0; id < nd - 1; ++id) {
|
||||||
|
|
@ -480,7 +467,7 @@ static void llama_params_fit_impl(
|
||||||
overflow_bufts.push_back(ggml_backend_cpu_buffer_type());
|
overflow_bufts.push_back(ggml_backend_cpu_buffer_type());
|
||||||
|
|
||||||
std::vector<ngl_t> ngl_per_device(nd);
|
std::vector<ngl_t> ngl_per_device(nd);
|
||||||
std::vector<int64_t> mem = get_memory_for_layers(__func__, ngl_per_device, overflow_bufts, partial_moe);
|
std::vector<int64_t> mem = get_memory_for_layers(__func__, ngl_per_device, overflow_bufts);
|
||||||
if (hp_nex > 0) {
|
if (hp_nex > 0) {
|
||||||
for (size_t id = 0; id < nd; id++) {
|
for (size_t id = 0; id < nd; id++) {
|
||||||
ngl_per_device[id].overflow_type = LAYER_FRACTION_MOE;
|
ngl_per_device[id].overflow_type = LAYER_FRACTION_MOE;
|
||||||
|
|
@ -493,13 +480,14 @@ static void llama_params_fit_impl(
|
||||||
// - interpolate the memory use / layer between low and high linearly to get a guess where it meets our target
|
// - interpolate the memory use / layer between low and high linearly to get a guess where it meets our target
|
||||||
// - check memory use of our guess, replace either the low or high bound
|
// - check memory use of our guess, replace either the low or high bound
|
||||||
// - once we only have a difference of a single layer, stop and return the lower bound that just barely still fits
|
// - once we only have a difference of a single layer, stop and return the lower bound that just barely still fits
|
||||||
|
// - the last device has the output layer, which cannot be a partial layer
|
||||||
if (hp_nex == 0) {
|
if (hp_nex == 0) {
|
||||||
LLAMA_LOG_INFO("%s: filling dense layers back-to-front:\n", __func__);
|
LLAMA_LOG_INFO("%s: filling dense layers back-to-front:\n", __func__);
|
||||||
} else {
|
} else {
|
||||||
LLAMA_LOG_INFO("%s: filling dense-only layers back-to-front:\n", __func__);
|
LLAMA_LOG_INFO("%s: filling dense-only layers back-to-front:\n", __func__);
|
||||||
}
|
}
|
||||||
for (int id = nd - 1; id >= 0; id--) {
|
for (int id = nd - 1; id >= 0; id--) {
|
||||||
uint32_t n_unassigned = hp_ngl;
|
uint32_t n_unassigned = hp_ngl + 1;
|
||||||
for (size_t jd = id + 1; jd < nd; ++jd) {
|
for (size_t jd = id + 1; jd < nd; ++jd) {
|
||||||
assert(n_unassigned >= ngl_per_device[jd].n_layer);
|
assert(n_unassigned >= ngl_per_device[jd].n_layer);
|
||||||
n_unassigned -= ngl_per_device[jd].n_layer;
|
n_unassigned -= ngl_per_device[jd].n_layer;
|
||||||
|
|
@ -508,10 +496,10 @@ static void llama_params_fit_impl(
|
||||||
std::vector<ngl_t> ngl_per_device_high = ngl_per_device;
|
std::vector<ngl_t> ngl_per_device_high = ngl_per_device;
|
||||||
ngl_per_device_high[id].n_layer = n_unassigned;
|
ngl_per_device_high[id].n_layer = n_unassigned;
|
||||||
if (hp_nex > 0) {
|
if (hp_nex > 0) {
|
||||||
ngl_per_device_high[id].n_part = ngl_per_device_high[id].n_layer;
|
ngl_per_device_high[id].n_part = size_t(id) < nd - 1 ? ngl_per_device_high[id].n_layer : ngl_per_device_high[id].n_layer - 1;
|
||||||
}
|
}
|
||||||
if (ngl_per_device_high[id].n_layer > 0) {
|
if (ngl_per_device_high[id].n_layer > 0) {
|
||||||
std::vector<int64_t> mem_high = get_memory_for_layers(__func__, ngl_per_device_high, overflow_bufts, partial_moe);
|
std::vector<int64_t> mem_high = get_memory_for_layers(__func__, ngl_per_device_high, overflow_bufts);
|
||||||
if (mem_high[id] > targets[id]) {
|
if (mem_high[id] > targets[id]) {
|
||||||
assert(ngl_per_device_high[id].n_layer > ngl_per_device[id].n_layer);
|
assert(ngl_per_device_high[id].n_layer > ngl_per_device[id].n_layer);
|
||||||
uint32_t delta = ngl_per_device_high[id].n_layer - ngl_per_device[id].n_layer;
|
uint32_t delta = ngl_per_device_high[id].n_layer - ngl_per_device[id].n_layer;
|
||||||
|
|
@ -526,7 +514,7 @@ static void llama_params_fit_impl(
|
||||||
if (hp_nex) {
|
if (hp_nex) {
|
||||||
ngl_per_device_test[id].n_part += step_size;
|
ngl_per_device_test[id].n_part += step_size;
|
||||||
}
|
}
|
||||||
const std::vector<int64_t> mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts, partial_moe);
|
const std::vector<int64_t> mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts);
|
||||||
|
|
||||||
if (mem_test[id] <= targets[id]) {
|
if (mem_test[id] <= targets[id]) {
|
||||||
ngl_per_device = ngl_per_device_test;
|
ngl_per_device = ngl_per_device_test;
|
||||||
|
|
@ -553,7 +541,7 @@ static void llama_params_fit_impl(
|
||||||
__func__, dev_names[id].c_str(), ngl_per_device[id].n_layer, mem[id]/MiB, projected_margin/MiB);
|
__func__, dev_names[id].c_str(), ngl_per_device[id].n_layer, mem[id]/MiB, projected_margin/MiB);
|
||||||
}
|
}
|
||||||
if (hp_nex == 0 || global_surplus_cpu_moe <= 0) {
|
if (hp_nex == 0 || global_surplus_cpu_moe <= 0) {
|
||||||
set_ngl_tensor_split_tbo(ngl_per_device, overflow_bufts, *mparams, partial_moe);
|
set_ngl_tensor_split_tbo(ngl_per_device, overflow_bufts, *mparams);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -576,13 +564,13 @@ static void llama_params_fit_impl(
|
||||||
for (size_t id = 0; id <= id_dense_start; id++) {
|
for (size_t id = 0; id <= id_dense_start; id++) {
|
||||||
std::vector<ngl_t> ngl_per_device_high = ngl_per_device;
|
std::vector<ngl_t> ngl_per_device_high = ngl_per_device;
|
||||||
for (size_t jd = id_dense_start; jd < nd; jd++) {
|
for (size_t jd = id_dense_start; jd < nd; jd++) {
|
||||||
const uint32_t n_layer_move = ngl_per_device_high[jd].n_layer;
|
const uint32_t n_layer_move = jd < nd - 1 ? ngl_per_device_high[jd].n_layer : ngl_per_device_high[jd].n_layer - 1;
|
||||||
ngl_per_device_high[id].n_layer += n_layer_move;
|
ngl_per_device_high[id].n_layer += n_layer_move;
|
||||||
ngl_per_device_high[jd].n_layer -= n_layer_move;
|
ngl_per_device_high[jd].n_layer -= n_layer_move;
|
||||||
ngl_per_device_high[jd].n_part = 0;
|
ngl_per_device_high[jd].n_part = 0;
|
||||||
}
|
}
|
||||||
size_t id_dense_start_high = nd - 1;
|
size_t id_dense_start_high = nd - 1;
|
||||||
std::vector<int64_t> mem_high = get_memory_for_layers(__func__, ngl_per_device_high, overflow_bufts, partial_moe);
|
std::vector<int64_t> mem_high = get_memory_for_layers(__func__, ngl_per_device_high, overflow_bufts);
|
||||||
|
|
||||||
if (mem_high[id] > targets[id]) {
|
if (mem_high[id] > targets[id]) {
|
||||||
assert(ngl_per_device_high[id].n_layer >= ngl_per_device_high[id].n_part);
|
assert(ngl_per_device_high[id].n_layer >= ngl_per_device_high[id].n_part);
|
||||||
|
|
@ -610,7 +598,7 @@ static void llama_params_fit_impl(
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
const std::vector<int64_t> mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts, partial_moe);
|
const std::vector<int64_t> mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts);
|
||||||
|
|
||||||
if (mem_test[id] <= targets[id]) {
|
if (mem_test[id] <= targets[id]) {
|
||||||
ngl_per_device = ngl_per_device_test;
|
ngl_per_device = ngl_per_device_test;
|
||||||
|
|
@ -637,7 +625,7 @@ static void llama_params_fit_impl(
|
||||||
}
|
}
|
||||||
|
|
||||||
// try to fit at least part of one more layer
|
// try to fit at least part of one more layer
|
||||||
if (ngl_per_device[id_dense_start].n_layer > 0) {
|
if (ngl_per_device[id_dense_start].n_layer > (id < nd - 1 ? 0 : 1)) {
|
||||||
std::vector<ngl_t> ngl_per_device_test = ngl_per_device;
|
std::vector<ngl_t> ngl_per_device_test = ngl_per_device;
|
||||||
size_t id_dense_start_test = id_dense_start;
|
size_t id_dense_start_test = id_dense_start;
|
||||||
ngl_per_device_test[id_dense_start_test].n_layer--;
|
ngl_per_device_test[id_dense_start_test].n_layer--;
|
||||||
|
|
@ -649,7 +637,7 @@ static void llama_params_fit_impl(
|
||||||
}
|
}
|
||||||
ngl_per_device_test[id].overflow_type = LAYER_FRACTION_UP;
|
ngl_per_device_test[id].overflow_type = LAYER_FRACTION_UP;
|
||||||
LLAMA_LOG_DEBUG("%s: trying to fit one extra layer with overflow_type=LAYER_FRACTION_UP\n", __func__);
|
LLAMA_LOG_DEBUG("%s: trying to fit one extra layer with overflow_type=LAYER_FRACTION_UP\n", __func__);
|
||||||
std::vector<int64_t> mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts, partial_moe);
|
std::vector<int64_t> mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts);
|
||||||
if (mem_test[id] < targets[id]) {
|
if (mem_test[id] < targets[id]) {
|
||||||
ngl_per_device = ngl_per_device_test;
|
ngl_per_device = ngl_per_device_test;
|
||||||
mem = mem_test;
|
mem = mem_test;
|
||||||
|
|
@ -659,7 +647,7 @@ static void llama_params_fit_impl(
|
||||||
|
|
||||||
ngl_per_device_test[id].overflow_type = LAYER_FRACTION_GATE;
|
ngl_per_device_test[id].overflow_type = LAYER_FRACTION_GATE;
|
||||||
LLAMA_LOG_DEBUG("%s: trying to fit one extra layer with overflow_type=LAYER_FRACTION_GATE\n", __func__);
|
LLAMA_LOG_DEBUG("%s: trying to fit one extra layer with overflow_type=LAYER_FRACTION_GATE\n", __func__);
|
||||||
mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts, partial_moe);
|
mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts);
|
||||||
if (mem_test[id] < targets[id]) {
|
if (mem_test[id] < targets[id]) {
|
||||||
ngl_per_device = ngl_per_device_test;
|
ngl_per_device = ngl_per_device_test;
|
||||||
mem = mem_test;
|
mem = mem_test;
|
||||||
|
|
@ -670,7 +658,7 @@ static void llama_params_fit_impl(
|
||||||
} else {
|
} else {
|
||||||
ngl_per_device_test[id].overflow_type = LAYER_FRACTION_ATTN;
|
ngl_per_device_test[id].overflow_type = LAYER_FRACTION_ATTN;
|
||||||
LLAMA_LOG_DEBUG("%s: trying to fit one extra layer with overflow_type=LAYER_FRACTION_ATTN\n", __func__);
|
LLAMA_LOG_DEBUG("%s: trying to fit one extra layer with overflow_type=LAYER_FRACTION_ATTN\n", __func__);
|
||||||
mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts, partial_moe);
|
mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts);
|
||||||
if (mem_test[id] < targets[id]) {
|
if (mem_test[id] < targets[id]) {
|
||||||
ngl_per_device = ngl_per_device_test;
|
ngl_per_device = ngl_per_device_test;
|
||||||
mem = mem_test;
|
mem = mem_test;
|
||||||
|
|
@ -687,7 +675,7 @@ static void llama_params_fit_impl(
|
||||||
__func__, dev_names[id].c_str(), ngl_per_device[id].n_layer, ngl_per_device[id].n_part, mem[id]/MiB, projected_margin/MiB);
|
__func__, dev_names[id].c_str(), ngl_per_device[id].n_layer, ngl_per_device[id].n_part, mem[id]/MiB, projected_margin/MiB);
|
||||||
}
|
}
|
||||||
|
|
||||||
set_ngl_tensor_split_tbo(ngl_per_device, overflow_bufts, *mparams, partial_moe);
|
set_ngl_tensor_split_tbo(ngl_per_device, overflow_bufts, *mparams);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool llama_params_fit(
|
bool llama_params_fit(
|
||||||
|
|
|
||||||
Binary file not shown.
|
|
@ -2109,9 +2109,9 @@
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"node_modules/@sveltejs/kit": {
|
"node_modules/@sveltejs/kit": {
|
||||||
"version": "2.48.5",
|
"version": "2.49.2",
|
||||||
"resolved": "https://registry.npmjs.org/@sveltejs/kit/-/kit-2.48.5.tgz",
|
"resolved": "https://registry.npmjs.org/@sveltejs/kit/-/kit-2.49.2.tgz",
|
||||||
"integrity": "sha512-/rnwfSWS3qwUSzvHynUTORF9xSJi7PCR9yXkxUOnRrNqyKmCmh3FPHH+E9BbgqxXfTevGXBqgnlh9kMb+9T5XA==",
|
"integrity": "sha512-Vp3zX/qlwerQmHMP6x0Ry1oY7eKKRcOWGc2P59srOp4zcqyn+etJyQpELgOi4+ZSUgteX8Y387NuwruLgGXLUQ==",
|
||||||
"dev": true,
|
"dev": true,
|
||||||
"license": "MIT",
|
"license": "MIT",
|
||||||
"dependencies": {
|
"dependencies": {
|
||||||
|
|
@ -5797,9 +5797,9 @@
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"node_modules/mdast-util-to-hast": {
|
"node_modules/mdast-util-to-hast": {
|
||||||
"version": "13.2.0",
|
"version": "13.2.1",
|
||||||
"resolved": "https://registry.npmjs.org/mdast-util-to-hast/-/mdast-util-to-hast-13.2.0.tgz",
|
"resolved": "https://registry.npmjs.org/mdast-util-to-hast/-/mdast-util-to-hast-13.2.1.tgz",
|
||||||
"integrity": "sha512-QGYKEuUsYT9ykKBCMOEDLsU5JRObWQusAolFMeko/tYPufNkRffBAQjIE+99jbA87xv6FgmjLtwjh9wBWajwAA==",
|
"integrity": "sha512-cctsq2wp5vTsLIcaymblUriiTcZd0CwWtCbLvrOzYCDZoWyMNV8sZ7krj09FSnsiJi3WVsHLM4k6Dq/yaPyCXA==",
|
||||||
"license": "MIT",
|
"license": "MIT",
|
||||||
"dependencies": {
|
"dependencies": {
|
||||||
"@types/hast": "^3.0.0",
|
"@types/hast": "^3.0.0",
|
||||||
|
|
|
||||||
|
|
@ -124,3 +124,10 @@ declare global {
|
||||||
SettingsConfigType
|
SettingsConfigType
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
declare global {
|
||||||
|
interface Window {
|
||||||
|
idxThemeStyle?: number;
|
||||||
|
idxCodeBlock?: number;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -244,7 +244,7 @@
|
||||||
|
|
||||||
<div class="info my-6 grid gap-4">
|
<div class="info my-6 grid gap-4">
|
||||||
{#if displayedModel()}
|
{#if displayedModel()}
|
||||||
<span class="inline-flex flex-wrap items-center gap-2 text-xs text-muted-foreground">
|
<div class="inline-flex flex-wrap items-start gap-2 text-xs text-muted-foreground">
|
||||||
{#if isRouter}
|
{#if isRouter}
|
||||||
<ModelsSelector
|
<ModelsSelector
|
||||||
currentModel={displayedModel()}
|
currentModel={displayedModel()}
|
||||||
|
|
@ -258,11 +258,13 @@
|
||||||
|
|
||||||
{#if currentConfig.showMessageStats && message.timings && message.timings.predicted_n && message.timings.predicted_ms}
|
{#if currentConfig.showMessageStats && message.timings && message.timings.predicted_n && message.timings.predicted_ms}
|
||||||
<ChatMessageStatistics
|
<ChatMessageStatistics
|
||||||
|
promptTokens={message.timings.prompt_n}
|
||||||
|
promptMs={message.timings.prompt_ms}
|
||||||
predictedTokens={message.timings.predicted_n}
|
predictedTokens={message.timings.predicted_n}
|
||||||
predictedMs={message.timings.predicted_ms}
|
predictedMs={message.timings.predicted_ms}
|
||||||
/>
|
/>
|
||||||
{/if}
|
{/if}
|
||||||
</span>
|
</div>
|
||||||
{/if}
|
{/if}
|
||||||
|
|
||||||
{#if config().showToolCalls}
|
{#if config().showToolCalls}
|
||||||
|
|
|
||||||
|
|
@ -1,20 +1,122 @@
|
||||||
<script lang="ts">
|
<script lang="ts">
|
||||||
import { Clock, Gauge, WholeWord } from '@lucide/svelte';
|
import { Clock, Gauge, WholeWord, BookOpenText, Sparkles } from '@lucide/svelte';
|
||||||
import { BadgeChatStatistic } from '$lib/components/app';
|
import { BadgeChatStatistic } from '$lib/components/app';
|
||||||
|
import * as Tooltip from '$lib/components/ui/tooltip';
|
||||||
|
import { ChatMessageStatsView } from '$lib/enums';
|
||||||
|
|
||||||
interface Props {
|
interface Props {
|
||||||
predictedTokens: number;
|
predictedTokens: number;
|
||||||
predictedMs: number;
|
predictedMs: number;
|
||||||
|
promptTokens?: number;
|
||||||
|
promptMs?: number;
|
||||||
}
|
}
|
||||||
|
|
||||||
let { predictedTokens, predictedMs }: Props = $props();
|
let { predictedTokens, predictedMs, promptTokens, promptMs }: Props = $props();
|
||||||
|
|
||||||
|
let activeView: ChatMessageStatsView = $state(ChatMessageStatsView.GENERATION);
|
||||||
|
|
||||||
let tokensPerSecond = $derived((predictedTokens / predictedMs) * 1000);
|
let tokensPerSecond = $derived((predictedTokens / predictedMs) * 1000);
|
||||||
let timeInSeconds = $derived((predictedMs / 1000).toFixed(2));
|
let timeInSeconds = $derived((predictedMs / 1000).toFixed(2));
|
||||||
|
|
||||||
|
let promptTokensPerSecond = $derived(
|
||||||
|
promptTokens !== undefined && promptMs !== undefined
|
||||||
|
? (promptTokens / promptMs) * 1000
|
||||||
|
: undefined
|
||||||
|
);
|
||||||
|
|
||||||
|
let promptTimeInSeconds = $derived(
|
||||||
|
promptMs !== undefined ? (promptMs / 1000).toFixed(2) : undefined
|
||||||
|
);
|
||||||
|
|
||||||
|
let hasPromptStats = $derived(
|
||||||
|
promptTokens !== undefined &&
|
||||||
|
promptMs !== undefined &&
|
||||||
|
promptTokensPerSecond !== undefined &&
|
||||||
|
promptTimeInSeconds !== undefined
|
||||||
|
);
|
||||||
</script>
|
</script>
|
||||||
|
|
||||||
<BadgeChatStatistic icon={WholeWord} value="{predictedTokens} tokens" />
|
<div class="inline-flex items-center text-xs text-muted-foreground">
|
||||||
|
<div class="inline-flex items-center rounded-sm bg-muted-foreground/15 p-0.5">
|
||||||
|
{#if hasPromptStats}
|
||||||
|
<Tooltip.Root>
|
||||||
|
<Tooltip.Trigger>
|
||||||
|
<button
|
||||||
|
type="button"
|
||||||
|
class="inline-flex h-5 w-5 items-center justify-center rounded-sm transition-colors {activeView ===
|
||||||
|
ChatMessageStatsView.READING
|
||||||
|
? 'bg-background text-foreground shadow-sm'
|
||||||
|
: 'hover:text-foreground'}"
|
||||||
|
onclick={() => (activeView = ChatMessageStatsView.READING)}
|
||||||
|
>
|
||||||
|
<BookOpenText class="h-3 w-3" />
|
||||||
|
<span class="sr-only">Reading</span>
|
||||||
|
</button>
|
||||||
|
</Tooltip.Trigger>
|
||||||
|
<Tooltip.Content>
|
||||||
|
<p>Reading (prompt processing)</p>
|
||||||
|
</Tooltip.Content>
|
||||||
|
</Tooltip.Root>
|
||||||
|
{/if}
|
||||||
|
<Tooltip.Root>
|
||||||
|
<Tooltip.Trigger>
|
||||||
|
<button
|
||||||
|
type="button"
|
||||||
|
class="inline-flex h-5 w-5 items-center justify-center rounded-sm transition-colors {activeView ===
|
||||||
|
ChatMessageStatsView.GENERATION
|
||||||
|
? 'bg-background text-foreground shadow-sm'
|
||||||
|
: 'hover:text-foreground'}"
|
||||||
|
onclick={() => (activeView = ChatMessageStatsView.GENERATION)}
|
||||||
|
>
|
||||||
|
<Sparkles class="h-3 w-3" />
|
||||||
|
<span class="sr-only">Generation</span>
|
||||||
|
</button>
|
||||||
|
</Tooltip.Trigger>
|
||||||
|
<Tooltip.Content>
|
||||||
|
<p>Generation (token output)</p>
|
||||||
|
</Tooltip.Content>
|
||||||
|
</Tooltip.Root>
|
||||||
|
</div>
|
||||||
|
|
||||||
<BadgeChatStatistic icon={Clock} value="{timeInSeconds}s" />
|
<div class="flex items-center gap-1 px-2">
|
||||||
|
{#if activeView === ChatMessageStatsView.GENERATION}
|
||||||
<BadgeChatStatistic icon={Gauge} value="{tokensPerSecond.toFixed(2)} tokens/s" />
|
<BadgeChatStatistic
|
||||||
|
class="bg-transparent"
|
||||||
|
icon={WholeWord}
|
||||||
|
value="{predictedTokens} tokens"
|
||||||
|
tooltipLabel="Generated tokens"
|
||||||
|
/>
|
||||||
|
<BadgeChatStatistic
|
||||||
|
class="bg-transparent"
|
||||||
|
icon={Clock}
|
||||||
|
value="{timeInSeconds}s"
|
||||||
|
tooltipLabel="Generation time"
|
||||||
|
/>
|
||||||
|
<BadgeChatStatistic
|
||||||
|
class="bg-transparent"
|
||||||
|
icon={Gauge}
|
||||||
|
value="{tokensPerSecond.toFixed(2)} tokens/s"
|
||||||
|
tooltipLabel="Generation speed"
|
||||||
|
/>
|
||||||
|
{:else if hasPromptStats}
|
||||||
|
<BadgeChatStatistic
|
||||||
|
class="bg-transparent"
|
||||||
|
icon={WholeWord}
|
||||||
|
value="{promptTokens} tokens"
|
||||||
|
tooltipLabel="Prompt tokens"
|
||||||
|
/>
|
||||||
|
<BadgeChatStatistic
|
||||||
|
class="bg-transparent"
|
||||||
|
icon={Clock}
|
||||||
|
value="{promptTimeInSeconds}s"
|
||||||
|
tooltipLabel="Prompt processing time"
|
||||||
|
/>
|
||||||
|
<BadgeChatStatistic
|
||||||
|
class="bg-transparent"
|
||||||
|
icon={Gauge}
|
||||||
|
value="{promptTokensPerSecond!.toFixed(2)} tokens/s"
|
||||||
|
tooltipLabel="Prompt processing speed"
|
||||||
|
/>
|
||||||
|
{/if}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
|
||||||
|
|
@ -587,7 +587,7 @@
|
||||||
|
|
||||||
&::after {
|
&::after {
|
||||||
content: '';
|
content: '';
|
||||||
position: fixed;
|
position: absolute;
|
||||||
bottom: 0;
|
bottom: 0;
|
||||||
z-index: -1;
|
z-index: -1;
|
||||||
left: 0;
|
left: 0;
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
<script lang="ts">
|
<script lang="ts">
|
||||||
import { BadgeInfo } from '$lib/components/app';
|
import { BadgeInfo } from '$lib/components/app';
|
||||||
|
import * as Tooltip from '$lib/components/ui/tooltip';
|
||||||
import { copyToClipboard } from '$lib/utils';
|
import { copyToClipboard } from '$lib/utils';
|
||||||
import type { Component } from 'svelte';
|
import type { Component } from 'svelte';
|
||||||
|
|
||||||
|
|
@ -7,19 +8,37 @@
|
||||||
class?: string;
|
class?: string;
|
||||||
icon: Component;
|
icon: Component;
|
||||||
value: string | number;
|
value: string | number;
|
||||||
|
tooltipLabel?: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
let { class: className = '', icon: Icon, value }: Props = $props();
|
let { class: className = '', icon: Icon, value, tooltipLabel }: Props = $props();
|
||||||
|
|
||||||
function handleClick() {
|
function handleClick() {
|
||||||
void copyToClipboard(String(value));
|
void copyToClipboard(String(value));
|
||||||
}
|
}
|
||||||
</script>
|
</script>
|
||||||
|
|
||||||
<BadgeInfo class={className} onclick={handleClick}>
|
{#if tooltipLabel}
|
||||||
{#snippet icon()}
|
<Tooltip.Root>
|
||||||
<Icon class="h-3 w-3" />
|
<Tooltip.Trigger>
|
||||||
{/snippet}
|
<BadgeInfo class={className} onclick={handleClick}>
|
||||||
|
{#snippet icon()}
|
||||||
|
<Icon class="h-3 w-3" />
|
||||||
|
{/snippet}
|
||||||
|
|
||||||
{value}
|
{value}
|
||||||
</BadgeInfo>
|
</BadgeInfo>
|
||||||
|
</Tooltip.Trigger>
|
||||||
|
<Tooltip.Content>
|
||||||
|
<p>{tooltipLabel}</p>
|
||||||
|
</Tooltip.Content>
|
||||||
|
</Tooltip.Root>
|
||||||
|
{:else}
|
||||||
|
<BadgeInfo class={className} onclick={handleClick}>
|
||||||
|
{#snippet icon()}
|
||||||
|
<Icon class="h-3 w-3" />
|
||||||
|
{/snippet}
|
||||||
|
|
||||||
|
{value}
|
||||||
|
</BadgeInfo>
|
||||||
|
{/if}
|
||||||
|
|
|
||||||
|
|
@ -7,15 +7,19 @@
|
||||||
import remarkRehype from 'remark-rehype';
|
import remarkRehype from 'remark-rehype';
|
||||||
import rehypeKatex from 'rehype-katex';
|
import rehypeKatex from 'rehype-katex';
|
||||||
import rehypeStringify from 'rehype-stringify';
|
import rehypeStringify from 'rehype-stringify';
|
||||||
import { copyCodeToClipboard, preprocessLaTeX } from '$lib/utils';
|
import type { Root as HastRoot, RootContent as HastRootContent } from 'hast';
|
||||||
import { rehypeRestoreTableHtml } from '$lib/markdown/table-html-restorer';
|
import type { Root as MdastRoot } from 'mdast';
|
||||||
import { browser } from '$app/environment';
|
import { browser } from '$app/environment';
|
||||||
|
import { onDestroy, tick } from 'svelte';
|
||||||
|
import { rehypeRestoreTableHtml } from '$lib/markdown/table-html-restorer';
|
||||||
|
import { rehypeEnhanceLinks } from '$lib/markdown/enhance-links';
|
||||||
|
import { rehypeEnhanceCodeBlocks } from '$lib/markdown/enhance-code-blocks';
|
||||||
|
import { remarkLiteralHtml } from '$lib/markdown/literal-html';
|
||||||
|
import { copyCodeToClipboard, preprocessLaTeX } from '$lib/utils';
|
||||||
import '$styles/katex-custom.scss';
|
import '$styles/katex-custom.scss';
|
||||||
|
|
||||||
import githubDarkCss from 'highlight.js/styles/github-dark.css?inline';
|
import githubDarkCss from 'highlight.js/styles/github-dark.css?inline';
|
||||||
import githubLightCss from 'highlight.js/styles/github.css?inline';
|
import githubLightCss from 'highlight.js/styles/github.css?inline';
|
||||||
import { mode } from 'mode-watcher';
|
import { mode } from 'mode-watcher';
|
||||||
import { remarkLiteralHtml } from '$lib/markdown/literal-html';
|
|
||||||
import CodePreviewDialog from './CodePreviewDialog.svelte';
|
import CodePreviewDialog from './CodePreviewDialog.svelte';
|
||||||
|
|
||||||
interface Props {
|
interface Props {
|
||||||
|
|
@ -23,33 +27,24 @@
|
||||||
class?: string;
|
class?: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
interface MarkdownBlock {
|
||||||
|
id: string;
|
||||||
|
html: string;
|
||||||
|
}
|
||||||
|
|
||||||
let { content, class: className = '' }: Props = $props();
|
let { content, class: className = '' }: Props = $props();
|
||||||
|
|
||||||
let containerRef = $state<HTMLDivElement>();
|
let containerRef = $state<HTMLDivElement>();
|
||||||
let processedHtml = $state('');
|
let renderedBlocks = $state<MarkdownBlock[]>([]);
|
||||||
|
let unstableBlockHtml = $state('');
|
||||||
let previewDialogOpen = $state(false);
|
let previewDialogOpen = $state(false);
|
||||||
let previewCode = $state('');
|
let previewCode = $state('');
|
||||||
let previewLanguage = $state('text');
|
let previewLanguage = $state('text');
|
||||||
|
|
||||||
function loadHighlightTheme(isDark: boolean) {
|
let pendingMarkdown: string | null = null;
|
||||||
if (!browser) return;
|
let isProcessing = false;
|
||||||
|
|
||||||
const existingThemes = document.querySelectorAll('style[data-highlight-theme]');
|
const themeStyleId = `highlight-theme-${(window.idxThemeStyle = (window.idxThemeStyle ?? 0) + 1)}`;
|
||||||
existingThemes.forEach((style) => style.remove());
|
|
||||||
|
|
||||||
const style = document.createElement('style');
|
|
||||||
style.setAttribute('data-highlight-theme', 'true');
|
|
||||||
style.textContent = isDark ? githubDarkCss : githubLightCss;
|
|
||||||
|
|
||||||
document.head.appendChild(style);
|
|
||||||
}
|
|
||||||
|
|
||||||
$effect(() => {
|
|
||||||
const currentMode = mode.current;
|
|
||||||
const isDark = currentMode === 'dark';
|
|
||||||
|
|
||||||
loadHighlightTheme(isDark);
|
|
||||||
});
|
|
||||||
|
|
||||||
let processor = $derived(() => {
|
let processor = $derived(() => {
|
||||||
return remark()
|
return remark()
|
||||||
|
|
@ -61,139 +56,64 @@
|
||||||
.use(rehypeKatex) // Render math using KaTeX
|
.use(rehypeKatex) // Render math using KaTeX
|
||||||
.use(rehypeHighlight) // Add syntax highlighting
|
.use(rehypeHighlight) // Add syntax highlighting
|
||||||
.use(rehypeRestoreTableHtml) // Restore limited HTML (e.g., <br>, <ul>) inside Markdown tables
|
.use(rehypeRestoreTableHtml) // Restore limited HTML (e.g., <br>, <ul>) inside Markdown tables
|
||||||
.use(rehypeStringify); // Convert to HTML string
|
.use(rehypeEnhanceLinks) // Add target="_blank" to links
|
||||||
|
.use(rehypeEnhanceCodeBlocks) // Wrap code blocks with header and actions
|
||||||
|
.use(rehypeStringify, { allowDangerousHtml: true }); // Convert to HTML string
|
||||||
});
|
});
|
||||||
|
|
||||||
function enhanceLinks(html: string): string {
|
/**
|
||||||
if (!html.includes('<a')) {
|
* Removes click event listeners from copy and preview buttons.
|
||||||
return html;
|
* Called on component destroy.
|
||||||
|
*/
|
||||||
|
function cleanupEventListeners() {
|
||||||
|
if (!containerRef) return;
|
||||||
|
|
||||||
|
const copyButtons = containerRef.querySelectorAll<HTMLButtonElement>('.copy-code-btn');
|
||||||
|
const previewButtons = containerRef.querySelectorAll<HTMLButtonElement>('.preview-code-btn');
|
||||||
|
|
||||||
|
for (const button of copyButtons) {
|
||||||
|
button.removeEventListener('click', handleCopyClick);
|
||||||
}
|
}
|
||||||
|
|
||||||
const tempDiv = document.createElement('div');
|
for (const button of previewButtons) {
|
||||||
tempDiv.innerHTML = html;
|
button.removeEventListener('click', handlePreviewClick);
|
||||||
|
|
||||||
// Make all links open in new tabs
|
|
||||||
const linkElements = tempDiv.querySelectorAll('a[href]');
|
|
||||||
let mutated = false;
|
|
||||||
|
|
||||||
for (const link of linkElements) {
|
|
||||||
const target = link.getAttribute('target');
|
|
||||||
const rel = link.getAttribute('rel');
|
|
||||||
|
|
||||||
if (target !== '_blank' || rel !== 'noopener noreferrer') {
|
|
||||||
mutated = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
link.setAttribute('target', '_blank');
|
|
||||||
link.setAttribute('rel', 'noopener noreferrer');
|
|
||||||
}
|
|
||||||
|
|
||||||
return mutated ? tempDiv.innerHTML : html;
|
|
||||||
}
|
|
||||||
|
|
||||||
function enhanceCodeBlocks(html: string): string {
|
|
||||||
if (!html.includes('<pre')) {
|
|
||||||
return html;
|
|
||||||
}
|
|
||||||
|
|
||||||
const tempDiv = document.createElement('div');
|
|
||||||
tempDiv.innerHTML = html;
|
|
||||||
|
|
||||||
const preElements = tempDiv.querySelectorAll('pre');
|
|
||||||
let mutated = false;
|
|
||||||
|
|
||||||
for (const [index, pre] of Array.from(preElements).entries()) {
|
|
||||||
const codeElement = pre.querySelector('code');
|
|
||||||
|
|
||||||
if (!codeElement) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
mutated = true;
|
|
||||||
|
|
||||||
let language = 'text';
|
|
||||||
const classList = Array.from(codeElement.classList);
|
|
||||||
|
|
||||||
for (const className of classList) {
|
|
||||||
if (className.startsWith('language-')) {
|
|
||||||
language = className.replace('language-', '');
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const rawCode = codeElement.textContent || '';
|
|
||||||
const codeId = `code-${Date.now()}-${index}`;
|
|
||||||
codeElement.setAttribute('data-code-id', codeId);
|
|
||||||
codeElement.setAttribute('data-raw-code', rawCode);
|
|
||||||
|
|
||||||
const wrapper = document.createElement('div');
|
|
||||||
wrapper.className = 'code-block-wrapper';
|
|
||||||
|
|
||||||
const header = document.createElement('div');
|
|
||||||
header.className = 'code-block-header';
|
|
||||||
|
|
||||||
const languageLabel = document.createElement('span');
|
|
||||||
languageLabel.className = 'code-language';
|
|
||||||
languageLabel.textContent = language;
|
|
||||||
|
|
||||||
const copyButton = document.createElement('button');
|
|
||||||
copyButton.className = 'copy-code-btn';
|
|
||||||
copyButton.setAttribute('data-code-id', codeId);
|
|
||||||
copyButton.setAttribute('title', 'Copy code');
|
|
||||||
copyButton.setAttribute('type', 'button');
|
|
||||||
|
|
||||||
copyButton.innerHTML = `
|
|
||||||
<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-copy-icon lucide-copy"><rect width="14" height="14" x="8" y="8" rx="2" ry="2"/><path d="M4 16c-1.1 0-2-.9-2-2V4c0-1.1.9-2 2-2h10c1.1 0 2 .9 2 2"/></svg>
|
|
||||||
`;
|
|
||||||
|
|
||||||
const actions = document.createElement('div');
|
|
||||||
actions.className = 'code-block-actions';
|
|
||||||
|
|
||||||
actions.appendChild(copyButton);
|
|
||||||
|
|
||||||
if (language.toLowerCase() === 'html') {
|
|
||||||
const previewButton = document.createElement('button');
|
|
||||||
previewButton.className = 'preview-code-btn';
|
|
||||||
previewButton.setAttribute('data-code-id', codeId);
|
|
||||||
previewButton.setAttribute('title', 'Preview code');
|
|
||||||
previewButton.setAttribute('type', 'button');
|
|
||||||
|
|
||||||
previewButton.innerHTML = `
|
|
||||||
<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-eye lucide-eye-icon"><path d="M2.062 12.345a1 1 0 0 1 0-.69C3.5 7.73 7.36 5 12 5s8.5 2.73 9.938 6.655a1 1 0 0 1 0 .69C20.5 16.27 16.64 19 12 19s-8.5-2.73-9.938-6.655"/><circle cx="12" cy="12" r="3"/></svg>
|
|
||||||
`;
|
|
||||||
|
|
||||||
actions.appendChild(previewButton);
|
|
||||||
}
|
|
||||||
|
|
||||||
header.appendChild(languageLabel);
|
|
||||||
header.appendChild(actions);
|
|
||||||
wrapper.appendChild(header);
|
|
||||||
|
|
||||||
const clonedPre = pre.cloneNode(true) as HTMLElement;
|
|
||||||
wrapper.appendChild(clonedPre);
|
|
||||||
|
|
||||||
pre.parentNode?.replaceChild(wrapper, pre);
|
|
||||||
}
|
|
||||||
|
|
||||||
return mutated ? tempDiv.innerHTML : html;
|
|
||||||
}
|
|
||||||
|
|
||||||
async function processMarkdown(text: string): Promise<string> {
|
|
||||||
try {
|
|
||||||
let normalized = preprocessLaTeX(text);
|
|
||||||
const result = await processor().process(normalized);
|
|
||||||
const html = String(result);
|
|
||||||
const enhancedLinks = enhanceLinks(html);
|
|
||||||
|
|
||||||
return enhanceCodeBlocks(enhancedLinks);
|
|
||||||
} catch (error) {
|
|
||||||
console.error('Markdown processing error:', error);
|
|
||||||
|
|
||||||
// Fallback to plain text with line breaks
|
|
||||||
return text.replace(/\n/g, '<br>');
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Removes this component's highlight.js theme style from the document head.
|
||||||
|
* Called on component destroy to clean up injected styles.
|
||||||
|
*/
|
||||||
|
function cleanupHighlightTheme() {
|
||||||
|
if (!browser) return;
|
||||||
|
|
||||||
|
const existingTheme = document.getElementById(themeStyleId);
|
||||||
|
existingTheme?.remove();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Loads the appropriate highlight.js theme based on dark/light mode.
|
||||||
|
* Injects a scoped style element into the document head.
|
||||||
|
* @param isDark - Whether to load the dark theme (true) or light theme (false)
|
||||||
|
*/
|
||||||
|
function loadHighlightTheme(isDark: boolean) {
|
||||||
|
if (!browser) return;
|
||||||
|
|
||||||
|
const existingTheme = document.getElementById(themeStyleId);
|
||||||
|
existingTheme?.remove();
|
||||||
|
|
||||||
|
const style = document.createElement('style');
|
||||||
|
style.id = themeStyleId;
|
||||||
|
style.textContent = isDark ? githubDarkCss : githubLightCss;
|
||||||
|
|
||||||
|
document.head.appendChild(style);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Extracts code information from a button click target within a code block.
|
||||||
|
* @param target - The clicked button element
|
||||||
|
* @returns Object with rawCode and language, or null if extraction fails
|
||||||
|
*/
|
||||||
function getCodeInfoFromTarget(target: HTMLElement) {
|
function getCodeInfoFromTarget(target: HTMLElement) {
|
||||||
const wrapper = target.closest('.code-block-wrapper');
|
const wrapper = target.closest('.code-block-wrapper');
|
||||||
|
|
||||||
|
|
@ -209,12 +129,7 @@
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
const rawCode = codeElement.getAttribute('data-raw-code');
|
const rawCode = codeElement.textContent ?? '';
|
||||||
|
|
||||||
if (rawCode === null) {
|
|
||||||
console.error('No raw code found');
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
const languageLabel = wrapper.querySelector<HTMLElement>('.code-language');
|
const languageLabel = wrapper.querySelector<HTMLElement>('.code-language');
|
||||||
const language = languageLabel?.textContent?.trim() || 'text';
|
const language = languageLabel?.textContent?.trim() || 'text';
|
||||||
|
|
@ -222,6 +137,28 @@
|
||||||
return { rawCode, language };
|
return { rawCode, language };
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Generates a unique identifier for a HAST node based on its position.
|
||||||
|
* Used for stable block identification during incremental rendering.
|
||||||
|
* @param node - The HAST root content node
|
||||||
|
* @param indexFallback - Fallback index if position is unavailable
|
||||||
|
* @returns Unique string identifier for the node
|
||||||
|
*/
|
||||||
|
function getHastNodeId(node: HastRootContent, indexFallback: number): string {
|
||||||
|
const position = node.position;
|
||||||
|
|
||||||
|
if (position?.start?.offset != null && position?.end?.offset != null) {
|
||||||
|
return `hast-${position.start.offset}-${position.end.offset}`;
|
||||||
|
}
|
||||||
|
|
||||||
|
return `${node.type}-${indexFallback}`;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Handles click events on copy buttons within code blocks.
|
||||||
|
* Copies the raw code content to the clipboard.
|
||||||
|
* @param event - The click event from the copy button
|
||||||
|
*/
|
||||||
async function handleCopyClick(event: Event) {
|
async function handleCopyClick(event: Event) {
|
||||||
event.preventDefault();
|
event.preventDefault();
|
||||||
event.stopPropagation();
|
event.stopPropagation();
|
||||||
|
|
@ -245,6 +182,25 @@
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Handles preview dialog open state changes.
|
||||||
|
* Clears preview content when dialog is closed.
|
||||||
|
* @param open - Whether the dialog is being opened or closed
|
||||||
|
*/
|
||||||
|
function handlePreviewDialogOpenChange(open: boolean) {
|
||||||
|
previewDialogOpen = open;
|
||||||
|
|
||||||
|
if (!open) {
|
||||||
|
previewCode = '';
|
||||||
|
previewLanguage = 'text';
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Handles click events on preview buttons within HTML code blocks.
|
||||||
|
* Opens a preview dialog with the rendered HTML content.
|
||||||
|
* @param event - The click event from the preview button
|
||||||
|
*/
|
||||||
function handlePreviewClick(event: Event) {
|
function handlePreviewClick(event: Event) {
|
||||||
event.preventDefault();
|
event.preventDefault();
|
||||||
event.stopPropagation();
|
event.stopPropagation();
|
||||||
|
|
@ -266,6 +222,61 @@
|
||||||
previewDialogOpen = true;
|
previewDialogOpen = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Processes markdown content into stable and unstable HTML blocks.
|
||||||
|
* Uses incremental rendering: stable blocks are cached, unstable block is re-rendered.
|
||||||
|
* @param markdown - The raw markdown string to process
|
||||||
|
*/
|
||||||
|
async function processMarkdown(markdown: string) {
|
||||||
|
if (!markdown) {
|
||||||
|
renderedBlocks = [];
|
||||||
|
unstableBlockHtml = '';
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const normalized = preprocessLaTeX(markdown);
|
||||||
|
const processorInstance = processor();
|
||||||
|
const ast = processorInstance.parse(normalized) as MdastRoot;
|
||||||
|
const processedRoot = (await processorInstance.run(ast)) as HastRoot;
|
||||||
|
const processedChildren = processedRoot.children ?? [];
|
||||||
|
const stableCount = Math.max(processedChildren.length - 1, 0);
|
||||||
|
const nextBlocks: MarkdownBlock[] = [];
|
||||||
|
|
||||||
|
for (let index = 0; index < stableCount; index++) {
|
||||||
|
const hastChild = processedChildren[index];
|
||||||
|
const id = getHastNodeId(hastChild, index);
|
||||||
|
const existing = renderedBlocks[index];
|
||||||
|
|
||||||
|
if (existing && existing.id === id) {
|
||||||
|
nextBlocks.push(existing);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
const html = stringifyProcessedNode(
|
||||||
|
processorInstance,
|
||||||
|
processedRoot,
|
||||||
|
processedChildren[index]
|
||||||
|
);
|
||||||
|
|
||||||
|
nextBlocks.push({ id, html });
|
||||||
|
}
|
||||||
|
|
||||||
|
let unstableHtml = '';
|
||||||
|
|
||||||
|
if (processedChildren.length > stableCount) {
|
||||||
|
const unstableChild = processedChildren[stableCount];
|
||||||
|
unstableHtml = stringifyProcessedNode(processorInstance, processedRoot, unstableChild);
|
||||||
|
}
|
||||||
|
|
||||||
|
renderedBlocks = nextBlocks;
|
||||||
|
await tick(); // Force DOM sync before updating unstable HTML block
|
||||||
|
unstableBlockHtml = unstableHtml;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Attaches click event listeners to copy and preview buttons in code blocks.
|
||||||
|
* Uses data-listener-bound attribute to prevent duplicate bindings.
|
||||||
|
*/
|
||||||
function setupCodeBlockActions() {
|
function setupCodeBlockActions() {
|
||||||
if (!containerRef) return;
|
if (!containerRef) return;
|
||||||
|
|
||||||
|
|
@ -287,40 +298,97 @@
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
function handlePreviewDialogOpenChange(open: boolean) {
|
/**
|
||||||
previewDialogOpen = open;
|
* Converts a single HAST node to an enhanced HTML string.
|
||||||
|
* Applies link and code block enhancements to the output.
|
||||||
|
* @param processorInstance - The remark/rehype processor instance
|
||||||
|
* @param processedRoot - The full processed HAST root (for context)
|
||||||
|
* @param child - The specific HAST child node to stringify
|
||||||
|
* @returns Enhanced HTML string representation of the node
|
||||||
|
*/
|
||||||
|
function stringifyProcessedNode(
|
||||||
|
processorInstance: ReturnType<typeof processor>,
|
||||||
|
processedRoot: HastRoot,
|
||||||
|
child: unknown
|
||||||
|
) {
|
||||||
|
const root: HastRoot = {
|
||||||
|
...(processedRoot as HastRoot),
|
||||||
|
children: [child as never]
|
||||||
|
};
|
||||||
|
|
||||||
if (!open) {
|
return processorInstance.stringify(root);
|
||||||
previewCode = '';
|
}
|
||||||
previewLanguage = 'text';
|
|
||||||
|
/**
|
||||||
|
* Queues markdown for processing with coalescing support.
|
||||||
|
* Only processes the latest markdown when multiple updates arrive quickly.
|
||||||
|
* @param markdown - The markdown content to render
|
||||||
|
*/
|
||||||
|
async function updateRenderedBlocks(markdown: string) {
|
||||||
|
pendingMarkdown = markdown;
|
||||||
|
|
||||||
|
if (isProcessing) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
isProcessing = true;
|
||||||
|
|
||||||
|
try {
|
||||||
|
while (pendingMarkdown !== null) {
|
||||||
|
const nextMarkdown = pendingMarkdown;
|
||||||
|
pendingMarkdown = null;
|
||||||
|
|
||||||
|
await processMarkdown(nextMarkdown);
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
console.error('Failed to process markdown:', error);
|
||||||
|
renderedBlocks = [];
|
||||||
|
unstableBlockHtml = markdown.replace(/\n/g, '<br>');
|
||||||
|
} finally {
|
||||||
|
isProcessing = false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
$effect(() => {
|
$effect(() => {
|
||||||
if (content) {
|
const currentMode = mode.current;
|
||||||
processMarkdown(content)
|
const isDark = currentMode === 'dark';
|
||||||
.then((result) => {
|
|
||||||
processedHtml = result;
|
loadHighlightTheme(isDark);
|
||||||
})
|
|
||||||
.catch((error) => {
|
|
||||||
console.error('Failed to process markdown:', error);
|
|
||||||
processedHtml = content.replace(/\n/g, '<br>');
|
|
||||||
});
|
|
||||||
} else {
|
|
||||||
processedHtml = '';
|
|
||||||
}
|
|
||||||
});
|
});
|
||||||
|
|
||||||
$effect(() => {
|
$effect(() => {
|
||||||
if (containerRef && processedHtml) {
|
updateRenderedBlocks(content);
|
||||||
|
});
|
||||||
|
|
||||||
|
$effect(() => {
|
||||||
|
const hasRenderedBlocks = renderedBlocks.length > 0;
|
||||||
|
const hasUnstableBlock = Boolean(unstableBlockHtml);
|
||||||
|
|
||||||
|
if ((hasRenderedBlocks || hasUnstableBlock) && containerRef) {
|
||||||
setupCodeBlockActions();
|
setupCodeBlockActions();
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
|
onDestroy(() => {
|
||||||
|
cleanupEventListeners();
|
||||||
|
cleanupHighlightTheme();
|
||||||
|
});
|
||||||
</script>
|
</script>
|
||||||
|
|
||||||
<div bind:this={containerRef} class={className}>
|
<div bind:this={containerRef} class={className}>
|
||||||
<!-- eslint-disable-next-line no-at-html-tags -->
|
{#each renderedBlocks as block (block.id)}
|
||||||
{@html processedHtml}
|
<div class="markdown-block" data-block-id={block.id}>
|
||||||
|
<!-- eslint-disable-next-line no-at-html-tags -->
|
||||||
|
{@html block.html}
|
||||||
|
</div>
|
||||||
|
{/each}
|
||||||
|
|
||||||
|
{#if unstableBlockHtml}
|
||||||
|
<div class="markdown-block markdown-block--unstable" data-block-id="unstable">
|
||||||
|
<!-- eslint-disable-next-line no-at-html-tags -->
|
||||||
|
{@html unstableBlockHtml}
|
||||||
|
</div>
|
||||||
|
{/if}
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<CodePreviewDialog
|
<CodePreviewDialog
|
||||||
|
|
@ -331,6 +399,11 @@
|
||||||
/>
|
/>
|
||||||
|
|
||||||
<style>
|
<style>
|
||||||
|
.markdown-block,
|
||||||
|
.markdown-block--unstable {
|
||||||
|
display: contents;
|
||||||
|
}
|
||||||
|
|
||||||
/* Base typography styles */
|
/* Base typography styles */
|
||||||
div :global(p:not(:last-child)) {
|
div :global(p:not(:last-child)) {
|
||||||
margin-bottom: 1rem;
|
margin-bottom: 1rem;
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,4 @@
|
||||||
|
/**
 * String-valued enum with two members, `GENERATION` ('generation') and
 * `READING` ('reading').
 *
 * NOTE(review): presumably selects which statistics view a chat message
 * shows — confirm against the component that consumes it.
 */
export enum ChatMessageStatsView {
	GENERATION = 'generation',
	READING = 'reading'
}
|
||||||
|
|
@ -1,5 +1,7 @@
|
||||||
export { AttachmentType } from './attachment';
|
export { AttachmentType } from './attachment';
|
||||||
|
|
||||||
|
export { ChatMessageStatsView } from './chat';
|
||||||
|
|
||||||
export {
|
export {
|
||||||
FileTypeCategory,
|
FileTypeCategory,
|
||||||
FileTypeImage,
|
FileTypeImage,
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,162 @@
|
||||||
|
/**
|
||||||
|
* Rehype plugin to enhance code blocks with wrapper, header, and action buttons.
|
||||||
|
*
|
||||||
|
* Wraps <pre><code> elements with a container that includes:
|
||||||
|
* - Language label
|
||||||
|
* - Copy button
|
||||||
|
* - Preview button (for HTML code blocks)
|
||||||
|
*
|
||||||
|
* This operates directly on the HAST tree for better performance,
|
||||||
|
* avoiding the need to stringify and re-parse HTML.
|
||||||
|
*/
|
||||||
|
|
||||||
|
import type { Plugin } from 'unified';
|
||||||
|
import type { Root, Element, ElementContent } from 'hast';
|
||||||
|
import { visit } from 'unist-util-visit';
|
||||||
|
|
||||||
|
// Augments the global `Window` type with the optional counter that
// `generateCodeId` reads and increments to number code blocks.
declare global {
	interface Window {
		// Last code block index handed out; undefined until the first ID is generated.
		idxCodeBlock?: number;
	}
}
|
||||||
|
|
||||||
|
// Inline 16x16 SVG markup for the copy button icon; injected as raw HTML into the button.
const COPY_ICON_SVG = `<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-copy-icon lucide-copy"><rect width="14" height="14" x="8" y="8" rx="2" ry="2"/><path d="M4 16c-1.1 0-2-.9-2-2V4c0-1.1.9-2 2-2h10c1.1 0 2 .9 2 2"/></svg>`;
|
||||||
|
|
||||||
|
// Inline 16x16 SVG markup for the preview (eye) button icon; injected as raw HTML into the button.
const PREVIEW_ICON_SVG = `<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-eye lucide-eye-icon"><path d="M2.062 12.345a1 1 0 0 1 0-.69C3.5 7.73 7.36 5 12 5s8.5 2.73 9.938 6.655a1 1 0 0 1 0 .69C20.5 16.27 16.64 19 12 19s-8.5-2.73-9.938-6.655"/><circle cx="12" cy="12" r="3"/></svg>`;
|
||||||
|
|
||||||
|
/**
 * Wraps a raw HTML string (such as an SVG icon) in a `<span>` element node.
 *
 * The markup is carried by a single child node of type `raw`. That node type
 * is not part of `ElementContent`, hence the cast.
 *
 * @param html - Markup to embed verbatim.
 * @returns A `span` element with empty properties and the raw node as its only child.
 */
function createRawHtmlElement(html: string): Element {
	const rawChild = { type: 'raw', value: html } as unknown as ElementContent;

	const span: Element = {
		type: 'element',
		tagName: 'span',
		properties: {},
		children: [rawChild]
	};

	return span;
}
|
||||||
|
|
||||||
|
/**
 * Builds the "copy code" `<button>` node for a code block.
 *
 * @param codeId - ID of the code block; stored in the button's `data-code-id`.
 * @returns A `button` element with class `copy-code-btn` containing the copy icon.
 */
function createCopyButton(codeId: string): Element {
	const icon = createRawHtmlElement(COPY_ICON_SVG);

	const button: Element = {
		type: 'element',
		tagName: 'button',
		properties: {
			className: ['copy-code-btn'],
			'data-code-id': codeId,
			title: 'Copy code',
			type: 'button'
		},
		children: [icon]
	};

	return button;
}
|
||||||
|
|
||||||
|
/**
 * Builds the "preview code" `<button>` node for a code block.
 *
 * @param codeId - ID of the code block; stored in the button's `data-code-id`.
 * @returns A `button` element with class `preview-code-btn` containing the preview icon.
 */
function createPreviewButton(codeId: string): Element {
	const icon = createRawHtmlElement(PREVIEW_ICON_SVG);

	const button: Element = {
		type: 'element',
		tagName: 'button',
		properties: {
			className: ['preview-code-btn'],
			'data-code-id': codeId,
			title: 'Preview code',
			type: 'button'
		},
		children: [icon]
	};

	return button;
}
|
||||||
|
|
||||||
|
/**
 * Builds the header placed above a code block.
 *
 * The header holds a language label (`span.code-language`) followed by an
 * action bar (`div.code-block-actions`). The action bar always has a copy
 * button; a preview button is appended when the language is `html`
 * (compared case-insensitively).
 *
 * @param language - Language name shown in the label.
 * @param codeId - ID of the code block, passed on to the action buttons.
 * @returns A `div.code-block-header` element.
 */
function createHeader(language: string, codeId: string): Element {
	const copyButton = createCopyButton(codeId);
	const isHtml = language.toLowerCase() === 'html';
	const actions: Element[] = isHtml ? [copyButton, createPreviewButton(codeId)] : [copyButton];

	const languageLabel: Element = {
		type: 'element',
		tagName: 'span',
		properties: { className: ['code-language'] },
		children: [{ type: 'text', value: language }]
	};

	const actionBar: Element = {
		type: 'element',
		tagName: 'div',
		properties: { className: ['code-block-actions'] },
		children: actions
	};

	return {
		type: 'element',
		tagName: 'div',
		properties: { className: ['code-block-header'] },
		children: [languageLabel, actionBar]
	};
}
|
||||||
|
|
||||||
|
/**
 * Groups a code block header and its `<pre>` element under one container.
 *
 * @param header - Header element, placed first.
 * @param preElement - The original `<pre>` element, placed second.
 * @returns A `div.code-block-wrapper` element with exactly those two children.
 */
function createWrapper(header: Element, preElement: Element): Element {
	const wrapper: Element = {
		type: 'element',
		tagName: 'div',
		properties: { className: ['code-block-wrapper'] },
		children: [header, preElement]
	};

	return wrapper;
}
|
||||||
|
|
||||||
|
/**
 * Reads the language name from a `<code>` element's class list.
 *
 * Returns the remainder of the first string class that starts with
 * `language-`. Falls back to `'text'` when `className` is not an array or
 * no such class exists.
 *
 * @param codeElement - The `<code>` element to inspect.
 * @returns The language name, or `'text'` as the default.
 */
function extractLanguage(codeElement: Element): string {
	const prefix = 'language-';
	const classList = codeElement.properties?.className;

	if (!Array.isArray(classList)) {
		return 'text';
	}

	const languageClass = classList.find(
		(entry): entry is string => typeof entry === 'string' && entry.startsWith(prefix)
	);

	return languageClass === undefined ? 'text' : languageClass.slice(prefix.length);
}
|
||||||
|
|
||||||
|
/**
 * Produces an ID of the form `code-…` for a code block.
 *
 * When `window` exists, the ID comes from the counter `window.idxCodeBlock`,
 * which is incremented on each call (starting from 1). Otherwise (e.g. SSR)
 * the ID is built from the current timestamp plus a short random suffix.
 *
 * @returns The generated ID string.
 */
function generateCodeId(): string {
	if (typeof window === 'undefined') {
		// No window object: fall back to timestamp + random suffix.
		const timestamp = Date.now();
		const randomSuffix = Math.random().toString(36).slice(2, 7);
		return `code-${timestamp}-${randomSuffix}`;
	}

	const nextIndex = (window.idxCodeBlock ?? 0) + 1;
	window.idxCodeBlock = nextIndex;

	return `code-${nextIndex}`;
}
|
||||||
|
|
||||||
|
/**
 * Rehype plugin that decorates code blocks.
 *
 * Every `<pre>` element that has a parent and a direct `<code>` child is
 * replaced, in its parent, by a `div.code-block-wrapper` containing:
 * - a header with the language label, a copy button and (for HTML code
 *   blocks) a preview button
 * - the original `<pre>` element
 *
 * The `<code>` element also receives a `data-code-id` property holding the
 * same ID the buttons carry.
 */
export const rehypeEnhanceCodeBlocks: Plugin<[], Root> = () => {
	const isCodeElement = (child: ElementContent): child is Element =>
		child.type === 'element' && child.tagName === 'code';

	return (tree: Root) => {
		visit(tree, 'element', (node: Element, index, parent) => {
			if (node.tagName !== 'pre' || !parent || index === undefined) return;

			const codeElement = node.children.find(isCodeElement);
			if (!codeElement) return;

			const language = extractLanguage(codeElement);
			const codeId = generateCodeId();

			// Tag the <code> element with the ID the action buttons refer to.
			codeElement.properties = { ...codeElement.properties, 'data-code-id': codeId };

			const wrapper = createWrapper(createHeader(language, codeId), node);

			// Swap the <pre> for the wrapper at the same position in the parent.
			(parent.children as ElementContent[]).splice(index, 1, wrapper);
		});
	};
};
|
||||||
|
|
@ -0,0 +1,33 @@
|
||||||
|
/**
|
||||||
|
* Rehype plugin to enhance links with security attributes.
|
||||||
|
*
|
||||||
|
* Adds target="_blank" and rel="noopener noreferrer" to all anchor elements,
|
||||||
|
* ensuring external links open in new tabs safely.
|
||||||
|
*/
|
||||||
|
|
||||||
|
import type { Plugin } from 'unified';
|
||||||
|
import type { Root, Element } from 'hast';
|
||||||
|
import { visit } from 'unist-util-visit';
|
||||||
|
|
||||||
|
/**
 * Rehype plugin that makes links open in a new tab.
 *
 * Every `<a>` element with a truthy `href` property gets:
 * - target="_blank"
 * - rel="noopener noreferrer"
 *
 * Anchors without an `href` are left untouched.
 */
export const rehypeEnhanceLinks: Plugin<[], Root> = () => {
	return (tree: Root) => {
		visit(tree, 'element', (node: Element) => {
			if (node.tagName !== 'a') return;

			const properties = node.properties;

			// Missing properties or a missing/empty href: nothing to enhance.
			if (!properties || !properties.href) return;

			Object.assign(properties, {
				target: '_blank',
				rel: 'noopener noreferrer'
			});
		});
	};
};
|
||||||
|
|
@ -171,6 +171,7 @@ class ChatStore {
|
||||||
updateProcessingStateFromTimings(
|
updateProcessingStateFromTimings(
|
||||||
timingData: {
|
timingData: {
|
||||||
prompt_n: number;
|
prompt_n: number;
|
||||||
|
prompt_ms?: number;
|
||||||
predicted_n: number;
|
predicted_n: number;
|
||||||
predicted_per_second: number;
|
predicted_per_second: number;
|
||||||
cache_n: number;
|
cache_n: number;
|
||||||
|
|
@ -212,6 +213,7 @@ class ChatStore {
|
||||||
if (message.role === 'assistant' && message.timings) {
|
if (message.role === 'assistant' && message.timings) {
|
||||||
const restoredState = this.parseTimingData({
|
const restoredState = this.parseTimingData({
|
||||||
prompt_n: message.timings.prompt_n || 0,
|
prompt_n: message.timings.prompt_n || 0,
|
||||||
|
prompt_ms: message.timings.prompt_ms,
|
||||||
predicted_n: message.timings.predicted_n || 0,
|
predicted_n: message.timings.predicted_n || 0,
|
||||||
predicted_per_second:
|
predicted_per_second:
|
||||||
message.timings.predicted_n && message.timings.predicted_ms
|
message.timings.predicted_n && message.timings.predicted_ms
|
||||||
|
|
@ -282,6 +284,7 @@ class ChatStore {
|
||||||
|
|
||||||
private parseTimingData(timingData: Record<string, unknown>): ApiProcessingState | null {
|
private parseTimingData(timingData: Record<string, unknown>): ApiProcessingState | null {
|
||||||
const promptTokens = (timingData.prompt_n as number) || 0;
|
const promptTokens = (timingData.prompt_n as number) || 0;
|
||||||
|
const promptMs = (timingData.prompt_ms as number) || undefined;
|
||||||
const predictedTokens = (timingData.predicted_n as number) || 0;
|
const predictedTokens = (timingData.predicted_n as number) || 0;
|
||||||
const tokensPerSecond = (timingData.predicted_per_second as number) || 0;
|
const tokensPerSecond = (timingData.predicted_per_second as number) || 0;
|
||||||
const cacheTokens = (timingData.cache_n as number) || 0;
|
const cacheTokens = (timingData.cache_n as number) || 0;
|
||||||
|
|
@ -320,6 +323,7 @@ class ChatStore {
|
||||||
speculative: false,
|
speculative: false,
|
||||||
progressPercent,
|
progressPercent,
|
||||||
promptTokens,
|
promptTokens,
|
||||||
|
promptMs,
|
||||||
cacheTokens
|
cacheTokens
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
@ -536,6 +540,7 @@ class ChatStore {
|
||||||
this.updateProcessingStateFromTimings(
|
this.updateProcessingStateFromTimings(
|
||||||
{
|
{
|
||||||
prompt_n: timings?.prompt_n || 0,
|
prompt_n: timings?.prompt_n || 0,
|
||||||
|
prompt_ms: timings?.prompt_ms,
|
||||||
predicted_n: timings?.predicted_n || 0,
|
predicted_n: timings?.predicted_n || 0,
|
||||||
predicted_per_second: tokensPerSecond,
|
predicted_per_second: tokensPerSecond,
|
||||||
cache_n: timings?.cache_n || 0,
|
cache_n: timings?.cache_n || 0,
|
||||||
|
|
@ -768,10 +773,11 @@ class ChatStore {
|
||||||
content: streamingState.response
|
content: streamingState.response
|
||||||
};
|
};
|
||||||
if (lastMessage.thinking?.trim()) updateData.thinking = lastMessage.thinking;
|
if (lastMessage.thinking?.trim()) updateData.thinking = lastMessage.thinking;
|
||||||
const lastKnownState = this.getCurrentProcessingStateSync();
|
const lastKnownState = this.getProcessingState(conversationId);
|
||||||
if (lastKnownState) {
|
if (lastKnownState) {
|
||||||
updateData.timings = {
|
updateData.timings = {
|
||||||
prompt_n: lastKnownState.promptTokens || 0,
|
prompt_n: lastKnownState.promptTokens || 0,
|
||||||
|
prompt_ms: lastKnownState.promptMs,
|
||||||
predicted_n: lastKnownState.tokensDecoded || 0,
|
predicted_n: lastKnownState.tokensDecoded || 0,
|
||||||
cache_n: lastKnownState.cacheTokens || 0,
|
cache_n: lastKnownState.cacheTokens || 0,
|
||||||
predicted_ms:
|
predicted_ms:
|
||||||
|
|
@ -1253,6 +1259,7 @@ class ChatStore {
|
||||||
this.updateProcessingStateFromTimings(
|
this.updateProcessingStateFromTimings(
|
||||||
{
|
{
|
||||||
prompt_n: timings?.prompt_n || 0,
|
prompt_n: timings?.prompt_n || 0,
|
||||||
|
prompt_ms: timings?.prompt_ms,
|
||||||
predicted_n: timings?.predicted_n || 0,
|
predicted_n: timings?.predicted_n || 0,
|
||||||
predicted_per_second: tokensPerSecond,
|
predicted_per_second: tokensPerSecond,
|
||||||
cache_n: timings?.cache_n || 0,
|
cache_n: timings?.cache_n || 0,
|
||||||
|
|
|
||||||
|
|
@ -342,6 +342,7 @@ export interface ApiProcessingState {
|
||||||
// Progress information from prompt_progress
|
// Progress information from prompt_progress
|
||||||
progressPercent?: number;
|
progressPercent?: number;
|
||||||
promptTokens?: number;
|
promptTokens?: number;
|
||||||
|
promptMs?: number;
|
||||||
cacheTokens?: number;
|
cacheTokens?: number;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue